Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5394658
Final commit: add clinical DICOM preprocessing files, workflow PDF, a…
Hitendrasinhdata7 Dec 14, 2025
aeabbbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2025
24665aa
Add clinical DICOM preprocessing Python module, test module, PDF; rem…
Hitendrasinhdata7 Dec 14, 2025
798f8af
Remove old notebook files after converting to .py modules
Hitendrasinhdata7 Dec 14, 2025
b3a6ac2
Add clinical DICOM preprocessing utilities for CT/MRI with unit tests
Hitendrasinhdata7 Dec 14, 2025
a446448
Update clinical preprocessing utilities and tests per CodeRabbit revi…
Hitendrasinhdata7 Dec 14, 2025
d7f134c
Refactor clinical preprocessing: add custom exceptions, use isinstanc…
Hitendrasinhdata7 Dec 14, 2025
ce7850f
Update clinical preprocessing: add Google-style Returns, parameter ch…
Hitendrasinhdata7 Dec 14, 2025
01b88fa
Fix clinical preprocessing module based on code review feedback
Hitendrasinhdata7 Dec 17, 2025
81b7f5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
821fc9a
Complete fix for all critical code review issues
Hitendrasinhdata7 Dec 17, 2025
1a15437
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
d12bd51
Hitendrasinh Rathod <Hitendrasinh.data7@gmail.com>
Hitendrasinhdata7 Dec 17, 2025
634136a
Merge branch 'clinical-dicom-preprocessing' of https://github.com/Hit…
Hitendrasinhdata7 Dec 17, 2025
34a7aa2
Complete fix for all CI and code review issues
Hitendrasinhdata7 Dec 17, 2025
3a493d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
2b0f6a7
Add MetaTensor import and return type hint
Hitendrasinhdata7 Dec 17, 2025
f0261b8
Hitendrasinh Rathod <Hitendrasinh.data7@gmail.com>
Hitendrasinhdata7 Dec 17, 2025
1418dcc
Merge branch 'clinical-dicom-preprocessing' of https://github.com/Hit…
Hitendrasinhdata7 Dec 17, 2025
1e68a5a
Fix docstring: add Returns section and correct Raises section formatting
Hitendrasinhdata7 Dec 17, 2025
7030e7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
441d48d
Add clinical preprocessing transforms
Hitendrasinhdata7 Dec 20, 2025
11dce12
Resolve merge conflict - keep clinical preprocessing module
Hitendrasinhdata7 Dec 20, 2025
f513c16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
2986d2e
Fix CodeRabbit review issues
Hitendrasinhdata7 Dec 20, 2025
b3e8f87
Address CodeRabbit review suggestions
Hitendrasinhdata7 Dec 20, 2025
5a4fbb5
Final fixes per CodeRabbit review
Hitendrasinhdata7 Dec 20, 2025
c273cf7
Fix CI compliance: formatting, type hints, line length
Hitendrasinhdata7 Dec 20, 2025
d0961e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
9753c34
Fix CI compliance for clinical preprocessing
Hitendrasinhdata7 Dec 20, 2025
0daee66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
885ae4b
Fix CodeRabbit issues: return type and exception classes
Hitendrasinhdata7 Dec 20, 2025
e50560d
Merge remote changes and fix CodeRabbit review issues
Hitendrasinhdata7 Dec 20, 2025
68038cb
Improve tests per CodeRabbit suggestions
Hitendrasinhdata7 Dec 20, 2025
b241ee6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2025
c4c2303
Improve mock setup and add MRI validation
Hitendrasinhdata7 Dec 20, 2025
01b60f0
Merge formatting and apply test improvements
Hitendrasinhdata7 Dec 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/clinical_dicom_workflow.pdf
Binary file not shown.
Binary file added monai/docs/clinical_dicom_workflow.pdf
Binary file not shown.
160 changes: 160 additions & 0 deletions monai/tests/test_clinical_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock, patch

import numpy as np
import pytest

from monai.data import write_nifti
from monai.transforms import EnsureChannelFirst, LoadImage, NormalizeIntensity, ScaleIntensityRange
from monai.transforms.clinical_preprocessing import (
ModalityTypeError,
UnsupportedModalityError,
get_ct_preprocessing_pipeline,
get_mri_preprocessing_pipeline,
preprocess_dicom_series,
preprocess_medical_image,
)


def test_ct_preprocessing_pipeline_structure():
"""Test CT pipeline structure."""
pipeline = get_ct_preprocessing_pipeline()
transforms = pipeline.transforms

assert len(transforms) == 3
assert isinstance(transforms[0], LoadImage)
assert transforms[0].image_only is True
assert transforms[1].__class__ is EnsureChannelFirst
assert isinstance(transforms[2], ScaleIntensityRange)

scale = transforms[2]
assert scale.a_min == -1000
assert scale.a_max == 400
assert scale.b_min == 0.0
assert scale.b_max == 1.0
assert scale.clip is True


def test_mri_preprocessing_pipeline_structure():
"""Test MRI pipeline structure."""
pipeline = get_mri_preprocessing_pipeline()
transforms = pipeline.transforms

assert len(transforms) == 3
assert isinstance(transforms[0], LoadImage)
assert transforms[0].image_only is True
assert transforms[1].__class__ is EnsureChannelFirst
assert isinstance(transforms[2], NormalizeIntensity)
assert transforms[2].nonzero is True


def test_invalid_modality_type():
"""Test non-string modality input."""
with pytest.raises(ModalityTypeError) as exc:
preprocess_dicom_series("dummy", 123)

assert "modality must be a string" in str(exc.value)

with pytest.raises(ModalityTypeError) as exc:
preprocess_medical_image("dummy", None)

assert "modality must be a string" in str(exc.value)


def test_unsupported_modality():
"""Test unsupported modality."""
with pytest.raises(UnsupportedModalityError) as exc:
preprocess_dicom_series("dummy", "PET")

msg = str(exc.value)
assert "Unsupported modality" in msg
assert "PET" in msg
assert "CT" in msg
assert "MR" in msg
assert "MRI" in msg

with pytest.raises(UnsupportedModalityError) as exc:
preprocess_medical_image("dummy", "PET")

msg = str(exc.value)
assert "Unsupported modality" in msg
assert "PET" in msg
assert "CT" in msg
assert "MR" in msg
assert "MRI" in msg


@patch("monai.transforms.clinical_preprocessing.LoadImage")
def test_modality_case_insensitivity(mock_load):
"""Test case-insensitive modality handling with whitespace trimming."""
mock_load.return_value = Mock(return_value=Mock())

test_cases = ["CT", "ct", "Ct", "CT ", " CT", "MR", "mr", "MRI", "mri", " MrI "]

for modality in test_cases:
result = preprocess_dicom_series("dummy.dcm", modality)
assert result is not None, f"Failed for modality: '{modality}'"
result2 = preprocess_medical_image("dummy.dcm", modality)
assert result2 is not None, f"preprocess_medical_image failed for modality: '{modality}'"


@patch("monai.transforms.clinical_preprocessing.LoadImage")
def test_mr_modality_distinct(mock_load):
"""Test MR modality is handled separately from MRI."""
mock_load.return_value = Mock(return_value=Mock())
result = preprocess_dicom_series("dummy.dcm", "MR")
assert result is not None
result2 = preprocess_medical_image("dummy.dcm", "MR")
assert result2 is not None


@patch("monai.transforms.clinical_preprocessing.LoadImage")
def test_edge_cases(mock_load):
"""Test edge cases for modality input."""
mock_load.return_value = Mock(return_value=Mock())

with pytest.raises(UnsupportedModalityError):
preprocess_dicom_series("dummy.dcm", "")

with pytest.raises(UnsupportedModalityError):
preprocess_dicom_series("dummy.dcm", " ")

long_modality = "CT" * 100
with pytest.raises(UnsupportedModalityError):
preprocess_dicom_series("dummy.dcm", long_modality)


def test_preprocess_dicom_series_integration(tmp_path):
"""Integration test with dummy NIfTI file."""
dummy_data = np.random.randn(64, 64, 64).astype(np.float32)
test_file = tmp_path / "test.nii.gz"

write_nifti(dummy_data, test_file)

for modality in ["CT", "MR", "MRI"]:
result = preprocess_dicom_series(str(test_file), modality)
assert result is not None
assert hasattr(result, "shape")
assert len(result.shape) == 4 # (C, H, W, D)
assert result.shape[0] == 1 # single channel

if modality == "CT":
# CT output should be in [0, 1] due to ScaleIntensityRange
assert result.min() >= 0.0
assert result.max() <= 1.0

result2 = preprocess_medical_image(str(test_file), modality)
assert result2 is not None
assert hasattr(result2, "shape")
assert len(result2.shape) == 4
assert result2.shape[0] == 1
10 changes: 10 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,16 @@
TransposeD,
TransposeDict,
)

# Clinical preprocessing utilities
from .clinical_preprocessing import (
ModalityTypeError,
UnsupportedModalityError,
get_ct_preprocessing_pipeline,
get_mri_preprocessing_pipeline,
preprocess_dicom_series,
)

from .utils import (
Fourier,
allow_missing_keys_mode,
Expand Down
153 changes: 153 additions & 0 deletions monai/transforms/clinical_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Clinical preprocessing transforms for medical imaging data.

This module provides modality-specific preprocessing pipelines for common medical imaging modalities.
"""

import torch

from monai.transforms import (
Compose,
EnsureChannelFirst,
LoadImage,
NormalizeIntensity,
ScaleIntensityRange,
)


class ModalityTypeError(TypeError):
"""Raised when modality is not a string."""
def __init__(self):
super().__init__("modality must be a string")


class UnsupportedModalityError(ValueError):
"""Raised when an unsupported modality is requested."""
def __init__(self, modality: str):
super().__init__(
f"Unsupported modality '{modality}'. Supported modalities: CT, MR, MRI"
)


def get_ct_preprocessing_pipeline() -> Compose:
"""
Create a preprocessing pipeline for CT images.

Returns:
Compose: Transform composition for CT preprocessing.
Applies Hounsfield Unit (HU) windowing [-1000, 400] scaled to [0, 1].
This range captures lung (-1000 to -400 HU) and soft tissue (0 to 100 HU) contrast.

Note:
Output will be a single-channel tensor with shape (1, H, W, D)
and values in range [0, 1].
"""
return Compose(
[
LoadImage(image_only=True),
EnsureChannelFirst(),
ScaleIntensityRange(
a_min=-1000,
a_max=400,
b_min=0.0,
b_max=1.0,
clip=True,
),
]
)


def get_mri_preprocessing_pipeline() -> Compose:
"""
Create a preprocessing pipeline for MRI images.

Returns:
Compose: Transform composition for MRI preprocessing.
Normalizes intensities using nonzero voxels only, excluding background regions
typical in MRI acquisitions.

Note:
Output will be a single-channel tensor with shape (1, H, W, D)
normalized based on nonzero voxel statistics.
"""
return Compose(
[
LoadImage(image_only=True),
EnsureChannelFirst(),
NormalizeIntensity(nonzero=True),
]
)


def preprocess_medical_image(path: str, modality: str) -> torch.Tensor:
"""
Preprocess a medical image based on imaging modality.

Args:
path: Path to the medical image file. Supports various formats including
DICOM, NIfTI, and others supported by MONAI's LoadImage transform.
modality: Imaging modality. Supported values are "CT", "MR", and "MRI" (case-insensitive).

Returns:
Preprocessed image data as a torch.Tensor (or MetaTensor with metadata).

Raises:
ModalityTypeError: If modality is not a string.
UnsupportedModalityError: If the provided modality is not supported.
"""
if not isinstance(modality, str):
raise ModalityTypeError()

modality_clean = modality.strip().upper()

if modality_clean in {"MR", "MRI"}:
pipeline = get_mri_preprocessing_pipeline()
elif modality_clean == "CT":
pipeline = get_ct_preprocessing_pipeline()
else:
raise UnsupportedModalityError(modality)

return pipeline(path)


# Keep the old function name for backward compatibility
def preprocess_dicom_series(path: str, modality: str) -> torch.Tensor:
"""
Preprocess a DICOM series or file based on imaging modality.

Note: This function also supports other medical image formats
(NIfTI, etc.) through MONAI's LoadImage transform.

Args:
path: Path to the DICOM file or directory containing a DICOM series.
modality: Imaging modality. Supported values are "CT", "MR", and "MRI" (case-insensitive).

Returns:
Preprocessed image data as a torch.Tensor (or MetaTensor with metadata).

Raises:
ModalityTypeError: If modality is not a string.
UnsupportedModalityError: If the provided modality is not supported.
"""
return preprocess_medical_image(path, modality)


__all__ = [
"ModalityTypeError",
"UnsupportedModalityError",
"get_ct_preprocessing_pipeline",
"get_mri_preprocessing_pipeline",
"preprocess_dicom_series",
"preprocess_medical_image",
]
Loading