Source code for linumpy.gpu.corrections
"""GPU-accelerated correction operations for linumpy."""
from typing import Any
import numpy as np
from . import GPU_AVAILABLE, to_cpu
[docs]
def fix_galvo_shift(volume: Any, shift: Any, axis: Any = 1, use_gpu: Any = True) -> Any:
    """
    Correct a galvo shift by circularly rolling the volume.

    Parameters
    ----------
    volume : np.ndarray
        Volume to correct.
    shift : int
        Number of pixels to roll by (sign gives the direction).
    axis : int
        Axis along which the roll is applied.
    use_gpu : bool
        Request the CuPy implementation; it is only used when a GPU is
        also reported as available.

    Returns
    -------
    np.ndarray
        The rolled volume. When ``shift`` is 0 the input object itself is
        returned, without a copy.
    """
    # Nothing to do for a zero shift: hand back the caller's array as-is.
    if shift == 0:
        return volume

    # CPU fallback whenever the GPU path was not requested or is unavailable.
    if not (use_gpu and GPU_AVAILABLE):
        return np.roll(volume, shift, axis=axis)

    import cupy as cp

    rolled_on_device = cp.roll(cp.asarray(volume), shift, axis=axis)
    # Bring the result back to host memory for the caller.
    return to_cpu(rolled_on_device)
[docs]
def detect_and_fix_galvo_shift(
    volume: Any, n_pixel_return: Any = 40, threshold: Any = 0.5, axis: Any = 1, use_gpu: Any = True
) -> Any:
    """
    Detect and conditionally fix galvo shift.

    Detection runs on the CPU from the mean projection of the volume along
    axis 0. The correction itself is delegated to ``fix_galvo_shift``, which
    may use the GPU.

    Parameters
    ----------
    volume : np.ndarray
        Input volume (3D).
    n_pixel_return : int
        Number of pixels in the galvo return region; forwarded to the
        detector.
    threshold : float
        Minimum detection confidence required to apply the fix (default 0.5;
        higher = more conservative).
    axis : int
        Axis of ``volume`` along which the detected shift is applied.
    use_gpu : bool
        Whether to request the GPU for the fix operation.

    Returns
    -------
    np.ndarray
        Corrected volume, or the input ``volume`` object itself when no
        correction was applied.
    dict
        Detection results with keys ``'shift'``, ``'confidence'`` and
        ``'fixed'``. ``'fixed'`` is True only when a non-zero shift passed
        the confidence threshold and was actually applied.
    """
    # NOTE(review): function-local import, presumably to avoid a circular
    # import or an eager dependency at module load -- confirm.
    from linumpy.geometry.galvo import detect_galvo_shift

    # Average intensity projection: mean over axis 0 of the volume.
    aip = np.mean(volume, axis=0)

    # NOTE(review): the detector does not receive `axis`, so it presumably
    # assumes the default layout (axis=1 of the volume) -- confirm before
    # calling with a non-default axis.
    shift, confidence = detect_galvo_shift(aip, n_pixel_return)
    result = {"shift": shift, "confidence": confidence, "fixed": False}

    # A zero shift leaves the volume unchanged, so it must not be reported as
    # a fix even when the detector is confident.
    if confidence >= threshold and shift != 0:
        volume = fix_galvo_shift(volume, shift, axis=axis, use_gpu=use_gpu)
        result["fixed"] = True

    return volume, result