# Source code for linumpy.gpu.fft_ops

"""
GPU-accelerated FFT operations for linumpy.

Provides GPU versions of FFT-based operations including phase correlation
for image registration and stitching.
"""

from typing import Any

import numpy as np

from . import GPU_AVAILABLE, to_cpu


def phase_correlation(vol1: Any, vol2: Any, n_peaks: Any = 8, use_gpu: Any = True) -> Any:
    """
    Estimate the translation between two images by phase correlation.

    The GPU implementation is used only when ``use_gpu`` is truthy and a GPU
    backend is available; in every other case the CPU implementation runs.

    Parameters
    ----------
    vol1 : np.ndarray
        Fixed image (2D or 3D)
    vol2 : np.ndarray
        Moving image (2D or 3D)
    n_peaks : int
        Number of peaks to sample for refinement
    use_gpu : bool
        Whether to use GPU acceleration

    Returns
    -------
    list
        Translation [dx, dy] or [dx, dy, dz] of vol2 relative to vol1
    float
        Cross-correlation score
    """
    backend = _phase_correlation_gpu if (use_gpu and GPU_AVAILABLE) else _phase_correlation_cpu
    return backend(vol1, vol2, n_peaks)
def _phase_correlation_gpu(vol1: Any, vol2: Any, n_peaks: Any = 8) -> Any: """GPU implementation of phase correlation.""" import cupy as cp vol_shape = vol1.shape ndim = vol1.ndim # Transfer to GPU vol1_gpu = cp.asarray(vol1, dtype=cp.float32) vol2_gpu = cp.asarray(vol2, dtype=cp.float32) # Extend images by 1/4 of their size (padding) new_shape = tuple(int(s * 1.25) for s in vol_shape) pad_size = tuple((int(np.ceil(0.5 * (n - s))),) * 2 for s, n in zip(vol_shape, new_shape, strict=False)) vol1_p = cp.pad(vol1_gpu, pad_size, mode="reflect") vol2_p = cp.pad(vol2_gpu, pad_size, mode="reflect") # Apply Hanning window vol1_p = _apply_hanning_window_gpu(vol1_p, [p[0] for p in pad_size]) vol2_p = _apply_hanning_window_gpu(vol2_p, [p[0] for p in pad_size]) # Phase correlation using cuFFT if ndim == 2: fft_func = cp.fft.fft2 ifft_func = cp.fft.ifft2 else: fft_func = cp.fft.fftn ifft_func = cp.fft.ifftn q_num = fft_func(vol2_p) * cp.conj(fft_func(vol1_p)) q_denum = cp.abs(q_num) # Avoid division by zero q_freq = cp.where(q_denum > 1e-10, q_num / q_denum, 0) q = ifft_func(q_freq) q_abs = cp.abs(q) # Find peaks from cupyx.scipy.ndimage import maximum_filter # Local maxima detection local_max = maximum_filter(q_abs, size=3) _peaks_mask = q_abs == local_max # Get top n_peaks flat_indices = cp.argsort(q_abs.ravel())[-n_peaks:] coordinates = cp.unravel_index(flat_indices, q_abs.shape) coordinates = cp.stack(coordinates, axis=1) # Try all translation permutations best_translation = None best_score = -1 coordinates_cpu = to_cpu(coordinates) vol1_cpu = to_cpu(vol1_gpu) vol2_cpu = to_cpu(vol2_gpu) for indices in coordinates_cpu: deltas = [] for idx, s in zip(indices, vol1_p.shape, strict=False): deltas.append(int(-idx + s / 2)) # Check bounds for ii in range(len(deltas)): if abs(deltas[ii]) > vol_shape[ii]: deltas[ii] -= int(np.sign(deltas[ii]) * vol_shape[ii]) # Generate candidate translations if ndim == 2: dx, dy = deltas candidates = [ [dx, dy], [dx - int(np.sign(dx) * 
vol1_p.shape[0] / 2), dy], [dx, dy - int(np.sign(dy) * vol1_p.shape[1] / 2)], [dx - int(np.sign(dx) * vol1_p.shape[0] / 2), dy - int(np.sign(dy) * vol1_p.shape[1] / 2)], ] else: dx, dy, dz = deltas nxp = int(np.sign(dx) * vol1_p.shape[0] / 2) nyp = int(np.sign(dy) * vol1_p.shape[1] / 2) nzp = int(np.sign(dz) * vol1_p.shape[2] / 2) candidates = [ [dx, dy, dz], [dx - nxp, dy, dz], [dx, dy - nyp, dz], [dx - nxp, dy - nyp, dz], [dx, dy, dz - nzp], [dx, dy - nyp, dz - nzp], [dx - nxp, dy, dz - nzp], [dx - nxp, dy - nyp, dz - nzp], ] for trans in candidates: score = _compute_correlation_score(vol1_cpu, vol2_cpu, trans) if score > best_score: best_score = score best_translation = trans return best_translation, best_score def _apply_hanning_window_gpu(vol: Any, pad_sizes: Any) -> Any: """Apply Hanning window on GPU.""" import cupy as cp ndim = vol.ndim result = vol.copy() for axis, pad in enumerate(pad_sizes): if pad <= 0: continue s = vol.shape[axis] h = cp.hanning(pad * 2) h_full = cp.ones(s) h_full[:pad] = h[:pad] h_full[-pad:] = h[pad:] # Reshape for broadcasting shape = [1] * ndim shape[axis] = s h_full = h_full.reshape(shape) result = result * h_full return result def _compute_correlation_score(vol1: Any, vol2: Any, translation: Any) -> Any: """Compute normalized cross-correlation score for a translation.""" # Compute overlap region slices1 = [] slices2 = [] for i, t in enumerate(translation): t = int(t) if t >= 0: slices1.append(slice(t, None)) slices2.append(slice(None, vol2.shape[i] - t if t > 0 else None)) else: slices1.append(slice(None, vol1.shape[i] + t)) slices2.append(slice(-t, None)) try: ov1 = vol1[tuple(slices1)] ov2 = vol2[tuple(slices2)] if ov1.size == 0 or ov2.size == 0: return 0 # Normalized cross-correlation ov1_norm = ov1 - np.mean(ov1) ov2_norm = ov2 - np.mean(ov2) std1 = np.std(ov1_norm) std2 = np.std(ov2_norm) if std1 < 1e-10 or std2 < 1e-10: return 0 return float(np.mean(ov1_norm * ov2_norm) / (std1 * std2)) except Exception: return 0 def 
_phase_correlation_cpu(vol1: Any, vol2: Any, n_peaks: Any = 8) -> Any: """CPU fallback for phase correlation - calls existing implementation.""" from linumpy.registration.transforms import pair_wise_phase_correlation return pair_wise_phase_correlation(vol1, vol2, n_peaks=n_peaks, return_cc=True)
def fft2(image: Any, use_gpu: Any = True) -> Any:
    """
    Compute the 2D discrete Fourier transform, on the GPU when possible.

    When ``use_gpu`` is falsy or no GPU is available, NumPy computes the
    transform. Otherwise the image is copied to the device, transformed with
    CuPy, and the spectrum is copied back to host memory.

    Parameters
    ----------
    image : np.ndarray
        Input 2D image
    use_gpu : bool
        Whether to use GPU

    Returns
    -------
    np.ndarray
        FFT result (complex)
    """
    if not (use_gpu and GPU_AVAILABLE):
        return np.fft.fft2(image)

    import cupy as cp

    spectrum_gpu = cp.fft.fft2(cp.asarray(image))
    return to_cpu(spectrum_gpu)
def ifft2(spectrum: Any, use_gpu: Any = True) -> Any:
    """
    Compute the 2D inverse discrete Fourier transform, on the GPU when possible.

    When ``use_gpu`` is falsy or no GPU is available, NumPy computes the
    transform. Otherwise the spectrum is copied to the device, inverted with
    CuPy, and the result is copied back to host memory.

    Parameters
    ----------
    spectrum : np.ndarray
        Input spectrum (complex)
    use_gpu : bool
        Whether to use GPU

    Returns
    -------
    np.ndarray
        Inverse FFT result
    """
    if not (use_gpu and GPU_AVAILABLE):
        return np.fft.ifft2(spectrum)

    import cupy as cp

    image_gpu = cp.fft.ifft2(cp.asarray(spectrum))
    return to_cpu(image_gpu)