# Patch summary: whitespace-collapsed unified diff covering two files.
#
# src/cleanvision/dataset/fsspec_dataset.py
#   - Imports UnidentifiedImageError from PIL alongside Image.
#   - __get_filepaths: after de-duplicating the collected paths, each path is
#     opened via self.fs.open(path, "rb", **self.storage_opts) and passed
#     through Image.open(...).verify(). Only paths for which that succeeds are
#     kept. UnidentifiedImageError, OSError and ValueError cause the path to
#     be dropped; the "Skipping corrupted image" message is printed only when
#     verbose is true. The sorted return value now uses the filtered list.
#
# src/cleanvision/dataset/torch_dataset.py
#   - Imports UnidentifiedImageError from PIL alongside Image.
#   - __init__: sets self._image_idx = None before scanning torch_dataset[0],
#     and raises ValueError when no element of that first sample is a PIL
#     image.
#   - _set_index: replaces the plain range-based index with a loop that reads
#     self._data[i][self._image_idx] for every i and keeps i only when the
#     value is a PIL Image. UnidentifiedImageError, OSError, ValueError and
#     TypeError produce a printed warning and the index is skipped.
#
# NOTE(review): `idx` from enumerate(unique_filepaths) is never used in the
#   added loop, and `del img` has no visible later use of the name.
# NOTE(review): the "silently drop corrupt images" comment sits directly above
#   a verbose-gated print, so the drop is not always silent.
# NOTE(review): _set_index prints its warning unconditionally (no verbose flag
#   is visible in that method) and leaves the previous one-line implementation
#   behind as a commented-out line; the "# todo: catch errors" marker in
#   __init__ is also still present.
# NOTE(review): both added loops touch every item up front (one open + verify
#   per path, one self._data[i] access per sample); presumably this makes
#   construction cost scale with dataset size -- confirm that is acceptable
#   for remote filesystems and large datasets.
# NOTE(review): the per-line ' ', '+', '-' columns and blank context lines are
#   not recoverable from this collapsed form; presumably the patch has to be
#   regenerated from the repository before it can be applied -- confirm.
diff --git a/src/cleanvision/dataset/fsspec_dataset.py b/src/cleanvision/dataset/fsspec_dataset.py index 942dd05b..ad70104f 100644 --- a/src/cleanvision/dataset/fsspec_dataset.py +++ b/src/cleanvision/dataset/fsspec_dataset.py @@ -2,7 +2,7 @@ from typing import List, Optional, Union, Dict -from PIL import Image +from PIL import Image, UnidentifiedImageError from cleanvision.utils.constants import IMAGE_FILE_EXTENSIONS from cleanvision.dataset.base_dataset import Dataset @@ -82,6 +82,18 @@ def __get_filepaths(self, dataset_path: str, verbose: bool) -> List[str]: continue filepaths += filetype_images unique_filepaths = list(set(filepaths)) + valid_filepaths = [] + for idx, path in enumerate(unique_filepaths): + try: + with self.fs.open(path, "rb", **self.storage_opts) as f: + img = Image.open(f) + img.verify() # integrity check (no pixel decode) + valid_filepaths.append(path) + del img + except (UnidentifiedImageError, OSError, ValueError): + # silently drop corrupt images + if verbose: + print(f"Skipping corrupted image: {path}") return sorted( - unique_filepaths + valid_filepaths ) # sort image names alphabetically and numerically diff --git a/src/cleanvision/dataset/torch_dataset.py b/src/cleanvision/dataset/torch_dataset.py index 8e8b7973..d63ad4c7 100644 --- a/src/cleanvision/dataset/torch_dataset.py +++ b/src/cleanvision/dataset/torch_dataset.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Union -from PIL import Image +from PIL import Image, UnidentifiedImageError from cleanvision.dataset.base_dataset import Dataset @@ -17,9 +17,14 @@ def __init__(self, torch_dataset: "VisionDataset") -> None: super().__init__() self._data = torch_dataset # todo: catch errors + self._image_idx = None for i, obj in enumerate(torch_dataset[0]): if isinstance(obj, Image.Image): self._image_idx = i + + if self._image_idx is None: + raise ValueError("No PIL image found in torchvision dataset sample") + self._set_index() def __len__(self) -> int: @@ -32,4 +37,15 @@ def 
get_name(self, index: Union[int, str]) -> str: return f"idx: {index}" def _set_index(self) -> None: - self.index = [i for i in range(len(self._data))] + valid_indices = [] + for i in range(len(self._data)): + try: + img = self._data[i][self._image_idx] + if not isinstance(img, Image.Image): + raise UnidentifiedImageError + valid_indices.append(i) + except (UnidentifiedImageError, OSError, ValueError, TypeError): + print(f"Warning: Skipping corrupted image at index {i}") + continue + # self.index = [i for i in range(len(self._data))] + self.index = valid_indices