"""ThorOCT module for handling OCT data from ThorLabs PSOCT microscopes."""
import gc
import zipfile
from pathlib import Path
from typing import cast
from xml.dom.minidom import Text as DOMText
from xml.dom.minidom import parse
import numpy as np
[docs]
class PreprocessingConfig:
    """
    Configuration for preprocessing OCT data.

    The values are plain class-level defaults; override them by assigning
    to the attributes of an instance.

    Attributes
    ----------
    return_complex (bool, default=False):
        If True, return the raw complex data.
        If False, return its magnitude instead.
    crop_first_index (int, default=320):
        Index along the depth (Z) axis at which cropping starts (inclusive).
    crop_second_index (int, default=750):
        Index along the depth (Z) axis at which cropping ends (exclusive).
    erase_raw_data (bool, default=True):
        If True, the reference to the compressed data source is dropped
        once loading has finished.
    erase_polarization_1 (bool, default=False):
        If True, do not keep polarization-1 data in memory.
    erase_polarization_2 (bool, default=True):
        If True, do not keep polarization-2 data in memory.
    """

    # Documented above and read by the preprocessing step; it needs a
    # default so that a freshly created configuration is usable as-is.
    return_complex: bool = False
    crop_first_index: int = 320
    crop_second_index: int = 750
    erase_raw_data: bool = True
    erase_polarization_1: bool = False
    erase_polarization_2: bool = True
[docs]
class ThorOCT:
"""Handle OCT data from ThorLabs PSOCT microscopes.
Provides methods to load, process, and extract metadata and data from compressed files.
Parameters
----------
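path (Path, optional): Path to the compressed data file; it is opened as a zip archive when ``compressed_data`` is not given.
compressed_data (zipfile.ZipFile, optional): Already-open ZipFile object containing the data; takes precedence over ``path``.
config (PreprocessingConfig, optional): Configuration for preprocessing.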
Attributes
----------
first_polarization (np.ndarray): Data for the first polarization.
second_polarization (np.ndarray): Data for the second polarization.
size_x (int): X-dimension size of the data.
size_y (int): Y-dimension size of the data.
size_z (int): Z-dimension size of the data.
header (xml.dom.minidom.Document): Parsed header metadata.
resolution (list): Resolution values for X, Y, and Z dimensions.
ascan_averaging_value (int): Number of A-scans for averaging.
"""
def __init__(
    self,
    path: Path | None = None,
    compressed_data: zipfile.ZipFile | None = None,
    config: PreprocessingConfig | None = None,
) -> None:
    """Initialize the ThorOCT object.

    Parameters
    ----------
    path : Path, optional
        Location of the compressed acquisition; opened as a zip archive
        when ``compressed_data`` is not supplied.
    compressed_data : zipfile.ZipFile, optional
        Already-open archive. Takes precedence over ``path``.
    config : PreprocessingConfig, optional
        Preprocessing options. A default ``PreprocessingConfig`` is used
        when omitted.
    """
    # An explicitly supplied archive wins; otherwise open the one at `path`.
    self.compressed_data = compressed_data or (zipfile.ZipFile(path) if path else None)
    # load() and _preprocess_data() read self.config, so it must always exist.
    self.config: PreprocessingConfig = config if config is not None else PreprocessingConfig()
    self.first_polarization: np.ndarray | None = None
    self.second_polarization: np.ndarray | None = None
    self.ascan_averaging_value: int | None = None
    # Filled in by the header/dimension extraction steps. Starting at None
    # lets the "header not loaded" guard and the `is not None` assertions
    # in the loading methods fire instead of raising AttributeError.
    self.header = None
    self.size_x: int | None = None
    self.size_y: int | None = None
    self.size_z: int | None = None
    self.resolution: list[float] | None = None
[docs]
def load(self) -> None:
    """
    Run the full loading sequence: header, dimensions, then polarization data.

    The per-channel skip flags are taken from ``self.config``. When
    ``erase_raw_data`` is set, the reference to the compressed source is
    dropped afterwards and a garbage collection pass is requested.

    Raises
    ------
    ValueError: If no valid data source is provided.
    """
    if not self.compressed_data:
        raise ValueError("No valid data source provided.")

    settings = self.config
    assert settings is not None

    self._extract_oct_header()
    self._extract_complex_dimensions()
    self._load_polarized_data(
        erase_polarization_1=settings.erase_polarization_1,
        erase_polarization_2=settings.erase_polarization_2,
    )

    if not settings.erase_raw_data:
        return
    # Let go of the compressed source itself and reclaim memory right away.
    self.compressed_data = None
    gc.collect()
def _extract_oct_header(self) -> None:
    """Parse ``Header.xml`` from the compressed source into ``self.header``.

    Raises
    ------
    FileNotFoundError: If the header file is not found in the compressed data.
    """
    header_member = "Header.xml"
    try:
        archive = self.compressed_data
        assert archive is not None
        with archive.open(header_member) as stream:
            parsed = parse(stream)
        self.header = parsed
    except KeyError as e:
        # The archive reports a missing member as KeyError; surface it as a
        # missing-file error instead.
        raise FileNotFoundError(f"Error loading header: {e}") from e
def _extract_complex_dimensions(self) -> None:
    """
    Read the averaging count, volume sizes and resolution from the header.

    ``ascan_averaging_value`` comes from the text of the first ``AScans``
    element. Sizes and ranges come from the attributes of the first
    ``DataFile`` element whose text is exactly ``data\\Complex.data``; each
    resolution entry is the matching range divided by the matching size.
    When no such element exists, sizes and resolution are left untouched.

    Raises
    ------
    ValueError: If the header has not been loaded.
    """
    if not self.header:
        raise ValueError("Header must be loaded before extracting dimensions.")
    header = self.header

    candidates = header.getElementsByTagName("DataFile")

    # Text inside <AScans> holds the averaging count.
    averaging_text = header.getElementsByTagName("AScans")[0].firstChild
    assert averaging_text is not None
    self.ascan_averaging_value = int(cast("DOMText", averaging_text).data.strip())

    # Locate the entry that describes the complex volume.
    matched_entry = None
    for entry in candidates:
        entry_text = entry.firstChild
        assert entry_text is not None
        if cast("DOMText", entry_text).data != "data\\Complex.data":
            continue
        matched_entry = entry
        break

    if not matched_entry:
        return

    read = matched_entry.getAttribute
    self.size_z = int(read("SizeZ"))
    self.size_x = int(read("SizeX"))
    self.size_y = int(read("SizeY"))
    extent_z = float(read("RangeZ"))
    extent_y = float(read("RangeY"))
    extent_x = float(read("RangeX"))
    self.resolution = [
        extent_x / self.size_x,
        extent_y / self.size_y,
        extent_z / self.size_z,
    ]
def _load_polarized_data(self, erase_polarization_2: bool, erase_polarization_1: bool) -> None:
    """Load the polarization data from the compressed file.

    Each requested channel goes through ``load_and_process``. That
    processing overwrites ``size_x``/``size_y``/``size_z`` and
    ``resolution`` with the values of the processed volume. The raw
    geometry is therefore captured once and restored before every channel,
    so the second channel is reshaped with the raw sizes and the resolution
    is not rescaled twice. Both channels are read with the same header
    dimensions, as before.

    Parameters
    ----------
    erase_polarization_2 : bool
        Whether to skip loading polarization 2 data.
    erase_polarization_1 : bool
        Whether to skip loading polarization 1 data.

    Raises
    ------
    FileNotFoundError: If required polarization data files are missing.
    """
    # (attribute to fill, member inside the archive, skip flag)
    channels = (
        ("first_polarization", "data/Complex.data", erase_polarization_1),
        ("second_polarization", "data/Complex_Cam1.data", erase_polarization_2),
    )

    # Snapshot of the geometry as it stands before any preprocessing ran.
    raw_geometry = {}
    for name in ("size_x", "size_y", "size_z", "resolution"):
        if hasattr(self, name):
            value = getattr(self, name)
            raw_geometry[name] = list(value) if isinstance(value, list) else value

    try:
        for attribute, member, skip in channels:
            if skip:
                continue
            # Undo the geometry changes made while processing a previous channel.
            for name, value in raw_geometry.items():
                setattr(self, name, list(value) if isinstance(value, list) else value)
            setattr(self, attribute, self.load_and_process(file=member))
    except KeyError as e:
        raise FileNotFoundError(f"Error loading polarization data: {e}") from e
def _stack_tiles_vertically(self, data: np.ndarray) -> np.ndarray:
    """
    Merge consecutive groups of slices along the second axis.

    The first axis is walked in groups of ``ascan_averaging_value`` slices.
    The slices of a group are concatenated end to end and the rows of the
    concatenated result are then reversed, so the last slice of the group
    comes first.

    eg. If ascan_averaging_value = 3,
    slices [0, 1, 2] become one output slice ordered 2, 1, 0.

    ``size_x``, ``size_y`` and ``resolution`` are updated to describe the
    returned array.

    Parameters
    ----------
    data : np.ndarray
        3D input; its first-axis length must be a multiple of
        ``ascan_averaging_value``.

    Returns
    -------
    np.ndarray
        Array of shape ``(n / k, k * m, depth)`` for an input of shape
        ``(n, m, depth)`` and ``k = ascan_averaging_value``.

    Raises
    ------
    ValueError
        If the first-axis length is not divisible by ``ascan_averaging_value``.
    """
    group_size = self.ascan_averaging_value
    assert group_size is not None
    assert self.size_x is not None and self.size_y is not None

    slice_count = data.shape[0]
    if slice_count % group_size != 0:
        raise ValueError(
            f"The number of tiles ({data.shape[0]}) must be divisible by "
            f"ascan_averaging_value ({self.ascan_averaging_value})."
        )

    # One merged, row-reversed block per group of `group_size` slices.
    merged_groups = [
        np.concatenate(data[start : start + group_size], axis=0)[::-1]
        for start in range(0, slice_count, group_size)
    ]
    merged = np.stack(merged_groups, axis=0)

    # The first two axes changed length, so rescale their resolution to match.
    new_x, new_y = merged.shape[0], merged.shape[1]
    self.resolution = [
        self.resolution[0] * self.size_x / new_x,
        self.resolution[1] * self.size_y / new_y,
        self.resolution[2],
    ]
    self.size_x = new_x
    self.size_y = new_y
    return merged
def _crop_z(self, data: np.ndarray, index1: int = 320, index2: int = 750) -> np.ndarray:
    """Keep only the ``[index1, index2)`` window of the third axis.

    ``self.size_z`` is updated to the length of the kept window.

    Parameters
    ----------
    data : np.ndarray
        The input 3D array; the crop is applied to its last (third) axis.
    index1 : int
        First index kept (inclusive).
    index2 : int
        End of the window (exclusive).

    Returns
    -------
    np.ndarray
        The cropped 3D array.

    Raises
    ------
    ValueError
        If ``index1`` is negative, ``index2`` exceeds the axis length, or
        ``index1`` is not strictly smaller than ``index2``.
    """
    window_is_valid = index1 >= 0 and index2 <= data.shape[2] and index1 < index2
    if not window_is_valid:
        raise ValueError(f"Invalid indices: index1={index1}, index2={index2}, data shape={data.shape}")

    depth_window = slice(index1, index2)
    cropped = data[:, :, depth_window]
    self.size_z = cropped.shape[2]
    return cropped
def _load_raw_data(self, file: str) -> np.ndarray:
    """Read one member of the compressed source as a complex64 volume.

    The member's bytes are interpreted as ``complex64`` values and reshaped,
    in C order, to ``(size_x, size_y, size_z)``.

    Parameters
    ----------
    file : str
        Name of the member inside the compressed data.

    Returns
    -------
    np.ndarray
        Raw complex data array.
    """
    archive = self.compressed_data
    assert archive is not None
    assert self.size_x is not None and self.size_y is not None and self.size_z is not None

    with archive.open(file) as stream:
        payload = stream.read()
        volume_shape = (self.size_x, self.size_y, self.size_z)
        flat_values = np.frombuffer(payload, dtype=np.complex64)
        volume = flat_values.reshape(volume_shape, order="C")
    return volume
def _preprocess_data(
    self,
    data: np.ndarray,
) -> np.ndarray:
    """Crop, stack and optionally convert a raw volume to magnitude.

    Steps, in order: crop the depth window configured by
    ``crop_first_index``/``crop_second_index``, merge slice groups with
    ``_stack_tiles_vertically``, trim the second axis to a multiple of
    ``ascan_averaging_value`` (recording the result in ``size_y``), then
    return either the complex values or their magnitude as ``float64``,
    depending on ``config.return_complex``.

    Parameters
    ----------
    data : np.ndarray
        Input complex data array.

    Returns
    -------
    np.ndarray
        Preprocessed data array.
    """
    settings = self.config
    assert settings is not None

    cropped = self._crop_z(
        data,
        index1=settings.crop_first_index,
        index2=settings.crop_second_index,
    )
    stacked = self._stack_tiles_vertically(cropped)

    averaging = self.ascan_averaging_value
    assert averaging is not None
    # Drop trailing rows so the second axis is a whole number of groups.
    leftover = stacked.shape[1] % averaging
    usable_rows = stacked.shape[1] - leftover
    trimmed = stacked[:, :usable_rows, :]
    self.size_y = trimmed.shape[1]

    if settings.return_complex:
        return trimmed
    return np.abs(trimmed).astype(np.float64)
[docs]
def load_and_process(self, file: str) -> np.ndarray:
    """Read one member of the compressed source and run the preprocessing on it.

    Parameters
    ----------
    file : str
        Name of the member inside the compressed data.

    Returns
    -------
    np.ndarray
        Fully processed data array.
    """
    return self._preprocess_data(self._load_raw_data(file))
@staticmethod
def get_psoct_tiles_ids(tiles_directory: Path, number_of_angles: int = 2) -> tuple:
    """
    Collect the .oct files of a directory, grouped by acquisition angle.

    Files are visited in sorted name order so the grouping is reproducible;
    the order of a raw directory listing is not guaranteed. The file at
    sorted position ``i`` is assigned to angle ``i % number_of_angles``.
    Tile positions are obtained from ``ThorOCT.extract_positions_from_scan``,
    which receives the path of the directory's ``.scan`` file, or ``None``
    when there is none.

    Parameters
    ----------
    tiles_directory : Path
        Path to the directory containing the OCT tiles.
    number_of_angles : int
        Number of acquisition angles (at least 1).

    Returns
    -------
    tuple
        ``(grouped_files, positions)``:

        - grouped_files: one list of file paths per angle.
        - positions: first value returned by ``extract_positions_from_scan``.

    Raises
    ------
    ValueError
        If the path is not a directory, ``number_of_angles`` is smaller
        than 1, or the directory holds no .oct file.
    """
    tiles_path = Path(tiles_directory)
    if not tiles_path.is_dir():
        raise ValueError(f"Provided path '{tiles_directory}' is not a valid directory.")
    if number_of_angles < 1:
        raise ValueError(f"number_of_angles must be at least 1, got {number_of_angles}.")

    scan_file = None
    oct_files = []
    # Sorted traversal: the angle assignment below depends on file order.
    for entry in sorted(tiles_path.iterdir()):
        if entry.suffix == ".scan":
            scan_file = entry
        elif entry.suffix == ".oct":
            oct_files.append(entry)

    positions, _ = ThorOCT.extract_positions_from_scan(str(scan_file) if scan_file is not None else None)

    if not oct_files:
        raise ValueError("Warning: No .oct files found in the directory.")

    # Round-robin split: file i belongs to angle i % number_of_angles.
    grouped_files = [oct_files[offset::number_of_angles] for offset in range(number_of_angles)]

    for angle_index, group in enumerate(grouped_files):
        print(f"File Count for Angle index = {angle_index + 1}: {len(group)}")
    print("Processing the following Files:")
    for file in grouped_files[0]:
        print(f" - {file}")
    return grouped_files, positions
@staticmethod
def orient_volume_psoct(vol: np.ndarray) -> np.ndarray:
    """Move the last axis of a 3D volume to the front.

    The axes are permuted as ``(2, 0, 1)``; this is the reordering the
    module uses to bring a volume to its RAS orientation.
    """
    return np.transpose(vol, (2, 0, 1))