"""
3D slice stacking utilities.
Consolidated from linum_stack_slices_motor.py and linum_stack_motor_only.py.
"""
import logging
from typing import Any
import numpy as np
# Module-level logger, named after this module; used for debug/warning output below.
logger: logging.Logger = logging.getLogger(__name__)
[docs]
def enforce_z_consistency(
z_matches: list,
confidence_per_slice: dict | None = None,
outlier_threshold_frac: float = 0.30,
confidence_protect_threshold: float = 0.6,
) -> tuple[list, list]:
"""Correct outlier Z-overlaps using neighbor interpolation.
Scans pairwise Z-overlap measurements for outliers (deviating more than
``outlier_threshold_frac`` from the median) and replaces them with the
local median of their immediate neighbors. Both ``overlap_voxels`` and
``blend_overlap_voxels`` are corrected independently.
Slices whose registration confidence (from ``confidence_per_slice``)
meets or exceeds ``confidence_protect_threshold`` are considered reliable
and are not modified.
Parameters
----------
z_matches : list of dict
Each dict must have keys ``overlap_voxels``, ``blend_overlap_voxels``
and ``moving_id``. Items are modified in-place.
confidence_per_slice : dict or None
Mapping from ``moving_id`` (int) to confidence score in [0, 1].
Slices with confidence >= ``confidence_protect_threshold`` are skipped.
If None, all slices are treated as having confidence 0.5.
outlier_threshold_frac : float
Fractional deviation from median above which a value is an outlier.
Default: 0.30 (30 %).
confidence_protect_threshold : float
Minimum confidence to protect a slice from correction. Default: 0.6.
Returns
-------
z_matches : list of dict
The corrected z_matches list (same objects, modified in-place).
corrections : list of dict
Log of corrections: each entry has keys ``moving_id``, ``field``,
``old_value`` and ``new_value``.
"""
if len(z_matches) < 3:
return z_matches, []
conf = confidence_per_slice or {}
corrections = []
for field in ("overlap_voxels", "blend_overlap_voxels"):
values = np.array([float(m[field]) for m in z_matches])
median_val = float(np.median(values))
threshold = outlier_threshold_frac * max(median_val, 1.0)
for i, match in enumerate(z_matches):
slice_id = match.get("moving_id", i)
# Protect high-confidence registrations from correction
if conf.get(slice_id, 0.5) >= confidence_protect_threshold:
continue
deviation = abs(float(match[field]) - median_val)
if deviation <= threshold:
continue
old_val = match[field]
neighbor_vals = []
if i > 0:
neighbor_vals.append(float(z_matches[i - 1][field]))
if i + 1 < len(z_matches):
neighbor_vals.append(float(z_matches[i + 1][field]))
new_val = int(np.median(neighbor_vals)) if neighbor_vals else int(median_val)
match[field] = new_val
corrections.append(
{
"moving_id": slice_id,
"field": field,
"old_value": old_val,
"new_value": new_val,
}
)
return z_matches, corrections
def find_z_overlap(
    fixed_vol: np.ndarray, moving_vol: np.ndarray, slicing_interval_mm: float, search_range_mm: float, resolution_um: float
) -> tuple[int, float]:
    """Find optimal Z-overlap between consecutive slices using cross-correlation.

    Searches around the expected overlap for the best normalized
    cross-correlation score, using the center XY region (a margin of a quarter
    of the smaller XY dimension is trimmed on each side) for speed.

    Parameters
    ----------
    fixed_vol : np.ndarray
        Bottom (fixed) slice volume (Z, Y, X).
    moving_vol : np.ndarray
        Top (moving) slice volume (Z, Y, X).
    slicing_interval_mm : float
        Expected physical slice thickness in mm.
    search_range_mm : float
        Search range around expected position in mm.
    resolution_um : float
        Z resolution in microns per voxel.

    Returns
    -------
    best_overlap : int
        Optimal overlap in Z voxels. The value always lies in
        ``[1, min(fixed_depth, moving_depth)]``, or is 0 if a volume has no Z planes.
    best_corr : float
        Correlation score at optimal overlap. It is 0.0 when the search
        window falls outside the valid range and no candidate is evaluated.
    """
    max_depth = min(fixed_vol.shape[0], moving_vol.shape[0])
    # round() instead of int(): the mm -> voxel ratio can land just below an integer
    # because of floating-point error, and truncation would then drop a whole voxel.
    interval_vox = int(round(slicing_interval_mm * 1000 / resolution_um))
    search_range_vox = int(round(search_range_mm * 1000 / resolution_um))
    expected_overlap_vox = max_depth - interval_vox

    min_overlap = max(1, expected_overlap_vox - search_range_vox)
    max_overlap = min(max_depth, expected_overlap_vox + search_range_vox)
    if min_overlap > max_overlap:
        # Search window lies entirely outside [1, max_depth]. Clamp rather than return a
        # negative/oversized overlap, which would make `vol[-overlap:]` select the wrong planes.
        return int(min(max(expected_overlap_vox, 1), max_depth)), 0.0

    h, w = fixed_vol.shape[1], fixed_vol.shape[2]
    margin = min(h, w) // 4
    y_slice = slice(margin, h - margin)
    x_slice = slice(margin, w - margin)

    best_overlap = expected_overlap_vox
    best_corr = -np.inf
    # Inclusive upper bound; a single-candidate window (min == max) is still evaluated.
    for overlap in range(min_overlap, max_overlap + 1):
        fixed_region = fixed_vol[-overlap:, y_slice, x_slice]
        moving_region = moving_vol[:overlap, y_slice, x_slice]
        fixed_norm = (fixed_region - fixed_region.mean()) / (fixed_region.std() + 1e-8)
        moving_norm = (moving_region - moving_region.mean()) / (moving_region.std() + 1e-8)
        corr = np.mean(fixed_norm * moving_norm)
        if corr > best_corr:
            best_corr = corr
            best_overlap = overlap
    return best_overlap, float(best_corr)
def apply_xy_shift(vol: np.ndarray, dx_px: float, dy_px: float, output_shape: tuple[int, int]) -> tuple:
    """Locate where a shifted volume lands on an output canvas.

    The shifts are rounded to whole pixels. The part of ``vol`` that still
    falls inside the ``(out_ny, out_nx)`` canvas is returned as a view,
    together with the matching destination window. No full-size output
    array is allocated.

    Parameters
    ----------
    vol : np.ndarray
        3D volume (Z, Y, X).
    dx_px, dy_px : float
        Shift in pixels along X and Y.
    output_shape : tuple
        (out_ny, out_nx) output canvas size.

    Returns
    -------
    cropped_vol : np.ndarray or None
        Slice of ``vol`` that overlaps the canvas, or None if nothing overlaps.
    dst_coords : tuple or None
        (y_start, y_end, x_start, x_end) in output coordinates, or None.
    """

    def _axis_window(offset, length, limit):
        # Unclipped destination span is [offset, offset + length); trim whatever
        # falls below 0 or beyond `limit`, and trim the source by the same amounts.
        lo, hi = offset, offset + length
        src_lo = -lo if lo < 0 else 0
        src_hi = length - (hi - limit if hi > limit else 0)
        return src_lo, src_hi, max(lo, 0), min(hi, limit)

    out_ny, out_nx = output_shape
    sy0, sy1, ty0, ty1 = _axis_window(round(dy_px), vol.shape[1], out_ny)
    sx0, sx1, tx0, tx1 = _axis_window(round(dx_px), vol.shape[2], out_nx)

    if sy1 <= sy0 or sx1 <= sx0:
        return None, None
    return vol[:, sy0:sy1, sx0:sx1], (ty0, ty1, tx0, tx1)
def blend_overlap_z(fixed_region: np.ndarray, moving_region: np.ndarray) -> np.ndarray:
    """Blend overlapping Z-region using a cosine (Hann) ramp along Z-axis.

    The moving-slice weight rises from 0 at the first Z plane to 1 at the last
    as ``0.5 * (1 - cos(t))`` for ``t`` in ``[0, pi]``. The ramp has zero slope
    at both endpoints, so there is no abrupt intensity change at either
    boundary of the overlap zone.

    A voxel counts as valid when it is strictly positive:

    * valid in both regions: weighted blend of the two;
    * valid in only one region: that region's value, unchanged;
    * valid in neither: 0.

    When the overlap has at most one Z plane (no ramp possible), the whole
    region with more valid voxels is returned. Ties go to ``moving_region``.

    Parameters
    ----------
    fixed_region : np.ndarray
        3D array (Z, Y, X) from the existing stack (bottom portion).
    moving_region : np.ndarray
        3D array (Z, Y, X) from the new slice (top portion), same shape as
        ``fixed_region``.

    Returns
    -------
    np.ndarray
        Blended region as a new float32 array with the input shape.
    """
    nz = fixed_region.shape[0]
    if nz <= 1:
        winner = moving_region if np.sum(moving_region > 0) >= np.sum(fixed_region > 0) else fixed_region
        # Always return a fresh float32 array, matching the multi-plane path (no aliasing of inputs).
        return winner.astype(np.float32)

    # Cosine (Hann) ramp: 0 -> 1 with zero slope at both ends
    t = np.linspace(0, np.pi, nz)
    z_weights = 0.5 * (1 - np.cos(t))

    fixed_valid = fixed_region > 0
    moving_valid = moving_region > 0
    both_valid = fixed_valid & moving_valid
    fixed_only = fixed_valid & ~moving_valid
    moving_only = moving_valid & ~fixed_valid

    blended = np.zeros_like(moving_region, dtype=np.float32)
    if np.any(both_valid):
        # Look up each valid voxel's weight by its Z index instead of materializing a
        # full-size weight volume and blending voxels that would be discarded.
        # np.nonzero and boolean indexing both enumerate elements in C order, so they align.
        alpha = z_weights[np.nonzero(both_valid)[0]]
        blended[both_valid] = (1 - alpha) * fixed_region[both_valid] + alpha * moving_region[both_valid]
    if np.any(fixed_only):
        blended[fixed_only] = fixed_region[fixed_only]
    if np.any(moving_only):
        blended[moving_only] = moving_region[moving_only]
    return blended
def blend_overlap_xy(existing: np.ndarray, new_data: np.ndarray, method: str = "none") -> np.ndarray:
    """Blend overlapping XY regions for motor-only stacking.

    Zero-valued voxels are treated as empty (no data).

    Parameters
    ----------
    existing : np.ndarray
        Existing data in the output region. Modified in-place by the
        'none', 'average' and 'feather' methods.
    new_data : np.ndarray
        Incoming data to blend (same shape as ``existing``).
    method : str
        One of:

        * 'none': overwrite ``existing`` wherever ``new_data`` is non-zero;
        * 'average': mean of the two where both are non-zero, ``new_data``
          where only it is non-zero;
        * 'max': element-wise maximum (returned as a new array);
        * 'feather': currently identical to 'average'.

        Any other value logs a warning and returns ``existing`` unchanged.

    Returns
    -------
    np.ndarray
        Blended result (``existing`` itself except for 'max').
    """
    if method == "none":
        mask = new_data != 0
        existing[mask] = new_data[mask]
        return existing
    if method in ("average", "feather"):
        existing_valid = existing != 0
        new_valid = new_data != 0
        both_valid = existing_valid & new_valid
        only_new = ~existing_valid & new_valid
        # Sum in float64: adding fixed-width integer arrays (e.g. uint16) directly wraps around.
        existing[both_valid] = (existing[both_valid].astype(np.float64) + new_data[both_valid]) / 2
        existing[only_new] = new_data[only_new]
        return existing
    if method == "max":
        return np.maximum(existing, new_data)
    # Unknown method: keep the previous best-effort behavior (no blending) but make it visible.
    logger.warning("Unknown XY blend method %r; new data was not blended", method)
    return existing
def refine_z_blend_overlap(
    existing: np.ndarray, moving_overlap: np.ndarray, max_refinement_px: float
) -> tuple[np.ndarray, float]:
    """Estimate and apply a small residual XY shift to ``moving_overlap`` before Z-blending.

    Both overlap regions are averaged along Z and passed to
    ``pair_wise_phase_correlation``; its first two components are taken as
    (dy, dx). The shift is applied with nearest-neighbour resampling only when
    its magnitude is at least 0.1 px and at most ``max_refinement_px``.

    Parameters
    ----------
    existing : np.ndarray
        3D array (Z, Y, X) from current stack at the overlap zone.
    moving_overlap : np.ndarray
        3D array (Z, Y, X) from incoming slice at the overlap zone.
    max_refinement_px : float
        Maximum allowed shift magnitude in pixels.

    Returns
    -------
    refined : np.ndarray
        Shifted copy of ``moving_overlap`` as float32, or ``moving_overlap``
        itself when no shift is applied.
    magnitude : float
        Shift magnitude applied (pixels), or 0.0 if not applied.
    """
    from scipy.ndimage import shift as ndi_shift

    from linumpy.registration.transforms import pair_wise_phase_correlation

    ref_proj = np.mean(existing, axis=0).astype(np.float32)
    mov_proj = np.mean(moving_overlap, axis=0).astype(np.float32)

    # Skip when fewer than 1000 pixels are non-zero in both projections.
    shared_px = np.sum((ref_proj > 0) & (mov_proj > 0))
    if shared_px < 1000:
        return moving_overlap, 0.0

    try:
        estimate = pair_wise_phase_correlation(ref_proj, mov_proj)
        dy = float(estimate[0])
        dx = float(estimate[1])
    except Exception as exc:  # best-effort: any failure just skips the refinement
        logger.debug("Z-blend phase correlation failed: %s", exc)
        return moving_overlap, 0.0

    magnitude = np.sqrt(dy**2 + dx**2)
    if magnitude < 0.1:
        # Shifts below 0.1 px are ignored.
        return moving_overlap, 0.0
    if magnitude > max_refinement_px:
        logger.debug("Z-blend refinement rejected: %.2f px > max %s px", magnitude, max_refinement_px)
        return moving_overlap, 0.0

    # Z is left untouched; order=0 is nearest-neighbour interpolation and
    # mode="nearest" pads by repeating edge values.
    shifted = ndi_shift(moving_overlap.astype(np.float32), [0, dy, dx], order=0, mode="nearest")
    return shifted, magnitude