"""Methods to download data from the Allen Institute."""
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any
import numpy as np
import requests
import SimpleITK as sitk
from tqdm import tqdm
[docs]
# Allen average-template resolutions (in micron) accepted by download_template.
AVAILABLE_RESOLUTIONS = [10, 25, 50, 100]
[docs]
def numpy_to_sitk_image(volume: np.ndarray, spacing: tuple | Sequence, cast_dtype: type | None = None) -> sitk.Image:
    """Wrap a (Z, Y, X) numpy volume in a SimpleITK image.

    Parameters
    ----------
    volume : np.ndarray
        3D array laid out as (Z, Y, X), i.e. axis 0 = Z/depth,
        axis 1 = Y/row, axis 2 = X/column (the project-wide convention).
    spacing : tuple
        Voxel spacing in mm, given in numpy axis order (res_z, res_y, res_x).
    cast_dtype : numpy dtype or None
        When given, the volume is cast to this dtype before the SITK image is
        built (e.g. float32 for registration). When None, the dtype of
        ``volume`` is kept.

    Returns
    -------
    sitk.Image
        Image with spacing set from ``spacing``, a zero origin and an
        identity direction matrix.
    """
    # Only an explicit cast may allocate a new numpy array; copy=False lets
    # numpy reuse the input buffer when the dtype already matches.
    if cast_dtype is None:
        source_array = volume
    else:
        source_array = volume.astype(cast_dtype, copy=False)

    # A (Z, Y, X) numpy array is read by SITK as an image of size (X, Y, Z),
    # so the data itself needs no transpose.
    image = sitk.GetImageFromArray(source_array)

    # SITK spacing is ordered (X, Y, Z): reverse the (res_z, res_y, res_x) input.
    res_z, res_y, res_x = spacing[0], spacing[1], spacing[2]
    image.SetSpacing([res_x, res_y, res_z])
    image.SetOrigin([0, 0, 0])
    image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
    return image
[docs]
def download_template(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image:
    """Download a 3D average mouse brain.

    The template is fetched from the Allen Institute archive unless a file for
    the requested resolution already exists in ``cache_dir``. The download is
    streamed to a temporary ``.part`` file and only renamed to its final name
    once it completed successfully, so an interrupted or failed download never
    leaves a truncated file that a later call would mistake for a cache hit.

    Parameters
    ----------
    resolution
        Allen template resolution in micron. Must be 10, 25, 50 or 100.
    cache
        Keep the downloaded volume in cache
    cache_dir
        Cache directory

    Returns
    -------
    Allen average mouse brain.

    Raises
    ------
    AssertionError
        If ``resolution`` is not one of ``AVAILABLE_RESOLUTIONS``.
    requests.HTTPError
        If the server answers with an error status.
    """
    # Kept as an assertion (rather than another exception type) so callers
    # relying on AssertionError for invalid resolutions keep working.
    assert resolution in AVAILABLE_RESOLUTIONS, (
        f"resolution must be one of {AVAILABLE_RESOLUTIONS}, got {resolution!r}"
    )

    # Preparing the cache directory
    output = Path(cache_dir)
    output.mkdir(exist_ok=True, parents=True)

    # Preparing the filenames
    nrrd_file = output / f"allen_template_{resolution}um.nrrd"

    # Preparing the request
    url = f"http://download.alleninstitute.org/informatics-archive/current-release/mouse_ccf/average_template/average_template_{int(resolution)}.nrrd"

    # Check that the data is in cache
    if not nrrd_file.is_file():
        partial_file = nrrd_file.with_name(nrrd_file.name + ".part")
        try:
            # The context manager releases the connection even on failure; the
            # timeout prevents a stalled server from hanging the call forever.
            with requests.get(url, stream=True, timeout=60) as response:
                # Fail loudly instead of caching an HTML error page as .nrrd.
                response.raise_for_status()
                content_length = response.headers.get("content-length")
                total_bytes = int(content_length) if content_length and content_length.isdigit() else None
                with (
                    partial_file.open("wb") as f,
                    tqdm(total=total_bytes, unit="B", unit_scale=True, desc=nrrd_file.name) as progress,
                ):
                    # 1 MiB chunks: the default chunk size of iter_content is
                    # a single byte, which is extremely slow for large files.
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
                        progress.update(len(chunk))
            # Publish the file in the cache only once it is complete.
            partial_file.replace(nrrd_file)
        finally:
            # No-op after a successful rename; removes leftovers on failure.
            partial_file.unlink(missing_ok=True)

    # Loading the nrrd file
    vol = sitk.ReadImage(str(nrrd_file))

    # Remove the file from cache
    if not cache:
        nrrd_file.unlink()  # Removes the nrrd file

    return vol
[docs]
def download_template_ras_aligned(resolution: int, cache: bool = True, cache_dir: str = ".data/") -> sitk.Image:
    """Download a 3D average mouse brain and reorient it to RAS+.

    The Allen CCF v3 template is stored in PIR orientation: SITK axes
    ``(X, Y, Z) = (AP, DV, ML)`` with ``+X = Posterior``, ``+Y = Inferior``
    and ``+Z = Right``. RAS+ (``+X = Right``, ``+Y = Anterior``,
    ``+Z = Superior``) is obtained with ``PermuteAxes((2, 0, 1))`` followed by
    a flip of **both** the Y and Z axes (P → A and I → S).

    Parameters
    ----------
    resolution
        Allen template resolution in micron. Must be 10, 25, 50 or 100.
    cache
        Keep the downloaded volume in cache
    cache_dir
        Cache directory

    Returns
    -------
    Allen average mouse brain in RAS+ orientation, with spacing in mm,
    zero origin and identity direction.
    """
    zero_origin = [0.0, 0.0, 0.0]
    identity_direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]

    template = download_template(resolution, cache, cache_dir)

    # Isotropic spacing in mm (resolution is given in micron), with a
    # standardized origin/direction so physical coordinates are stable.
    spacing_mm = resolution / 1e3
    template.SetSpacing([spacing_mm, spacing_mm, spacing_mm])
    template.SetOrigin(zero_origin)
    template.SetDirection(identity_direction)

    # PIR → RAS: the permutation maps (P, I, R) → (R, P, I); flipping
    # Y (P → A) and Z (I → S) then yields (R, A, S).
    permuted = sitk.PermuteAxes(template, (2, 0, 1))
    reoriented = sitk.Flip(permuted, (False, True, True))

    # Permute/flip update the geometry metadata; reset it to zero/identity.
    reoriented.SetOrigin(zero_origin)
    reoriented.SetDirection(identity_direction)
    return reoriented
[docs]
# Upper-cased metric name -> function configuring that metric on a
# sitk.ImageRegistrationMethod. Also serves as the list of supported metrics.
_METRIC_CONFIGURATORS = {
    "MI": lambda reg: reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50),
    "MSE": lambda reg: reg.SetMetricAsMeanSquares(),
    "CC": lambda reg: reg.SetMetricAsCorrelation(),
    "ANTSCC": lambda reg: reg.SetMetricAsANTSNeighborhoodCorrelation(radius=20),
}


def _downsample_atlas_to_moving(allen_atlas: sitk.Image, moving_spacing: tuple, verbose: bool) -> sitk.Image:
    """Resample the atlas to the moving image's finest spacing when the moving image is coarser."""
    moving_min_spacing_mm = min(moving_spacing)
    allen_spacing_mm = allen_atlas.GetSpacing()
    allen_min_spacing_mm = min(allen_spacing_mm)
    # Only act when even the finest moving axis is >20% coarser than the atlas.
    if not (moving_min_spacing_mm > allen_min_spacing_mm * 1.2):
        return allen_atlas

    target_spacing_mm = float(moving_min_spacing_mm)
    allen_size = allen_atlas.GetSize()
    # Keep the physical extent: new voxel count = old extent / new spacing.
    new_size = [
        max(1, round(sz * sp / target_spacing_mm)) for sz, sp in zip(allen_size, allen_spacing_mm, strict=False)
    ]
    ref = sitk.Image(new_size, allen_atlas.GetPixelIDValue())
    ref.SetOrigin(allen_atlas.GetOrigin())
    ref.SetDirection(allen_atlas.GetDirection())
    ref.SetSpacing((target_spacing_mm,) * 3)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampled = resampler.Execute(allen_atlas)
    if verbose:
        print(
            f"Downsampled Allen atlas to match moving spacing: "
            f"{allen_spacing_mm} mm → {resampled.GetSpacing()} mm, "
            f"size {allen_size} → {resampled.GetSize()}"
        )
    return resampled


def _crop_to_tissue(moving_image: np.ndarray, moving_spacing: tuple, margin_voxels: int, verbose: bool) -> tuple:
    """Crop the volume to the bounding box of its non-zero voxels; return (volume, crop origin in mm, ZYX order)."""
    nonzero_coords = np.nonzero(moving_image)
    if len(nonzero_coords[0]) == 0:
        # Nothing but zeros: leave the volume untouched, no offset to restore.
        return moving_image, (0.0, 0.0, 0.0)

    bbox_slices = tuple(
        slice(
            max(0, int(dim.min()) - margin_voxels),
            min(moving_image.shape[ax], int(dim.max()) + margin_voxels + 1),
        )
        for ax, dim in enumerate(nonzero_coords)
    )
    crop_origin_mm = (
        bbox_slices[0].start * moving_spacing[0],
        bbox_slices[1].start * moving_spacing[1],
        bbox_slices[2].start * moving_spacing[2],
    )
    cropped = moving_image[bbox_slices]
    if verbose:
        print(f"Cropped tissue bounding box: {moving_image.shape} -> {cropped.shape}")
    return cropped, crop_origin_mm


def _needs_resampling(moving_sitk: sitk.Image, allen_atlas: sitk.Image) -> bool:
    """Tell whether spacing differs by >10% or size by more than a factor 2 on any axis."""
    spacing_ratio = np.array(allen_atlas.GetSpacing()) / np.array(moving_sitk.GetSpacing())
    size_ratio = np.array(allen_atlas.GetSize(), dtype=float) / np.array(moving_sitk.GetSize(), dtype=float)
    return bool(np.any(np.abs(spacing_ratio - 1.0) > 0.1) or np.any(size_ratio < 0.5) or np.any(size_ratio > 2.0))


def _tissue_centroid_mm(image: sitk.Image) -> np.ndarray | None:
    """Return the physical centroid of the strictly positive voxels, or None if there are none."""
    arr = sitk.GetArrayFromImage(image)  # shape (Z, Y, X) in numpy
    nonzero_idx = np.argwhere(arr > 0)  # rows are (z, y, x)
    if len(nonzero_idx) == 0:
        return None
    centroid_zyx = nonzero_idx.mean(axis=0)
    # SITK index order is (x, y, z), reverse of numpy (z, y, x)
    centroid_xyz = [float(centroid_zyx[2]), float(centroid_zyx[1]), float(centroid_zyx[0])]
    return np.array(image.TransformContinuousIndexToPhysicalPoint(centroid_xyz))


def _build_initial_transform(
    fixed_image: sitk.Image,
    moving_image_sitk: sitk.Image,
    moving_center: np.ndarray,
    initial_rotation_deg: tuple,
    allow_moments: bool,
    verbose: bool,
) -> sitk.Transform:
    """Build the starting Euler transform: centre alignment plus user rotation, or MOMENTS when applicable."""
    initial_transform = sitk.Euler3DTransform()

    # Fixed image centre in physical space
    fixed_center_idx = [s / 2.0 for s in fixed_image.GetSize()]
    fixed_center = np.array(fixed_image.TransformContinuousIndexToPhysicalPoint(fixed_center_idx))

    # ITK transform maps fixed→moving: T(p) = R(p - c) + c + t.
    # With identity rotation and c = fixed_center, T(fixed_center) = fixed_center + t,
    # and we need T(fixed_center) = moving_center, hence t = moving_center - fixed_center.
    translation = tuple(moving_center - fixed_center)

    initial_transform.SetCenter(fixed_center)
    rx_rad = np.deg2rad(initial_rotation_deg[0])
    ry_rad = np.deg2rad(initial_rotation_deg[1])
    rz_rad = np.deg2rad(initial_rotation_deg[2])
    initial_transform.SetTranslation(translation)
    initial_transform.SetRotation(rx_rad, ry_rad, rz_rad)

    if verbose:
        print(f"Initial center alignment: fixed={fixed_center}, moving (original)={moving_center}")
        print(f"Translation to align centers: {translation}")
        if any(r != 0 for r in initial_rotation_deg):
            print(f"Initial rotation (deg): {initial_rotation_deg}")

    # MOMENTS initialization is only attempted when the user gave no rotation
    # (an explicit rotation takes precedence) and the caller allows it.
    if allow_moments and all(r == 0 for r in initial_rotation_deg):
        try:
            init_transform = sitk.Euler3DTransform()
            init_transform = sitk.CenteredTransformInitializer(
                fixed_image, moving_image_sitk, init_transform, sitk.CenteredTransformInitializerFilter.MOMENTS
            )
            # Accept the MOMENTS result only if its translation is smaller than
            # half the fixed image diagonal; otherwise keep the centre-aligned one.
            init_params = init_transform.GetParameters()
            init_translation = np.array(init_params[3:6])
            translation_magnitude = np.linalg.norm(init_translation)
            fixed_size_mm = np.array(fixed_image.GetSpacing()) * np.array(fixed_image.GetSize())
            max_reasonable_translation = np.linalg.norm(fixed_size_mm) * 0.5
            if translation_magnitude < max_reasonable_translation:
                initial_transform = init_transform
                if verbose:
                    print(f"Using MOMENTS initialization (translation magnitude: {translation_magnitude:.2f} mm)")
            else:
                if verbose:
                    print(
                        f"MOMENTS initialization translation too large ({translation_magnitude:.2f} mm), using center-aligned"
                    )
        except Exception as e:
            # Deliberate best-effort: any failure falls back to centre alignment.
            if verbose:
                print(f"MOMENTS initialization failed: {e}, using center-aligned translation")

    if verbose:
        final_params = initial_transform.GetParameters()
        final_center = initial_transform.GetCenter()
        print(f"Final initial transform: rotation={final_params[:3]}, translation={final_params[3:]}")
        print(f"Transform center: {final_center}")

    return initial_transform


def _attach_iteration_reporter(
    registration_method: sitk.ImageRegistrationMethod,
    verbose: bool,
    progress_callback: Callable[[Any], None] | None,
) -> None:
    """Register a per-iteration command that prints progress and/or calls the user callback."""
    if not (verbose or progress_callback is not None):
        return

    def command_iteration(method: Any) -> None:
        if verbose:
            if method.GetOptimizerIteration() == 0:
                print(f"Estimated scales: {method.GetOptimizerScales()}")
            print(
                f"Iteration {method.GetOptimizerIteration():3d} = "
                f"{method.GetMetricValue():7.5f} : "
                f"{method.GetOptimizerPosition()}"
            )
        if progress_callback is not None:
            progress_callback(method)

    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))


def _restore_crop_offset(final_transform: sitk.Transform, crop_origin_mm: tuple, verbose: bool) -> None:
    """Add the crop origin to the transform's translation (in place) so it applies to the uncropped volume."""
    if not any(v != 0.0 for v in crop_origin_mm):
        return
    # T(p) = R(p-c)+c+t maps Allen coords to cropped-brain coords (origin=0).
    # The same tissue in the full volume sits at (cropped_coord + crop_origin),
    # so t_full = t_crop + crop_origin (the centre c cancels out).
    params = list(final_transform.GetParameters())
    # Euler3D params are (rx, ry, rz, tx, ty, tz) in SITK XYZ order, while
    # crop_origin_mm is in numpy (Z, Y, X) order.
    params[3] += crop_origin_mm[2]  # SITK X = numpy axis 2
    params[4] += crop_origin_mm[1]  # SITK Y = numpy axis 1
    params[5] += crop_origin_mm[0]  # SITK Z = numpy axis 0
    final_transform.SetParameters(params)
    if verbose:
        print(
            f"Adjusted translation for crop: +"
            f"[{crop_origin_mm[2]:.3f}, {crop_origin_mm[1]:.3f}, {crop_origin_mm[0]:.3f}] mm (SITK XYZ)"
        )


def register_3d_rigid_to_allen(
    moving_image: np.ndarray,
    moving_spacing: tuple,
    allen_resolution: int = 100,
    metric: str = "MI",
    max_iterations: int = 1000,
    verbose: bool = False,
    progress_callback: Callable[[Any], None] | None = None,
    initial_rotation_deg: tuple = (0.0, 0.0, 0.0),
) -> tuple:
    """Perform 3D rigid registration of a brain volume to the Allen atlas.

    Steps: validate the metric, fetch the RAS-aligned atlas (downsampling it
    when the moving image is coarser), crop the moving volume to its tissue
    bounding box, resample it onto the atlas grid if needed, run a
    multi-resolution gradient-descent rigid registration, and finally add the
    crop offset back to the translation so the transform applies to the
    original, uncropped volume.

    Parameters
    ----------
    moving_image : np.ndarray
        3D brain volume to register (shape: Z, Y, X)
    moving_spacing : tuple
        Voxel spacing in mm (res_z, res_y, res_x)
    allen_resolution : int
        Allen template resolution in micron (default: 100)
    metric : str
        Similarity metric (case-insensitive): 'MI' (mutual information),
        'MSE', 'CC' (correlation), or 'AntsCC' (ANTS correlation)
    max_iterations : int
        Maximum number of iterations
    verbose : bool
        Print registration progress
    progress_callback : callable, optional
        Callback function called on each iteration with the registration method.
        Function signature: callback(registration_method)
    initial_rotation_deg : tuple, optional
        Initial rotation in degrees (rx, ry, rz) applied before optimization.

    Returns
    -------
    transform : sitk.Transform
        Rigid (Euler 3D) transform to align moving_image to Allen atlas
    stop_condition : str
        Optimizer stopping condition
    error : float
        Final registration metric value

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names. Raised before any
        download or resampling takes place.
    """
    # Fail fast on a bad metric instead of after the costly atlas preparation.
    metric_key = metric.upper()
    if metric_key not in _METRIC_CONFIGURATORS:
        raise ValueError(f"Unknown metric: {metric}. Choose from: MI, MSE, CC, AntsCC")

    # Download and prepare Allen atlas in RAS orientation
    allen_atlas = download_template_ras_aligned(allen_resolution, cache=True)

    # Registration cost is dominated by the fixed (Allen) image size, so when
    # the moving image is coarser it is much cheaper to downsample the atlas
    # than to upsample the moving image to a grid carrying no extra information.
    allen_atlas = _downsample_atlas_to_moving(allen_atlas, moving_spacing, verbose)

    # Crop to the tissue bounding box: empty canvas (e.g. from motor drift)
    # would otherwise cause the Allen-domain resampling to clip away tissue.
    moving_image, crop_origin_mm = _crop_to_tissue(moving_image, moving_spacing, margin_voxels=10, verbose=verbose)

    # Origin stays at (0,0,0) so the compact brain overlaps the Allen domain;
    # the crop offset is restored on the final transform after registration.
    moving_sitk = numpy_to_sitk_image(moving_image, moving_spacing)

    # Geometric centre of the (cropped) moving image, used for the initial
    # translation unless it is replaced by the resampled tissue centroid below.
    moving_center_idx = [s / 2.0 for s in moving_sitk.GetSize()]
    moving_center = np.array(moving_sitk.TransformContinuousIndexToPhysicalPoint(moving_center_idx))

    needs_resample = _needs_resampling(moving_sitk, allen_atlas)
    if needs_resample:
        if verbose:
            print(
                f"Resampling moving image from {moving_sitk.GetSpacing()} mm, size {moving_sitk.GetSize()} "
                f"to {allen_atlas.GetSpacing()} mm, size {allen_atlas.GetSize()}"
            )
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(allen_atlas)
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(0)
        moving_sitk = resampler.Execute(moving_sitk)

        # The pre-resampling centre can lie far outside the Allen domain, which
        # would give a translation mapping every Allen voxel outside the moving
        # buffer. Use the centroid of the tissue that survived the clipping.
        # If nothing survived, keep the pre-resampling centre and accept a
        # potentially poor initialization.
        tissue_center = _tissue_centroid_mm(moving_sitk)
        if tissue_center is not None:
            moving_center = tissue_center
            if verbose:
                print(f"Resampled brain centroid (physical): {moving_center} mm")

    # Normalize images for better registration
    fixed_image = sitk.Normalize(allen_atlas)
    moving_image_sitk = sitk.Normalize(moving_sitk)

    if verbose:
        print(f"Fixed (Allen) image: size={fixed_image.GetSize()}, spacing={fixed_image.GetSpacing()}")
        print(f"Moving (brain) image: size={moving_image_sitk.GetSize()}, spacing={moving_image_sitk.GetSpacing()}")

    registration_method = sitk.ImageRegistrationMethod()
    _METRIC_CONFIGURATORS[metric_key](registration_method)

    # Regular sampling of 25% of the voxels for reproducibility and speed
    registration_method.SetMetricSamplingStrategy(registration_method.REGULAR)
    registration_method.SetMetricSamplingPercentage(0.25)

    # Conservative optimizer settings to prevent overshooting
    registration_method.SetOptimizerAsRegularStepGradientDescent(
        learningRate=0.5,
        minStep=0.0001,
        numberOfIterations=max_iterations,
        relaxationFactor=0.5,
        gradientMagnitudeTolerance=1e-8,
    )
    # Parameter scales estimated from physical shifts
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid: start coarse, refine progressively
    registration_method.SetShrinkFactorsPerLevel([8, 4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel([4, 2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # MOMENTS is skipped after resampling: the brain then occupies a small part
    # of the Allen-sized image and the normalized zero-padded background
    # dominates the centre-of-mass computation.
    initial_transform = _build_initial_transform(
        fixed_image,
        moving_image_sitk,
        moving_center,
        initial_rotation_deg,
        allow_moments=not needs_resample,
        verbose=verbose,
    )
    registration_method.SetInitialTransform(initial_transform)
    registration_method.SetInterpolator(sitk.sitkLinear)

    _attach_iteration_reporter(registration_method, verbose, progress_callback)

    # Execute registration
    final_transform = registration_method.Execute(fixed_image, moving_image_sitk)
    stop_condition = registration_method.GetOptimizerStopConditionDescription()
    error = registration_method.GetMetricValue()

    if verbose:
        print(f"Registration complete: {stop_condition}")
        print(f"Final metric value: {error:.6f}")
        final_params = final_transform.GetParameters()
        print(f"Final transform: rotation={final_params[:3]}, translation={final_params[3:]}")
        print(f"Fixed image size: {fixed_image.GetSize()}, spacing: {fixed_image.GetSpacing()}")
        print(f"Moving image size: {moving_image_sitk.GetSize()}, spacing: {moving_image_sitk.GetSpacing()}")

    # Make the transform valid for the full original (uncropped) volume.
    _restore_crop_offset(final_transform, crop_origin_mm, verbose)

    return final_transform, stop_condition, error