#!/usr/bin/env python3
"""Quick reconstruction and processing methods for the S-OCT data."""
import re
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from imageio import imwrite
from matplotlib.patches import Rectangle
from scipy.ndimage import binary_fill_holes, median_filter
from skimage.color import label2rgb
from skimage.filters import threshold_otsu
from skimage.measure import label
from skimage.transform import resize
from tqdm.auto import tqdm
from linumpy.microscope.oct import OCT
[docs]
def get_largest_cc(segmentation: np.ndarray) -> np.ndarray:
"""Get the largest connected component in a binary image.
Parameters
----------
segmentation : np.ndarray
The binary image to process.
Returns
-------
np.ndarray
The largest connected component.
"""
labels = label(segmentation)
assert labels.max() != 0 # assume at least 1 CC
largest_cc = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
return largest_cc
[docs]
DEFAULT_TILE_FILE_PATTERN = r"tile_x(?P<x>\d+)_y(?P<y>\d+)_z(?P<z>\d+)"
[docs]
def get_tiles_ids(directory: Path, z: int | None = None) -> tuple:
"""Analyze a directory and detect all the tiles it contains."""
input_directory = Path(directory)
# Get a list of the input tiles
tiles_to_process = f"*z{z:02d}" if z is not None else "tile_*"
tiles = list(input_directory.rglob(tiles_to_process))
tiles = [t for t in tiles if t.name.startswith("tile_") and not t.is_file()]
tile_ids = get_tiles_ids_from_list(tiles)
return tiles, tile_ids
[docs]
def get_tiles_ids_from_list(tiles_list: list, file_pattern: str = DEFAULT_TILE_FILE_PATTERN) -> list:
"""Return tile (x, y, z) IDs parsed from a list of tile paths."""
tiles_list.sort()
# Get the tile positions
tile_ids = []
n_tiles = len(tiles_list)
for t in tqdm(tiles_list, desc="Extracting tile ids", total=n_tiles):
# Extract the tile's mosaic position.
match = re.match(file_pattern, t.name)
assert match is not None
mx = int(match.group("x"))
my = int(match.group("y"))
mz = int(match.group("z"))
tile_ids.append((mx, my, mz))
return tile_ids
[docs]
def get_mosaic_info(directory: Path, z: int, overlap_fraction: float = 0.2, use_stage_positions: bool = False) -> dict:
"""Return mosaic geometry and tile metadata for a given z-slice."""
# Get a list of the input tiles
tiles, _tile_ids = get_tiles_ids(directory, z)
# Get the tile positions (in pixel and mm)
file_pattern = r"tile_x(?P<x>\d+)_y(?P<y>\d+)_z(?P<z>\d+)"
tiles_positions_px = []
tiles_positions_mm = []
mosaic_tile_pos = []
# Progress bars overlap as the position is the same in all threads. Position is 1 to avoid overlap with outer loop.
# No better solution has been found.
oct_tile: OCT | None = None
for t in tqdm(tiles, desc="Reading mosaic info", leave=False, position=1):
oct_tile = OCT(t)
# Extract the tile's mosaic position.
match = re.match(file_pattern, t.name)
assert match is not None
mx = int(match.group("x"))
my = int(match.group("y"))
if oct_tile.position_available and use_stage_positions:
x_mm, y_mm, _ = oct_tile.position
else:
# Compute the tile position in mm
x_mm = oct_tile.dimension[0] * (1 - overlap_fraction) * mx
y_mm = oct_tile.dimension[1] * (1 - overlap_fraction) * my
x_px = int(np.floor(x_mm / oct_tile.resolution[0]))
y_px = int(np.floor(y_mm / oct_tile.resolution[1]))
mosaic_tile_pos.append((mx, my))
tiles_positions_mm.append((x_mm, y_mm))
tiles_positions_px.append((x_px, y_px))
# Compute the mosaic shape
assert oct_tile is not None
x_min = min([x for x, _ in tiles_positions_px])
y_min = min([y for _, y in tiles_positions_px])
x_max = max([x for x, _ in tiles_positions_px]) + oct_tile.shape[0]
y_max = max([y for _, y in tiles_positions_px]) + oct_tile.shape[1]
mosaic_nrows = x_max - x_min
mosaic_ncols = y_max - y_min
# Get the mosaic grid shape
n_mx = len(np.unique([x[0] for x in mosaic_tile_pos]))
n_my = len(np.unique([x[1] for x in mosaic_tile_pos]))
# Get the mosaic limits in mm
xmin_mm = np.min([p[0] for p in tiles_positions_mm]) - oct_tile.dimension[0] / 2
ymin_mm = np.min([p[1] for p in tiles_positions_mm]) - oct_tile.dimension[1] / 2
xmax_mm = np.max([p[0] for p in tiles_positions_mm]) + oct_tile.dimension[0] / 2
ymax_mm = np.max([p[1] for p in tiles_positions_mm]) + oct_tile.dimension[1] / 2
mosaic_center_mm = ((xmin_mm + xmax_mm) / 2, (ymin_mm + ymax_mm) / 2)
mosaic_width_mm = xmax_mm - xmin_mm
mosaic_height_mm = ymax_mm - ymin_mm
info = {
"tiles": tiles,
"tiles_pos_px": tiles_positions_px,
"tiles_pos_mm": tiles_positions_mm,
"mosaic_tile_pos": mosaic_tile_pos,
"mosaic_nrows": mosaic_nrows,
"mosaic_ncols": mosaic_ncols,
"mosaic_xmin_px": x_min,
"mosaic_ymin_px": y_min,
"mosaic_xmax_px": x_max,
"mosaic_ymax_px": y_max,
"mosaic_xmin_mm": xmin_mm,
"mosaic_ymin_mm": ymin_mm,
"mosaic_xmax_mm": xmax_mm,
"mosaic_ymax_mm": ymax_mm,
"mosaic_center_mm": mosaic_center_mm,
"mosaic_width_mm": mosaic_width_mm,
"mosaic_height_mm": mosaic_height_mm,
"mosaic_grid_shape": (n_mx, n_my),
"tile_shape_px": oct_tile.shape,
"tile_shape_mm": oct_tile.dimension,
"tile_resolution": oct_tile.resolution,
}
return info
[docs]
def quick_stitch(
directory: Path,
z: int,
overlap_fraction: float = 0.2,
n_rot: int = 3,
zmin: int = 0,
zmax: int = -1,
use_log: bool = False,
use_stage_positions: bool = False,
flip_ud: bool = True,
flip_lr: bool = False,
galvo_shift: int | None = None,
galvo_shift_first_tile: tuple = (0, 0),
) -> np.ndarray:
"""Stitch all tiles in a directory for a given z-slice into a mosaic."""
# TODO: accelerate the stitching by preprocessing the tiles in parallel
input_directory = Path(directory)
# Get a list of the input tiles
tiles_to_process = f"*z{z:02d}"
tiles = list(input_directory.glob(tiles_to_process))
# Get the tile positions (in pixel and mm)
file_pattern = r"tile_x(?P<x>\d+)_y(?P<y>\d+)_z(?P<z>\d+)"
tiles_positions_px = []
tiles_positions_mm = []
oct_tile: OCT | None = None
for t in tiles:
oct_tile = OCT(t)
if oct_tile.position_available and use_stage_positions:
x_mm, y_mm, _ = oct_tile.position
else:
# Extract the tile's mosaic position.
match = re.match(file_pattern, t.name)
assert match is not None
mx = int(match.group("x"))
my = int(match.group("y"))
# Compute the tile position in mm
x_mm = oct_tile.dimension[0] * (1 - overlap_fraction) * mx
y_mm = oct_tile.dimension[1] * (1 - overlap_fraction) * my
x_px = int(np.floor(x_mm / oct_tile.resolution[0]))
y_px = int(np.floor(y_mm / oct_tile.resolution[1]))
tiles_positions_mm.append((x_mm, y_mm))
tiles_positions_px.append((x_px, y_px))
# Compute the mosaic shape
assert oct_tile is not None
x_min = min([x for x, _ in tiles_positions_px])
y_min = min([y for _, y in tiles_positions_px])
x_max = max([x for x, _ in tiles_positions_px]) + oct_tile.shape[0]
y_max = max([y for _, y in tiles_positions_px]) + oct_tile.shape[1]
mosaic_nrows = x_max - x_min
mosaic_ncols = y_max - y_min
mosaic = np.zeros((mosaic_nrows, mosaic_ncols), dtype=np.float32)
# Perform stitching
for i in tqdm(range(len(tiles)), desc="Quick Stitch"):
oct_tile = OCT(tiles[i])
# Compute the pixel position within the mosaic
rmin = tiles_positions_px[i][0] - x_min
rmax = rmin + oct_tile.shape[0]
cmin = tiles_positions_px[i][1] - y_min
cmax = cmin + oct_tile.shape[1]
# Get the tile id
match = re.match(file_pattern, tiles[i].name)
assert match is not None
mx = int(match.group("x"))
my = int(match.group("y"))
apply_shift = True
if mx < galvo_shift_first_tile[0] or (mx == galvo_shift_first_tile[0] and my < galvo_shift_first_tile[1]):
apply_shift = False
# Load the fringes
img = oct_tile.load_image(fix_galvo_shift=galvo_shift) if apply_shift else oct_tile.load_image()
# Log transform
if use_log:
img = np.log(img)
# Compute an AIP
img = img[zmin:zmax, :, :].mean(axis=0)
# BUG: there are sometimes missing bscans
if img.shape != oct_tile.shape[0:2]:
img = (
np.zeros((int(oct_tile.shape[0]), int(oct_tile.shape[1])))
if np.any(np.array(img.shape) == 0)
else resize(img, oct_tile.shape[0:2])
)
# Apply rotations
img = np.rot90(img, k=n_rot)
# Flips
if flip_lr:
img = np.fliplr(img)
if flip_ud:
img = np.flipud(img)
# Add the tile to the mosaic
mosaic[rmin:rmax, cmin:cmax] = img
return mosaic
[docs]
def detect_mosaic(
directory: Path,
z: int,
img: np.ndarray | None = None,
margin: float = 0.5,
display: bool = False,
image_file: Path | None = None,
roi_file: Path | None = None,
keep_largest_island: bool = False,
stitching_settings: dict | None = None,
) -> tuple:
"""Detect the tissue in the mosaic and compute the limits of the tissue.
Parameters
----------
directory : str
The directory containing the tiles.
z : int
The z slices to process
img : np.ndarray or None
Optional pre-computed mosaic image.
stitching_settings : dict or None
Optional stitching settings override.
margin : float
The margin to add to the tissue limits (in mm).
display : bool
Display the result in a matplotlib window.
image_file : str
The filename to save the quickstitch image.
roi_file : str
The filename to save the ROI image.
keep_largest_island : bool
Keep the largest connected component in the mask.
"""
# Additional parameters
threshold_size = 1024 # maximum image size to use for the thresholding
normalization_percentile = 99.7
median_size = 15 # pixel
# Extract the parameters
directory = Path(directory)
# Get the mosaic information
info = get_mosaic_info(directory, z=z, use_stage_positions=True)
# Extract the tile positions from the metadata
xmin = np.min([p[0] for p in info["tiles_pos_mm"]]) - info["tile_shape_mm"][0] / 2
ymin = np.min([p[1] for p in info["tiles_pos_mm"]]) - info["tile_shape_mm"][1] / 2
xmax = np.max([p[0] for p in info["tiles_pos_mm"]]) + info["tile_shape_mm"][0] / 2
ymax = np.max([p[1] for p in info["tiles_pos_mm"]]) + info["tile_shape_mm"][1] / 2
# Stitch the image using the tile position
if img is None:
img = quick_stitch(directory, z=z, use_stage_positions=True, **(stitching_settings or {}))
# Save the quick stitch image
if image_file is not None:
save_quickstitch(img, image_file)
# Rescale the image to a small size
new_shape = tuple((np.array(img.shape) * threshold_size / np.min(img.shape)).astype(int).tolist())
img = resize(img, new_shape)
# Normalize the intensity
img = (img.astype(np.float32) - img.min()) / (np.percentile(img, normalization_percentile) - img.min())
img[img > 1] = 1
# Process the image, to find a mask
thresh = threshold_otsu(img)
mask = img > thresh
mask = median_filter(mask, median_size)
# Fill holes
mask = binary_fill_holes(mask)
# Keep the largest connected component
if keep_largest_island:
mask = get_largest_cc(mask)
# Compute the mosaic limits
n_rows, n_cols = img.shape
rows, cols = np.where(mask)
roi_r_min = rows.min()
roi_r_max = rows.max()
roi_c_min = cols.min()
roi_c_max = cols.max()
# Convert to mm
roi_x_min = (xmax - xmin) * roi_r_min / n_rows + xmin
roi_x_max = (xmax - xmin) * roi_r_max / n_rows + xmin
roi_y_min = (ymax - ymin) * roi_c_min / n_cols + ymin
roi_y_max = (ymax - ymin) * roi_c_max / n_cols + ymin
# Add margin
roi_x_min_margin = roi_x_min - margin
roi_x_max_margin = roi_x_max + margin
roi_y_min_margin = roi_y_min - margin
roi_y_max_margin = roi_y_max + margin
# TODO: Make sure the mosaic limits are within the allowed imaging limits
# Display the result
if display or roi_file is not None:
_fig, ax = plt.subplots()
ax.imshow(label2rgb(mask, img, bg_label=0, colors=["blue"]), extent=(ymin, ymax, xmax, xmin)) # Y axes are inverted
rect = Rectangle(
(roi_y_min, roi_x_min),
width=(roi_y_max - roi_y_min),
height=(roi_x_max - roi_x_min),
fill=None,
edgecolor="red",
linestyle="dashed",
label="ROI",
)
ax.add_patch(rect)
rect_margin = Rectangle(
(roi_y_min_margin, roi_x_min_margin),
width=(roi_y_max_margin - roi_y_min_margin),
height=(roi_x_max_margin - roi_x_min_margin),
fill=None,
edgecolor="red",
label="ROI + margin",
)
ax.add_patch(rect_margin)
ax.set_ylabel("x axis (mm)")
ax.set_xlabel("y axis (mm)")
title = (
f"xmin={roi_x_min_margin:.4f}mm, xmax={roi_x_max_margin:.4f}mm\n"
f"ymin={roi_y_min_margin:.4f}mm, ymax={roi_y_max_margin:.4f}mm"
)
ax.set_title(title)
ax.legend()
if roi_file is not None:
filename = Path(roi_file)
filename.parent.mkdir(exist_ok=True, parents=True)
plt.savefig(filename, dpi=150, bbox_inches="tight", pad_inches=0)
if display:
plt.show()
return roi_x_min_margin, roi_x_max_margin, roi_y_min_margin, roi_y_max_margin
[docs]
def save_quickstitch(img: np.ndarray, quickstitch_file: Path) -> None:
"""Save the quickstitch mosaic to a file, normalizing intensity."""
filename = Path(quickstitch_file)
# Normalize the intensity
mask = img > 0
imin = img[mask].min()
imax = np.percentile(img[mask], 99.7)
mosaic = (img - imin) / (imax - imin)
mosaic[~mask] = 0.0
# Save the mosaic
if filename.name.endswith(".jpg") or filename.name.endswith(".png"):
mosaic[mosaic > 1] = 1
mosaic = (255 * mosaic).astype(np.uint8)
elif filename.name.endswith(".tiff"):
mosaic = mosaic.astype(np.float32)
filename.parent.mkdir(exist_ok=True, parents=True)
imwrite(filename, mosaic)