"""N4 bias field correction for serial OCT stacks.
Provides N4 correction (SimpleITK on CPU, or an optional GPU backend), a
tissue-mask helper, and a driver that runs N4 independently per serial
section, optionally in parallel via :mod:`multiprocessing`.
Typical two-pass usage::
from linumpy.intensity.bias_field import compute_tissue_mask, n4_correct_per_section, n4_correct
mask = compute_tissue_mask(vol)
vol_ps, _ = n4_correct_per_section(vol, n_serial_slices=50, mask=mask, n_processes=48)
vol_out, _ = n4_correct(vol_ps, mask)
"""
from __future__ import annotations
import multiprocessing
from typing import Any
import numpy as np
import SimpleITK as sitk
from linumpy.intensity.normalization import _chunk_boundaries
# ---------------------------------------------------------------------------
# Tissue mask
# ---------------------------------------------------------------------------
def _compute_tissue_mask_gpu(
    vol: np.ndarray,
    smoothing_sigma: float,
    smoothing_sigma_z: float,
    n_serial_slices: int,
    closing_radius: int,
    z_closing_sections: int,
) -> np.ndarray:
    """GPU implementation of :func:`compute_tissue_mask`.

    Keeps the per-section pipeline (gaussian -> Otsu -> threshold -> per-Z
    hole fill + closing) resident on GPU. Only the final bool section mask
    is copied back to the host (8x smaller than a float32 D2H of the
    smoothed volume). Sections are processed one at a time, with one H2D
    upload each.

    The final Z-closing runs on GPU when the (Z-padded) mask fits in free
    device memory with 4x headroom. Otherwise it runs on CPU via
    :func:`scipy.ndimage.binary_closing`.

    Every closing is applied to an edge-replicated, padded copy and then
    cropped. This stops ``binary_closing``'s implicit zero border from
    eroding tissue that touches the array edges.

    Raises
    ------
    ImportError
        If CuPy / ``cupyx.scipy.ndimage`` / scikit-image cannot be imported.
    cupy.cuda.memory.OutOfMemoryError
        Not handled here: if a section does not fit on the GPU, the error
        propagates to the caller.
    """
    import cupy as cp
    from cupyx.scipy.ndimage import binary_closing as cp_binary_closing
    from cupyx.scipy.ndimage import binary_fill_holes as cp_binary_fill_holes
    from cupyx.scipy.ndimage import gaussian_filter as cp_gaussian_filter
    from skimage.morphology import disk

    sigma_zyx = (smoothing_sigma_z, smoothing_sigma, smoothing_sigma)
    radius = closing_radius
    structuring_g = cp.asarray(disk(radius), dtype=bool) if radius > 0 else None
    bounds = _chunk_boundaries(vol.shape[0], n_serial_slices)
    mask = np.zeros(vol.shape, dtype=bool)
    for s, e in bounds:
        section_g = cp.asarray(vol[s:e], dtype=cp.float32)
        smoothed_g = cp_gaussian_filter(section_g, sigma=sigma_zyx)
        del section_g
        # Otsu on the GPU section using cupy.histogram on positive voxels.
        nonzero_g = smoothed_g[smoothed_g > 0]
        if nonzero_g.size < 100:
            # Too few positive voxels for a meaningful split: keep the section.
            mask[s:e] = True
            del smoothed_g, nonzero_g
            cp.get_default_memory_pool().free_all_blocks()
            continue
        thresh = float(_otsu_threshold_gpu(nonzero_g))
        del nonzero_g
        section_mask_g = smoothed_g > thresh
        del smoothed_g
        # Per-Z hole filling and closing (oblique masks differ across Z).
        for z in range(section_mask_g.shape[0]):
            plane_g = cp_binary_fill_holes(section_mask_g[z])
            if structuring_g is not None:
                # Edge-pad so the erosion half of the closing does not clear
                # tissue within `radius` pixels of the image border.
                padded_g = cp.pad(plane_g, radius, mode="edge")
                closed_g = cp_binary_closing(padded_g, structure=structuring_g)
                plane_g = closed_g[radius:-radius, radius:-radius]
            section_mask_g[z] = plane_g
        mask[s:e] = cp.asnumpy(section_mask_g)
        del section_mask_g
        cp.get_default_memory_pool().free_all_blocks()

    # Bridge step artifacts at section boundaries by closing along Z. The
    # mask is edge-padded along Z on the host first; without this the first
    # and last `z_closing_sections` planes would be eroded to False.
    if z_closing_sections > 0 and n_serial_slices > 1:
        k = z_closing_sections
        z_struct = np.ones((2 * k + 1, 1, 1), dtype=bool)
        padded = np.pad(mask, ((k, k), (0, 0), (0, 0)), mode="edge")
        core = slice(k, k + mask.shape[0])
        free_mem, _ = cp.cuda.runtime.memGetInfo()
        if int(padded.size) * 4 < free_mem:  # 4x headroom for kernel scratch
            padded_g = cp.asarray(padded)
            struct_g = cp.asarray(z_struct)
            closed_g = cp_binary_closing(padded_g, structure=struct_g)
            mask = cp.asnumpy(closed_g[core])
            del padded_g, struct_g, closed_g
            cp.get_default_memory_pool().free_all_blocks()
        else:
            from scipy.ndimage import binary_closing as np_binary_closing

            mask = np_binary_closing(padded, structure=z_struct)[core]
    return mask
def _otsu_threshold_gpu(values: Any, nbins: int = 256) -> float:
    """Return Otsu's threshold for a 1-D CuPy array.

    The values are binned into ``nbins`` equal bins spanning ``[min, max]``.
    The returned threshold is the bin centre that maximises the
    between-class variance, which is the same search that
    ``skimage.filters.threshold_otsu`` performs. When every value is
    identical (``max <= min``), the minimum is returned unchanged.
    """
    import cupy as cp

    vmin = float(values.min().item())
    vmax = float(values.max().item())
    if vmax <= vmin:
        return vmin

    counts, bin_edges = cp.histogram(values, bins=nbins, range=(vmin, vmax))
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    counts = counts.astype(cp.float64)
    weighted = counts * bin_centers

    # Class weights / means for "below" (low) and "above" (high) each cut;
    # the max(., 1) guard only protects against empty-class division.
    w_low = cp.cumsum(counts)
    w_high = cp.cumsum(counts[::-1])[::-1]
    mu_low = cp.cumsum(weighted) / cp.maximum(w_low, 1.0)
    mu_high = (cp.cumsum(weighted[::-1]) / cp.maximum(w_high[::-1], 1.0))[::-1]

    between = w_low[:-1] * w_high[1:] * (mu_low[:-1] - mu_high[1:]) ** 2
    best = int(cp.argmax(between).item())
    return float(bin_centers[best].item())
[docs]
def compute_tissue_mask(
vol: np.ndarray,
smoothing_sigma: float = 2.0,
n_serial_slices: int = 1,
closing_radius: int = 3,
z_closing_sections: int = 2,
smoothing_sigma_z: float = 1.0,
use_gpu: bool = False,
) -> np.ndarray:
"""Return a 3-D boolean mask where *True* indicates tissue (not agarose).
The volume is lightly smoothed with an anisotropic 3-D Gaussian
(``smoothing_sigma`` in XY, ``smoothing_sigma_z`` in Z) and a single
Otsu threshold is computed per serial section from the smoothed
voxel histogram (background-zero voxels excluded). The threshold is
then applied per voxel, so the mask follows tissue shape through Z
and correctly handles oblique sections (e.g. 45° acquisitions),
where the tissue footprint shifts across Z within a section.
Each Z-plane is post-processed with hole-filling and morphological
closing to remove internal speckle (e.g. dark white-matter or
ventricle voxels falling below the Otsu threshold). Finally the
stacked 3-D mask is closed along Z to bridge step artifacts at
section boundaries.
Parameters
----------
vol : np.ndarray
3-D volume (Z, Y, X), any float dtype.
smoothing_sigma : float
Gaussian smoothing sigma in XY (pixels) before thresholding.
n_serial_slices : int
Number of serial sections in the volume. When 1 (default), one
global Otsu threshold is used.
closing_radius : int
Radius (pixels) of the 2-D disk used for morphological closing
on each Z-plane mask. 0 disables 2-D closing.
z_closing_sections : int
Number of adjacent sections to bridge with a 3-D closing pass on
the stacked mask. 0 disables Z-direction closing.
smoothing_sigma_z : float
Gaussian smoothing sigma along Z (voxels) before thresholding.
Small values (1-2) denoise without blurring oblique edges.
use_gpu : bool
If True, run the dominant 3-D ``gaussian_filter`` on GPU via
CuPy (Z-chunked for memory safety). Falls back to CPU silently
if CuPy is unavailable. Otsu and morphology stay on CPU.
Returns
-------
np.ndarray
Boolean array of shape (Z, Y, X) -- True where tissue is present.
"""
from scipy.ndimage import binary_closing, binary_fill_holes, gaussian_filter
from skimage.filters import threshold_otsu
from skimage.morphology import disk
if use_gpu:
try:
return _compute_tissue_mask_gpu(
vol,
smoothing_sigma=smoothing_sigma,
smoothing_sigma_z=smoothing_sigma_z,
n_serial_slices=n_serial_slices,
closing_radius=closing_radius,
z_closing_sections=z_closing_sections,
)
except ImportError:
pass # CuPy missing -- fall back to CPU below.
# Anisotropic 3-D smoothing: stronger in XY, light in Z to preserve
# oblique tissue boundaries without per-Z Otsu noise.
sigma_zyx = (smoothing_sigma_z, smoothing_sigma, smoothing_sigma)
smoothed = gaussian_filter(vol.astype(np.float32), sigma=sigma_zyx)
bounds = _chunk_boundaries(vol.shape[0], n_serial_slices)
mask = np.zeros(vol.shape, dtype=bool)
structuring = disk(closing_radius) if closing_radius > 0 else None
for s, e in bounds:
section_smooth = smoothed[s:e]
nonzero = section_smooth[section_smooth > 0]
if nonzero.size < 100:
mask[s:e] = True
continue
thresh = threshold_otsu(nonzero)
section_mask = section_smooth > thresh
# Per-Z hole filling and closing (oblique masks differ across Z).
for z in range(section_mask.shape[0]):
plane = binary_fill_holes(section_mask[z])
if structuring is not None:
plane = binary_closing(plane, structure=structuring)
section_mask[z] = plane
mask[s:e] = section_mask
# Bridge step artifacts at section boundaries by closing along Z.
if z_closing_sections > 0 and n_serial_slices > 1:
z_struct = np.ones((2 * z_closing_sections + 1, 1, 1), dtype=bool)
mask = binary_closing(mask, structure=z_struct)
return mask
# ---------------------------------------------------------------------------
# N4 core
# ---------------------------------------------------------------------------
[docs]
def n4_correct(
    vol: np.ndarray,
    mask: np.ndarray | None = None,
    *,
    shrink_factor: int = 4,
    n_iterations: list[int] | None = None,
    spline_distance_mm: float = 10.0,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    backend: str = "cpu",
    out: np.ndarray | None = None,
    bias_out: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Run N4 bias field correction on a 3-D volume.

    The N4 fit is performed on a spatially downsampled copy (``shrink_factor``).
    The log bias field is then linearly resampled onto the full-resolution
    grid and exponentiated before division. When a dimension is not a
    multiple of ``shrink_factor``, some full-resolution border voxels fall
    outside the shrunk grid. Those voxels take the nearest fitted value
    rather than an uncorrected bias of 1.

    Parameters
    ----------
    vol : np.ndarray
        Float32 input volume (Z, Y, X).
    mask : np.ndarray or None
        Boolean tissue mask (Z, Y, X) -- same shape as *vol*. When *None*,
        no mask image is passed to N4.
    shrink_factor : int
        Isotropic spatial downsampling factor for the N4 fit.
    n_iterations : list of int or None
        Max iterations per fitting level; its length sets the number of fitting
        levels. Defaults to ``[50, 50, 50, 50]`` (4 levels).
    spline_distance_mm : float
        Approximate distance (in mm) between B-spline control-point knots.
    voxel_size_mm : 3-tuple of float
        Voxel size (z, y, x) in mm -- sets physical spacing for SimpleITK.
    backend : {"cpu", "gpu", "auto"}
        Backend selector. ``"cpu"`` (default) uses SimpleITK's N4
        implementation. ``"gpu"`` dispatches to
        :func:`linumpy.gpu.n4.n4_correct_gpu` (CuPy-accelerated when CUDA is
        available, NumPy fallback otherwise). ``"auto"`` picks ``"gpu"`` when
        CuPy + CUDA are available and ``"cpu"`` otherwise.
    out, bias_out : np.ndarray, optional
        Destination buffers (GPU backend only). When provided, the
        N4 driver writes its full-resolution outputs directly into
        these buffers instead of allocating fresh arrays, saving up to
        two full-volume float32 allocations. ``out`` may safely alias
        the input ``vol`` -- the host buffer is not read after the
        initial H2D upload.

    Returns
    -------
    corrected : np.ndarray
        Bias-corrected float32 volume, same shape as *vol*.
    bias_field : np.ndarray
        Estimated bias field (multiplicative), float32, same shape as *vol*.

    Raises
    ------
    ValueError
        If *backend* is not recognised, if ``out``/``bias_out`` are given
        with the CPU backend, or (CPU backend) if *mask* does not have the
        same shape as *vol*.
    """
    if backend not in ("cpu", "gpu", "auto"):
        raise ValueError(f"backend must be 'cpu', 'gpu', or 'auto', got {backend!r}")
    if backend == "auto":
        from linumpy.gpu import GPU_AVAILABLE

        backend = "gpu" if GPU_AVAILABLE else "cpu"
    if backend == "gpu":
        from linumpy.gpu.n4 import n4_correct_gpu

        return n4_correct_gpu(
            vol,
            mask,
            shrink_factor=shrink_factor,
            n_iterations=n_iterations,
            spline_distance_mm=spline_distance_mm,
            voxel_size_mm=voxel_size_mm,
            use_gpu=True,
            out=out,
            bias_out=bias_out,
        )
    if out is not None or bias_out is not None:
        raise ValueError("out / bias_out are only supported with backend='gpu'")

    # Nothing below writes into vol_f32 (SimpleITK copies it, and
    # apply_bias_field casts to a new array), so reusing the caller's
    # buffer when it is already contiguous float32 is safe.
    vol_f32 = np.ascontiguousarray(vol, dtype=np.float32)
    if mask is not None and mask.shape != vol_f32.shape:
        raise ValueError(f"mask shape {mask.shape} does not match vol shape {vol_f32.shape}")
    if n_iterations is None:
        n_iterations = [50, 50, 50, 50]

    # sitk.GetImageFromArray already maps numpy index order (z, y, x) onto
    # ITK axes (x, y, z), so the array is passed untransposed. Transposing
    # would make ITK's x-axis the volume's Z and swap the X/Z spacing and
    # control-point counts set below.
    sitk_vol = sitk.GetImageFromArray(vol_f32)
    sitk_vol.SetSpacing((float(voxel_size_mm[2]), float(voxel_size_mm[1]), float(voxel_size_mm[0])))
    if mask is not None:
        sitk_mask = sitk.GetImageFromArray(np.ascontiguousarray(mask, dtype=np.uint8))
        sitk_mask.CopyInformation(sitk_vol)
    else:
        sitk_mask = None

    # Shrink for fast fit
    shrinker = sitk.ShrinkImageFilter()
    shrinker.SetShrinkFactors([shrink_factor] * 3)
    sitk_vol_shrunk = shrinker.Execute(sitk_vol)
    sitk_mask_shrunk = shrinker.Execute(sitk_mask) if sitk_mask is not None else None

    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    corrector.SetMaximumNumberOfIterations(n_iterations)
    # Per-axis control points = physical extent (mm) / spline_distance (mm).
    # SimpleITK expects (x, y, z) order while voxel_size_mm / vol.shape are (z, y, x).
    min_control_points = corrector.GetSplineOrder() + 1  # ITK requires n_pts > spline_order
    extents_mm_zyx = [vol_f32.shape[i] * float(voxel_size_mm[i]) for i in range(3)]
    n_pts_zyx = [max(min_control_points, round(e / spline_distance_mm)) for e in extents_mm_zyx]
    corrector.SetNumberOfControlPoints([n_pts_zyx[2], n_pts_zyx[1], n_pts_zyx[0]])
    if sitk_mask_shrunk is not None:
        corrector.Execute(sitk_vol_shrunk, sitk_mask_shrunk)
    else:
        corrector.Execute(sitk_vol_shrunk)

    # Reconstruct the full-resolution bias field. The shrunk grid is centred
    # and may cover slightly less physical extent than the full grid; use
    # nearest-neighbour extrapolation there instead of a constant log-bias
    # of 0 (which would leave those border voxels uncorrected).
    log_bias_shrunk = corrector.GetLogBiasFieldAsImage(sitk_vol_shrunk)
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitk_vol)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0.0)
    resampler.SetOutputPixelType(sitk.sitkFloat32)
    resampler.SetUseNearestNeighborExtrapolator(True)
    log_bias_full = resampler.Execute(log_bias_shrunk)

    log_bias_arr = sitk.GetArrayFromImage(log_bias_full)  # already (Z, Y, X)
    bias_field = np.exp(log_bias_arr).astype(np.float32)
    corrected = apply_bias_field(vol_f32, bias_field)
    return corrected, bias_field
# ---------------------------------------------------------------------------
# Bias field application
# ---------------------------------------------------------------------------
[docs]
def apply_bias_field(vol: np.ndarray, bias_field: np.ndarray, floor: float = 1e-6) -> np.ndarray:
    """Return *vol* divided element-wise by *bias_field* as a new float32 array.

    Both operands are cast to float32 first. Divisor values below *floor*
    (including zeros and negatives) are raised to *floor*, which prevents
    division by zero.

    Parameters
    ----------
    vol : np.ndarray
        Input volume, any shape.
    bias_field : np.ndarray
        Multiplicative bias field, same shape as *vol*.
    floor : float
        Smallest divisor value allowed.

    Returns
    -------
    np.ndarray
        Corrected float32 array; *vol* itself is left untouched.
    """
    numerator = vol.astype(np.float32)
    safe_divisor = np.maximum(bias_field.astype(np.float32), floor)
    quotient = numerator / safe_divisor
    return quotient.astype(np.float32)
# ---------------------------------------------------------------------------
# Per-section parallel N4
# ---------------------------------------------------------------------------
def _n4_section_worker(args: tuple[Any, ...]) -> tuple[np.ndarray, np.ndarray]:
    """Correct one serial section; defined at top level so it can be pickled.

    *args* is the ``(section_volume, section_mask_or_None, n4_kwargs)``
    triple built by :func:`n4_correct_per_section`. It is forwarded
    unchanged to :func:`n4_correct`.
    """
    section_vol, section_mask, n4_kwargs = args
    return n4_correct(section_vol, section_mask, **n4_kwargs)
[docs]
def n4_correct_per_section(
    vol: np.ndarray,
    n_serial_slices: int,
    mask: np.ndarray | None = None,
    *,
    n_processes: int = 1,
    **kwargs: Any,
) -> tuple[np.ndarray, np.ndarray]:
    """Run N4 bias field correction independently on each serial section.

    Splits the volume along Z into *n_serial_slices* chunks and corrects each
    chunk independently (serial sections have independent optical attenuation).
    Chunks are dispatched to a :class:`multiprocessing.Pool` when
    *n_processes* > 1. Section copies are created lazily, and results are
    written straight into preallocated output arrays. Peak memory is
    therefore about one output volume plus the sections in flight, rather
    than a full copy of the input plus all chunk results plus the
    concatenated output.

    Parameters
    ----------
    vol : np.ndarray
        Float32 3-D volume (Z, Y, X).
    n_serial_slices : int
        Number of serial tissue sections stacked along Z.
    mask : np.ndarray or None
        Boolean tissue mask (Z, Y, X). Sliced alongside *vol*.
    n_processes : int
        Number of parallel worker processes. 1 runs serially. Forced to 1
        when the GPU backend is in effect.
    **kwargs
        Extra keyword arguments forwarded to :func:`n4_correct`
        (e.g. ``shrink_factor``, ``spline_distance_mm``, ``backend``).

    Returns
    -------
    corrected : np.ndarray
        Bias-corrected float32 volume, same shape as *vol*.
    bias_field : np.ndarray
        Per-section bias field stitched into a single (Z, Y, X) array.

    Raises
    ------
    ValueError
        If there are no sections to correct.
    """
    # Materialise once: the bounds are iterated twice below.
    bounds = list(_chunk_boundaries(vol.shape[0], n_serial_slices))

    # GPU backend cannot be parallelised across processes (single device);
    # force serial execution.
    backend = kwargs.get("backend", "cpu")
    if backend == "auto":
        from linumpy.gpu import GPU_AVAILABLE

        effective_gpu = GPU_AVAILABLE
    else:
        effective_gpu = backend == "gpu"
    if effective_gpu and n_processes != 1:
        import logging

        logging.getLogger(__name__).warning(
            "GPU N4 backend cannot be parallelised across processes (single device); "
            "forcing n_processes=1 (was %d). Per-section sections will run serially on GPU.",
            n_processes,
        )
        n_processes = 1

    # Lazily built so only the sections currently being processed are copied.
    work_items = (
        (
            vol[s:e].copy(),
            mask[s:e].copy() if mask is not None else None,
            kwargs,
        )
        for s, e in bounds
    )
    total_z = sum(e - s for s, e in bounds)

    if n_processes == 1:
        return _stitch_sections(map(_n4_section_worker, work_items), total_z)
    with multiprocessing.Pool(processes=n_processes) as pool:
        # imap keeps input order; stitching must finish before the pool is
        # terminated on leaving the `with` block.
        return _stitch_sections(pool.imap(_n4_section_worker, work_items), total_z)


def _stitch_sections(results: Any, total_z: int) -> tuple[np.ndarray, np.ndarray]:
    """Write ordered ``(corrected, bias)`` section results into two Z-stacks.

    Equivalent to concatenating the chunks along axis 0, but the outputs
    are allocated once (dtype and YX shape taken from the first result)
    and filled as results arrive.
    """
    corrected = bias_field = None
    offset = 0
    for corr_chunk, bias_chunk in results:
        if corrected is None:
            corrected = np.empty((total_z, *corr_chunk.shape[1:]), dtype=corr_chunk.dtype)
            bias_field = np.empty((total_z, *bias_chunk.shape[1:]), dtype=bias_chunk.dtype)
        stop = offset + corr_chunk.shape[0]
        corrected[offset:stop] = corr_chunk
        bias_field[offset:stop] = bias_chunk
        offset = stop
    if corrected is None:
        raise ValueError("n4_correct_per_section: no sections to correct")
    return corrected, bias_field