# Source code for pymedio.base

"""Base image class for image containers
Author: Jacob Reinhold <jcreinhold@gmail.com>
"""

from __future__ import annotations

# Public API of this module: only the ``BasicImage`` class is exported.
__all__ = ["BasicImage"]

import collections.abc
import typing
import warnings

import numpy as np
import numpy.typing as npt

import pymedio.typing as miot
import pymedio.utils as miou

# Values annotated for the ``method`` argument of ``BasicImage.__array_ufunc__``.
# "at" is listed because that method has a dedicated ``method == "at"`` branch;
# "inner" is kept so the alias only widens and existing annotations stay valid.
_UfuncMethod = typing.Literal[
    "__call__", "reduce", "reduceat", "accumulate", "outer", "inner", "at"
]
# Type variable for "this class or a subclass", used so constructors and
# factory methods are typed as returning the concrete subclass.
_Image = typing.TypeVar("_Image", bound="BasicImage")


class BasicImage(npt.NDArray[miot.DType]):
    """``ndarray`` subclass that carries a 4x4 affine matrix and an info string.

    ``affine`` and ``info`` are stored next to the pixel data and are carried
    over to views/slices (``__array_finalize__``) and to ufunc results
    (``__array_ufunc__``).
    """

    _affine: npt.NDArray
    _info: npt.NDArray[np.str_]

    def __new__(
        cls: typing.Type[_Image],
        data: npt.ArrayLike,
        affine: npt.NDArray | None = None,
        *,
        info: str | npt.NDArray[np.str_] | None = None,
        copy: bool = False,
    ) -> _Image:
        """Create an image from array-like ``data``.

        Args:
            data: Pixel/voxel values; validated by ``_check_data``.
            affine: 4x4 NumPy array, or ``None`` for the identity matrix;
                validated by ``_check_affine``.
            info: Free-form text stored with the image (``None`` -> ``""``).
            copy: Forwarded to ``_check_data`` to request a copy of ``data``.

        Raises:
            TypeError: If ``affine`` is neither ``None`` nor a NumPy array.
            ValueError: If ``affine`` is not 4x4 or ``data`` has a
                zero-length dimension.
        """
        obj = cls._check_data(data, copy=copy).view(cls)
        obj._affine = cls._check_affine(affine)
        # NOTE(review): if ``miou.to_f64`` returns its argument without
        # copying, this also freezes the caller's array -- confirm.
        obj._affine.flags.writeable = False
        obj._info = cls._check_info(info)
        obj._info.flags.writeable = False
        return obj

    def __array_finalize__(self, obj: object) -> None:
        """Inherit ``affine`` and ``info`` from ``obj``.

        NumPy calls this hook for views, slices and templated arrays. When
        ``obj`` lacks the attributes, the defaults of ``_check_affine`` /
        ``_check_info`` (identity matrix, empty string) are used.
        """
        if obj is None:
            return
        self._affine = self._check_affine(getattr(obj, "_affine", None))
        self._affine.flags.writeable = False
        self._info = self._check_info(getattr(obj, "info", None))
        self._info.flags.writeable = False

    @property
    def repr_properties(self) -> list[str]:
        """Formatted ``name: value`` fragments shown by ``__repr__``."""
        return [
            f"shape: {self.shape}",
            f"spacing: {self.get_spacing_string()}",
            f"dtype: {self.dtype.name}",
        ]

    def __repr__(self) -> str:
        """Return a one-line summary; 0-d images print as their scalar value."""
        if self.ndim == 0:
            return str(self.item())
        properties = "; ".join(self.repr_properties)
        return f"{self.__class__.__name__}({properties})"

    def __array_ufunc__(
        self,
        ufunc: np.ufunc,
        method: _UfuncMethod,
        *inputs: typing.Any,
        out: typing.Any = None,
        **kwargs: typing.Any,
    ) -> typing.Any:
        """Apply ``ufunc`` and re-wrap results with this image's metadata.

        ``BasicImage`` inputs and outputs are viewed as plain ``ndarray``
        before delegating to the base implementation. Results that were not
        written into a caller-supplied ``out`` array are wrapped in
        ``self.__class__`` with ``self.affine`` and ``self.info``.

        Returns:
            A single result or a tuple of results; ``None`` for the ``"at"``
            method; ``NotImplemented`` if the base implementation declines.
        """
        affine = self.affine
        info = self.info

        # Hand plain ndarrays to the base implementation so the ufunc call
        # does not dispatch straight back to this override.
        ufunc_args = [
            arg.view(np.ndarray) if isinstance(arg, BasicImage) else arg
            for arg in inputs
        ]

        outputs = out
        if outputs:
            kwargs["out"] = tuple(
                target.view(np.ndarray) if isinstance(target, BasicImage) else target
                for target in outputs
            )
        else:
            outputs = (None,) * ufunc.nout

        results = super().__array_ufunc__(ufunc, method, *ufunc_args, **kwargs)
        # Propagate a declined operation instead of wrapping the
        # ``NotImplemented`` singleton into an object-dtype image.
        if results is NotImplemented:
            return NotImplemented

        if method == "at":
            # ``ufunc.at`` works in place and has no return value.
            if isinstance(inputs[0], BasicImage):
                inputs[0].affine = affine
                inputs[0].info = info
            return None

        if ufunc.nout == 1:
            results = (results,)

        results = tuple(
            (
                self.__class__(np.asarray(result), affine, info=info)
                if output is None
                else output
            )
            for result, output in zip(results, outputs)
        )

        return results[0] if len(results) == 1 else results

    @property
    def affine(self) -> npt.NDArray[np.float64]:
        """The stored 4x4 affine matrix."""
        return self._affine

    @affine.setter
    def affine(self, new_affine: npt.NDArray) -> None:
        # Validated like the constructor argument. Unlike ``__new__`` the
        # stored array is not marked read-only here.
        self._affine = self._check_affine(new_affine)

    @property
    def info(self) -> npt.NDArray[np.str_]:
        """The stored info text as a NumPy string array."""
        return self._info

    @info.setter
    def info(self, new_info: str | npt.NDArray[np.str_] | None) -> None:
        # ``None`` is normalised to an empty string by ``_check_info``.
        self._info = self._check_info(new_info)

    @property
    def direction(self) -> miot.Direction:
        """Direction returned by ``miou.get_metadata_from_ras_affine``.

        The helper is called with this image's affine and ``lps=False``.
        """
        _, _, direction = miou.get_metadata_from_ras_affine(self.affine, lps=False)
        return direction

    @property
    def spacing(self) -> tuple[miot.Float, ...]:
        """Voxel spacing in mm."""
        _, spacing = miou.get_rotation_and_spacing_from_affine(self.affine)
        return tuple(spacing)

    @property
    def origin(self) -> tuple[miot.Float, ...]:
        """Center of first voxel in array, in mm."""
        return tuple(self.affine[:3, 3])

    @property
    def memory(self) -> miot.Int:
        """Number of bytes that the image array occupies in RAM"""
        # ``nbytes`` is an exact Python int. ``np.prod(self.shape)`` yields a
        # float for 0-d arrays and a fixed-width NumPy integer otherwise.
        return int(self.nbytes)
[docs] def get_spacing_string(self) -> str: strings = [f"{n:.2g}" for n in self.spacing] string = f'({", ".join(strings)})' return string
@staticmethod def _check_data(data: npt.ArrayLike, copy: bool) -> npt.NDArray: data = np.array(data, copy=copy) if np.isnan(data).any(): warnings.warn("NaNs found in data", RuntimeWarning) if any(d == 0 for d in data.shape): msg = f"Data must have all non-zero dimensions. Got {data.shape}." raise ValueError(msg) return data @staticmethod def _check_affine(affine: npt.NDArray | None) -> npt.NDArray[np.float64]: if affine is None: return np.eye(4, dtype=np.float64) if not isinstance(affine, np.ndarray): bad_type = type(affine) raise TypeError(f"Affine must be a NumPy array, not {bad_type}") if affine.shape != (4, 4): bad_shape = affine.shape raise ValueError(f"Affine shape must be (4, 4), not {bad_shape}") return miou.to_f64(affine) @staticmethod def _check_info( info: str | npt.NDArray[np.str_] | None, ) -> npt.NDArray[np.str_]: if info is None: info = "" return np.asarray(info, dtype=np.str_)
[docs] def to_npz(self, file: miot.PathLike | typing.BinaryIO) -> None: np.savez_compressed( file, data=np.asarray(self), affine=self.affine, info=self.info )
[docs] @classmethod def from_npz( cls, file: miot.PathLike | typing.BinaryIO, **np_load_kwargs ) -> _Image: _data = np.load(file, **np_load_kwargs) return cls(_data["data"], affine=_data["affine"], info=_data["info"])
[docs] def torch_compatible(self) -> npt.NDArray: return miou.ensure_4d(miou.check_uint_to_int(np.asarray(self)))
[docs] def resample_image(self, shape: collections.abc.Sequence[int]) -> _Image: if self.ndim != len(shape): raise ValueError("length of 'shape' != number of dimensions in image.") if any(s <= 0 for s in shape): raise ValueError("All elements of 'shape' must be positive.") data = np.asarray(self) resolution_scale = np.empty(len(shape), dtype=np.float64) for i, (new_s, s) in enumerate(zip(shape, self.shape)): resolution_scale[i] = s / new_s def interpolate(fp: np.ndarray) -> np.ndarray: return np.interp( np.linspace(0, s, dtype=self.dtype, endpoint=False, num=new_s), np.arange(0, s, dtype=self.dtype), fp, ) data = np.apply_along_axis(interpolate, i, data) # https://math.stackexchange.com/q/1120209 rzs = self.affine[:3, :3] # rotation, scale (zoom), shear rotation, zs = np.linalg.qr(rzs) scale = zs.diagonal() shear = zs * (1.0 / scale) new_affine = np.zeros_like(self.affine) new_scale = scale * resolution_scale new_affine[:3, :3] = (rotation * new_scale) @ shear new_affine[:, 3] = self.affine[:, 3] return self.__class__( data.astype(self.dtype, copy=False), new_affine, info=self.info )