"""Functions for reading and writing medical images and affine transforms.
Adapted from torchio and modified to output np.ndarray.
https://github.com/fepegar/torchio
Author: Jacob Reinhold <jcreinhold@gmail.com>
"""
from __future__ import annotations
# Names exported by ``from <module> import *``. Other un-prefixed functions
# in this module (e.g. ``read_image_from_stream``, ``read_shape``) are not
# listed here but can still be imported explicitly by name.
__all__ = [
    "read_affine",
    "read_image",
    "read_matrix",
    "write_image",
    "write_matrix",
]
import gzip
import logging
import pathlib
import traceback
import typing
import warnings
import numpy as np
import numpy.typing as npt
try:
import nibabel as nib
import SimpleITK as sitk
except ImportError as imp_exn:
imp_exn_msg = f"NiBabel and SimpleITK must be installed to use {__name__}."
raise ImportError(imp_exn_msg) from imp_exn
import pymedio.typing as miot
import pymedio.utils as miou
# Image formats that are typically 2D. The lower-case suffixes are listed
# once; IMAGE_2D_FORMATS holds them followed by their upper-case variants.
_2d_formats = [".jpg", ".jpeg", ".bmp", ".png", ".tif", ".tiff"]
IMAGE_2D_FORMATS = [*_2d_formats, *map(str.upper, _2d_formats)]
# Type of a NiBabel image *class* (not an instance); used to annotate the
# ``image_class`` argument of ``read_image_from_stream``.
NibabelImageClass = typing.Type[
    typing.Union[
        nib.nifti1.Nifti1Pair,
        nib.nifti1.Nifti1Image,
        nib.nifti2.Nifti2Pair,
        nib.nifti2.Nifti2Image,
        nib.cifti2.cifti2.Cifti2Image,
        nib.spm2analyze.Spm2AnalyzeImage,
        nib.spm99analyze.Spm99AnalyzeImage,
        nib.analyze.AnalyzeImage,
        nib.minc1.Minc1Image,
        nib.minc2.Minc2Image,
        nib.freesurfer.mghformat.MGHImage,
        nib.gifti.gifti.GiftiImage,
    ]
]
# Module logger; ``read_image_from_stream`` logs each failed image-class probe
# at DEBUG level through it.
logger = logging.getLogger(__name__)
def read_image(
    path: miot.PathLike,
    *,
    dtype: typing.Type[miot.DType] | None = None,
    eager: bool = True,
) -> miot.DataAffine[miot.DType]:
    """Read a medical image as a ``(data, affine)`` pair.

    SimpleITK is tried first (``_read_sitk`` treats a directory as a DICOM
    series). If SimpleITK raises ``RuntimeError``, a warning is emitted and
    NiBabel is tried instead.

    Args:
        path: image file, or a DICOM directory for the SimpleITK reader.
        dtype: optional dtype forwarded to whichever reader is used.
        eager: if True, SimpleITK copies the pixel buffer and NiBabel is
            asked not to memory-map; if False, SimpleITK returns an array
            view and NiBabel is asked to memory-map.

    Returns:
        ``(data, affine)`` as produced by the reader that succeeded.

    Raises:
        RuntimeError: if NiBabel cannot identify the file format either.
    """
    try:
        result = _read_sitk(path, dtype=dtype, copy=eager)
    except RuntimeError as exn1:  # try with NiBabel
        message = f"Error loading image with SimpleITK:\n{exn1}\n\nTrying NiBabel..."
        # stacklevel=2 attributes the warning to the caller of read_image
        # rather than to this module.
        warnings.warn(message, stacklevel=2)
        try:
            result = _read_nibabel(path, dtype=dtype, mmap=not eager)
        except nib.loadsave.ImageFileError as exn2:  # type: ignore[attr-defined]
            message = (
                f"File '{path}' not understood."
                " Check supported formats at"
                " https://simpleitk.readthedocs.io/en/master/IO.html#images"
                " and https://nipy.org/nibabel/api.html#file-formats"
            )
            raise RuntimeError(message) from exn2
    return result
def read_image_from_stream(
    stream: typing.IO,
    *,
    dtype: typing.Type[miot.DType] | None = None,
    gzipped: bool = False,
    image_class: typing.Optional[NibabelImageClass] = None,
) -> miot.DataAffine[miot.DType]:
    """Read an image from an open binary stream with NiBabel.

    Based on https://mail.python.org/pipermail/neuroimaging/2017-February/001345.html

    Args:
        stream: file-like object containing the image bytes.
        dtype: dtype passed to ``get_fdata``.
        gzipped: if True, *stream* is wrapped in ``gzip.GzipFile`` first.
        image_class: NiBabel image class to decode with. If None, every class
            in ``nib.imageclasses.all_image_classes`` that has
            ``from_file_map`` is tried in order until one succeeds; each
            failure is logged at DEBUG level.

    Returns:
        ``(data, affine)``. 5D data keeps only index 0 of axis 3 and is moved
        to channel-first order, as in ``_read_nibabel``. The affine is
        converted with ``miou.to_f64``, also as in ``_read_nibabel``.

    Raises:
        RuntimeError: if no candidate class could open the stream.
    """
    _stream = gzip.GzipFile(fileobj=stream) if gzipped else stream
    fh = nib.fileholders.FileHolder(fileobj=_stream)  # type: ignore[arg-type]
    if image_class is None:
        for cls in nib.imageclasses.all_image_classes:
            if not hasattr(cls, "from_file_map"):
                continue
            try:
                # header and image data are read from the same file holder
                file_map = {"header": fh, "image": fh}
                img = cls.from_file_map(file_map, mmap=False)  # type: ignore[call-arg]
                break
            except Exception:  # best effort: log and try the next class
                logger.debug(traceback.format_exc())
        else:
            raise RuntimeError("Couldn't open data stream.")
    else:
        img = image_class.from_file_map({"header": fh, "image": fh}, mmap=False)
    data = img.get_fdata(dtype=dtype)  # type: ignore[attr-defined]
    if data.ndim == 5:
        data = data[..., 0, :]
        data = data.transpose(3, 0, 1, 2)
    # Match _read_nibabel so both NiBabel paths return a float64 affine.
    affine = miou.to_f64(img.affine)  # type: ignore[attr-defined]
    return data, affine
def _read_nibabel(
    path: miot.PathLike,
    *,
    dtype: typing.Type[miot.DType] | None = None,
    mmap: bool = False,
) -> miot.DataAffine[miot.DType]:
    """Load *path* with NiBabel and return ``(data, affine)``.

    For 5D data only index 0 of axis 3 is kept, and the last axis is moved to
    the front. The affine is returned as float64 via ``miou.to_f64``.
    """
    nib_image = nib.loadsave.load(str(path), mmap=mmap)
    voxels = nib_image.get_fdata(dtype=dtype)  # type: ignore[attr-defined]
    if voxels.ndim == 5:
        # (a, b, c, d, e) -> (a, b, c, e) -> (e, a, b, c)
        voxels = np.moveaxis(voxels[..., 0, :], -1, 0)
    return voxels, miou.to_f64(nib_image.affine)  # type: ignore[attr-defined]
def _read_sitk(
    path: miot.PathLike,
    *,
    dtype: typing.Type[miot.DType] | None = None,
    copy: bool = True,
) -> miot.DataAffine[miot.DType]:
    """Read *path* with SimpleITK and convert it with ``sitk_to_array``.

    A directory is assumed to hold a DICOM series; anything else is passed to
    ``sitk.ReadImage``.
    """
    sitk_image = (
        _read_dicom_sitk(path)
        if pathlib.Path(path).is_dir()
        else sitk.ReadImage(str(path))
    )
    return sitk_to_array(sitk_image, dtype=dtype, copy=copy)
def _read_dicom_sitk(directory: miot.PathLike) -> sitk.Image:
    """Read the DICOM series in *directory* with ``sitk.ImageSeriesReader``.

    Raises:
        FileNotFoundError: if *directory* is not a directory, or if GDCM
            returns no series file names for it.
    """
    folder = pathlib.Path(directory)
    if not folder.is_dir():  # unreachable if called from _read_sitk
        raise FileNotFoundError(f"Directory '{folder}' not found")
    series_reader = sitk.ImageSeriesReader()
    file_names = series_reader.GetGDCMSeriesFileNames(str(folder))
    if not file_names:
        raise FileNotFoundError(
            f"The directory '{folder}' does not seem to contain DICOM files"
        )
    series_reader.SetFileNames(file_names)
    return typing.cast(sitk.Image, series_reader.Execute())
def read_shape(path: miot.PathLike) -> miot.Shape:
    """Return ``(channels, sx, sy, sz)`` for the image at *path*.

    Uses ``ImageFileReader.ReadImageInformation``, so the pixel data is not
    loaded. 2D images get ``sz = 1``. For 4D images the last size entry is
    taken as the channel count (assumed to be a malformed NIfTI).
    """
    info = sitk.ImageFileReader()
    info.SetFileName(str(path))
    info.ReadImageInformation()
    channels = info.GetNumberOfComponents()
    size = info.GetSize()
    ndim = info.GetDimension()
    if ndim == 2:
        size = (*size, 1)
    elif ndim == 4:  # assume bad NIfTI
        size, channels = size[:3], size[3]
    sx, sy, sz = size
    return (channels, sx, sy, sz)
def read_affine(path: miot.PathLike) -> npt.NDArray[np.float64]:
    """Return the 4x4 RAS affine of the image at *path*.

    Only the header is read, via ``get_reader`` (which calls
    ``ReadImageInformation``).
    """
    return get_ras_affine_from_sitk(get_reader(path))
def get_reader(path: miot.PathLike, *, read: bool = True) -> sitk.ImageFileReader:
    """Create a ``sitk.ImageFileReader`` for *path*.

    Args:
        path: file to read.
        read: if True (default), call ``ReadImageInformation`` before
            returning so header metadata is available.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(str(path))
    if not read:
        return file_reader
    file_reader.ReadImageInformation()
    return file_reader
def write_image(
    array: npt.NDArray,
    affine: npt.NDArray,
    path: miot.PathLike,
    *,
    squeeze: bool = True,
    **write_sitk_kwargs: bool,
) -> None:
    """Write *array* with *affine* to *path*.

    SimpleITK is tried first. If it raises ``RuntimeError``, NiBabel is used
    instead.

    Args:
        array: image data.
        affine: 4x4 affine matrix.
        path: destination file.
        squeeze: if True, drop length-1 axes before writing. The squeezed
            array is what both writers receive.
        **write_sitk_kwargs: forwarded to ``_write_sitk``
            (``use_compression``, ``is_multichannel``).
    """
    to_write = array.squeeze() if squeeze else array
    try:
        _write_sitk(to_write, affine, path, **write_sitk_kwargs)
    except RuntimeError:  # SimpleITK failed: fall back to NiBabel
        _write_nibabel(to_write, affine, path)
def _write_nibabel(
    array: npt.NDArray,
    affine: npt.NDArray,
    path: miot.PathLike,
) -> None:
    """Write *array* and *affine* with NiBabel as a NIfTI-1 image.

    Expects a path with an extension that can be used by ``nibabel.save``
    to write a NIfTI-1 image, such as '.nii', '.nii.gz', '.hdr' or '.img'.

    Axis 0 is treated as the channel axis:

    * if it has length 1, it is dropped;
    * a 4D ``(c, x, y, z)`` array with ``c > 1`` is stored as a 5D
      ``(x, y, z, 1, c)`` array with the "vector" intent. This is the layout
      that ``_read_nibabel`` converts back to ``(c, x, y, z)``;
    * any other array (e.g. an already squeezed ``(x, y, z)`` volume) is
      written unchanged and without the vector intent.

    Raises:
        nibabel.loadsave.ImageFileError: if the suffix (ignoring '.gz') is
            not a NIfTI-1 single-file or pair suffix.
    """
    is_vector = False
    if array.shape[0] == 1:
        array = array[0]
    elif array.ndim == 4:
        # NIfTI components must be at the end, in a 5D array: the vector
        # intent expects them in dimension 5, with dimension 4 (time) of
        # length 1. A 4D (x, y, z, c) array would put them on the time axis.
        array = array.transpose((1, 2, 3, 0))[:, :, :, np.newaxis, :]
        is_vector = True
    # Only vector-flag arrays whose components were moved to dimension 5. A
    # squeezed 3D volume's first axis is spatial, not a channel axis.

    suffix = pathlib.Path(str(path).replace(".gz", "")).suffix
    img: typing.Union[nib.nifti1.Nifti1Image, nib.nifti1.Nifti1Pair]
    if ".nii" in suffix:
        img = nib.nifti1.Nifti1Image(np.asanyarray(array), affine)
    elif ".hdr" in suffix or ".img" in suffix:
        img = nib.nifti1.Nifti1Pair(np.asanyarray(array), affine)
    else:
        raise nib.loadsave.ImageFileError  # type: ignore[attr-defined]
    if is_vector:
        img.header.set_intent("vector")
    img.header["qform_code"] = 1
    img.header["sform_code"] = 0
    nib.loadsave.save(img, str(path))
def _write_sitk(
    array: npt.NDArray,
    affine: npt.NDArray,
    path: miot.PathLike,
    *,
    use_compression: bool = True,
    is_multichannel: bool = False,
) -> None:
    """Write *array* with *affine* to *path* using SimpleITK.

    For '.png', '.jpg', '.jpeg' and '.bmp' paths (case-insensitive), data
    that is not already uint8 is cast to uint8 with a ``RuntimeWarning``.

    Args:
        array: image data, passed to ``array_to_sitk``.
        affine: 4x4 RAS affine, passed to ``array_to_sitk``.
        path: destination file.
        use_compression: forwarded to ``sitk.WriteImage``.
        is_multichannel: whether axis 0 holds per-pixel components.
    """
    path = pathlib.Path(path)
    # Compare case-insensitively so '.PNG' is treated like '.png'; this
    # module's IMAGE_2D_FORMATS also lists upper-case suffixes.
    if path.suffix.lower() in (".png", ".jpg", ".jpeg", ".bmp"):
        # Only warn when a cast actually happens.
        if array.dtype != np.uint8:
            msg = f"Casting to uint8 before saving to {path}"
            warnings.warn(msg, RuntimeWarning)
            array = array.astype(dtype=np.uint8, copy=False)
    image = array_to_sitk(array, affine, is_multichannel=is_multichannel)
    sitk.WriteImage(image, str(path), use_compression)
def read_matrix(path: miot.PathLike) -> npt.NDArray[np.float64]:
    """Read an affine transform file and return it as an array.

    The format is chosen from the suffix: '.tfm' and '.h5' are read as ITK
    transforms; '.txt' and '.trsf' as NiftyReg/blockmatching text matrices.

    Raises:
        ValueError: if the suffix is none of the above.
    """
    matrix_path = pathlib.Path(path)
    suffix = matrix_path.suffix
    is_itk = suffix in (".tfm", ".h5")
    is_niftyreg = suffix in (".txt", ".trsf")
    if not (is_itk or is_niftyreg):
        raise ValueError(f"Unknown suffix for transform file: '{suffix}'")
    if is_itk:
        return _read_itk_matrix(matrix_path)
    return _read_niftyreg_matrix(matrix_path)
def write_matrix(matrix: npt.NDArray, path: miot.PathLike) -> None:
    """Write an affine transform to *path*.

    The format is chosen from the suffix, mirroring ``read_matrix``: '.tfm'
    and '.h5' are written as ITK transforms; '.txt' and '.trsf' as
    NiftyReg/blockmatching text matrices.

    Raises:
        ValueError: if the suffix is none of the above.
    """
    path = pathlib.Path(path)
    suffix = path.suffix
    if suffix in (".tfm", ".h5"):  # ITK
        _write_itk_matrix(matrix, path)
    elif suffix in (".txt", ".trsf"):  # NiftyReg, blockmatching
        _write_niftyreg_matrix(matrix, path)
    else:
        # Previously fell through silently, writing nothing; fail loudly
        # like read_matrix does.
        raise ValueError(f"Unknown suffix for transform file: '{suffix}'")
def _to_itk_convention(matrix: npt.NDArray) -> npt.NDArray[np.float64]:
    """Convert a RAS homogeneous matrix to ITK's LPS convention.

    The matrix is conjugated by ``miou.flipxy_44()`` (presumably the x/y sign
    flip) and then inverted.
    """
    flip = miou.flipxy_44()
    conjugated = np.dot(np.dot(flip, matrix), flip)
    return np.linalg.inv(conjugated)
def _from_itk_convention(matrix: npt.NDArray) -> npt.NDArray[np.float64]:
    """Convert an ITK (LPS) homogeneous matrix to the RAS convention.

    The matrix is conjugated by ``miou.flipxy_44()`` (presumably the x/y sign
    flip) and then inverted.
    """
    flip = miou.flipxy_44()
    conjugated = np.dot(flip, np.dot(matrix, flip))
    return np.linalg.inv(conjugated)
def _read_itk_matrix(path: miot.PathLike) -> npt.NDArray[np.float64]:
    """Read an affine transform in ITK's .tfm/.h5 format as a RAS matrix.

    The first 9 transform parameters form the 3x3 matrix ``A`` and the next 3
    the translation ``t``. ITK applies such a transform as
    ``A @ (x - c) + t + c``, where the centre ``c`` is stored in the fixed
    parameters. When a 3-element centre is present it is folded into the
    translation (``t + c - A @ c``), so files written with a non-zero centre
    (e.g. by other registration tools) are read correctly. A zero centre
    leaves the result unchanged.

    Returns:
        4x4 homogeneous matrix in RAS convention (via _from_itk_convention).
    """
    transform = sitk.ReadTransform(str(path))
    parameters = transform.GetParameters()
    rotation_matrix = miou.to_f64(parameters[:9]).reshape(3, 3)
    translation = miou.to_f64(parameters[9:]).reshape(3)
    center = np.asarray(transform.GetFixedParameters(), dtype=np.float64)
    if center.shape == (3,):
        # offset = translation + center - A @ center
        translation = translation + center - rotation_matrix @ center
    matrix: np.ndarray = np.hstack([rotation_matrix, translation.reshape(3, 1)])
    homogeneous_matrix_lps: np.ndarray = np.vstack([matrix, [0.0, 0.0, 0.0, 1.0]])
    homogeneous_matrix_ras = _from_itk_convention(homogeneous_matrix_lps)
    return homogeneous_matrix_ras
def _write_itk_matrix(matrix: npt.NDArray, tfm_path: miot.PathLike) -> None:
    """Write *matrix* to *tfm_path* as an ITK affine transform.

    The tfm file contains the matrix from floating to reference.
    """
    _matrix_to_itk_transform(matrix).WriteTransform(str(tfm_path))
def _matrix_to_itk_transform(
    matrix: npt.NDArray, *, dims: int = 3
) -> sitk.AffineTransform:
    """Build a ``sitk.AffineTransform`` from a 4x4 RAS matrix.

    The matrix is first converted with ``_to_itk_convention``. The upper-left
    ``dims x dims`` block becomes the linear part and rows ``:dims`` of
    column 3 become the translation.
    """
    itk_matrix = _to_itk_convention(matrix)
    linear_part = itk_matrix[:dims, :dims].flatten().tolist()
    offset = itk_matrix[:dims, 3].tolist()
    return sitk.AffineTransform(linear_part, offset)
def _read_niftyreg_matrix(path: miot.PathLike) -> npt.NDArray[np.float64]:
    """Read a NiftyReg/blockmatching text matrix as a float64 NumPy array.

    The file is parsed with ``np.loadtxt`` and its inverse is returned.
    ``_write_niftyreg_matrix`` describes the file as ref -> flo, so the
    returned matrix is the inverse of what is stored.
    """
    return np.linalg.inv(np.loadtxt(path, dtype=np.float64))
def _write_niftyreg_matrix(matrix: npt.NDArray, txt_path: miot.PathLike) -> None:
    """Write an affine transform in NiftyReg's .txt format (ref -> flo).

    The inverse of *matrix* is written with 8 decimal places.
    """
    np.savetxt(txt_path, np.linalg.inv(matrix), fmt="%.8f")
def array_to_sitk(
    array: npt.NDArray,
    affine: npt.NDArray,
    *,
    is_multichannel: bool = False,
) -> sitk.Image:
    """Create a SimpleITK image from an array and a 4x4 affine matrix.

    Args:
        array: array-like image data. When *is_multichannel* is True, axis 0
            holds the per-pixel components and the remaining axes are
            spatial; otherwise all axes are spatial.
        affine: 4x4 RAS affine, converted with ``miou.to_f64``; origin,
            spacing and direction are derived from it.
        is_multichannel: store axis 0 as vector components.

    Returns:
        The SimpleITK image.

    Raises:
        RuntimeError: if the image's component count or size does not match
            *array*.
    """
    array = np.asanyarray(array)
    # Read ndim only after conversion so array-likes (e.g. nested lists)
    # work instead of failing on a missing ``.ndim`` attribute.
    ndim = array.ndim
    affine = miou.to_f64(affine)
    # Reverse all axes; the size check below expects the sitk size to equal
    # the array's spatial shape in order.
    image = sitk.GetImageFromArray(array.transpose(), isVector=is_multichannel)
    is_2d = (ndim == 3 and is_multichannel) or (ndim == 2 and not is_multichannel)
    origin, spacing, direction = miou.get_metadata_from_ras_affine(
        affine,
        is_2d=is_2d,
    )
    image.SetOrigin(origin)
    image.SetSpacing(spacing)
    image.SetDirection(direction)
    num_spatial_dims = 2 if is_2d else 3
    offset = 1 if is_multichannel else 0
    if is_multichannel:
        num_components = array.shape[0]
        if (_n_comp := image.GetNumberOfComponentsPerPixel()) != num_components:
            msg = f"sitk components {_n_comp} != array components {num_components}"
            raise RuntimeError(msg)
    spatial_dims = array.shape[offset : offset + num_spatial_dims]
    if image.GetSize() != spatial_dims:
        raise RuntimeError(f"{image.GetSize()} != {spatial_dims}")
    return image
def sitk_to_array(
    image: sitk.Image,
    *,
    dtype: typing.Type[miot.DType] | None = None,
    copy: bool = True,
) -> miot.DataAffine[miot.DType]:
    """Convert a SimpleITK image to a ``(data, affine)`` pair.

    ``data`` is the SimpleITK pixel array with all axes reversed. For
    multi-component images the component axis is then expected first (this
    is checked). ``affine`` comes from ``get_ras_affine_from_sitk``.

    Args:
        image: image to convert.
        dtype: optional dtype passed to ``np.asarray``.
        copy: if True use ``GetArrayFromImage`` (copy); otherwise use
            ``GetArrayViewFromImage`` (view).

    Raises:
        RuntimeError: if a multi-component image does not end up with the
            components on axis 0.
    """
    if copy:
        pixels = sitk.GetArrayFromImage(image)
    else:
        # NOTE(review): SimpleITK's docs say the image must be kept alive
        # while an array view is in use; verify callers keep *image*
        # referenced when copy=False.
        pixels = sitk.GetArrayViewFromImage(image)
    data: np.ndarray = np.asarray(pixels, dtype=dtype).transpose()
    n_components = image.GetNumberOfComponentsPerPixel()
    if image.GetDimension() == 5:
        # probably a bad NIfTI (1, sx, sy, sz, c): take the last axis as
        # components, drop the leading axis, then move components first.
        n_components = data.shape[-1]
        data = data[0].transpose(3, 0, 1, 2)
    if n_components > 1 and data.shape[0] != n_components:
        raise RuntimeError(f"{data.shape[0]} != {n_components}")
    return data, get_ras_affine_from_sitk(image)
def get_ras_affine_from_sitk(
    sitk_object: sitk.Image | sitk.ImageFileReader,
    *,
    dtype: typing.Type[miot.DType] | None = None,
) -> npt.NDArray[miot.DType]:
    """Build a 4x4 RAS affine from a SimpleITK image's or reader's geometry.

    Spacing, direction and origin are read from *sitk_object* and treated as
    LPS. Direction and origin are multiplied by ``miou.flipxy_33()`` to get
    RAS. The number of direction entries selects the case:

    * 9: a regular 3D image;
    * 4: a 2D image, padded to 3D with an identity third axis, unit spacing
      and zero origin;
    * 16: probably a malformed NIfTI read as 4D. Only the upper-left 3x3
      direction block and the first three spacing/origin values are kept.

    Args:
        sitk_object: object providing GetSpacing/GetDirection/GetOrigin.
        dtype: dtype used for the arrays and the returned affine; defaults
            to ``np.float64``.

    Raises:
        RuntimeError: for any other number of direction entries.
    """
    if dtype is None:
        dtype = np.float64  # type: ignore[assignment]
    spacing: np.ndarray = np.asarray(sitk_object.GetSpacing(), dtype=dtype)
    direction_lps: np.ndarray = np.asarray(sitk_object.GetDirection(), dtype=dtype)
    origin_lps: np.ndarray = np.asarray(sitk_object.GetOrigin(), dtype=dtype)
    n_entries = len(direction_lps)

    if n_entries == 4:  # 2D: pad to 3D
        rotation_lps = np.eye(3, dtype=dtype)
        rotation_lps[:2, :2] = direction_lps.reshape(2, 2)
        spacing = np.append(spacing, 1)
        origin_lps = np.append(origin_lps, 0)
    elif n_entries == 9:
        rotation_lps = direction_lps.reshape(3, 3)
    elif n_entries == 16:  # probably a bad NIfTI: keep the spatial part
        rotation_lps = direction_lps.reshape(4, 4)[:3, :3]
        spacing, origin_lps = spacing[:-1], origin_lps[:-1]
    else:
        raise RuntimeError(f"Invalid direction length: {n_entries}")

    flip = miou.flipxy_33()
    affine: np.ndarray = np.eye(4, dtype=dtype)
    # Multiplying by spacing scales the columns, i.e. rotation @ diag(spacing).
    affine[:3, :3] = np.dot(flip, rotation_lps) * spacing
    affine[:3, 3] = np.dot(flip, origin_lps)
    return affine