-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfolder.py
141 lines (108 loc) · 5.38 KB
/
folder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# ---------------------------------------
# Modified from torchvision by QIU Tian
# ---------------------------------------
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
from ._base_ import BaseDataset
__all__ = ["ImageFolder"]
# Lower-case file suffixes treated as images; has_file_allowed_extension lower-cases
# the filename before matching, so e.g. "photo.JPG" also qualifies.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Return True if *filename* ends with one of the allowed *extensions*.

    Args:
        filename: Path or name of the file to check; it is lower-cased before matching,
            while the suffixes themselves are used as given.
        extensions: A single suffix string or an iterable of suffixes. A non-string
            iterable is converted to a tuple because ``str.endswith`` accepts only a
            string or a tuple.

    Returns:
        True if the lower-cased filename ends with any of the suffixes.
    """
    lowered = filename.lower()
    if isinstance(extensions, str):
        return lowered.endswith(extensions)
    return lowered.endswith(tuple(extensions))
def is_image_file(filename: str) -> bool:
    """Return True if *filename* ends with one of the suffixes in ``IMG_EXTENSIONS``.

    Args:
        filename: Path or name of the file to check.
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, allowed)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class sub-folders of *directory* and assign each one an index.

    Args:
        directory: Folder whose immediate sub-directories are the classes.

    Returns:
        A tuple ``(classes, class_to_idx)``. ``classes`` is the sorted list of
        sub-directory names. ``class_to_idx`` maps each name to its position
        in that list.

    Raises:
        FileNotFoundError: If *directory* contains no sub-directories (also raised
            by ``os.scandir`` when *directory* does not exist).
    """
    # Use the scandir iterator as a context manager so it is closed deterministically,
    # even if entry.is_dir() raises part-way through the listing.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generate a list of samples of the form ``(path_to_sample, class_index)``.

    Note: ``class_to_idx`` is optional; when omitted, the logic of the ``find_classes``
    function is used.

    Args:
        directory: Root folder. Each class is looked up in ``directory/<class_name>``.
            A leading ``~`` is expanded.
        class_to_idx: Mapping from class folder name to class index. If ``None``,
            it is computed with ``find_classes``. An empty mapping is rejected.
        extensions: Allowed file suffix(es). Exactly one of ``extensions`` and
            ``is_valid_file`` must be given.
        is_valid_file: Predicate called with each file's full path. Files for which
            it returns True are kept.

    Returns:
        ``(path, class_index)`` pairs. They are ordered by class name, then by
        ``os.walk`` root (sorted), then by file name (sorted).

    Raises:
        ValueError: If ``class_to_idx`` is an empty mapping, or if both or neither
            of ``extensions`` and ``is_valid_file`` are given.
        FileNotFoundError: If any class in ``class_to_idx`` yields no valid file,
            including classes whose folder does not exist.
    """
    directory = os.path.expanduser(directory)
    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        # Build the file filter from the allowed suffixes.

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # Missing class folders are skipped here and reported as empty below.
            continue
        # followlinks=True descends into symlinked sub-folders; sorting keeps the
        # sample order deterministic regardless of the file system's listing order.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    # set.add is idempotent, so no prior membership test is needed.
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class ImageFolder(BaseDataset):
    """Image classification dataset laid out as ``<root>/<split>/<class_name>/...``.

    Class names are the sub-folder names of ``os.path.join(self.root, split)``, as
    found by :meth:`find_classes`. Every file accepted by the filter becomes one
    ``(path, class_index)`` sample.

    Args:
        root: Passed to ``BaseDataset``. ``self.root`` is read afterwards, so it is
            presumably set there from this argument (NOTE(review): confirm).
        split: Name of the sub-folder under ``self.root`` to load.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each class index.
        loader: Passed to ``BaseDataset``. ``__getitem__`` calls
            ``self.loader(path, format="RGB")``; when ``None``, ``BaseDataset`` is
            presumably expected to provide a default (NOTE(review): confirm).
        is_valid_file: Optional predicate on a file path. If omitted, files are
            filtered by ``IMG_EXTENSIONS`` instead.

    Attributes:
        classes: Sorted class names.
        class_to_idx: Mapping from class name to class index.
        samples: List of ``(path, class_index)`` pairs.
        targets: Class index of each entry in ``samples``.
    """

    def __init__(self, root, split, transform=None, target_transform=None, loader=None,
                 is_valid_file=None):
        super().__init__(root, split, transform, target_transform, loader)
        # make_dataset() requires exactly one of extensions / is_valid_file.
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        split_folder = os.path.join(self.root, split)
        classes, class_to_idx = self.find_classes(split_folder)
        samples = self.make_dataset(split_folder, class_to_idx, extensions, is_valid_file)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [target for _, target in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]],
        extensions: Optional[Union[str, Tuple[str, ...]]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Collect ``(path, class_index)`` samples under *directory*.

        Unlike the module-level ``make_dataset``, *class_to_idx* is mandatory here.

        Raises:
            ValueError: If ``class_to_idx`` is ``None``, plus anything raised by the
                module-level ``make_dataset``.
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        # The class body is not an enclosing scope for methods, so this name resolves
        # to the module-level make_dataset function rather than recursing.
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return ``(classes, class_to_idx)`` for *directory*.

        This is an overridable hook that delegates to the module-level ``find_classes``.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path, format="RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)