"""DICOM-specific functions
inspired by dicom-numpy: https://github.com/innolitics/dicom-numpy
Author: Jacob Reinhold <jcreinhold@gmail.com>
"""
from __future__ import annotations
# Public API of this module (what ``from <module> import *`` exposes).
__all__ = ["DICOMDir", "DICOMImage", "gather_dicom"]
import collections.abc
import dataclasses
import functools
import io
import logging
import math
import operator
import pathlib
import typing
import warnings
import zipfile
import numpy as np
import numpy.typing as npt
try:
    import pydicom
except ImportError as missing_pydicom:
    # Re-raise with a message naming this module; keep the original cause.
    raise ImportError(
        f"pydicom must be installed to use {__name__}."
    ) from missing_pydicom
import pymedio.base as miob
import pymedio.exceptions as mioe
import pymedio.typing as miot
import pymedio.utils as miou
logger = logging.getLogger(__name__)

# Absolute tolerance used when comparing ImageOrientationPatient values.
ORIENTATION_ATOL = 1.0e-5

# Generic type variables.
T = typing.TypeVar("T")
DType = typing.TypeVar("DType", bound=np.generic)

# A (possibly lazy) iterable of pydicom datasets, as produced by gather_dicom.
Datasets = typing.Iterable[pydicom.Dataset]
def _all_float_like(seq: collections.abc.Sequence[float]) -> bool:
return all(isinstance(x, (float, int)) for x in seq)
def gather_dicom(
    dicom_path: miot.PathLike | typing.Iterable[miot.PathLike],
    *,
    defer_size: str | int | None = "1 KB",
    extension: str = ".dcm",
    return_paths: bool = False,
) -> Datasets | tuple[Datasets, tuple[miot.PathLike, ...]]:
    """Lazily read DICOM files from a directory or an explicit list of paths.

    Args:
        dicom_path: either a directory (``str`` or ``pathlib.Path``), whose
            files matching ``*{extension}`` are used in sorted order, or an
            iterable of file paths that must all end with ``extension``.
            One-shot iterators such as generators are accepted.
        defer_size: forwarded unchanged to ``pydicom.dcmread``.
        extension: suffix used to select (directory) or validate (iterable)
            the files.
        return_paths: if True, also return the tuple of paths that were used.

    Returns:
        A generator that calls ``pydicom.dcmread`` on each path as it is
        iterated, or ``(generator, paths)`` when ``return_paths`` is True.

    Raises:
        ValueError: if ``dicom_path`` is a ``str``/``Path`` that is not a
            directory, is not iterable, or contains a path not ending with
            ``extension``.
    """
    err_msg = "dicom_path must be path to a dir. or a list of dcm paths"
    paths: tuple[miot.PathLike, ...]
    if isinstance(dicom_path, (str, pathlib.Path)):
        dcm_dir = pathlib.Path(dicom_path)
        if not dcm_dir.is_dir():
            raise ValueError(err_msg)
        paths = tuple(sorted(dcm_dir.glob(f"*{extension}")))
    else:
        # Materialize exactly once: validating a one-shot iterator and then
        # calling tuple() on it again would silently yield an empty tuple.
        try:
            paths = tuple(dicom_path)  # type: ignore[arg-type]
        except TypeError as exn:
            raise ValueError(err_msg) from exn
        if not all(str(p).endswith(extension) for p in paths):
            raise ValueError(err_msg)
    datasets = (pydicom.dcmread(path, defer_size=defer_size) for path in paths)
    return (datasets, paths) if return_paths else datasets
@dataclasses.dataclass(frozen=True)
class Cosines:
    """Row, column and slice direction cosines of an image.

    Adapted from dicom-numpy (dicom_numpy/combine_slices.py); compare ITK
    (Modules/IO/GDCM/src/itkGDCMImageIO.cxx).
    """

    row: npt.NDArray
    column: npt.NDArray
    slice: npt.NDArray

    def __repr__(self) -> str:
        components = {"row": self.row, "column": self.column, "slice": self.slice}
        body = ", ".join(f"{key}={val}" for key, val in components.items())
        return f"Cosines({body})"

    @classmethod
    def from_orientation(
        cls: typing.Type[Cosines],
        image_orientation: typing.Sequence[float] | npt.NDArray,
    ) -> Cosines:
        """Build read-only, validated cosines from a 6-value orientation.

        The first three values are the row cosine, the last three the column
        cosine; the slice cosine is their cross product.

        Raises:
            ValueError: if the input is not 1-D with six elements (ndarray) or
                not six int/float values (other sequences).
            mioe.DicomImportException: raised by ``validate`` when the cosines
                are far from orthonormal.
        """
        if isinstance(image_orientation, np.ndarray):
            shape_ok = image_orientation.ndim == 1 and image_orientation.size == 6
            if not shape_ok:
                raise ValueError("image_orientation must be seq. of 1 dim and len=6.")
        else:
            values_ok = len(image_orientation) == 6 and _all_float_like(
                image_orientation
            )
            if not values_ok:
                raise ValueError("image_orientation must be seq. floats with len=6.")
        row = miou.to_f64(image_orientation[:3])
        column = miou.to_f64(image_orientation[3:])
        result = cls(row, column, np.cross(row, column))
        result.writable(False)
        result.validate()
        return result

    def validate(self) -> None:
        """Check that the row and column cosines are orthonormal.

        Each check raises ``mioe.DicomImportException`` outside a 1e-4
        tolerance and emits a warning outside a 1e-8 tolerance.
        """
        row, col = self.row, self.column
        self._validate_value(
            float(np.dot(row, col).item()),
            f"Non-orthogonal direction cosines: {row}, {col}",
            f"Direction cosines aren't quite ortho: {row}, {col}",
            self._almost_zero,
        )
        for label, vec in (("Row", row), ("Column", col)):
            self._validate_value(
                float(np.linalg.norm(vec).item()),
                f"{label} direction cosine's magnitude is not 1: {vec}",
                f"{label} direction cosine's magnitude not quite 1: {vec}",
                self._almost_one,
            )

    def writable(self, value: bool, /) -> None:
        """Set the ``writeable`` flag of all three cosine arrays to ``value``."""
        for arr in (self.row, self.column, self.slice):
            arr.flags.writeable = value

    @staticmethod
    def _validate_value(
        value: float,
        err_msg: str,
        warn_msg: str,
        check_func: typing.Callable[..., bool],
    ) -> None:
        # Coarse tolerance decides error vs. pass; fine tolerance decides warning.
        within_error_tol = check_func(value, atol=1e-4)
        if not within_error_tol:
            raise mioe.DicomImportException(value, err_msg)
        within_warn_tol = check_func(value, atol=1e-8)
        if not within_warn_tol:
            warnings.warn(warn_msg)

    @staticmethod
    def _almost_zero(value: float, *, atol: float) -> bool:
        return math.isclose(value, 0.0, abs_tol=atol)

    @staticmethod
    def _almost_one(value: float, *, atol: float) -> bool:
        return math.isclose(value, 1.0, abs_tol=atol)
@dataclasses.dataclass(frozen=True)
class SortedSlices:
    """DICOM slices ordered by their position along the slice normal.

    Attributes (all tuples have the same length, checked by ``__len__``):
        slices: datasets in ascending order of ``positions``.
        indices: index of each sorted slice in the sequence originally passed
            to ``from_datasets`` (use it to recover the input order).
        positions: dot product of each slice's ImagePositionPatient with
            ``cosines.slice``, in ascending order.
        cosines: direction cosines used to compute ``positions``.
    """

    slices: tuple[pydicom.Dataset, ...]
    indices: tuple[int, ...]  # mapping to get original slice order
    positions: tuple[float, ...]
    cosines: Cosines

    def __repr__(self) -> str:
        return f"SortedSlices(n_slices={len(self)}, cosines={self.cosines!r})"

    def __len__(self) -> int:
        """Return the number of slices.

        Raises:
            RuntimeError: if ``slices``, ``indices`` and ``positions`` differ
                in length.
        """
        _len = len(self.slices)
        exn_msg: list[str] = []
        if _len != (ind_len := len(self.indices)):
            exn_msg.append(f"num slices {_len} != num indices {ind_len}")
        if _len != (pos_len := len(self.positions)):
            exn_msg.append(f"num slices {_len} != num positions {pos_len}")
        if exn_msg:
            raise RuntimeError(" ".join(exn_msg))
        return _len

    @classmethod
    def from_datasets(
        cls, slice_datasets: typing.Sequence[pydicom.Dataset]
    ) -> SortedSlices:
        """Sort a sequence of pydicom datasets into slice order.

        The slice normal comes from the first dataset's ImageOrientationPatient;
        slices are stably sorted by the projection of their ImagePositionPatient
        onto it (ties keep input order).

        Raises:
            ValueError: if ``slice_datasets`` is empty.
        """
        if not slice_datasets:
            raise ValueError("slice_datasets empty")
        image_orientation = miou.to_f64(slice_datasets[0].ImageOrientationPatient)
        cosines = Cosines.from_orientation(image_orientation)
        slices, indices, positions = cls._sort_by_position(
            slice_datasets, range(len(slice_datasets)), cosines
        )
        return cls(
            slices=slices,
            indices=indices,
            positions=positions,
            cosines=cosines,
        )

    def check_nonuniformity(
        self,
        *,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        missing_slices_cutoff: float = 1e-1,
    ) -> None:
        """Raise or warn when slice spacings deviate from the first spacing.

        Spacings are compared relative to the first spacing (``rtol``, no
        ``atol``). Nothing is checked for fewer than two slices.

        Raises:
            mioe.OutsideMaxNonUniformity: deviation beyond ``max_nonuniformity``
                and ``fail_outside_max_nonuniformity`` is True.
            mioe.MissingSlicesException: deviation beyond
                ``missing_slices_cutoff`` and ``fail_outside_max_nonuniformity``
                is True.
        If not failing, each condition warns instead; the missing-slices warning
        is suppressed when the non-uniformity warning was already issued.
        """
        if len(self) <= 1:
            return
        diffs = np.diff(self.positions)
        warned = False
        if not np.allclose(diffs, diffs[0], atol=0.0, rtol=max_nonuniformity):
            msg = f"The slice spacing is non-uniform. Slice spacings:\n{diffs}"
            if fail_outside_max_nonuniformity:
                raise mioe.OutsideMaxNonUniformity(msg)
            warned = True
            warnings.warn(msg)
        if not np.allclose(diffs, diffs[0], atol=0.0, rtol=missing_slices_cutoff):
            msg = "There appear to be missing slices."
            if fail_outside_max_nonuniformity:
                raise mioe.MissingSlicesException(msg)
            if not warned:
                warnings.warn(msg)

    def remove_anomalous_slices(
        self,
        *,
        strict_unique_orientation: bool = True,
        unique_positions: bool = True,
    ) -> SortedSlices:
        """Drop slices whose orientation differs from the most common one.

        Args:
            strict_unique_orientation: if True, the most common orientation is
                found by exact equality (``np.unique``); otherwise orientations
                within ``ORIENTATION_ATOL`` are grouped by ``_approx_unique``.
            unique_positions: if True, also drop any slice whose position is
                exactly equal to that of an earlier kept slice.

        Returns:
            A new instance whose cosines come from the retained orientation and
            whose positions/order are recomputed with those cosines. A warning
            is issued whenever at least one slice was removed.

        Raises:
            RuntimeError: if the result has more slices than ``self``.
        """
        orientations = [
            tuple(float(x) for x in s.ImageOrientationPatient) for s in self.slices
        ]
        if strict_unique_orientation:
            unq_oris: np.ndarray
            unq_oris, counts = np.unique(orientations, axis=0, return_counts=True)
            most_common_orientation = unq_oris[np.argmax(counts)]
        else:
            # sorted by ascending count, so the last entry is the most common
            most_common_orientation = self._approx_unique(orientations)[-1]

        kept_slices: list[pydicom.Dataset] = []
        kept_indices: list[int] = []
        for _slice, idx, orientation in zip(self.slices, self.indices, orientations):
            if np.allclose(orientation, most_common_orientation, atol=ORIENTATION_ATOL):
                kept_slices.append(_slice)
                kept_indices.append(idx)
            else:
                logger.debug("Slice at index %s has a different orientation.", idx)

        new_image_orientation = miou.to_f64(most_common_orientation)
        new_cosines = Cosines.from_orientation(new_image_orientation)
        # self.positions were projected onto the normal of the *first input*
        # slice, which may be one of the anomalous slices just dropped; recompute
        # and re-sort so positions, slice_spacing and affine match new_cosines.
        sorted_slices, sorted_indices, sorted_positions = self._sort_by_position(
            kept_slices, kept_indices, new_cosines
        )

        new_slices: list[pydicom.Dataset] = []
        new_indices: list[int] = []
        new_positions: list[float] = []
        seen_positions: set[float] = set()
        for _slice, idx, pos in zip(sorted_slices, sorted_indices, sorted_positions):
            if unique_positions and pos in seen_positions:
                logger.debug("Slice at index %s has a non-unique position.", idx)
                continue
            new_slices.append(_slice)
            new_indices.append(idx)
            new_positions.append(pos)
            seen_positions.add(pos)

        n_removed = len(self) - len(new_slices)
        if n_removed < 0:
            raise RuntimeError("Images added in remove image func. Report error.")
        if n_removed > 0:  # warn even when only a single slice was dropped
            warnings.warn(f"{n_removed} anomalous images removed.")
        return type(self)(
            slices=tuple(new_slices),
            indices=tuple(new_indices),
            positions=tuple(new_positions),
            cosines=new_cosines,
        )

    @functools.cached_property
    def patient_position(self) -> npt.NDArray:
        """ImagePositionPatient of the first sorted slice, as float64."""
        return miou.to_f64(self.slices[0].ImagePositionPatient)

    @functools.cached_property
    def slice_spacing(self) -> float:
        """Mean spacing between sorted positions.

        For a single slice, ``SpacingBetweenSlices`` if present, else 0.

        Raises:
            RuntimeError: if there are no slices.
        """
        spacing: float
        if len(self) > 1:
            slice_positions_diffs = np.diff(np.sort(self.positions))
            # avg. b/c that's what ITK seems to use, so use for consistency
            spacing = float(np.mean(slice_positions_diffs).item())
        elif len(self) == 1:
            spacing = float(getattr(self.slices[0], "SpacingBetweenSlices", 0))
        else:
            raise RuntimeError("slice_datasets must contain at least one dicom image")
        return spacing

    @functools.cached_property
    def affine(self) -> npt.NDArray:
        """4x4 voxel-index-to-patient transform, left-multiplied by flipxy_44.

        Columns 0-2 are the row, column and slice cosines scaled by the column
        spacing, row spacing and slice spacing (1.0 if the spacing is 0); column
        3 is ``patient_position``. NOTE(review): the result is named
        ``transform_ras``, so ``miou.flipxy_44()`` presumably maps DICOM LPS to
        RAS — confirm.
        """
        row_spacing, column_spacing = self.slices[0].PixelSpacing
        transform: np.ndarray = np.identity(4, dtype=np.float64)
        slice_spacing = self.slice_spacing or 1.0
        transform[:3, 0] = self.cosines.row * column_spacing
        transform[:3, 1] = self.cosines.column * row_spacing
        transform[:3, 2] = self.cosines.slice * slice_spacing
        transform[:3, 3] = self.patient_position
        transform_ras: npt.NDArray = np.dot(miou.flipxy_44(), transform)
        return transform_ras

    @staticmethod
    def _sort_by_position(
        datasets: typing.Sequence[pydicom.Dataset],
        indices: typing.Sequence[int],
        cosines: Cosines,
    ) -> tuple[tuple[pydicom.Dataset, ...], tuple[int, ...], tuple[float, ...]]:
        """Project each ImagePositionPatient onto ``cosines.slice`` and stably sort."""
        positions = [
            np.dot(cosines.slice, miou.to_f64(ds.ImagePositionPatient)).item()
            for ds in datasets
        ]
        order = sorted(range(len(positions)), key=positions.__getitem__)
        return (
            tuple(datasets[i] for i in order),
            tuple(indices[i] for i in order),
            tuple(positions[i] for i in order),
        )

    @staticmethod
    def _approx_unique(
        values: typing.Sequence[T],
        *,
        atol: float = ORIENTATION_ATOL,
    ) -> tuple[T, ...]:
        """Group values that are ``np.allclose`` within ``atol``.

        Each value joins the closest (Euclidean) existing representative it is
        close to, otherwise it becomes a new representative. Representatives
        are returned in ascending order of group size (ties in first-seen order).
        """
        # TODO: improve computational efficiency
        # TODO: fix bad init -> bad result
        if not values:
            return tuple()
        counts: dict[T, int] = {}
        rep_arrays: dict[T, np.ndarray] = {}
        for val in values:
            np_val = miou.to_f64(typing.cast(npt.ArrayLike, val))  # hoisted
            best_key: T | None = None
            best_dist = np.inf
            for key, np_key in rep_arrays.items():
                if np.allclose(np_val, np_key, atol=atol):
                    dist = np.linalg.norm(np_val - np_key).item()
                    if dist < best_dist:
                        best_dist = dist
                        best_key = key
            if best_key is None:
                counts[val] = 1
                rep_arrays[val] = np_val
            else:
                counts[best_key] += 1
        ranked = sorted(counts.items(), key=operator.itemgetter(1))
        return tuple(key for key, _ in ranked)

    def _zip_with(
        self, *args: typing.Iterable[typing.Any]
    ) -> typing.Iterable[tuple[typing.Any, ...]]:
        return zip(self.slices, self.indices, self.positions, *args)
@dataclasses.dataclass(frozen=True)
class DICOMDir:
    """A sorted, validated series of DICOM image slices.

    Attributes:
        slices: datasets in slice order.
        positions: slice positions along the slice normal, ascending.
        slice_spacing: spacing between slices (``SortedSlices.slice_spacing``).
        affine: 4x4 transform (``SortedSlices.affine``); set read-only by
            ``from_datasets``.
        paths: source paths in the same order as ``slices``, if known.
    """

    slices: tuple[pydicom.Dataset, ...]
    positions: tuple[float, ...]
    slice_spacing: float
    affine: npt.NDArray
    paths: tuple[miot.PathLike, ...] | None = None

    def __len__(self) -> int:
        """Number of slices; RuntimeError if ``paths`` has a different length."""
        _len = len(self.slices)
        if self.paths is not None and _len != (path_len := len(self.paths)):
            raise RuntimeError(f"num slices {_len} != num paths {path_len}")
        return _len

    @classmethod
    def from_datasets(
        cls: typing.Type[DICOMDir],
        datasets: typing.Sequence[pydicom.Dataset],
        *,
        paths: typing.Sequence[miot.PathLike] | None = None,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        remove_anomalous_images: bool = True,
    ) -> DICOMDir:
        """Sort, optionally clean, check and validate datasets.

        Args:
            datasets: image datasets in any order.
            paths: optional paths aligned with ``datasets``; reordered to match
                the sorted slices.
            max_nonuniformity, fail_outside_max_nonuniformity: forwarded to
                ``SortedSlices.check_nonuniformity``.
            remove_anomalous_images: if True, run
                ``SortedSlices.remove_anomalous_slices`` first.

        Raises:
            mioe.DicomImportException: if ``datasets`` is empty or slices
                disagree on an invariant attribute (see ``validate``).
        """
        if not datasets:
            msg = "Must provide at least one image DICOM dataset"
            raise mioe.DicomImportException(msg)
        sorted_slices = SortedSlices.from_datasets(datasets)
        if remove_anomalous_images:
            sorted_slices = sorted_slices.remove_anomalous_slices()
        sorted_slices.check_nonuniformity(
            max_nonuniformity=max_nonuniformity,
            fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
        )
        positions = tuple(sorted(sorted_slices.positions))
        idxs = sorted_slices.indices  # original index of each sorted slice
        dicom_dir = cls(
            slices=sorted_slices.slices,
            positions=positions,
            slice_spacing=sorted_slices.slice_spacing,
            affine=sorted_slices.affine,
            paths=None if paths is None else tuple(paths[i] for i in idxs),
        )
        dicom_dir.writable(False)
        dicom_dir.validate()
        return dicom_dir

    @classmethod
    def from_path(
        cls: typing.Type[DICOMDir],
        dicom_path: miot.PathLike | typing.Iterable[miot.PathLike],
        *,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        remove_anomalous_images: bool = True,
        defer_size: str | int | None = "1 KB",
        extension: str = ".dcm",
    ) -> DICOMDir:
        """Read files via ``gather_dicom`` and build with ``from_datasets``."""
        gathered = gather_dicom(
            dicom_path, defer_size=defer_size, extension=extension, return_paths=True
        )
        assert isinstance(gathered, tuple)  # narrows the union for mypy
        images, paths = gathered
        return cls.from_datasets(
            tuple(images),
            paths=paths,
            max_nonuniformity=max_nonuniformity,
            fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
            remove_anomalous_images=remove_anomalous_images,
        )

    @classmethod
    def from_zipped_stream(
        cls: typing.Type[DICOMDir],
        data_stream: typing.IO,
        *,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        remove_anomalous_images: bool = True,
        encryption_key: bytes | str | None = None,
        **zip_kwargs: typing.Any,
    ) -> DICOMDir:
        """Build from a (optionally Fernet-encrypted) zip archive stream.

        With ``encryption_key``, the whole stream is read from offset 0 and
        decrypted in memory before unzipping. ``zip_kwargs`` go to
        ``zipfile.ZipFile``.

        Raises:
            RuntimeError: a key was given but ``cryptography`` is unavailable.
        """
        if encryption_key is not None:
            try:
                import cryptography.fernet as crypto
            except (ModuleNotFoundError, ImportError) as crypto_imp_exn:
                msg = "If encryption key provided, cryptography package required."
                raise RuntimeError(msg) from crypto_imp_exn
            fernet = crypto.Fernet(encryption_key)
            data_stream.seek(0)
            data_stream = io.BytesIO(fernet.decrypt(data_stream.read()))
        with zipfile.ZipFile(data_stream, mode="r", **zip_kwargs) as zf:
            datasets = cls.dicom_datasets_from_zip(zf)
        return cls.from_datasets(
            datasets,
            max_nonuniformity=max_nonuniformity,
            fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
            remove_anomalous_images=remove_anomalous_images,
        )

    @staticmethod
    def dicom_datasets_from_zip(
        zip_file: zipfile.ZipFile,
    ) -> list[pydicom.Dataset]:
        """Read every valid DICOM image file in ``zip_file``.

        Directory entries, files pydicom rejects as invalid DICOM, and DICOMDIR
        index files are skipped (logged at INFO).

        Raises:
            mioe.DicomImportException: if no usable dataset was found.
        """
        datasets: list[pydicom.Dataset] = []
        for name in zip_file.namelist():
            if name.endswith("/"):
                continue  # skip directories
            with zip_file.open(name, mode="r") as f:
                try:
                    dataset = pydicom.dcmread(f)  # type: ignore[arg-type]
                except pydicom.errors.InvalidDicomError as e:
                    msg = f"Skipping invalid DICOM file '{name}': {e}"
                    logger.info(msg)
                    continue
            if DICOMDir._is_dicomdir(dataset):
                # A DICOMDIR index holds no pixel data and presumably lacks the
                # image attributes (e.g. ImagePositionPatient) sorting reads.
                logger.info(f"Skipping DICOMDIR file '{name}'")
                continue
            datasets.append(dataset)
        if not datasets:
            msg = "Zipfile does not contain any valid DICOM files"
            raise mioe.DicomImportException(msg)
        return datasets

    def validate(self) -> None:
        """Ensure all slices agree on key attributes.

        Raises:
            mioe.DicomImportException: on any mismatch, or if a slice lacks
                ImageOrientationPatient.
        """
        invariant_properties = frozenset(
            (
                "BitsAllocated",
                "Columns",
                "Modality",
                "PixelRepresentation",
                "PixelSpacing",
                "Rows",
                "SamplesPerPixel",
                "SeriesInstanceUID",
                "SOPClassUID",
            )
        )
        for property_name in invariant_properties:
            self._slice_attribute_equal(property_name)
        self._slice_attribute_almost_equal(
            "ImageOrientationPatient", atol=ORIENTATION_ATOL
        )

    def writable(self, value: bool, /) -> None:
        """Set the ``writeable`` flag of ``affine``."""
        self.affine.flags.writeable = value

    def _slice_attribute_equal(self, property_name: str) -> None:
        # A missing attribute is treated as None, so "absent everywhere" passes.
        initial_value = getattr(self.slices[0], property_name, None)
        for dataset in self.slices[1:]:
            value = getattr(dataset, property_name, None)
            if value != initial_value:
                msg = "All slices must have the same value for "
                msg += f"'{property_name}': {value} != {initial_value}"
                raise mioe.DicomImportException(msg)

    def _slice_attribute_almost_equal(
        self,
        property_name: str,
        *,
        atol: float,
    ) -> None:
        initial_value: miot.SupportsArray | None
        initial_value = getattr(self.slices[0], property_name, None)
        for dataset in self.slices[1:]:
            value: miot.SupportsArray | None
            value = getattr(dataset, property_name, None)
            if value is None or initial_value is None:
                msg = f"All slices must contain the attribute {property_name}"
                raise mioe.DicomImportException(msg)
            if not np.allclose(value, initial_value, atol=atol):
                msg = "All slices must have the same value for "
                msg += f"'{property_name}' within '{atol}': {value} != {initial_value}"
                raise mioe.DicomImportException(msg)

    @staticmethod
    def _is_dicomdir(dataset: pydicom.Dataset) -> bool:
        """True if the Media Storage SOP Class is the DICOMDIR class.

        pydicom stores group-0002 elements such as MediaStorageSOPClassUID in
        ``dataset.file_meta``, so fall back to it when the main dataset lacks it.
        """
        media_sop_class: str | None
        media_sop_class = getattr(dataset, "MediaStorageSOPClassUID", None)
        if media_sop_class is None:
            file_meta = getattr(dataset, "file_meta", None)
            media_sop_class = getattr(file_meta, "MediaStorageSOPClassUID", None)
        result: bool = media_sop_class == "1.2.840.10008.1.3.10"
        return result
class DICOMImage(miob.BasicImage[typing.Any, miot.DType]):  # type: ignore[type-arg]
    """Image whose voxel data and affine are assembled from a ``DICOMDir``."""

    @classmethod
    def from_dicomdir(
        cls: typing.Type[DICOMImage],
        dicom_dir: DICOMDir,
        *,
        rescale: bool | None = None,
        rescale_dtype: typing.Type[miot.DType] | None = None,
        order: typing.Literal["F", "C"] | None = None,
    ) -> DICOMImage[miot.DType]:
        """Stack the slices of ``dicom_dir`` and pair them with its affine.

        ``rescale``, ``rescale_dtype`` and ``order`` are as in
        ``_merge_slice_pixel_arrays``.
        """
        data = cls._merge_slice_pixel_arrays(
            dicom_dir.slices, rescale=rescale, rescale_dtype=rescale_dtype, order=order
        )
        return cls(data=data, affine=dicom_dir.affine)

    @classmethod
    def from_path(
        cls: typing.Type[DICOMImage],
        dicom_path: miot.PathLike | typing.Iterable[miot.PathLike],
        *,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        remove_anomalous_images: bool = True,
        rescale: bool | None = None,
        rescale_dtype: typing.Type[miot.DType] | None = None,
        order: typing.Literal["F", "C"] | None = None,
        extension: str = ".dcm",
        defer_size: str | int | None = "1 KB",
    ) -> DICOMImage[miot.DType]:
        """Load a directory or list of DICOM files (see ``DICOMDir.from_path``).

        ``defer_size`` is forwarded to ``DICOMDir.from_path`` (same default).
        """
        dicomdir = DICOMDir.from_path(
            dicom_path,
            max_nonuniformity=max_nonuniformity,
            fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
            remove_anomalous_images=remove_anomalous_images,
            defer_size=defer_size,
            extension=extension,
        )
        return cls.from_dicomdir(
            dicomdir, rescale=rescale, rescale_dtype=rescale_dtype, order=order
        )

    @classmethod
    def from_zipped_stream(
        cls: typing.Type[DICOMImage],
        data_stream: typing.IO,
        *,
        max_nonuniformity: float = 5e-4,
        fail_outside_max_nonuniformity: bool = True,
        remove_anomalous_images: bool = True,
        encryption_key: bytes | str | None = None,
        rescale: bool | None = None,
        rescale_dtype: typing.Type[miot.DType] | None = None,
        order: typing.Literal["F", "C"] | None = None,
        **zip_kwargs: typing.Any,
    ) -> DICOMImage[miot.DType]:
        """Load from a zip stream (see ``DICOMDir.from_zipped_stream``)."""
        dicomdir = DICOMDir.from_zipped_stream(
            data_stream,
            max_nonuniformity=max_nonuniformity,
            fail_outside_max_nonuniformity=fail_outside_max_nonuniformity,
            remove_anomalous_images=remove_anomalous_images,
            encryption_key=encryption_key,
            **zip_kwargs,
        )
        return cls.from_dicomdir(
            dicomdir, rescale=rescale, rescale_dtype=rescale_dtype, order=order
        )

    @classmethod
    def _merge_slice_pixel_arrays(
        cls: typing.Type[DICOMImage],
        slices: typing.Sequence[pydicom.Dataset],
        *,
        rescale: bool | None = None,
        rescale_dtype: typing.Type[miot.DType] | None = None,
        order: typing.Literal["F", "C"] | None = None,
    ) -> npt.NDArray:
        """Stack each slice's transposed ``pixel_array`` along a new last axis.

        Args:
            slices: datasets in stacking order; must be non-empty.
            rescale: apply RescaleSlope/RescaleIntercept (1.0/0.0 when absent).
                If None, rescale when any slice has either attribute.
            rescale_dtype: output dtype when rescaling (float32 if None). The
                rescale runs in place in this dtype, and NumPy's in-place
                'same_kind' casting rejects float arithmetic on integer arrays,
                so pass a floating dtype.
            order: force "C" or "F" memory layout of the result; by default the
                layout follows the first slice's transposed pixel array.

        Raises:
            ValueError: ``slices`` is empty or ``order`` is not None/"C"/"F".
            RuntimeError: the result's dtype differs from the requested one.
        """
        if order not in (None, "C", "F"):
            msg = f"If order given, must be either 'F' or 'C'. Got {order}."
            raise ValueError(msg)
        if not slices:
            raise ValueError("slices must contain at least one dataset")
        if rescale is None:
            rescale = any(cls._requires_rescaling(d) for d in slices)
        if rescale and rescale_dtype is None:
            rescale_dtype = np.float32  # type: ignore[assignment]
        first_pixels = slices[0].pixel_array.T
        voxels_dtype = rescale_dtype if rescale else first_pixels.dtype
        slice_order = "F" if first_pixels.flags.f_contiguous else "C"
        voxels = np.empty(
            first_pixels.shape + (len(slices),),
            dtype=voxels_dtype,
            order=typing.cast(typing.Literal["F", "C"], slice_order),
        )
        for k, dataset in enumerate(slices):
            # Copy (with cast) into a view of the output and rescale *there*, so
            # the dataset's cached pixel_array is never modified in place.
            voxel_slice = voxels[..., k]
            voxel_slice[...] = dataset.pixel_array.T
            if rescale:
                slope = float(getattr(dataset, "RescaleSlope", 1.0))
                intercept = float(getattr(dataset, "RescaleIntercept", 0.0))
                if slope != 1.0:
                    voxel_slice *= slope
                if intercept != 0.0:
                    voxel_slice += intercept
        if order == "C":
            voxels = np.ascontiguousarray(voxels)
        elif order == "F":
            voxels = np.asfortranarray(voxels)
        if voxels.dtype != voxels_dtype:
            exn_msg = f"voxels.dtype {voxels.dtype} != requested dtype {voxels_dtype}"
            raise RuntimeError(exn_msg)
        return voxels

    @staticmethod
    def _requires_rescaling(dataset: pydicom.Dataset) -> bool:
        """True if the dataset defines RescaleSlope or RescaleIntercept."""
        return hasattr(dataset, "RescaleSlope") or hasattr(dataset, "RescaleIntercept")