"""Slice interpolation utilities for missing serial sections.
Single strategy: :func:`interpolate_z_morph` -- z-aware morphing via fractional
affine warps (``T**alpha``, :func:`scipy.linalg.fractional_matrix_power`).
Reconstructs a synthetic slice that transitions along Z from ``vol_before[-1]``
to ``vol_after[0]``, matching the physical geometry of serial sectioning.
When any quality gate fails, :func:`interpolate_z_morph` **does not fabricate
a volume**. It returns ``(None, diagnostics)`` with
``diagnostics["interpolation_failed"] = True`` and a specific
``fallback_reason``. The caller must honour this by *not* emitting a
reconstructed slice -- a blend of the two neighbours would also be
fabricated data, with the added failure mode of ghost/double-contour
artefacts whenever the two neighbours differ.
Downstream of a failed interpolation, the pipeline treats the slice as a
genuine multi-slice gap: no zarr is produced, the manifest fragment records
``interpolation_failed=true``, and ``slice_config_final.csv`` surfaces the
failure so the final report can flag it.
Returns ``(volume, diagnostics)`` where *volume* is ``None`` on failure and
*diagnostics* is always a JSON-serialisable dict.
See ``docs/SLICE_INTERPOLATION_FEATURE.md`` for the scientific rationale,
failure modes, and the connection to ``slice_config.csv`` (``interpolated``,
``interpolation_failed``, ``interpolation_fallback_reason`` columns).
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
import SimpleITK as sitk
from scipy.linalg import fractional_matrix_power
from scipy.ndimage import distance_transform_edt, gaussian_filter
from linumpy.registration.sitk import apply_transform, register_2d_images_sitk
[docs]
# Module-level logger, named after this module; used for gate-failure warnings.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Normalization / NCC helpers
# ---------------------------------------------------------------------------
def _normalize_plane_for_ncc(plane: np.ndarray) -> np.ndarray:
    """Standardise a plane (zero mean, ~unit std) for normalised cross-correlation.

    When the plane has any strictly positive pixel, intensities are first
    rescaled so that the 5th..95th percentile range of the positive pixels
    maps onto ``[0, 1]`` (values outside are clipped). The result is then
    mean-centred and divided by ``std + 1e-8`` so a flat plane yields zeros
    instead of a division by zero.
    """
    values = plane.astype(np.float32)
    foreground = values > 0
    if foreground.any():
        # Percentiles are taken over positive pixels only, so background
        # does not drag the intensity window down.
        low = float(np.percentile(values[foreground], 5))
        high = float(np.percentile(values[foreground], 95))
        span = max(high - low, 1e-8)
        values = np.clip((values - low) / span, 0, 1)
    centred = values - values.mean()
    return centred / (values.std() + 1e-8)
def _ncc(a: np.ndarray, b: np.ndarray, margin_frac: float = 0.25) -> float:
    """Normalised cross-correlation of two equally-shaped 2D images.

    Only the central window is compared: a border of
    ``int(min(H, W) * margin_frac)`` pixels is dropped on every side before
    each image is standardised with :func:`_normalize_plane_for_ncc`.

    Raises
    ------
    ValueError
        If the two images do not have the same shape.
    """
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch in NCC: {a.shape} vs {b.shape}")
    rows, cols = a.shape
    border = int(min(rows, cols) * margin_frac)
    window = (slice(border, rows - border), slice(border, cols - border))
    product = _normalize_plane_for_ncc(a[window]) * _normalize_plane_for_ncc(b[window])
    return float(np.mean(product))
def _foreground_fraction(plane: np.ndarray, threshold: float | None = None) -> float:
    """Return the fraction of pixels strictly above a background cut-off.

    With ``threshold=None`` the cut-off is the 1st percentile of the strictly
    positive pixels, a soft background estimate for data with a dark offset.
    An empty plane, or a plane without any positive pixel (when no explicit
    threshold is given), yields ``0.0``.
    """
    if plane.size == 0:
        return 0.0
    cutoff = threshold
    if cutoff is None:
        above_zero = plane[plane > 0]
        if above_zero.size == 0:
            return 0.0
        cutoff = float(np.percentile(above_zero, 1))
    return float(np.mean(plane > cutoff))
# ---------------------------------------------------------------------------
# Affine / fractional-affine helpers
# ---------------------------------------------------------------------------
def _matrix_fractional_power(matrix: np.ndarray, alpha: float, _imag_tol: float = 1e-4) -> tuple[np.ndarray, float]:
    """Compute the real part of ``matrix ** alpha`` and report how complex it was.

    Delegates to :func:`scipy.linalg.fractional_matrix_power` with NumPy
    floating-point warnings silenced for the duration of the call.

    Parameters
    ----------
    matrix : np.ndarray
        Square matrix to raise to a fractional power.
    alpha : float
        Exponent.
    _imag_tol : float
        Currently unused by this function; kept in the signature for
        interface compatibility.

    Returns
    -------
    real_part : np.ndarray
        Real part of the fractional power.
    imag_magnitude : float
        ``max|imag| / max(max|real|, 1)``; ``0.0`` when SciPy returned a
        real matrix. A large value means the discarded imaginary part was
        significant, so the real projection is not a valid power.
    """
    with np.errstate(all="ignore"):
        powered = fractional_matrix_power(matrix, alpha)
    if not np.iscomplexobj(powered):
        return powered, 0.0
    real_part = powered.real
    # Normalise by the real magnitude, floored at 1 so tiny matrices do not
    # blow the ratio up.
    denom = max(float(np.max(np.abs(real_part))), 1.0)
    imag_magnitude = float(np.max(np.abs(powered.imag)) / denom)
    return real_part, imag_magnitude
def _fractional_affine_parts(
    matrix: np.ndarray, translation: np.ndarray, alpha: float
) -> tuple[np.ndarray, np.ndarray, float]:
    """Return the matrix and translation of ``T**alpha`` for a fixed-centre affine.

    The affine is ``T(x) = M (x - c) + c + t``. In centre-relative
    coordinates ``y = x - c`` this is ``y -> M y + t``; iterating it ``k``
    times gives ``M^k y + (sum_{i<k} M^i) t``. Extending the geometric sum to
    real exponents yields::

        M_alpha = M ** alpha
        t_alpha = (I - M_alpha) @ (I - M)^{-1} @ t

    Special cases:

    * ``alpha == 0`` returns ``(I, 0, 0.0)`` and ``alpha == 1`` returns
      ``(M, t, 0.0)`` exactly, without going through the fractional power,
      so boundary planes are not perturbed by numerical drift.
    * When ``|det(I - M)| < 1e-10`` the translation falls back to
      ``alpha * t``, which is exact for ``M = I``.

    Returns
    -------
    m_alpha : np.ndarray
        Matrix part of the fractional transform.
    t_alpha : np.ndarray
        Translation part of the fractional transform.
    imag_magnitude : float
        Relative imaginary magnitude reported by
        :func:`_matrix_fractional_power` (``0.0`` on the exact shortcuts).
    """
    size = matrix.shape[0]
    eye = np.eye(size)

    # Exact shortcuts for the two end points.
    if alpha == 0.0:
        return eye, np.zeros(size, dtype=np.float64), 0.0
    if alpha == 1.0:
        return np.asarray(matrix, dtype=np.float64), np.asarray(translation, dtype=np.float64), 0.0

    powered, imag_magnitude = _matrix_fractional_power(matrix, alpha)
    shift = np.asarray(translation, dtype=np.float64)
    one_minus_m = eye - matrix
    if abs(np.linalg.det(one_minus_m)) < 1e-10:
        # (I - M) is (near-)singular: use the pure-translation limit.
        shift_alpha = alpha * shift
    else:
        shift_alpha = (eye - powered) @ np.linalg.solve(one_minus_m, shift)
    return powered, shift_alpha, imag_magnitude
# ---------------------------------------------------------------------------
# Simple interpolators (unchanged API). NOTE: interpolate_z_morph does not call these -- on gate failure it hard-skips instead of falling back to a blend.
# ---------------------------------------------------------------------------
[docs]
def interpolate_average(vol_before: np.ndarray, vol_after: np.ndarray) -> np.ndarray:
    """Return the voxel-wise 50/50 mean of two volumes, computed in float32."""
    half_before = 0.5 * vol_before.astype(np.float32)
    half_after = 0.5 * vol_after.astype(np.float32)
    return half_before + half_after
[docs]
def interpolate_weighted(vol_before: np.ndarray, vol_after: np.ndarray, sigma: float = 2.0) -> np.ndarray:
    """Average two volumes 50/50, then Gaussian-smooth the result along axis 0.

    ``sigma`` is the standard deviation of the filter along the first axis;
    the other two axes are not smoothed (sigma 0).
    """
    half_before = 0.5 * vol_before.astype(np.float32)
    half_after = 0.5 * vol_after.astype(np.float32)
    mean_volume = half_before + half_after
    return gaussian_filter(mean_volume, sigma=(sigma, 0, 0))
# ---------------------------------------------------------------------------
# Boundary plane / slab selection
# ---------------------------------------------------------------------------
def _build_reference_slab(vol: np.ndarray, z_center: int, slab_size: int) -> np.ndarray:
    """Mean-intensity projection of a few planes around *z_center*, as float32.

    The window spans ``max(1, slab_size) // 2`` planes on each side of
    *z_center* and is clamped to the volume bounds. Consequently a 1-plane
    slab returns the plane itself, and an even ``slab_size`` covers
    ``slab_size + 1`` planes when not clamped.
    """
    radius = max(1, slab_size) // 2
    start = max(0, z_center - radius)
    stop = min(vol.shape[0], z_center + radius + 1)
    window = vol[start:stop]
    return window.mean(axis=0).astype(np.float32)
[docs]
def find_best_overlap_planes(
    vol_before: np.ndarray,
    vol_after: np.ndarray,
    search_window: int = 5,
    min_foreground_fraction: float = 0.1,
) -> tuple[int, int, float]:
    """Pick the best-correlated pair of planes across the volume boundary.

    Candidates are the last ``search_window`` planes of *vol_before* and the
    first ``search_window`` planes of *vol_after*; a plane is kept only if
    :func:`_foreground_fraction` reports at least ``min_foreground_fraction``.
    Every (before, after) candidate pair is scored with a normalised
    cross-correlation on the central ROI (a quarter of the smaller image side
    is dropped on each edge), and the first pair with the strictly highest
    score wins.

    Returns
    -------
    (ref_before, ref_after, best_corr)
        Plane indices and their correlation. When either side has no
        candidate, ``(nz_before - 1, 0, -inf)`` is returned and a warning is
        logged.
    """
    depth_before = vol_before.shape[0]
    depth_after = vol_after.shape[0]

    first_before = max(0, depth_before - search_window)
    candidates_before = []
    for z in range(first_before, depth_before):
        if _foreground_fraction(vol_before[z]) >= min_foreground_fraction:
            candidates_before.append(z)

    candidates_after = []
    for z in range(min(search_window, depth_after)):
        if _foreground_fraction(vol_after[z]) >= min_foreground_fraction:
            candidates_after.append(z)

    if not (candidates_before and candidates_after):
        logger.warning(
            "find_best_overlap_planes: no candidate plane passed the foreground filter (before_zs=%s, after_zs=%s)",
            candidates_before,
            candidates_after,
        )
        return depth_before - 1, 0, float("-inf")

    rows, cols = vol_before.shape[1], vol_before.shape[2]
    border = min(rows, cols) // 4
    window = (slice(border, rows - border), slice(border, cols - border))

    # Standardise on the ROI itself rather than on the full plane: with a
    # mostly-zero background, full-plane statistics would leave the tissue
    # ROI with a positive mean and push mean(a*b) outside [-1, 1].
    normed_before = {}
    for z in candidates_before:
        normed_before[z] = _normalize_plane_for_ncc(vol_before[z][window])
    normed_after = {}
    for z in candidates_after:
        normed_after[z] = _normalize_plane_for_ncc(vol_after[z][window])

    # Defaults are the planes closest to the cut, in case no score beats -inf.
    winner_before = candidates_before[-1]
    winner_after = candidates_after[0]
    top_score = -np.inf
    for z_b, plane_b in normed_before.items():
        for z_a, plane_a in normed_after.items():
            score = float(np.mean(plane_b * plane_a))
            if score > top_score:
                top_score, winner_before, winner_after = score, z_b, z_a
    return winner_before, winner_after, top_score
# ---------------------------------------------------------------------------
# 2D affine registration helpers (used by interpolate_z_morph)
# ---------------------------------------------------------------------------
def _prepare_2d(plane: np.ndarray) -> np.ndarray:
    """Min-max rescale a 2D plane to ``[0, 1]`` (float32) for registration.

    A plane whose maximum is not greater than its minimum (e.g. a constant
    image) is returned as float32 without rescaling.
    """
    image = plane.astype(np.float32)
    low = float(image.min())
    high = float(image.max())
    return (image - low) / (high - low) if high > low else image
def _register_boundary(
    fixed_2d: np.ndarray,
    moving_2d: np.ndarray,
    metric: str,
    max_iterations: int,
) -> sitk.Transform:
    """Run a 2D affine registration of *moving_2d* onto *fixed_2d*.

    Thin wrapper around ``register_2d_images_sitk`` that pins the method to
    ``"affine"``, requests a 2D (not 3D) transform, disables verbose output,
    and returns only the first element of the three-element result.
    """
    options = {
        "method": "affine",
        "metric": metric,
        "max_iterations": max_iterations,
        "return_3d_transform": False,
        "verbose": False,
    }
    # The call yields exactly three values; only the transform is needed.
    transform, _unused_a, _unused_b = register_2d_images_sitk(fixed_2d, moving_2d, **options)
    return transform
# ---------------------------------------------------------------------------
# Blend helpers
# ---------------------------------------------------------------------------
def _gaussian_feather_blend(
    warped_before: np.ndarray,
    warped_after: np.ndarray,
    w_before: np.ndarray | None = None,
    w_after: np.ndarray | None = None,
) -> np.ndarray:
    """Blend two 3D stacks plane by plane with distance-transform feathering.

    For each plane, the Euclidean distance transform of each input's
    foreground mask (``> 0``) is smoothed in XY (sigma 2) and multiplied by
    that input's per-plane z-weight. The normalised products are the blend
    weights.

    Pixels that are foreground in only one input are handled separately:
    that input contributes at its own z-weight (not at weight 1) and the
    other input contributes nothing. A plane whose z-weight is 0 for one
    side therefore receives no data from that side, and pixels covered only
    by the zero-weighted side stay 0.

    Parameters
    ----------
    warped_before, warped_after : np.ndarray
        3D stacks to blend, indexed ``(z, x, y)``.
    w_before, w_after : np.ndarray | None
        Per-plane z-weights, flattened to 1D; ``None`` means all ones.
    """
    depth, rows, cols = warped_before.shape
    fg_before = warped_before > 0
    fg_after = warped_after > 0

    # Per-plane distance to the nearest background pixel; empty planes stay 0.
    feather_before = np.zeros((depth, rows, cols), dtype=np.float32)
    feather_after = np.zeros((depth, rows, cols), dtype=np.float32)
    for z in range(depth):
        if fg_before[z].any():
            feather_before[z] = distance_transform_edt(fg_before[z])
        if fg_after[z].any():
            feather_after[z] = distance_transform_edt(fg_after[z])
    feather_before = gaussian_filter(feather_before, sigma=(0, 2, 2))
    feather_after = gaussian_filter(feather_after, sigma=(0, 2, 2))

    if w_before is None:
        z_before = np.ones((depth,), dtype=np.float32)
    else:
        z_before = np.asarray(w_before, dtype=np.float32).reshape(-1)
    if w_after is None:
        z_after = np.ones((depth,), dtype=np.float32)
    else:
        z_after = np.asarray(w_after, dtype=np.float32).reshape(-1)

    scaled_before = feather_before * z_before.reshape(-1, 1, 1)
    scaled_after = feather_after * z_after.reshape(-1, 1, 1)
    denom = scaled_before + scaled_after + 1e-10
    share_before = scaled_before / denom
    share_after = scaled_after / denom

    # Pixels covered by exactly one input: that input contributes at its
    # z-weight and the other contributes nothing, so the inactive side never
    # leaks "ghost" data into regions it does not cover.
    exclusive_before = fg_before & ~fg_after
    exclusive_after = fg_after & ~fg_before
    z_before_full = np.broadcast_to(z_before.reshape(-1, 1, 1), share_before.shape)
    z_after_full = np.broadcast_to(z_after.reshape(-1, 1, 1), share_after.shape)

    share_before = np.where(exclusive_before, z_before_full, share_before)
    share_before = np.where(exclusive_after, 0.0, share_before)
    share_after = np.where(exclusive_before, 0.0, share_after)
    share_after = np.where(exclusive_after, z_after_full, share_after)

    return share_before * warped_before + share_after * warped_after
# ---------------------------------------------------------------------------
# z-morph interpolation: physical-geometry aware
# ---------------------------------------------------------------------------
def _fractional_affine_transform(
    matrix: np.ndarray,
    translation: np.ndarray,
    center: np.ndarray,
    alpha: float,
    dim: int = 2,
) -> sitk.AffineTransform:
    """Build a SimpleITK affine representing ``T**alpha`` about a fixed centre.

    The fractional matrix and translation come from
    :func:`_fractional_affine_parts`; the imaginary-magnitude diagnostic it
    returns is discarded here. ``dim`` is the dimensionality passed to
    ``sitk.AffineTransform`` -- presumably it must match the shape of
    *matrix* (not checked in this function).
    """
    powered_matrix, powered_shift, _ = _fractional_affine_parts(matrix, translation, alpha)
    affine = sitk.AffineTransform(dim)
    affine.SetMatrix(powered_matrix.flatten().tolist())
    affine.SetTranslation(powered_shift.tolist())
    affine.SetCenter(center.tolist())
    return affine
[docs]
def interpolate_z_morph(
    vol_before: np.ndarray,
    vol_after: np.ndarray,
    output_z: int | None = None,
    metric: str = "MSE",
    max_iterations: int = 1000,
    overlap_search_window: int = 5,
    min_overlap_correlation: float = 0.3,
    reference_slab_size: int = 3,
    min_foreground_fraction: float = 0.1,
    min_ncc_improvement: float = 0.05,
    blend_method: str = "gaussian",
) -> tuple[np.ndarray | None, dict[str, Any]]:
    """Z-aware morphing interpolation between two adjacent volumes.

    Steps:

    1. :func:`find_best_overlap_planes` selects the best-correlated plane
       pair near the boundary; a mean slab is built around each plane.
    2. The *before* slab (moving) is affinely registered onto the *after*
       slab (fixed), giving an affine ``T``.
    3. For every output plane at fractional depth ``alpha`` in ``[0, 1]``,
       ``vol_before[-1]`` is warped by ``T**alpha`` and ``vol_after[0]`` by
       ``T**(alpha - 1)``; the two are cross-faded with z-weights
       ``1 - alpha`` / ``alpha``. The first and last output planes are
       therefore meant to reproduce the two boundary planes; the mean
       absolute residuals are recorded in the diagnostics.

    **Hard skip on gate failure.** When any quality gate fails
    (``no_foreground_planes``, ``low_overlap_ncc``, ``registration_exception``,
    ``reg_did_not_improve``, ``affine_determinant_non_positive``) the function
    returns ``(None, diagnostics)`` with
    ``diagnostics["interpolation_failed"] = True`` and a specific
    ``fallback_reason``. **No fabricated volume is produced** -- blending the
    two neighbours would also be made-up data, with the added failure mode of
    ghost/double-contour artefacts whenever the two neighbours differ.

    The NCC-improvement and determinant gates are NaN-safe: a non-finite
    post-registration NCC or determinant fails the corresponding gate
    instead of slipping through a ``<`` / ``<=`` comparison.

    See ``docs/SLICE_INTERPOLATION_FEATURE.md`` for the physical model, the
    rationale for the hard-skip behaviour, and parameter-tuning guidance.

    Returns
    -------
    volume : np.ndarray | None
        Interpolated 3D volume, shape ``(output_z or min(nz_before, nz_after), H, W)``,
        or ``None`` when a quality gate fails.
    diagnostics : dict
        JSON-serialisable trace of the attempt.

    Raises
    ------
    ValueError
        If all gates pass but ``blend_method`` is not ``"linear"`` or
        ``"gaussian"``, or the number of output planes is smaller than 1.
    """
    nz_before, nx, ny = vol_before.shape
    nz_after = vol_after.shape[0]
    nz_out = output_z if output_z is not None else min(nz_before, nz_after)

    diag: dict[str, Any] = {
        "method": "zmorph",
        "method_used": "zmorph",
        "fallback_reason": None,
        "nz_out": int(nz_out),
        "overlap_search_window": overlap_search_window,
        "reference_slab_size": reference_slab_size,
        "min_foreground_fraction": min_foreground_fraction,
        "min_overlap_correlation": min_overlap_correlation,
        "min_ncc_improvement": min_ncc_improvement,
        "blend_method": blend_method,
        "registration_metric": metric,
        "max_iterations": max_iterations,
    }

    def _hard_skip(reason: str) -> tuple[None, dict[str, Any]]:
        """Abort interpolation without fabricating a volume.

        A weighted blend of the two neighbours would also be made-up data
        (doubled contours, ghosted structures). Honest behaviour is to emit
        no output and let the pipeline treat the slot as a genuine gap.
        """
        diag["method_used"] = None
        diag["fallback_reason"] = reason
        diag["interpolation_failed"] = True
        logger.warning(
            "[interpolation] zmorph could not produce a reliable output (reason=%s); "
            "emitting no volume (slot will stay empty in the final reconstruction)",
            reason,
        )
        return None, diag

    # -- Boundary plane/slab selection --------------------------------------
    ref_before, ref_after, best_corr = find_best_overlap_planes(
        vol_before,
        vol_after,
        search_window=overlap_search_window,
        min_foreground_fraction=min_foreground_fraction,
    )
    diag["ref_before"] = int(ref_before)
    diag["ref_after"] = int(ref_after)
    diag["pre_reg_ncc"] = float(best_corr)

    if not np.isfinite(best_corr):
        return _hard_skip("no_foreground_planes")
    if best_corr < min_overlap_correlation:
        return _hard_skip("low_overlap_ncc")

    # -- Boundary registration (after = fixed, before = moving) -------------
    slab_before = _build_reference_slab(vol_before, ref_before, reference_slab_size)
    slab_after = _build_reference_slab(vol_after, ref_after, reference_slab_size)
    fixed_2d = _prepare_2d(slab_after)
    moving_2d = _prepare_2d(slab_before)
    try:
        transform_2d = _register_boundary(fixed_2d, moving_2d, metric=metric, max_iterations=max_iterations)
    except Exception as exc:
        # Any registration failure is reported as a gate failure, not raised.
        diag["registration_error_message"] = str(exc)
        return _hard_skip("registration_exception")

    affine_2d = sitk.AffineTransform(transform_2d)
    matrix = np.array(affine_2d.GetMatrix()).reshape(2, 2)
    translation = np.array(affine_2d.GetTranslation())
    center = np.array(affine_2d.GetCenter())
    det = float(np.linalg.det(matrix))

    warped_slab_before = apply_transform(slab_before.astype(np.float32), transform_2d)
    post_reg_ncc = _ncc(slab_after, warped_slab_before)
    ncc_gain = post_reg_ncc - best_corr
    diag["post_reg_ncc"] = float(post_reg_ncc)
    diag["ncc_improvement"] = float(ncc_gain)
    diag["affine_matrix"] = matrix.tolist()
    diag["affine_translation"] = translation.tolist()
    diag["affine_determinant"] = det

    # Written as "not (x >= y)" / "not (x > 0)" so that NaN fails the gate:
    # any comparison with NaN is False, which would let "<" / "<=" pass it.
    if not ncc_gain >= min_ncc_improvement:
        return _hard_skip("reg_did_not_improve")
    if not det > 0.0:
        diag["affine_determinant_invalid"] = True
        return _hard_skip("affine_determinant_non_positive")

    # -- Argument checks that only matter once a volume will be built -------
    # Done before the (expensive) warp loop so bad arguments fail early and
    # with a clear message instead of an IndexError on an empty result.
    if blend_method not in ("linear", "gaussian"):
        raise ValueError(f"Unknown blend_method: {blend_method}")
    if nz_out < 1:
        raise ValueError(f"Number of output planes must be >= 1, got {nz_out}")

    # -- Build the morphed output ------------------------------------------
    top_of_after = vol_after[0].astype(np.float32)
    bottom_of_before = vol_before[-1].astype(np.float32)
    warped_before = np.zeros((nz_out, nx, ny), dtype=np.float32)
    warped_after = np.zeros((nz_out, nx, ny), dtype=np.float32)

    def _warp(plane: np.ndarray, power: float) -> tuple[np.ndarray, float]:
        """Warp *plane* by ``T**power``; also return the imaginary-part diagnostic."""
        m_pow, t_pow, imag = _fractional_affine_parts(matrix, translation, power)
        tform = sitk.AffineTransform(2)
        tform.SetMatrix(m_pow.flatten().tolist())
        tform.SetTranslation(t_pow.tolist())
        tform.SetCenter(center.tolist())
        return apply_transform(plane, tform), imag

    # A single output plane sits halfway between the two boundaries.
    alphas = [z / (nz_out - 1) if nz_out > 1 else 0.5 for z in range(nz_out)]
    max_imag = 0.0
    for z, alpha in enumerate(alphas):
        # before contribution: T**alpha, alpha in [0, 1]
        warped_before[z], imag_fwd = _warp(bottom_of_before, alpha)
        # after contribution: T**(alpha - 1), alpha - 1 in [-1, 0]
        warped_after[z], imag_bwd = _warp(top_of_after, alpha - 1.0)
        max_imag = max(max_imag, imag_fwd, imag_bwd)

    diag["fractional_power_max_imag"] = float(max_imag)
    if max_imag > 1e-3:
        # Flag only: the real projection of the fractional power dropped a
        # non-negligible imaginary part.
        diag["warning_large_imag_part"] = True

    w_before_arr = np.asarray([1.0 - alpha for alpha in alphas], dtype=np.float32)
    w_after_arr = 1.0 - w_before_arr
    if blend_method == "linear":
        result = w_before_arr.reshape(-1, 1, 1) * warped_before + w_after_arr.reshape(-1, 1, 1) * warped_after
    else:  # "gaussian" (validated above)
        result = _gaussian_feather_blend(warped_before, warped_after, w_before=w_before_arr, w_after=w_after_arr)

    # Quality stats at the two boundaries (should be near-perfect)
    diag["top_boundary_residual_mean"] = float(np.abs(result[0] - bottom_of_before).mean())
    diag["bottom_boundary_residual_mean"] = float(np.abs(result[-1] - top_of_after).mean())
    return result, diag