# Source code for linumpy.gpu.bspline

"""Tensor-product cubic B-spline scattered-data approximation.

Provides a simple GPU/CPU primitive for fitting a smooth 3-D field to
scattered (weighted) voxel samples on a regular control-point lattice
and evaluating the resulting field at arbitrary voxel grids.

Used by :mod:`linumpy.gpu.n4` for the bias-field B-spline update step,
but kept generic so other smoothing/warp primitives can reuse it.

The fit implements the single-level Lee-Wolberg-Shin (1997) B-spline
approximation that ITK uses inside ``BSplineScatteredDataPointSetToImageFilter``
(the engine of N4).  For each scattered sample p with value v_p the
locally-optimal value at surrounding control point c is::

    phi_c(p) = w_c(p) * v_p / sum_d w_d(p)^2

and the per-control-point coefficient is the squared-weight average::

    coeff[c] = sum_p w_c(p)^2 * phi_c(p) / sum_p w_c(p)^2
             = sum_p gamma_p * w_c(p)^3 * v_p / S(p)
               -------------------------------------
                       sum_p gamma_p * w_c(p)^2

where ``S(p) = sum_d w_d(p)^2`` and gamma_p folds in the per-voxel
mask/weight.  Because the tensor-product basis is separable,
``w_c(p)^k`` factorises across axes and S(p) factorises into a product
of per-axis sums of squared basis weights, so the fit reduces to two
chains of three contiguous tensor contractions -- one chain through
``B^3`` for the numerator and one through ``B^2`` for the denominator.
This matches the ITK behaviour while remaining GPU-friendly: each chain
is a plain tensordot sequence.

An earlier implementation used a Nadaraya-Watson kernel regression
(``coeff[c] = sum_p w_c(p) * v_p / sum_p w_c(p)``).  That form has no
implicit smoothness penalty and, at the dense control grids reached by
later N4 fitting levels, lets the fit absorb tissue-scale features
(e.g. white-matter contrast) into the bias estimate.  PSDB's squared
weights regularise short-range support and recover the contrast.
"""

from __future__ import annotations

from typing import Any

import numpy as np

from linumpy.gpu import GPU_AVAILABLE, get_array_module


def _is_gpu_array(arr: Any) -> bool:
    """Return True if *arr* is a CuPy ndarray (so callers can keep results on GPU)."""
    try:
        import cupy as cp
    except ImportError:
        return False
    return isinstance(arr, cp.ndarray)


# ---------------------------------------------------------------------------
# Cubic B-spline basis
# ---------------------------------------------------------------------------


def _cubic_bspline_basis(t: Any, xp: Any) -> Any:
    """Return the four uniform cubic B-spline basis weights at offset *t*.

    Parameters
    ----------
    t : array-like
        Fractional offset(s) in [0, 1).  Any shape.
    xp : module
        Array module (numpy or cupy).

    Returns
    -------
    array
        Stack of shape ``t.shape + (4,)`` with weights ``[B0, B1, B2, B3]``.
        Weights sum to 1 along the last axis.
    """
    t = xp.asarray(t, dtype=xp.float32)
    t2 = t * t
    t3 = t2 * t
    one_m_t = 1.0 - t
    b0 = (one_m_t * one_m_t * one_m_t) / 6.0
    b1 = (3.0 * t3 - 6.0 * t2 + 4.0) / 6.0
    b2 = (-3.0 * t3 + 3.0 * t2 + 3.0 * t + 1.0) / 6.0
    b3 = t3 / 6.0
    return xp.stack([b0, b1, b2, b3], axis=-1)


# ---------------------------------------------------------------------------
# Coordinate mapping
# ---------------------------------------------------------------------------


def _voxel_to_control_coords(n_voxels: int, n_control: int, xp: Any) -> Any:
    """Map ``[0, n_voxels-1]`` voxel indices to control-grid coordinates.

    Voxel 0 maps to control coordinate 0; voxel ``n_voxels - 1`` maps to
    ``n_control - 3``.  This leaves one control-point of padding on each
    side so the 4-tap cubic B-spline kernel has full support at the
    boundaries.
    """
    if n_voxels == 1:
        return xp.zeros(1, dtype=xp.float32)
    span = float(n_control - 3)
    if span <= 0:
        raise ValueError(f"n_control={n_control} too small; need at least 4 control points to host a cubic B-spline.")
    return xp.arange(n_voxels, dtype=xp.float32) * (span / float(n_voxels - 1))


# ---------------------------------------------------------------------------
# Per-axis basis matrix
# ---------------------------------------------------------------------------


def _build_axis_basis(n_voxels: int, n_control: int, xp: Any) -> Any:
    """Build the dense ``(n_voxels, n_control)`` cubic B-spline basis matrix.

    For voxel ``i`` with control coordinate ``u_i`` the four basis weights
    are deposited at columns ``floor(u_i) + {-1, 0, 1, 2}``.  Columns that
    fall outside ``[0, n_control - 1]`` are clamped to the nearest valid
    column and their weights accumulated there, so every row still sums
    to 1 (a row therefore has at most four non-zero entries).

    A dense layout is used so that fitting and evaluation can be written
    as plain tensor contractions against this matrix.
    """
    coords = _voxel_to_control_coords(n_voxels, n_control, xp)
    base = xp.floor(coords).astype(xp.int32)
    frac = coords - base.astype(xp.float32)
    taps = _cubic_bspline_basis(frac, xp)  # (n_voxels, 4)

    basis = xp.zeros((n_voxels, n_control), dtype=xp.float32)
    row_idx = xp.arange(n_voxels, dtype=xp.int32)
    last_col = n_control - 1
    for tap, offset in enumerate((-1, 0, 1, 2)):
        col_idx = xp.clip(base + offset, 0, last_col)
        # Unbuffered scatter-add: clamping can send several taps of one
        # row to the same boundary column, and those must sum rather
        # than overwrite each other.
        xp.add.at(basis, (row_idx, col_idx), taps[:, tap])
    return basis


# ---------------------------------------------------------------------------
# Fit
# ---------------------------------------------------------------------------


def bspline_fit_precompute(
    bases: tuple[Any, Any, Any],
    *,
    eps: float = 1e-8,
) -> tuple[Any, Any, Any, Any, Any, Any, Any]:
    """Derive the fit constants that depend only on the per-axis bases.

    The element-wise squared and cubed basis matrices, and the separable
    per-voxel denominator

    ``S(p) = (sum_c M_z[z,c]^2)(sum_c M_y[y,c]^2)(sum_c M_x[x,c]^2)``

    are functions of *bases* alone.  Callers issuing many fits at the same
    shape can compute them once here and hand the result to
    :func:`bspline_fit` through its ``precomputed`` argument.

    Parameters
    ----------
    bases : tuple of arrays
        Per-axis basis matrices ``(M_z, M_y, M_x)``.
    eps : float
        Lower bound applied to ``S`` so later divisions are safe.

    Returns
    -------
    tuple
        ``(M_z2, M_y2, M_x2, M_z3, M_y3, M_x3, S_safe)`` with
        ``S_safe = maximum(S, eps)`` cast to float32.
    """
    axis_mats = tuple(bases)
    basis_z, _, _ = axis_mats
    # The array module follows the first basis matrix.
    xp = get_array_module(use_gpu=_is_gpu_array(basis_z))

    squared = [mat * mat for mat in axis_mats]
    cubed = [sq * mat for sq, mat in zip(squared, axis_mats)]

    # Per-voxel sum of squared weights along each axis.
    rows_z, rows_y, rows_x = (sq.sum(axis=1) for sq in squared)

    # Outer product of the three per-axis sums gives S on the full grid;
    # clamping here spares bspline_fit a full-volume maximum per call.
    s_full = rows_z[:, None, None] * rows_y[None, :, None] * rows_x[None, None, :]
    s_safe = xp.maximum(s_full, eps).astype(xp.float32)
    return (*squared, *cubed, s_safe)


def bspline_fit(
    values: np.ndarray,
    weights: np.ndarray | None,
    mask: np.ndarray | None,
    n_control_points: tuple[int, int, int],
    *,
    use_gpu: bool = True,
    eps: float = 1e-8,
    bases: tuple[Any, Any, Any] | None = None,
    precomputed: tuple[Any, Any, Any, Any, Any, Any, Any] | None = None,
) -> np.ndarray:
    """Fit a tensor-product cubic B-spline to scattered voxel samples.

    Parameters
    ----------
    values : np.ndarray
        Sample values, shape (Z, Y, X).  Converted to float32.
    weights : np.ndarray or None
        Per-voxel non-negative weights (same shape).  ``None`` = all ones.
    mask : np.ndarray or None
        Mask selecting which voxels participate in the fit; it is cast to
        float32 and multiplied into the weights.  ``None`` = all voxels.
    n_control_points : tuple of int
        Control-grid size ``(Cz, Cy, Cx)``.  Each value must be ``>= 4``.
    use_gpu : bool
        Request the GPU backend.  The request is combined with
        ``GPU_AVAILABLE`` before the array module is selected.
    eps : float
        Floor on the squared-weight denominator to avoid division by zero
        for control points with no support.  Also forwarded to
        :func:`bspline_fit_precompute` when *precomputed* is not given.
    bases : tuple of arrays, optional
        Pre-built per-axis basis matrices ``(M_z, M_y, M_x)`` from
        :func:`_build_axis_basis` matching ``values.shape`` and
        ``n_control_points``.  Only consulted when *precomputed* is
        ``None``; in that case it skips the per-call basis build.
    precomputed : tuple of arrays, optional
        Output of :func:`bspline_fit_precompute` for the matching bases.
        When provided, neither the bases nor the squared/cubed matrices
        nor the separable denominator ``S`` are rebuilt.

    Returns
    -------
    array
        Control coefficients, shape ``n_control_points``, float32.  When
        *values* is a CuPy array the result is returned as computed
        (no host transfer); otherwise a NumPy array is returned.

    Raises
    ------
    ValueError
        If *values* is not 3-D or any entry of *n_control_points* is
        smaller than 4.
    """
    if values.ndim != 3:
        raise ValueError(f"values must be 3-D, got shape {values.shape}")
    cz, cy, cx = n_control_points
    if min(cz, cy, cx) < 4:
        raise ValueError(f"n_control_points must each be >= 4, got {n_control_points}")

    xp = get_array_module(use_gpu=use_gpu and GPU_AVAILABLE)

    vals = xp.asarray(values, dtype=xp.float32)
    w = xp.ones_like(vals) if weights is None else xp.asarray(weights, dtype=xp.float32)
    if mask is not None:
        w = w * xp.asarray(mask, dtype=xp.float32)

    # PSDB: separable tensor-product implementation of the Lee-Wolberg-Shin
    # single-level scattered-data B-spline approximation.
    #
    #   coeff[c] = sum_p gamma_p * w_c(p)^3 * v_p / S(p)
    #              ------------------------------------------
    #                       sum_p gamma_p * w_c(p)^2
    #
    # Only the squared/cubed per-axis matrices and S(p) enter the
    # contractions below, so the raw bases are needed solely to derive
    # them.  When the caller already supplies *precomputed*, building (or
    # even unpacking) the bases would be wasted work and is skipped.
    if precomputed is None:
        if bases is None:
            z_n, y_n, x_n = vals.shape
            # M_axis[i, c] is the cubic B-spline weight that voxel ``i``
            # deposits onto control point ``c``.
            bases = (
                _build_axis_basis(z_n, cz, xp),
                _build_axis_basis(y_n, cy, xp),
                _build_axis_basis(x_n, cx, xp),
            )
        precomputed = bspline_fit_precompute(tuple(bases), eps=eps)
    M_z2, M_y2, M_x2, M_z3, M_y3, M_x3, S_safe = precomputed

    psi = (w * vals) / S_safe  # (Z, Y, X)

    # num[Cz, Cy, Cx] = sum_{z,y,x} M_z3[z,Cz] M_y3[y,Cy] M_x3[x,Cx] * psi
    num = xp.tensordot(psi, M_x3, axes=([2], [0]))  # (Nz, Ny, Cx)
    num = xp.tensordot(num, M_y3, axes=([1], [0]))  # (Nz, Cx, Cy)
    num = xp.tensordot(num, M_z3, axes=([0], [0]))  # (Cx, Cy, Cz)
    num = xp.transpose(num, (2, 1, 0))  # (Cz, Cy, Cx)

    # den[Cz, Cy, Cx] = sum_{z,y,x} M_z2[z,Cz] M_y2[y,Cy] M_x2[x,Cx] * w
    den = xp.tensordot(w, M_x2, axes=([2], [0]))
    den = xp.tensordot(den, M_y2, axes=([1], [0]))
    den = xp.tensordot(den, M_z2, axes=([0], [0]))
    den = xp.transpose(den, (2, 1, 0))

    coeff = (num / xp.maximum(den, eps)).astype(xp.float32)

    # Preserve caller's array module: cupy in -> cupy out, numpy in -> numpy out.
    if _is_gpu_array(values) or xp is np:
        return coeff
    import cupy as cp

    # coeff is already float32, so no extra cast/copy is needed on the host.
    return cp.asnumpy(coeff)


# ---------------------------------------------------------------------------
# Evaluate
# ---------------------------------------------------------------------------


def bspline_evaluate(
    control_coeffs: np.ndarray,
    target_shape: tuple[int, int, int],
    *,
    use_gpu: bool = True,
    bases: tuple[Any, Any, Any] | None = None,
) -> np.ndarray:
    """Evaluate a cubic B-spline given control coefficients on a regular grid.

    Uses the same voxel-to-control coordinate mapping as
    :func:`bspline_fit`: target voxel 0 maps to control coordinate 0 and
    target voxel ``N - 1`` maps to ``Cn - 3``.

    Parameters
    ----------
    control_coeffs : np.ndarray
        Control-grid coefficients, shape ``(Cz, Cy, Cx)``.  Converted to
        float32.
    target_shape : tuple of int
        Output volume shape ``(Z, Y, X)``.
    use_gpu : bool
        Request the GPU backend.  The request is combined with
        ``GPU_AVAILABLE`` before the array module is selected.
    bases : tuple of arrays, optional
        Pre-built per-axis basis matrices ``(M_z, M_y, M_x)`` matching
        ``target_shape`` and ``control_coeffs.shape``.  When provided,
        skips the per-call build.

    Returns
    -------
    array
        Evaluated field, shape ``target_shape``, float32.  When
        *control_coeffs* is a CuPy array the result is returned as
        computed (no host transfer); otherwise a NumPy array is returned.

    Raises
    ------
    ValueError
        If *control_coeffs* is not 3-D or *target_shape* does not have
        exactly three entries.
    """
    xp = get_array_module(use_gpu=use_gpu and GPU_AVAILABLE)

    coeff = xp.asarray(control_coeffs, dtype=xp.float32)
    # Fail with an explicit message (mirroring bspline_fit) instead of the
    # opaque tuple-unpacking error a wrong rank would otherwise trigger.
    if coeff.ndim != 3:
        raise ValueError(f"control_coeffs must be 3-D, got shape {coeff.shape}")
    cz, cy, cx = coeff.shape
    z_n, y_n, x_n = target_shape

    if bases is None:
        M_z = _build_axis_basis(z_n, cz, xp)  # (Nz, Cz)
        M_y = _build_axis_basis(y_n, cy, xp)
        M_x = _build_axis_basis(x_n, cx, xp)
    else:
        M_z, M_y, M_x = bases

    # out[z, y, x] = sum_{Z,Y,X} M_z[z,Z] M_y[y,Y] M_x[x,X] * coeff[Z,Y,X]
    out = xp.tensordot(coeff, M_x, axes=([2], [1]))  # (Cz, Cy, Nx)
    out = xp.tensordot(out, M_y, axes=([1], [1]))  # (Cz, Nx, Ny)
    out = xp.tensordot(out, M_z, axes=([0], [1]))  # (Nx, Ny, Nz)
    out = xp.transpose(out, (2, 1, 0)).astype(xp.float32)  # (Nz, Ny, Nx)

    # Preserve caller's array module: cupy in -> cupy out, numpy in -> numpy out.
    if _is_gpu_array(control_coeffs) or xp is np:
        return out
    import cupy as cp

    # out is already float32, so no extra cast/copy is needed on the host.
    return cp.asnumpy(out)