diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b50508962f..4ad60483fd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2498,7 +2498,7 @@ def distance_transform_edt( if return_indices: dtype = torch.int32 if indices is None: - indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore + indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore else: if not isinstance(indices, torch.Tensor) and indices.device != img.device: raise TypeError("indices must be a torch.Tensor on the same device as img") @@ -2532,7 +2532,7 @@ def distance_transform_edt( raise TypeError("distances must be a numpy.ndarray of dtype float64") if return_indices: if indices is None: - indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32) + indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32) else: if not isinstance(indices, np.ndarray): raise TypeError("indices must be a numpy.ndarray")