Source code for linumpy.metrics.image_quality

#!/usr/bin/env python3
"""
Image quality assessment functions for slice analysis.

This module provides CPU-based functions for assessing image quality in 3D volumes,
including:

- Structural Similarity Index (SSIM)
- Edge preservation scoring
- Variance consistency analysis
- Overall slice quality assessment

For GPU-accelerated versions, see `linumpy.gpu.image_quality`.

Usage::

    from linumpy.metrics.image_quality import (
        compute_ssim_2d,
        compute_ssim_3d,
        compute_edge_score,
        compute_variance_score,
        assess_slice_quality,
    )

    # Compare two volumes
    ssim = compute_ssim_3d(vol1, vol2)

    # Assess overall slice quality
    quality, metrics = assess_slice_quality(vol, vol_before, vol_after)
"""

from typing import Any

import numpy as np


[docs] def normalize_image(img: np.ndarray) -> np.ndarray: """ Normalize image to [0, 1] range. Parameters ---------- img : np.ndarray Input image. Returns ------- np.ndarray Normalized image as float32. """ result = img.astype(np.float32) img_min, img_max = result.min(), result.max() if img_max > img_min: result = (result - img_min) / (img_max - img_min) return result
[docs] def compute_ssim_2d(img1: np.ndarray, img2: np.ndarray, win_size: int = 7) -> float: """ Compute SSIM between two 2D images. Parameters ---------- img1, img2 : np.ndarray Input images (2D). win_size : int Window size for SSIM computation. Returns ------- float SSIM score (0 to 1, higher is better). """ if img1.shape != img2.shape: min_y = min(img1.shape[0], img2.shape[0]) min_x = min(img1.shape[1], img2.shape[1]) img1 = img1[:min_y, :min_x] img2 = img2[:min_y, :min_x] try: from skimage.metrics import structural_similarity as ssim # Normalize images i1 = normalize_image(img1) i2 = normalize_image(img2) # Adjust window size for image dimensions actual_win_size = min(win_size, min(i1.shape) - 1) if actual_win_size % 2 == 0: actual_win_size -= 1 if actual_win_size < 3: actual_win_size = 3 return float(ssim(i1, i2, win_size=actual_win_size, data_range=1.0)) except Exception: # Fallback to normalized cross-correlation i1 = normalize_image(img1) i2 = normalize_image(img2) corr = np.corrcoef(i1.flatten(), i2.flatten())[0, 1] return float(max(0.0, corr)) if not np.isnan(corr) else 0.0
[docs] def compute_ssim_3d(vol1: np.ndarray, vol2: np.ndarray, win_size: int = 7, sample_depth: int = 0, xy_roi: int = 0) -> float: """ Compute mean SSIM between two 3D volumes. Computes SSIM for each z-slice and returns the mean. Parameters ---------- vol1, vol2 : np.ndarray Input volumes (Z, Y, X). win_size : int Window size for SSIM computation. sample_depth : int Number of z-planes to sample. 0 = all planes. xy_roi : int Side length of center crop in XY (pixels). 0 = full plane. Use a small value (e.g. 1024) on very large single-resolution zarr arrays to avoid loading gigabytes per plane. Returns ------- float Mean SSIM score (0 to 1, higher is better). """ nz = min(vol1.shape[0], vol2.shape[0]) ny = min(vol1.shape[1], vol2.shape[1]) nx = min(vol1.shape[2], vol2.shape[2]) # Compute center-crop bounds once (same for every plane) if xy_roi > 0: yc, xc = ny // 2, nx // 2 half = xy_roi // 2 ys, ye = max(0, yc - half), min(ny, yc + half) xs, xe = max(0, xc - half), min(nx, xc + half) else: ys, ye, xs, xe = 0, ny, 0, nx # Sample z-planes if requested indices = np.linspace(0, nz - 1, sample_depth, dtype=int) if sample_depth > 0 and nz > sample_depth else np.arange(nz) ssim_scores = [] for z in indices: # Load one plane (or crop) at a time -- works for zarr and numpy p1 = np.asarray(vol1[z, ys:ye, xs:xe]) p2 = np.asarray(vol2[z, ys:ye, xs:xe]) score = compute_ssim_2d(p1, p2, win_size) ssim_scores.append(score) return float(np.mean(ssim_scores))
[docs] def compute_edge_score(vol: np.ndarray, reference: np.ndarray, sample_z: int | None = None) -> float: """ Compute edge preservation score between volume and reference. Uses Sobel edge detection to compare edge structures. Parameters ---------- vol : np.ndarray Input volume (Z, Y, X) or 2D image. reference : np.ndarray Reference volume or image. sample_z : int, optional Z-index to sample for 3D volumes. If None, uses middle slice. Returns ------- float Edge preservation score (0 to 1, higher is better). """ from scipy.ndimage import sobel # Handle 3D volumes if vol.ndim == 3: if sample_z is None: sample_z = vol.shape[0] // 2 v = normalize_image(vol[sample_z]) r = normalize_image(reference[sample_z] if reference.ndim == 3 else reference) else: v = normalize_image(vol) r = normalize_image(reference) if v.shape != r.shape: min_y = min(v.shape[0], r.shape[0]) min_x = min(v.shape[1], r.shape[1]) v = v[:min_y, :min_x] r = r[:min_y, :min_x] # Compute edges using Sobel edges_v = np.sqrt(sobel(v, axis=0) ** 2 + sobel(v, axis=1) ** 2) edges_r = np.sqrt(sobel(r, axis=0) ** 2 + sobel(r, axis=1) ** 2) # Normalize edges if edges_v.max() > 0: edges_v = edges_v / edges_v.max() if edges_r.max() > 0: edges_r = edges_r / edges_r.max() # Compute correlation -- suppress divide warning when edges are constant (e.g. zero array) with np.errstate(invalid="ignore"): correlation = np.corrcoef(edges_v.flatten(), edges_r.flatten())[0, 1] if np.isnan(correlation): return 0.0 return float(max(0.0, correlation))
[docs] def compute_variance_score(vol: np.ndarray, reference: np.ndarray) -> float: """ Compute variance consistency score between volume and reference. Low variance may indicate data loss or corruption. Parameters ---------- vol : np.ndarray Input volume. reference : np.ndarray Reference volume. Returns ------- float Variance score (0 to 1, higher means more similar variance). """ var_vol = float(np.var(vol)) var_ref = float(np.var(reference)) if var_ref == 0: return 0.0 ratio = var_vol / var_ref # Score is 1 when variances are equal, decreases as they diverge score = 2.0 / (1.0 + abs(np.log(ratio + 1e-10))) return float(min(1.0, max(0.0, score)))
[docs] def assess_slice_quality( vol: np.ndarray, vol_before: np.ndarray | None, vol_after: np.ndarray | None, sample_depth: int = 5, weights: dict[str, float] | None = None, xy_roi: int = 0, ) -> tuple[float, dict[str, Any]]: """ Assess overall quality of a slice volume. Uses multiple metrics to determine slice quality: - SSIM with neighboring slices (50%) - Edge preservation compared to expected structure (30%) - Variance consistency (20%) Parameters ---------- vol : np.ndarray The slice volume (Z, Y, X). vol_before : np.ndarray or None The previous slice volume. vol_after : np.ndarray or None The next slice volume. sample_depth : int Number of z-planes to sample for SSIM. 0 = all. weights : dict, optional Custom weights for metrics. Keys: 'ssim', 'edge', 'variance'. xy_roi : int Side length of center crop in XY (pixels). 0 = full plane. Use a small value (e.g. 1024) on very large single-resolution zarr arrays to avoid loading gigabytes per plane. Returns ------- float Overall quality score (0 to 1). dict Individual metric values. """ if weights is None: weights = {"ssim": 0.5, "edge": 0.3, "variance": 0.2} nz = vol.shape[0] if vol.ndim == 3 else 1 ny = vol.shape[1] if vol.ndim == 3 else vol.shape[0] nx = vol.shape[2] if vol.ndim == 3 else vol.shape[1] # Compute center-crop bounds once -- all plane reads below use this region. # For large single-resolution zarr mosaic grids this is the primary # performance control: a 1024x1024 crop loads ~2 MB instead of ~5 GB. if xy_roi > 0: yc, xc = ny // 2, nx // 2 half = xy_roi // 2 ys, ye = max(0, yc - half), min(ny, yc + half) xs, xe = max(0, xc - half), min(nx, xc + half) else: ys, ye, xs, xe = 0, ny, 0, nx # Load a strided subsample (≤ 8 planes) of the crop for has-data / variance checks. step = max(1, nz // 8) vol_sample = np.asarray(vol[::step, ys:ye, xs:xe]) metrics: dict[str, Any] = { "ssim_before": 0.0, "ssim_after": 0.0, "ssim_mean": 0.0, "edge_score": 0.0, "variance_score": 0.0, "depth": nz, "has_data": True, } # Check if slice has meaningful data using the cheap sample if vol_sample.max() == vol_sample.min() or np.std(vol_sample) < 1e-6: metrics["has_data"] = False metrics["overall"] = 0.0 return 0.0, metrics # Compute SSIM with neighbors -- each call loads only sample_depth cropped planes ssim_scores = [] if vol_before is not None: metrics["ssim_before"] = compute_ssim_3d(vol, vol_before, sample_depth=sample_depth, xy_roi=xy_roi) ssim_scores.append(metrics["ssim_before"]) if vol_after is not None: metrics["ssim_after"] = compute_ssim_3d(vol, vol_after, sample_depth=sample_depth, xy_roi=xy_roi) ssim_scores.append(metrics["ssim_after"]) if ssim_scores: metrics["ssim_mean"] = float(np.mean(ssim_scores)) # Build a single reference plane (middle z, cropped) for edge and variance scores. mid_z = nz // 2 ny_n = min(ye, vol_before.shape[1] if vol_before is not None else ye, vol_after.shape[1] if vol_after is not None else ye) nx_n = min(xe, vol_before.shape[2] if vol_before is not None else xe, vol_after.shape[2] if vol_after is not None else xe) # Re-clip crop to neighbour extents ye_n = min(ye, ny_n) xe_n = min(xe, nx_n) ref_plane: np.ndarray | None = None if vol_before is not None and vol_after is not None: z_b = min(mid_z, vol_before.shape[0] - 1) z_a = min(mid_z, vol_after.shape[0] - 1) ref_plane = 0.5 * np.asarray(vol_before[z_b, ys:ye_n, xs:xe_n]).astype(np.float32) + 0.5 * np.asarray( vol_after[z_a, ys:ye_n, xs:xe_n] ).astype(np.float32) elif vol_before is not None: z_b = min(mid_z, vol_before.shape[0] - 1) ref_plane = np.asarray(vol_before[z_b, ys:ye_n, xs:xe_n]).astype(np.float32) elif vol_after is not None: z_a = min(mid_z, vol_after.shape[0] - 1) ref_plane = np.asarray(vol_after[z_a, ys:ye_n, xs:xe_n]).astype(np.float32) # Compute edge preservation score using the single cropped reference plane if ref_plane is not None: vol_plane = np.asarray(vol[mid_z, ys:ye_n, xs:xe_n]) metrics["edge_score"] = compute_edge_score(vol_plane, ref_plane) # Compute variance consistency using the strided crop vs reference plane if ref_plane is not None: metrics["variance_score"] = compute_variance_score(vol_sample, vol_sample * 0 + ref_plane.mean()) # Compute overall score overall = ( weights["ssim"] * metrics["ssim_mean"] + weights["edge"] * metrics["edge_score"] + weights["variance"] * metrics["variance_score"] ) metrics["overall"] = float(overall) return float(overall), metrics
[docs] def detect_calibration_slice(volumes: dict[int, np.ndarray], thickness_ratio: float = 1.5) -> list[int]: """ Detect calibration slices by their different thickness. Calibration slices are typically thicker than regular slices. Parameters ---------- volumes : dict Mapping from slice_id to volume array. thickness_ratio : float Slices with depth > median * ratio are flagged. Returns ------- list List of slice IDs identified as calibration slices. """ if not volumes: return [] slice_ids = sorted(volumes.keys()) depths = {sid: vol.shape[0] for sid, vol in volumes.items()} valid_depths = [d for d in depths.values() if d > 0] if not valid_depths: return [] median_depth = float(np.median(valid_depths)) # Check first few slices for unusual thickness calibration = [] for sid in slice_ids[:3]: if sid in depths and depths[sid] > 0: ratio = depths[sid] / median_depth if ratio > thickness_ratio: calibration.append(sid) return calibration
[docs] def compute_quality_report(slice_qualities: dict[int, dict[str, Any]], min_quality: float = 0.0) -> dict[str, Any]: """ Generate a quality report from slice quality assessments. Parameters ---------- slice_qualities : dict Mapping from slice_id to quality metrics dict. min_quality : float Minimum quality threshold for flagging. Returns ------- dict Summary report with statistics and flagged slices. """ if not slice_qualities: return {"error": "No slices to analyze"} overall_scores = [q.get("overall", 0.0) for q in slice_qualities.values()] report = { "n_slices": len(slice_qualities), "mean_quality": float(np.mean(overall_scores)), "std_quality": float(np.std(overall_scores)), "min_quality": float(np.min(overall_scores)), "max_quality": float(np.max(overall_scores)), "low_quality_slices": [], "no_data_slices": [], } for sid, metrics in slice_qualities.items(): if not metrics.get("has_data", True): report["no_data_slices"].append(sid) elif metrics.get("overall", 0.0) < min_quality: report["low_quality_slices"].append(sid) return report