"""
GPU-accelerated FFT operations for linumpy.
Provides GPU versions of FFT-based operations including phase correlation
for image registration and stitching.
"""
from typing import Any
import numpy as np
from . import GPU_AVAILABLE, to_cpu
def phase_correlation(vol1: Any, vol2: Any, n_peaks: Any = 8, use_gpu: Any = True) -> Any:
    """
    Estimate the translation between two images by phase correlation.

    The GPU implementation is used only when ``use_gpu`` is truthy *and* the
    package reports a usable GPU (``GPU_AVAILABLE``); in every other case the
    CPU implementation is called instead.

    Parameters
    ----------
    vol1 : np.ndarray
        Fixed image (2D or 3D).
    vol2 : np.ndarray
        Moving image (2D or 3D).
    n_peaks : int
        How many correlation peaks to examine while refining the estimate.
    use_gpu : bool
        Request GPU acceleration when it is available.

    Returns
    -------
    list
        Translation [dx, dy] or [dx, dy, dz] of vol2 relative to vol1.
    float
        Cross-correlation score.
    """
    run_on_gpu = use_gpu and GPU_AVAILABLE
    backend = _phase_correlation_gpu if run_on_gpu else _phase_correlation_cpu
    return backend(vol1, vol2, n_peaks)
def _phase_correlation_gpu(vol1: Any, vol2: Any, n_peaks: Any = 8) -> Any:
    """GPU implementation of phase correlation.

    Both volumes are reflect-padded by about 25% of their size, tapered at the
    padded borders with a Hanning ramp, and phase-correlated with CuPy FFTs.
    The ``n_peaks`` strongest *local maxima* of the correlation surface are
    turned into candidate integer translations (each axis also tries the
    alternative half a padded period away). Every candidate is scored with
    ``_compute_correlation_score`` on the unpadded volumes, and the best one wins.

    Parameters
    ----------
    vol1 : array-like
        Fixed image (any number of dimensions).
    vol2 : array-like
        Moving image; must have the same shape as ``vol1``.
    n_peaks : int
        Number of correlation peaks to evaluate. Values below 1 are treated as 1.

    Returns
    -------
    list of int
        Best translation, one entry per axis, in the convention expected by
        ``_compute_correlation_score(vol1, vol2, translation)``.
    float
        Normalized cross-correlation score of that translation.

    Raises
    ------
    ValueError
        If ``vol1`` and ``vol2`` do not have the same shape.
    """
    import cupy as cp
    from cupyx.scipy.ndimage import maximum_filter

    # Transfer to GPU in single precision.
    vol1_gpu = cp.asarray(vol1, dtype=cp.float32)
    vol2_gpu = cp.asarray(vol2, dtype=cp.float32)
    if vol1_gpu.shape != vol2_gpu.shape:
        # Without this check a size-1 axis would silently broadcast in the
        # cross-power product below.
        raise ValueError(
            f"vol1 and vol2 must have the same shape, got {vol1_gpu.shape} and {vol2_gpu.shape}"
        )
    vol_shape = vol1_gpu.shape

    # Extend images by 1/4 of their size (split evenly on both sides of each axis).
    new_shape = tuple(int(s * 1.25) for s in vol_shape)
    pad_size = tuple((int(np.ceil(0.5 * (n - s))),) * 2 for s, n in zip(vol_shape, new_shape, strict=True))
    pads = [p[0] for p in pad_size]

    # Taper the padded borders. The cast keeps the FFTs in single precision even
    # if the window helper returns a float64 array.
    vol1_p = _apply_hanning_window_gpu(cp.pad(vol1_gpu, pad_size, mode="reflect"), pads).astype(
        cp.float32, copy=False
    )
    vol2_p = _apply_hanning_window_gpu(cp.pad(vol2_gpu, pad_size, mode="reflect"), pads).astype(
        cp.float32, copy=False
    )
    padded_shape = vol1_p.shape

    # Normalized cross-power spectrum; near-zero bins are zeroed to avoid 0/0.
    cross_power = cp.fft.fftn(vol2_p) * cp.conj(cp.fft.fftn(vol1_p))
    magnitude = cp.abs(cross_power)
    cross_power = cp.where(magnitude > 1e-10, cross_power / magnitude, 0)
    q_abs = cp.abs(cp.fft.ifftn(cross_power))

    # Rank only local maxima, so the n_peaks candidates are distinct peaks rather
    # than neighbouring pixels of the same peak. The inverse FFT output is
    # periodic, hence mode="wrap" for the neighbourhood.
    is_peak = q_abs == maximum_filter(q_abs, size=3, mode="wrap")
    peak_values = cp.where(is_peak, q_abs, -1.0).ravel()
    # Guard: argsort()[-0:] would select every voxel.
    n_peaks = max(1, int(n_peaks))
    top = cp.argsort(peak_values)[-n_peaks:]
    top = top[peak_values[top] >= 0]
    if top.size == 0:
        # No local maximum detected (e.g. NaNs in the surface): fall back to the raw ranking.
        top = cp.argsort(q_abs.ravel())[-n_peaks:]
    coordinates_cpu = to_cpu(cp.stack(cp.unravel_index(top, q_abs.shape), axis=1))

    vol1_cpu = to_cpu(vol1_gpu)
    vol2_cpu = to_cpu(vol2_gpu)

    best_translation = None
    best_score = -np.inf
    # Different peaks can map to the same translation; score each one only once.
    scored = {}
    for indices in coordinates_cpu:
        for trans in _candidate_translations(indices, vol_shape, padded_shape):
            key = tuple(trans)
            if key in scored:
                continue
            score = _compute_correlation_score(vol1_cpu, vol2_cpu, trans)
            scored[key] = score
            if score > best_score:
                best_score = score
                best_translation = trans
    return best_translation, best_score


def _candidate_translations(peak_index: Any, vol_shape: Any, padded_shape: Any) -> list:
    """Convert one correlation-peak index into candidate integer translations.

    Returns every combination of, per axis, the direct estimate and its alias
    shifted by half the padded size (2**ndim lists of ints, x-axis varying fastest).
    """
    deltas = []
    for idx, padded, original in zip(peak_index, padded_shape, vol_shape, strict=True):
        delta = int(-idx + padded / 2)
        # Fold estimates larger than the volume back into range.
        if abs(delta) > original:
            delta -= int(np.sign(delta) * original)
        deltas.append(delta)

    candidates = [[]]
    for delta, padded in zip(deltas, padded_shape, strict=True):
        alias = delta - int(np.sign(delta) * padded / 2)
        candidates = [c + [delta] for c in candidates] + [c + [alias] for c in candidates]
    return candidates
def _apply_hanning_window_gpu(vol: Any, pad_sizes: Any) -> Any:
    """Apply a Hanning taper to the borders of a GPU array.

    Along each axis, ``vol`` is multiplied by a profile that equals 1 in the
    interior. Over the first ``pad`` samples the profile follows the rising half
    of a Hanning window of length ``2 * pad``, and over the last ``pad`` samples
    it follows the falling half. Axes whose pad is ``<= 0`` are left untouched.

    Parameters
    ----------
    vol : cupy.ndarray
        Array to taper.
    pad_sizes : sequence of int
        Taper width per axis, in axis order.

    Returns
    -------
    cupy.ndarray
        A new tapered array; ``vol`` itself is not modified. Floating-point and
        complex inputs keep their precision (float32 stays float32). Other dtypes
        are promoted to float64 when at least one axis is tapered.
    """
    import cupy as cp

    ndim = vol.ndim
    # cp.ones()/cp.hanning() default to float64, which would promote float32 data
    # (and every downstream FFT) to double precision. Build the window in the
    # input's real precision instead.
    window_dtype = vol.real.dtype if vol.dtype.kind in "fc" else cp.float64
    result = vol.copy()
    for axis, pad in enumerate(pad_sizes):
        if pad <= 0:
            continue
        s = vol.shape[axis]
        h = cp.hanning(pad * 2).astype(window_dtype, copy=False)
        h_full = cp.ones(s, dtype=window_dtype)
        h_full[:pad] = h[:pad]
        h_full[-pad:] = h[pad:]
        # Reshape to (1, ..., s, ..., 1) so the profile broadcasts along `axis`.
        shape = [1] * ndim
        shape[axis] = s
        result = result * h_full.reshape(shape)
    return result
def _compute_correlation_score(vol1: Any, vol2: Any, translation: Any) -> Any:
"""Compute normalized cross-correlation score for a translation."""
# Compute overlap region
slices1 = []
slices2 = []
for i, t in enumerate(translation):
t = int(t)
if t >= 0:
slices1.append(slice(t, None))
slices2.append(slice(None, vol2.shape[i] - t if t > 0 else None))
else:
slices1.append(slice(None, vol1.shape[i] + t))
slices2.append(slice(-t, None))
try:
ov1 = vol1[tuple(slices1)]
ov2 = vol2[tuple(slices2)]
if ov1.size == 0 or ov2.size == 0:
return 0
# Normalized cross-correlation
ov1_norm = ov1 - np.mean(ov1)
ov2_norm = ov2 - np.mean(ov2)
std1 = np.std(ov1_norm)
std2 = np.std(ov2_norm)
if std1 < 1e-10 or std2 < 1e-10:
return 0
return float(np.mean(ov1_norm * ov2_norm) / (std1 * std2))
except Exception:
return 0
def _phase_correlation_cpu(vol1: Any, vol2: Any, n_peaks: Any = 8) -> Any:
    """Run phase correlation on the CPU.

    Delegates to ``linumpy.registration.transforms.pair_wise_phase_correlation``
    with ``return_cc=True`` and hands back its result unchanged.
    """
    from linumpy.registration.transforms import (
        pair_wise_phase_correlation as reference_impl,
    )

    result = reference_impl(vol1, vol2, return_cc=True, n_peaks=n_peaks)
    return result
def fft2(image: Any, use_gpu: Any = True) -> Any:
    """
    Compute the 2D discrete Fourier transform, on the GPU when possible.

    Parameters
    ----------
    image : np.ndarray
        Input 2D image.
    use_gpu : bool
        Use CuPy when it is requested and a GPU is available; otherwise NumPy is used.

    Returns
    -------
    np.ndarray
        Complex spectrum, always returned as a host (NumPy) array.
    """
    if not (use_gpu and GPU_AVAILABLE):
        return np.fft.fft2(image)

    import cupy as cp

    spectrum_gpu = cp.fft.fft2(cp.asarray(image))
    return to_cpu(spectrum_gpu)
def ifft2(spectrum: Any, use_gpu: Any = True) -> Any:
    """
    Compute the 2D inverse discrete Fourier transform, on the GPU when possible.

    Parameters
    ----------
    spectrum : np.ndarray
        Input (complex) 2D spectrum.
    use_gpu : bool
        Use CuPy when it is requested and a GPU is available; otherwise NumPy is used.

    Returns
    -------
    np.ndarray
        Inverse transform (complex), always returned as a host (NumPy) array.
    """
    if not (use_gpu and GPU_AVAILABLE):
        return np.fft.ifft2(spectrum)

    import cupy as cp

    image_gpu = cp.fft.ifft2(cp.asarray(spectrum))
    return to_cpu(image_gpu)