"""Refinement registration: best-Z search and small rotation/translation correction."""
import numpy as np
import SimpleITK as sitk
from scipy.ndimage import center_of_mass, sobel
from skimage.registration import phase_cross_correlation
[docs]
def find_best_z(fixed_vol: np.ndarray, moving_slice: np.ndarray, expected_z: int, search_range: int) -> tuple[int, float]:
    """Find the Z-index in fixed_vol that best matches moving_slice.

    Both the moving slice and each candidate fixed slice are cropped to a
    central region (``min(h, w) // 4`` pixels trimmed from every edge). Their
    positive pixels are then contrast-stretched to [0, 1] using the 5th-95th
    percentile, and the region is z-scored. The score is the mean product of
    the two z-scored regions, i.e. a normalized (Pearson) cross-correlation.

    Parameters
    ----------
    fixed_vol : array-like
        Fixed volume (Z, Y, X) or dask/zarr array. Each ``fixed_vol[z]`` is
        cropped with the same region as ``moving_slice``, so it is expected to
        have the same in-plane shape.
    moving_slice : np.ndarray
        2D slice to match.
    expected_z : int
        Expected Z-index in fixed_vol for the match. It is clamped to
        ``[0, nz - 1]``.
    search_range : int
        Search +/-search_range around expected_z (clipped to the volume).
        ``0`` scores only ``expected_z`` itself.

    Returns
    -------
    best_z : int
        Z-index giving the best correlation. When several candidates tie, the
        one closest to ``expected_z`` wins, so degenerate inputs (e.g. a blank
        moving slice) do not drift to the edge of the search window.
    best_corr : float
        Correlation score at best_z. It is 0.0 when there is no slice to
        compare (empty volume or negative ``search_range``), and ``-inf`` if
        every candidate correlation is NaN.
    """

    def _standardize(region: np.ndarray) -> np.ndarray:
        # Percentile-stretch the positive pixels to [0, 1], then z-score the region.
        region = region.astype(np.float32)
        valid = region > 0
        if valid.any():
            pmin, pmax = map(float, np.percentile(region[valid], [5, 95]))
            region = np.clip((region - pmin) / max(pmax - pmin, 1e-8), 0, 1)
        return (region - region.mean()) / (region.std() + 1e-8)

    nz = fixed_vol.shape[0]
    expected_z = max(0, min(nz - 1, expected_z))
    z_min = max(0, expected_z - search_range)
    z_max = min(nz - 1, expected_z + search_range)
    if z_min > z_max:
        # Nothing to compare: empty volume or a negative search_range.
        return expected_z, 0.0

    h, w = moving_slice.shape
    margin = min(h, w) // 4
    roi = (slice(margin, h - margin), slice(margin, w - margin))
    moving_norm = _standardize(moving_slice[roi])

    best_z, best_corr = expected_z, -np.inf
    for z in range(z_min, z_max + 1):
        # np.asarray materialises lazy (dask/zarr) slices.
        fixed_norm = _standardize(np.asarray(fixed_vol[z])[roi])
        corr = float(np.mean(fixed_norm * moving_norm))
        closer = abs(z - expected_z) < abs(best_z - expected_z)
        if corr > best_corr or (corr == best_corr and closer):
            best_z, best_corr = z, corr
    return best_z, best_corr
[docs]
def register_refinement(
fixed: np.ndarray,
moving: np.ndarray,
enable_rotation: bool = True,
max_rotation_deg: float = 5.0,
max_translation_px: float = 20.0,
fixed_mask: np.ndarray | None = None,
moving_mask: np.ndarray | None = None,
initial_offset: tuple[float, float] | None = None,
) -> tuple[float, float, float, float]:
"""Compute small rotation and translation refinement using SimpleITK.
Parameters
----------
fixed, moving : np.ndarray
2D images for registration (should be normalized to [0, 1]).
enable_rotation : bool
Enable Euler2D rotation (default True). False = translation only.
max_rotation_deg : float
Maximum allowed rotation in degrees.
max_translation_px : float
Maximum allowed translation in pixels.
fixed_mask, moving_mask : np.ndarray or None
Optional tissue masks multiplied into images before registration.
initial_offset : tuple[float, float] or None
Optional initial (dy, dx) translation offset for the transform.
Returns
-------
tx, ty : float
Translation refinement in pixels.
angle_deg : float
Rotation angle in degrees.
metric : float
Registration metric value.
"""
fixed_std = np.std(fixed[fixed > 0]) if np.any(fixed > 0) else 0
moving_std = np.std(moving[moving > 0]) if np.any(moving > 0) else 0
if fixed_std < 0.01 or moving_std < 0.01:
return 0.0, 0.0, 0.0, 0.0
fixed_masked = fixed * fixed_mask.astype(np.float32) if fixed_mask is not None else fixed
moving_masked = moving * moving_mask.astype(np.float32) if moving_mask is not None else moving
fixed_sitk = sitk.GetImageFromArray(fixed_masked.astype(np.float32))
moving_sitk = sitk.GetImageFromArray(moving_masked.astype(np.float32))
if enable_rotation:
transform = sitk.Euler2DTransform()
center = [fixed.shape[1] / 2.0, fixed.shape[0] / 2.0]
transform.SetCenter(center)
if initial_offset is not None:
transform.SetTranslation([initial_offset[1], initial_offset[0]])
else:
transform = sitk.TranslationTransform(2)
if initial_offset is not None:
transform.SetOffset([initial_offset[1], initial_offset[0]])
reg = sitk.ImageRegistrationMethod()
reg.SetMetricAsCorrelation()
reg.SetOptimizerAsGradientDescent(
learningRate=1.0, numberOfIterations=200, convergenceMinimumValue=1e-6, convergenceWindowSize=10
)
reg.SetOptimizerScalesFromPhysicalShift()
reg.SetInitialTransform(transform, inPlace=False)
reg.SetInterpolator(sitk.sitkLinear)
reg.SetShrinkFactorsPerLevel([4, 2, 1])
reg.SetSmoothingSigmasPerLevel([2, 1, 0])
try:
final = reg.Execute(fixed_sitk, moving_sitk)
metric = reg.GetMetricValue()
inner = final.GetNthTransform(0) if final.GetName() == "CompositeTransform" else final
if enable_rotation:
euler = sitk.Euler2DTransform(inner)
angle_deg = np.degrees(euler.GetAngle())
tx, ty = euler.GetTranslation()
else:
tx, ty = inner.GetOffset()
angle_deg = 0.0
mag = np.sqrt(tx**2 + ty**2)
if mag > max_translation_px:
scale = max_translation_px / mag
tx, ty = tx * scale, ty * scale
if abs(angle_deg) > max_rotation_deg:
angle_deg = float(np.clip(angle_deg, -max_rotation_deg, max_rotation_deg))
return tx, ty, angle_deg, metric
except Exception:
return 0.0, 0.0, 0.0, float("inf")
[docs]
def centre_of_mass_offset(fixed: np.ndarray, moving: np.ndarray, *, threshold: float = 0.1) -> tuple[float, float]:
    """Compute alignment offset using center-of-mass difference.

    Pixels above ``threshold`` are foreground and contribute to the centroid
    with weight equal to their intensity. All other pixels, including NaNs,
    are ignored.

    Parameters
    ----------
    fixed, moving : np.ndarray
        2D normalized images (values in [0, 1]).
    threshold : float, keyword-only
        Foreground intensity threshold (default 0.1, suited to [0, 1] images).

    Returns
    -------
    dy, dx : float
        Translation from moving to fixed (fixed - moving offset). Returns
        (0.0, 0.0) when either image has no foreground pixel.
    """
    fixed_fg = fixed > threshold
    moving_fg = moving > threshold
    if not fixed_fg.any() or not moving_fg.any():
        return 0.0, 0.0
    # np.where rather than multiplying by the mask: NaN * 0 is NaN, which
    # would otherwise poison the whole centroid.
    cy_f, cx_f = center_of_mass(np.where(fixed_fg, fixed, 0.0))
    cy_m, cx_m = center_of_mass(np.where(moving_fg, moving, 0.0))
    return float(cy_f - cy_m), float(cx_f - cx_m)
[docs]
def gradient_magnitude_alignment(fixed: np.ndarray, moving: np.ndarray) -> tuple[float, float]:
    """Estimate the (dy, dx) shift aligning ``moving`` onto ``fixed`` from edge maps.

    Each image is cast to float32 and reduced to its Sobel gradient magnitude.
    The two edge maps are then registered with phase cross-correlation
    (no spectrum normalization, 1/4-pixel upsampling).

    Parameters
    ----------
    fixed, moving : np.ndarray
        2D normalized images (values in [0, 1]).

    Returns
    -------
    dy, dx : float
        Translation from moving to fixed (fixed - moving offset).
    """
    edge_maps = []
    for image in (fixed, moving):
        as_float = image.astype(np.float32)
        d_col = sobel(as_float, axis=1)
        d_row = sobel(as_float, axis=0)
        edge_maps.append(np.sqrt(d_col**2 + d_row**2))
    fixed_edges, moving_edges = edge_maps

    shift_vec, _error, _phase = phase_cross_correlation(
        fixed_edges, moving_edges, normalization=None, upsample_factor=4
    )
    dy, dx = shift_vec[0], shift_vec[1]
    return float(dy), float(dx)