# Source code for linumpy.gpu.n4

"""GPU N4 bias field correction.

Implements the Tustison 2010 N4 algorithm using the B-spline primitive
in :mod:`linumpy.gpu.bspline` and a CuPy/NumPy-shared histogram
sharpening routine.

Each fitting level loops over:

1.  Compute the log-residual ``r = log(v) - log_bias`` on masked voxels.
2.  Sharpen the residual histogram by Wiener-deconvolving it with a
    Gaussian PSF (Sled 1998 / Tustison 2010), producing a LUT mapping
    observed log-intensity to expected (unbiased) log-intensity.
3.  Voxel-wise, compute the per-voxel log-bias update
    ``delta = r - LUT(r)`` and zero it outside the mask.
4.  Fit a tensor-product cubic B-spline to ``delta`` on a regular
    control grid, evaluate at full resolution, and add to ``log_bias``.

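The control-point grid is held fixed across fitting levels; each level
only contributes further iterations, and the fitted coefficients are
accumulated so the final bias field is a single B-spline evaluation.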

Memory budget (per N4 call):

    ~6 x volume_size x 4 bytes

i.e. ~6 GiB for a (256, 1024, 1024) float32 volume.
"""

from __future__ import annotations

from typing import Any

import numpy as np

from linumpy.gpu import GPU_AVAILABLE, get_array_module
from linumpy.gpu.bspline import (
    _build_axis_basis,
    _is_gpu_array,
    bspline_evaluate,
    bspline_fit,
    bspline_fit_precompute,
)

# ---------------------------------------------------------------------------
# Histogram sharpening
# ---------------------------------------------------------------------------


def _build_log_psf(n_bins: int, bin_width: float, fwhm: float, xp: Any) -> Any:
    """Build a unit-sum Gaussian kernel sampled on the histogram bins.

    The kernel peaks at bin index ``n_bins // 2`` and is sampled at
    multiples of *bin_width* on either side of that bin.

    Parameters
    ----------
    n_bins : int
        Number of samples (histogram bins) in the returned kernel.
    bin_width : float
        Spacing between consecutive bins, in log-intensity units.
    fwhm : float
        Full-width-half-maximum of the Gaussian, in log-intensity units.
    xp : module
        Array module used to build the kernel (``xp.arange``, ``xp.exp``).

    Returns
    -------
    array
        1-D kernel of length *n_bins*, normalised so its entries sum to one.
    """
    # FWHM = 2 * sqrt(2 * ln 2) * sigma for a Gaussian.
    fwhm_per_sigma = 2.3548200450309493
    sigma = fwhm / fwhm_per_sigma

    # Signed bin offset from the central bin, then scaled to log-intensity.
    offsets = xp.arange(n_bins, dtype=xp.float32) - n_bins // 2
    z = offsets * bin_width / sigma

    kernel = xp.exp(-0.5 * z**2)
    return kernel / kernel.sum()


def sharpen_residual(
    log_v: np.ndarray,
    mask: np.ndarray | None,
    *,
    n_bins: int = 200,
    fwhm_log: float = 0.15,
    wiener_noise: float = 0.01,
    use_gpu: bool = True,
    mask_weights: np.ndarray | None = None,
) -> np.ndarray:
    """Return the per-voxel sharpened log-intensity (LUT-mapped).

    Implements the Sled/Tustison histogram sharpening: build the weighted
    log-intensity histogram restricted to *mask*, deconvolve it by a
    Gaussian PSF (Wiener-regularised), and return the LUT
    ``E[v_true | v_obs]`` evaluated at every voxel in *log_v*.

    Parameters
    ----------
    log_v : np.ndarray
        Log-intensity volume (any shape); converted to float32 internally.
    mask : np.ndarray or None
        Boolean mask; only masked voxels contribute to the histogram.
        When ``None``, all voxels are used.
    n_bins : int
        Histogram bin count. Must be at least 2.
    fwhm_log : float
        Full-width-half-maximum of the Gaussian PSF in log-intensity
        units. Must be positive. Controls how much sharpening is applied
        (smaller FWHM means less sharpening, since the deconvolution
        kernel is narrower). N4 default is approximately 0.15.
    wiener_noise : float
        Wiener regularisation term. Larger values stabilise the
        deconvolution at the expense of sharpening.
    use_gpu : bool
        Use CuPy when available.
    mask_weights : np.ndarray, optional
        Float32 view of *mask* (``mask.astype(float32)``). When provided,
        skips the per-call cast -- a full-volume allocation that the N4
        fit loop otherwise repeats every iteration. It is coerced to the
        active array module; this is a no-op when it is already a float32
        array of that module.

    Returns
    -------
    np.ndarray
        Sharpened log-intensity, float32, same shape as *log_v*. Outside
        the mask the input log-intensity is returned unchanged. The input
        is also returned unsharpened when the masked range is empty,
        non-finite, or narrower than ``1e-8``. The result stays on the
        device when *log_v* was a GPU array and is a NumPy array otherwise.

    Raises
    ------
    ValueError
        If ``n_bins < 2`` (the bin width would be undefined) or
        ``fwhm_log <= 0`` (the PSF would be degenerate and the result
        all-NaN).
    """
    if n_bins < 2:
        raise ValueError(f"n_bins must be >= 2, got {n_bins}")
    if not fwhm_log > 0:
        raise ValueError(f"fwhm_log must be positive, got {fwhm_log}")

    xp = get_array_module(use_gpu=use_gpu and GPU_AVAILABLE)
    log_v_xp = xp.asarray(log_v, dtype=xp.float32)
    if mask is None:
        mask_xp = xp.ones_like(log_v_xp, dtype=xp.bool_)
    else:
        mask_xp = xp.asarray(mask, dtype=xp.bool_)

    # Compute masked min/max without materialising the masked subset
    # (boolean indexing is a slow scatter-gather on GPU). We use
    # +/-inf sentinels outside the mask so reductions ignore them.
    pos_inf = xp.float32(np.inf)
    neg_inf = xp.float32(-np.inf)
    r_min = float(xp.where(mask_xp, log_v_xp, pos_inf).min())
    r_max = float(xp.where(mask_xp, log_v_xp, neg_inf).max())

    # Empty mask / non-finite data, or a degenerate distribution: no
    # sharpening is possible, so hand the input back unchanged.
    if not (np.isfinite(r_min) and np.isfinite(r_max)) or r_max - r_min < 1e-8:
        if _is_gpu_array(log_v):
            return log_v_xp
        return np.asarray(log_v).astype(np.float32)

    bin_width = (r_max - r_min) / float(n_bins - 1)
    bin_centres = xp.linspace(r_min, r_max, n_bins, dtype=xp.float32)

    # Quantise the FULL volume once. bin_idx_full feeds both the
    # weighted histogram (via bincount) AND the per-voxel LUT lookup,
    # so we avoid a second pass over the volume and the
    # boolean-indexed copy of the masked subset.
    bin_idx_full = xp.clip(
        ((log_v_xp - r_min) / bin_width + 0.5).astype(xp.int64),
        0,
        n_bins - 1,
    )
    if mask_weights is None:
        mask_w = mask_xp.astype(xp.float32)
    else:
        # Coerce so a host-side weights array also works on the GPU path.
        mask_w = xp.asarray(mask_weights, dtype=xp.float32)
    hist = xp.bincount(
        bin_idx_full.reshape(-1),
        weights=mask_w.reshape(-1),
        minlength=n_bins,
    ).astype(xp.float32)

    # Gaussian PSF (centred); FFT-shift to align with FFT convention.
    # Zero-pad histogram and PSF to ``n_pad = 2 * n_bins`` so the FFT
    # convolutions are linear, not circular. Without padding, mass in
    # the top bins (typically white matter for OCT) wraps into the
    # bottom-bin LUT entries (and vice-versa), pulling WM intensities
    # downward and visibly muting bright tissue.
    n_pad = 2 * n_bins
    psf = _build_log_psf(n_bins, bin_width, fwhm_log, xp)
    psf_padded = xp.zeros(n_pad, dtype=xp.float32)
    psf_padded[:n_bins] = psf
    psf_shifted = xp.roll(psf_padded, -(n_bins // 2))
    hist_padded = xp.zeros(n_pad, dtype=xp.float32)
    hist_padded[:n_bins] = hist
    psf_fft = xp.fft.rfft(psf_shifted)
    hist_fft = xp.fft.rfft(hist_padded)

    # Wiener deconvolution: H_sharp = H * conj(G) / (|G|^2 + noise).
    psf_mag2 = (psf_fft * xp.conj(psf_fft)).real
    sharp_fft = hist_fft * xp.conj(psf_fft) / (psf_mag2 + wiener_noise)
    hist_sharp = xp.fft.irfft(sharp_fft, n=n_pad)[:n_bins]
    hist_sharp = xp.maximum(hist_sharp, 0.0)

    # LUT: for each output bin i, E[r | r_obs = bin_centres[i]]
    #   = sum_j r_j * hist_sharp[j] * G(i - j) / sum_j hist_sharp[j] * G(i - j)
    # i.e. (bin_centres * hist_sharp) (*) G / hist_sharp (*) G.
    # Pad to n_pad as well so the LUT convolution is linear.
    weighted = bin_centres * hist_sharp
    weighted_padded = xp.zeros(n_pad, dtype=xp.float32)
    weighted_padded[:n_bins] = weighted
    hist_sharp_padded = xp.zeros(n_pad, dtype=xp.float32)
    hist_sharp_padded[:n_bins] = hist_sharp
    num_fft = xp.fft.rfft(weighted_padded)
    den_fft = xp.fft.rfft(hist_sharp_padded)
    num = xp.fft.irfft(num_fft * psf_fft, n=n_pad)[:n_bins]
    den = xp.fft.irfft(den_fft * psf_fft, n=n_pad)[:n_bins]
    lut = num / xp.maximum(den, 1e-12)

    # Apply LUT to every voxel; outside mask, leave intensity unchanged.
    sharpened = lut[bin_idx_full]
    sharpened = xp.where(mask_xp, sharpened, log_v_xp).astype(xp.float32)

    # Device in -> device out; NumPy backend already yields a host array.
    if _is_gpu_array(log_v) or xp is np:
        return sharpened

    # Host input processed on the GPU: copy the result back to the host.
    import cupy as cp

    return cp.asnumpy(sharpened).astype(np.float32)


# ---------------------------------------------------------------------------
# N4 driver
# ---------------------------------------------------------------------------


def _check_host_buffer(buf: np.ndarray | None, name: str, shape: tuple[int, int, int]) -> None:
    """Raise ``ValueError`` unless *buf* is ``None`` or a float32 array of *shape*."""
    if buf is None:
        return
    if buf.shape != shape or buf.dtype != np.float32:
        raise ValueError(f"{name} must be {shape} float32, got {buf.shape} {buf.dtype}")


def n4_correct_gpu(
    vol: np.ndarray,
    mask: np.ndarray | None = None,
    *,
    shrink_factor: int = 4,
    n_iterations: list[int] | None = None,
    spline_distance_mm: float = 10.0,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    n_bins: int = 200,
    fwhm_log: float = 0.15,
    wiener_noise: float = 0.01,
    convergence_tol: float = 1e-3,
    use_gpu: bool = True,
    out: np.ndarray | None = None,
    bias_out: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """GPU-accelerated N4 bias field correction.

    Faithful CuPy/NumPy port of the Tustison 2010 N4 algorithm: at each
    fitting level, alternate Sled-style histogram sharpening and tensor
    cubic B-spline scattered-data fitting until convergence. The B-spline
    control mesh is fixed across levels (matching SimpleITK's behaviour);
    ``n_iterations`` only controls per-level iteration counts and the
    residual is composed across levels.

    Parameters mirror :func:`linumpy.intensity.bias_field.n4_correct` so
    the two backends are interchangeable. Extra knobs (``n_bins``,
    ``fwhm_log``, ``wiener_noise``) tune the sharpening histogram.

    Parameters
    ----------
    vol : np.ndarray
        Float32 input volume (Z, Y, X).
    mask : np.ndarray or None
        Boolean tissue mask. ``None`` = full volume.
    shrink_factor : int
        Isotropic spatial subsampling factor for the fit (>=1).
    n_iterations : list of int or None
        Max iterations per fitting level. Length sets the number of
        levels. Default ``[20, 20, 20]``. Fewer iterations than the
        SimpleITK CPU backend because the GPU PSDB residual update has no
        internal multilevel dampening, so each iteration has full effect;
        more than ~20 per level causes the bias field to absorb true
        tissue contrast (verified empirically on live OCT).
    spline_distance_mm : float
        Approximate distance between B-spline control knots at level 0.
    voxel_size_mm : 3-tuple of float
        Voxel size (z, y, x) in mm.
    n_bins, fwhm_log, wiener_noise : sharpening parameters
        See :func:`sharpen_residual`.
    convergence_tol : float
        Per-iteration convergence threshold on the relative L2 change of
        ``log_bias``. Iterations stop early when the change drops below
        this value.
    use_gpu : bool
        Use CuPy when available.
    out : np.ndarray, optional
        Destination buffer for the corrected output (full ``vol.shape``,
        float32). When provided, the streaming Z-tile loop writes results
        directly into this buffer instead of allocating a fresh array.
        May safely alias the input ``vol`` (the host buffer is not read
        after the initial H2D upload at function entry), saving a
        full-volume float32 allocation on large mosaics.
    bias_out : np.ndarray, optional
        Destination buffer for the bias-field output, same shape and
        dtype constraints as *out*.

    Returns
    -------
    corrected : np.ndarray
        Bias-corrected float32 volume (Z, Y, X), full resolution.
    bias_field : np.ndarray
        Estimated multiplicative bias field, float32, full resolution.

    Raises
    ------
    ValueError
        If *out* or *bias_out* is given but is not a float32 array with
        the same shape as *vol*. Checked before the fit starts.
    """
    if n_iterations is None:
        # Matches the documented default; see the parameter notes above
        # for why more than ~20 iterations per level is harmful.
        n_iterations = [20, 20, 20]

    xp = get_array_module(use_gpu=use_gpu and GPU_AVAILABLE)
    on_gpu = xp is not np
    cp = None
    if on_gpu:
        import cupy as cp

    # Single host -> device transfer. All intermediates remain on `xp`.
    vol_xp = xp.asarray(vol, dtype=xp.float32)
    full_shape: tuple[int, int, int] = (
        int(vol_xp.shape[0]),
        int(vol_xp.shape[1]),
        int(vol_xp.shape[2]),
    )

    # Fail fast on bad destination buffers instead of discovering the
    # problem only after the whole (expensive) fit has run.
    _check_host_buffer(out, "out", full_shape)
    _check_host_buffer(bias_out, "bias_out", full_shape)

    mask_xp = xp.ones(full_shape, dtype=xp.bool_) if mask is None else xp.asarray(mask, dtype=xp.bool_)

    # Spatial subsampling for fit (stride-subsample, on device).
    if shrink_factor > 1:
        # Materialise the subsampled volume/mask as contiguous copies so we
        # can release the full-resolution `mask_xp` (and any temporaries
        # that share storage with it) during the fit loop. The fit only
        # touches the small arrays; `vol_xp` is kept live for the
        # full-resolution evaluation pass at the bottom of the function.
        stride = (slice(None, None, shrink_factor),) * 3
        vol_small = xp.ascontiguousarray(vol_xp[stride])
        mask_small = xp.ascontiguousarray(mask_xp[stride])
    else:
        vol_small = vol_xp
        mask_small = mask_xp

    # Free the full-resolution mask now -- only `mask_small` is used in the
    # fit loop, and a fresh full mask is not needed for the evaluation pass.
    del mask_xp
    if on_gpu:
        cp.get_default_memory_pool().free_all_blocks()

    log_v = xp.log(xp.maximum(vol_small, 1e-6)).astype(xp.float32)

    # Base control-point grid sized to physical extent. ITK's spline order is
    # 3, so we need at least 4 control points per axis. We keep this grid
    # FIXED across all fitting levels: SimpleITK's N4 reuses one B-spline
    # mesh and accumulates residual composition across levels. Doubling the
    # grid per level (as earlier versions did) yields an effectively
    # per-voxel control mesh at level 2-3 on typical OCT slabs, which
    # absorbs true tissue contrast and produces a visibly jagged bias
    # estimate.
    extents_mm = tuple(full_shape[i] * float(voxel_size_mm[i]) for i in range(3))
    n_ctrl_base = tuple(max(4, round(e / spline_distance_mm)) for e in extents_mm)
    small_shape: tuple[int, int, int] = (
        int(vol_small.shape[0]),
        int(vol_small.shape[1]),
        int(vol_small.shape[2]),
    )
    n_ctrl: tuple[int, int, int] = (
        max(4, min(n_ctrl_base[0], small_shape[0])),
        max(4, min(n_ctrl_base[1], small_shape[1])),
        max(4, min(n_ctrl_base[2], small_shape[2])),
    )

    # Build the three (n_voxels, n_control) cubic-B-spline basis matrices
    # once and reuse them across every level/iteration for both the fit
    # (forward) and evaluate (transpose-shaped) contractions.
    bases = (
        _build_axis_basis(small_shape[0], n_ctrl[0], xp),
        _build_axis_basis(small_shape[1], n_ctrl[1], xp),
        _build_axis_basis(small_shape[2], n_ctrl[2], xp),
    )

    log_bias = xp.zeros_like(vol_small, dtype=xp.float32)
    # Float32 mask: used as the fit weights and, because it is identical on
    # every iteration, also handed to sharpen_residual to skip its own cast.
    weights = mask_small.astype(xp.float32)

    # Accumulate control coefficients so the final full-resolution bias
    # field can be obtained by a single B-spline evaluation rather than
    # by upsampling the coarse field with a different kernel.
    coeff_total = xp.zeros(n_ctrl, dtype=xp.float32)

    # The B-spline fit denominator S(p) and the squared/cubed per-axis basis
    # matrices depend only on `bases`, which we hold fixed across all
    # fitting levels. Precompute them once so each iteration's
    # bspline_fit call skips a small_vol-sized allocation of S plus six
    # basis-power multiplies.
    fit_precomputed = bspline_fit_precompute(bases)

    for max_iters in n_iterations:
        for _ in range(max_iters):
            # current = log_v - log_bias (small_vol-sized intermediate; the
            # CuPy memory pool reuses the slot every iteration).
            current = log_v - log_bias
            sharpened = sharpen_residual(
                current,
                mask_small,
                n_bins=n_bins,
                fwhm_log=fwhm_log,
                wiener_noise=wiener_noise,
                use_gpu=use_gpu,
                mask_weights=weights,
            )
            # `weights == mask_small.astype(float32)`, so multiplying by it
            # zeros the residual outside the mask without an extra
            # ``xp.where`` call.
            current -= sharpened
            current *= weights
            del sharpened

            coeffs = bspline_fit(
                current,
                weights=weights,
                mask=None,  # weights already encodes the mask
                n_control_points=n_ctrl,
                use_gpu=use_gpu,
                bases=bases,
                precomputed=fit_precomputed,
            )
            del current
            update = bspline_evaluate(
                coeffs,
                target_shape=small_shape,
                use_gpu=use_gpu,
                bases=bases,
            ).astype(xp.float32)
            log_bias += update
            coeff_total += coeffs
            del coeffs

            # Convergence: ||update|| / ||log_bias|| < tol. Compute on
            # device and issue a single D2H sync so each iteration of the
            # fit loop has at most one host round-trip (instead of two
            # ``float(xp.linalg.norm(...))`` calls).
            update_sq = (update * update).sum()
            del update
            bias_sq = (log_bias * log_bias).sum()
            converged = bool(
                xp.logical_and(
                    bias_sq > 0,
                    update_sq < (convergence_tol * convergence_tol) * bias_sq,
                )
            )
            if converged:
                break

    # Evaluate the accumulated B-spline at full resolution directly,
    # using the same cubic basis as the coarse-grid fits. This replaces
    # the previous separable Catmull-Rom upsample of the coarse log-bias
    # field (different kernel -> ~2-3% spatial mismatch vs the ITK
    # reference, which evaluates the spline analytically on the fine
    # grid).
    #
    # The final stage materializes (log_bias_full, bias_field, corrected)
    # at full volume size. For large volumes that dwarfs GPU memory, so
    # we drop the fit-time intermediates first and stream the evaluation
    # in Z-tiles back to host.
    del log_v, log_bias, weights, mask_small, vol_small, bases, fit_precomputed
    if on_gpu:
        cp.get_default_memory_pool().free_all_blocks()

    full_bases = (
        _build_axis_basis(full_shape[0], n_ctrl[0], xp),
        _build_axis_basis(full_shape[1], n_ctrl[1], xp),
        _build_axis_basis(full_shape[2], n_ctrl[2], xp),
    )
    M_z_full, M_y_full, M_x_full = full_bases

    # Pick a Z-tile that keeps the per-tile working set small relative to
    # vol_xp (which we keep on device for the per-voxel division). Each
    # tile allocates ~3x its float32 size on GPU (log_bias_chunk, bias,
    # corrected). Aim for ~2 GB total per tile.
    tile_bytes_target = 2 * 1024**3
    bytes_per_z = full_shape[1] * full_shape[2] * 4 * 3
    z_tile = max(1, min(full_shape[0], tile_bytes_target // max(bytes_per_z, 1)))

    # Buffers were validated up front; allocate only what the caller did
    # not supply, and only now, so the fit does not carry them.
    corrected_host = np.empty(full_shape, dtype=np.float32) if out is None else out
    bias_host = np.empty(full_shape, dtype=np.float32) if bias_out is None else bias_out

    for z0 in range(0, full_shape[0], z_tile):
        z1 = min(z0 + z_tile, full_shape[0])
        log_bias_chunk = bspline_evaluate(
            coeff_total,
            target_shape=(z1 - z0, full_shape[1], full_shape[2]),
            use_gpu=use_gpu,
            bases=(M_z_full[z0:z1], M_y_full, M_x_full),
        )
        bias_chunk = xp.exp(log_bias_chunk).astype(xp.float32)
        del log_bias_chunk
        corrected_chunk = (vol_xp[z0:z1] / xp.maximum(bias_chunk, 1e-6)).astype(xp.float32)
        if on_gpu:
            corrected_host[z0:z1] = cp.asnumpy(corrected_chunk)
            bias_host[z0:z1] = cp.asnumpy(bias_chunk)
        else:
            corrected_host[z0:z1] = corrected_chunk
            bias_host[z0:z1] = bias_chunk
        del bias_chunk, corrected_chunk
        if on_gpu:
            cp.get_default_memory_pool().free_all_blocks()

    return corrected_host, bias_host