From 2d5eb24e7035ba485b596cd8f8d638ebdccdef3e Mon Sep 17 00:00:00 2001 From: alexanderjaus Date: Thu, 11 Dec 2025 14:54:01 +0100 Subject: [PATCH 1/3] Fix EDT indices allocation for channel-first tensors on CPU/GPU --- monai/transforms/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b50508962f..2da00fad4f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2498,7 +2498,7 @@ def distance_transform_edt( if return_indices: dtype = torch.int32 if indices is None: - indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore + indices = torch.zeros((img.shape[0],) + (img.dim()-1,) + img.shape[1:], dtype=dtype) # type: ignore else: if not isinstance(indices, torch.Tensor) and indices.device != img.device: raise TypeError("indices must be a torch.Tensor on the same device as img") @@ -2532,7 +2532,7 @@ def distance_transform_edt( raise TypeError("distances must be a numpy.ndarray of dtype float64") if return_indices: if indices is None: - indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32) + indices = np.zeros((img_.shape[0],)+(img_.ndim-1,) + img_.shape[1:], dtype=np.int32) else: if not isinstance(indices, np.ndarray): raise TypeError("indices must be a numpy.ndarray") From 258a4726c8b8163eec0e3bbf6515afe3a2a31a99 Mon Sep 17 00:00:00 2001 From: alexanderjaus Date: Thu, 11 Dec 2025 15:26:15 +0100 Subject: [PATCH 2/3] Apply codestyle fixes --- monai/transforms/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 2da00fad4f..4ad60483fd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2498,7 +2498,7 @@ def distance_transform_edt( if return_indices: dtype = torch.int32 if indices is None: - indices = torch.zeros((img.shape[0],) + (img.dim()-1,) + img.shape[1:], dtype=dtype) # type: ignore + indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore else: if not isinstance(indices, torch.Tensor) and indices.device != img.device: raise TypeError("indices must be a torch.Tensor on the same device as img") @@ -2532,7 +2532,7 @@ def distance_transform_edt( raise TypeError("distances must be a numpy.ndarray of dtype float64") if return_indices: if indices is None: - indices = np.zeros((img_.shape[0],)+(img_.ndim-1,) + img_.shape[1:], dtype=np.int32) + indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32) else: if not isinstance(indices, np.ndarray): raise TypeError("indices must be a numpy.ndarray") From 838e347f61fa4f1ebf919e51f60b29354a935ad0 Mon Sep 17 00:00:00 2001 From: alexanderjaus Date: Sat, 13 Dec 2025 15:00:59 +0100 Subject: [PATCH 3/3] DCO Remediation Commit for alexanderjaus I, alexanderjaus , hereby add my Signed-off-by to this commit: 2d5eb24e7035ba485b596cd8f8d638ebdccdef3e I, alexanderjaus , hereby add my Signed-off-by to this commit: 258a4726c8b8163eec0e3bbf6515afe3a2a31a99 Signed-off-by: alexanderjaus