"""Image class to hold a medical image
Taken from torchio and modified to use eager
load and use np.ndarray instead of torch.tensor
https://github.com/fepegar/torchio
Author: Jacob Reinhold <jcreinhold@gmail.com>
"""
from __future__ import annotations
__all__ = ["Image"]
import typing
import zipfile
import numpy as np
import numpy.typing as npt
try:
import nibabel as nib
import SimpleITK as sitk
except ImportError as imp_exn:
imp_exn_msg = f"NiBabel and SimpleITK must be installed to use {__name__}."
raise ImportError(imp_exn_msg) from imp_exn
import pymedio.base as miob
import pymedio.dicom as miod
import pymedio.functional as miof
import pymedio.typing as miot
[docs]class Image(miob.BasicImage[typing.Any, miot.DType]): # type: ignore[type-arg]
@property
def repr_properties(self) -> list[str]:
props = super().repr_properties
props += [f"orientation: {''.join(self.orientation)}+"]
return props
@property
def orientation(self) -> tuple[str, str, str]:
"""Orientation codes."""
codes: tuple[str, str, str]
codes = nib.orientations.aff2axcodes(self.affine)
return codes
@property
def bounds(self) -> npt.NDArray[np.float64]:
"""Position of centers of voxels in smallest and largest indices."""
ini = 0, 0, 0
fin: np.ndarray = np.asarray(self.shape) - 1
point_ini = nib.affines.apply_affine(self.affine, ini)
point_fin = nib.affines.apply_affine(self.affine, fin)
return np.asarray((point_ini, point_fin))
[docs] def axis_name_to_index(self, axis: str) -> int:
# Top and bottom are used for the vertical 2D axis as the use of
# Height vs Horizontal might be ambiguous
if not isinstance(axis, str):
raise ValueError("Axis must be a string")
axis = axis[0].upper()
if axis in "TB": # Top, Bottom
return -2
else:
try:
index = self.orientation.index(axis)
except ValueError:
index = self.orientation.index(self.flip_axis(axis))
# Return negative indices so that it does not matter whether we
# refer to spatial dimensions or not
index = -3 + index
return index
[docs] @staticmethod
def flip_axis(axis: str) -> str:
labels = "LRPAISTBDV"
first = labels[::2]
last = labels[1::2]
flip_dict = {a: b for a, b in zip(first + last, last + first)}
axis = axis[0].upper()
flipped_axis = flip_dict.get(axis)
if flipped_axis is None:
values = ", ".join(labels)
message = f"Axis not understood. Please use one of: {values}"
raise ValueError(message)
return flipped_axis
[docs] def get_bounds(self) -> miot.Bounds:
"""Get minimum and maximum world coordinates occupied by the image."""
first_index = 3 * (-0.5,)
last_index: np.ndarray = np.asarray(self.shape) - 0.5
first_point = nib.affines.apply_affine(self.affine, first_index)
last_point = nib.affines.apply_affine(self.affine, last_index)
array: np.ndarray = np.asarray((first_point, last_point))
bounds_x, bounds_y, bounds_z = array.T.tolist()
return bounds_x, bounds_y, bounds_z
[docs] def save(self, path: miot.PathLike, *, squeeze: bool = True) -> None:
miof.write_image(np.array(self), self.affine, path, squeeze=squeeze)
[docs] def to_filename(self, path: miot.PathLike) -> None:
self.save(path, squeeze=False)
[docs] def get_center(self, lps: bool = False) -> miot.TripletFloat:
"""Get image center in RAS+ or LPS+ coordinates"""
size: np.ndarray = np.asarray(self.shape)
center_index = (size - 1) / 2
r, a, s = nib.affines.apply_affine(self.affine, center_index)
return (-r, -a, s) if lps else (r, a, s)
[docs] @classmethod
def from_path(
cls: typing.Type[Image],
path: miot.PathLike,
*,
dtype: typing.Type[miot.DType] | None = None,
eager: bool = True,
) -> Image[miot.DType]:
data, affine = miof.read_image(path, dtype=dtype, eager=eager)
return cls(data=data, affine=affine)
[docs] @classmethod
def from_stream(
cls: typing.Type[Image],
data_stream: typing.IO,
*,
dtype: typing.Type[miot.DType] | None = None,
gzipped: bool = False,
image_class: miof.NibabelImageClass | None = None,
) -> Image[miot.DType]:
data, affine = miof.read_image_from_stream(
data_stream, dtype=dtype, gzipped=gzipped, image_class=image_class
)
return cls(data=data, affine=affine)
[docs] @classmethod
def from_zipped_stream(
cls: typing.Type[Image],
data_stream: typing.IO,
*,
dtype: typing.Type[miot.DType] | None = None,
gzipped: bool = False,
image_class: miof.NibabelImageClass | None = None,
**zip_kwargs: typing.Any,
) -> Image[miot.DType]:
with zipfile.ZipFile(data_stream, "r", **zip_kwargs) as zf:
names = [name for name in zf.namelist() if not name.endswith("/")]
if (n := len(names)) != 1:
msg = f"{n} files in zipped archive. This constructor requires only 1."
raise RuntimeError(msg)
name = names[0]
with zf.open(name, mode="r") as f:
return cls.from_stream(
typing.cast(typing.BinaryIO, f),
dtype=dtype,
gzipped=gzipped,
image_class=image_class,
)
[docs] @classmethod
def from_sitk(
cls: typing.Type[Image],
sitk_image: sitk.Image,
*,
dtype: typing.Type[miot.DType] | None = None,
) -> Image:
data, affine = miof.sitk_to_array(sitk_image, dtype=dtype)
return cls(data=data, affine=affine)
[docs] @classmethod
def from_dicom_image(
cls: typing.Type[Image], dicom_image: miod.DICOMImage
) -> Image:
return cls(data=dicom_image, affine=dicom_image.affine)
[docs] @classmethod
def from_dicom_zipped_stream(
cls: typing.Type[Image],
data_stream: typing.IO,
*,
max_nonuniformity: float = 5e-4,
fail_outside_max_nonuniformity: bool = True,
remove_anomalous_images: bool = True,
encryption_key: bytes | str | None = None,
rescale: bool | None = None,
rescale_dtype: typing.Type[miot.DType] | None = None,
) -> Image[miot.DType]:
dicom_image = miod.DICOMImage.from_zipped_stream(
data_stream,
max_nonuniformity=max_nonuniformity,
fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
remove_anomalous_images=remove_anomalous_images,
encryption_key=encryption_key,
rescale=rescale,
rescale_dtype=rescale_dtype,
)
return cls.from_dicom_image(dicom_image)
[docs] def to_sitk(self, **kwargs: bool) -> sitk.Image:
"""Get the image as an instance of :class:`sitk.Image`."""
return miof.array_to_sitk(np.array(self), self.affine, **kwargs)
[docs] def to_nibabel(self) -> nib.nifti1.Nifti1Image:
return nib.nifti1.Nifti1Image(np.array(self), self.affine)