Source code for linumpy.registration.transforms

"""Transform construction and mosaic-level transform estimation."""

import random
from collections.abc import Sequence
from typing import Any

import numpy as np
import SimpleITK as sitk
from skimage.exposure import match_histograms
from skimage.filters import threshold_otsu

from linumpy.registration.phase_correlation import pair_wise_phase_correlation


def create_transform(tx: float, ty: float, angle_deg: float, center: Sequence[float]) -> sitk.Euler3DTransform:
    """Build a SimpleITK 3D Euler transform from in-plane (2D) parameters.

    Only the Z rotation angle is non-zero, and both the rotation centre and the
    translation carry a zero Z component.

    Parameters
    ----------
    tx, ty : float
        Translation in pixels.
    angle_deg : float
        Rotation angle in degrees (around Z axis).
    center : sequence
        (cx, cy) rotation center; only the first two entries are read.

    Returns
    -------
    sitk.Euler3DTransform
    """
    cx, cy = center[0], center[1]
    theta_z = np.radians(angle_deg)

    euler = sitk.Euler3DTransform()
    # Keep the centre -> rotation -> translation call order of the original.
    euler.SetCenter([cx, cy, 0.0])
    euler.SetRotation(0.0, 0.0, theta_z)
    euler.SetTranslation([tx, ty, 0.0])
    return euler
def compute_motor_transform(tile_shape: Sequence[int], overlap_fraction: float) -> np.ndarray:
    """Return the diagonal 2x2 matrix mapping tile indices to motor-based pixel positions.

    Each tile advances by its size minus the expected overlap. Multiplying the
    result by a tile index ``[i, j]`` gives the pixel position of tile ``(i, j)``,
    corresponding to precise motor/stage positions from acquisition.

    Parameters
    ----------
    tile_shape : tuple or list
        Tile shape as (height, width) in pixels.
    overlap_fraction : float
        Expected overlap between tiles (0-1).

    Returns
    -------
    np.ndarray
        2x2 float64 matrix ``diag(height * (1 - overlap), width * (1 - overlap))``.
    """
    keep = 1.0 - overlap_fraction
    steps = [tile_shape[0] * keep, tile_shape[1] * keep]
    # Force float64 so the dtype matches an array built from Python float literals.
    return np.diag(np.asarray(steps, dtype=np.float64))
[docs] def estimate_mosaic_transform( mosaics: list[Any], max_empty_fraction: float = 0.9, n_samples: int = 512, seed: int | None = None ) -> tuple[np.ndarray, np.ndarray, int]: """Estimate the 2x2 mosaic transform from pairwise phase-correlation registration. For each mosaic, neighbouring tile pairs are registered with :func:`pair_wise_phase_correlation` and the resulting pixel displacements are assembled into a least-squares system to recover the underlying affine transform. Parameters ---------- mosaics : list of MosaicGrid Loaded mosaic grids to use for estimation. max_empty_fraction : float, optional Maximum fraction of empty pixels in an overlap region to still use the pair (default 0.9). n_samples : int, optional Maximum number of tile pairs to sample across all mosaics (default 512). seed : int, optional Random seed for reproducible tile-pair sampling. Returns ------- transform : np.ndarray 2x2 transform matrix. residuals : np.ndarray Residuals from the least-squares fit. tile_count : int Number of tile pairs actually used. 
""" rows, rows_px, cols, cols_px = [], [], [], [] tile_count = 0 if seed is not None: random.seed(seed) mosaic_idx = list(range(len(mosaics))) random.shuffle(mosaic_idx) thresholds = [threshold_otsu(m.image) for m in mosaics] for m_id in mosaic_idx: mosaic = mosaics[m_id] thresh = thresholds[m_id] for i in range(mosaic.n_tiles_x): for j in range(mosaic.n_tiles_y): if tile_count > n_samples: break neighbors, tiles = mosaic.get_neighbors_around_tile(i, j) for _n, t in zip(neighbors, tiles, strict=False): r = t[0] - i c = t[1] - j o1, o2, p1, _p2 = mosaic.get_neighbor_overlap_from_pos((i, j), t) o1_empty = np.sum(o1 <= thresh) > max_empty_fraction * o1.size o2_empty = np.sum(o2 <= thresh) > max_empty_fraction * o2.size if o1_empty or o2_empty: continue o2 = match_histograms(o2, o1) dx, dy = pair_wise_phase_correlation(o1, o2) r_px = p1[2] - mosaic.tile_size_x + dx if r == -1 else p1[0] + dx c_px = p1[3] - mosaic.tile_size_y + dy if c == -1 else p1[1] + dy rows.append(r) cols.append(c) rows_px.append(r_px) cols_px.append(c_px) tile_count += 1 # Build and solve the least-squares system a = np.zeros((len(rows) * 2, 4)) b = np.zeros((len(rows) * 2, 1)) for i in range(len(rows)): a[2 * i, :] = [rows[i], cols[i], 0, 0] b[2 * i, 0] = rows_px[i] a[2 * i + 1, :] = [0, 0, rows[i], cols[i]] b[2 * i + 1, 0] = cols_px[i] result = np.linalg.lstsq(a, b, rcond=None) transform = result[0].reshape((2, 2)) residuals = result[1] if len(result[1]) > 0 else np.array([0.0]) return transform, residuals, tile_count