Source code for linumpy.stack_alignment.io

"""CSV loading and cumulative shift computation."""

from collections.abc import Sequence
from pathlib import Path

import pandas as pd

from linumpy.stack_alignment.units import detect_shift_units


def load_shifts_csv(shifts_path: Path) -> tuple[dict, list]:
    """Load a pairwise-shift CSV and accumulate it into absolute offsets.

    Each CSV row describes the shift (in mm) from ``fixed_id`` to
    ``moving_id``. The shifts are summed along the sorted list of slice IDs
    so that every slice gets an offset relative to the first slice.

    Only shifts between *consecutive* IDs in the sorted list are used. When
    no row exists for a consecutive pair, that step contributes a zero shift
    and the previous cumulative offset is carried forward. If the same
    (fixed_id, moving_id) pair appears more than once, the last row wins.

    Parameters
    ----------
    shifts_path : str or Path
        Path to CSV file with columns: fixed_id, moving_id, x_shift_mm,
        y_shift_mm

    Returns
    -------
    cumsum : dict
        Mapping from slice_id to (cumulative_dx_mm, cumulative_dy_mm).
        Empty when the CSV has no data rows.
    all_ids : list
        Sorted list of all slice IDs. Empty when the CSV has no data rows.
    """
    df = pd.read_csv(shifts_path)

    fixed_ids = df["fixed_id"].tolist()
    moving_ids = df["moving_id"].tolist()
    all_ids = sorted(set(fixed_ids + moving_ids))

    # A header-only CSV has no slices at all; there is no "first slice" to
    # anchor the accumulation on, so report empty results instead of failing.
    if not all_ids:
        return {}, []

    # Column-wise iteration avoids building one Series per row (iterrows).
    shift_lookup = {
        (int(fixed_id), int(moving_id)): (dx_mm, dy_mm)
        for fixed_id, moving_id, dx_mm, dy_mm in zip(
            fixed_ids,
            moving_ids,
            df["x_shift_mm"].tolist(),
            df["y_shift_mm"].tolist(),
        )
    }

    # The first slice is the origin; walk consecutive ID pairs from there.
    cumsum = {all_ids[0]: (0.0, 0.0)}
    for fixed_id, moving_id in zip(all_ids, all_ids[1:]):
        dx_mm, dy_mm = shift_lookup.get((fixed_id, moving_id), (0.0, 0.0))
        prev_dx, prev_dy = cumsum[fixed_id]
        cumsum[moving_id] = (prev_dx + dx_mm, prev_dy + dy_mm)

    return cumsum, all_ids
def build_cumulative_shifts(
    shifts_df: pd.DataFrame,
    selected_slice_ids: list,
    resolution: Sequence[float],
    center_drift: bool = True,
) -> dict:
    """Build cumulative pixel shifts for selected slices.

    Pairwise mm shifts are first accumulated over *every* slice ID found in
    ``shifts_df`` (sorted ascending), so slices that are not selected still
    contribute their step to the slices that follow them. The cumulative mm
    offsets of the selected slices are then scaled to pixels.

    Only shifts between consecutive IDs in the sorted list are used; a
    consecutive pair with no row contributes a zero shift. If the same
    (fixed_id, moving_id) pair appears more than once, the last row wins.

    Parameters
    ----------
    shifts_df : pd.DataFrame
        DataFrame with columns: fixed_id, moving_id, x_shift_mm, y_shift_mm
    selected_slice_ids : list
        Sorted list of slice IDs to process.
    resolution : tuple
        Resolution (res_z, res_y, res_x) from read_omezarr. It is passed to
        ``detect_shift_units``, whose two return values are used here as the
        x and y pixel sizes in µm (pixels per mm = 1000 / value).
    center_drift : bool
        If True, subtract the offset of the middle selected slice
        (index ``len(selected_slice_ids) // 2``) from every slice, so that
        slice ends up at (0, 0).

    Returns
    -------
    dict
        Mapping from slice_id to (cumulative_dx_px, cumulative_dy_px).
        Selected IDs that do not occur in ``shifts_df`` get (0.0, 0.0)
        before centering. When ``shifts_df`` has no rows, every selected
        slice maps to (0.0, 0.0).
    """
    # With no shift rows there is nothing to accumulate or convert: every
    # selected slice gets a zero offset (centering zeros is a no-op). This
    # mirrors the zero fallback used below for unknown slice IDs.
    if shifts_df.empty:
        return {slice_id: (0.0, 0.0) for slice_id in selected_slice_ids}

    # Read each column once instead of walking the frame twice with iterrows.
    fixed_ids = [int(value) for value in shifts_df["fixed_id"].tolist()]
    moving_ids = [int(value) for value in shifts_df["moving_id"].tolist()]
    shift_lookup = dict(
        zip(
            zip(fixed_ids, moving_ids),
            zip(
                shifts_df["x_shift_mm"].tolist(),
                shifts_df["y_shift_mm"].tolist(),
            ),
        )
    )
    all_slice_ids = sorted(set(fixed_ids) | set(moving_ids))

    # Accumulate over all slices (not just the selected ones) so skipped
    # slices still carry their shift forward.
    cumsum_all = {all_slice_ids[0]: (0.0, 0.0)}
    for fixed_id, moving_id in zip(all_slice_ids, all_slice_ids[1:]):
        dx_mm, dy_mm = shift_lookup.get((fixed_id, moving_id), (0.0, 0.0))
        prev_dx, prev_dy = cumsum_all[fixed_id]
        cumsum_all[moving_id] = (prev_dx + dx_mm, prev_dy + dy_mm)

    res_x_um, res_y_um = detect_shift_units(resolution)
    mm_to_px_x = 1000.0 / res_x_um
    mm_to_px_y = 1000.0 / res_y_um

    cumsum_selected = {}
    for slice_id in selected_slice_ids:
        if slice_id in cumsum_all:
            dx_mm, dy_mm = cumsum_all[slice_id]
            cumsum_selected[slice_id] = (dx_mm * mm_to_px_x, dy_mm * mm_to_px_y)
        else:
            # Slice never appears in the shift table: treat it as unshifted.
            cumsum_selected[slice_id] = (0.0, 0.0)

    if center_drift and cumsum_selected:
        middle_id = selected_slice_ids[len(selected_slice_ids) // 2]
        center_dx, center_dy = cumsum_selected[middle_id]
        cumsum_selected = {
            slice_id: (dx - center_dx, dy - center_dy)
            for slice_id, (dx, dy) in cumsum_selected.items()
        }

    return cumsum_selected