Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/cleanvision/dataset/fsspec_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Optional, Union, Dict

from PIL import Image
from PIL import Image, UnidentifiedImageError

from cleanvision.utils.constants import IMAGE_FILE_EXTENSIONS
from cleanvision.dataset.base_dataset import Dataset
Expand Down Expand Up @@ -82,6 +82,18 @@ def __get_filepaths(self, dataset_path: str, verbose: bool) -> List[str]:
continue
filepaths += filetype_images
unique_filepaths = list(set(filepaths))
valid_filepaths = []
for idx, path in enumerate(unique_filepaths):
try:
with self.fs.open(path, "rb", **self.storage_opts) as f:
img = Image.open(f)
img.verify() # integrity check (no pixel decode)
valid_filepaths.append(path)
del img
except (UnidentifiedImageError, OSError, ValueError):
# silently drop corrupt images
if verbose:
print(f"Skipping corrupted image: {path}")
return sorted(
unique_filepaths
valid_filepaths
) # sort image names alphabetically and numerically
20 changes: 18 additions & 2 deletions src/cleanvision/dataset/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Union

from PIL import Image
from PIL import Image, UnidentifiedImageError

from cleanvision.dataset.base_dataset import Dataset

Expand All @@ -17,9 +17,14 @@ def __init__(self, torch_dataset: "VisionDataset") -> None:
super().__init__()
self._data = torch_dataset
# todo: catch errors
self._image_idx = None
for i, obj in enumerate(torch_dataset[0]):
if isinstance(obj, Image.Image):
self._image_idx = i

if self._image_idx is None:
raise ValueError("No PIL image found in torchvision dataset sample")

self._set_index()

def __len__(self) -> int:
Expand All @@ -32,4 +37,15 @@ def get_name(self, index: Union[int, str]) -> str:
return f"idx: {index}"

def _set_index(self) -> None:
self.index = [i for i in range(len(self._data))]
valid_indices = []
for i in range(len(self._data)):
try:
img = self._data[i][self._image_idx]
if not isinstance(img, Image.Image):
raise UnidentifiedImageError
valid_indices.append(i)
except (UnidentifiedImageError, OSError, ValueError, TypeError):
print(f"Warning: Skipping corrupted image at index {i}")
continue
# self.index = [i for i in range(len(self._data))]
self.index = valid_indices
Loading