-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_dataset.py
146 lines (114 loc) · 5.8 KB
/
build_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
from torch.utils.data import random_split, ConcatDataset
from torchvision import datasets
from torchvision import transforms
class DataLoader:
    """Build torchvision MNIST / FashionMNIST datasets and torch data loaders.

    The dataset is selected by name (``"MNIST"`` or ``"FashionMNIST"``).
    Datasets are (re)created by the ``get_*`` methods, downloading into
    ``./data/...`` when missing, and the most recently created datasets and
    loaders are kept on the instance attributes ``train_dataset``,
    ``test_dataset``, ``train_data_loader`` and ``test_data_loader``.
    """

    def __init__(self, dataset):
        """Remember the dataset name; nothing is downloaded or loaded yet.

        Args:
            dataset: Dataset name, ``"MNIST"`` or ``"FashionMNIST"``. The name
                is not checked here; an unsupported name raises ``ValueError``
                the first time a dataset has to be loaded.
        """
        self.dataset_name = dataset
        self.train_dataset = None
        self.test_dataset = None
        self.train_data_loader = None
        self.test_data_loader = None

    def get_train_test_dataloader(self, batch_size, shuffle=True, split_percentage=1,
                                  num_workers=1, pin_memory=False):
        """Create a train loader and a test loader from one random train split.

        The train split is randomly divided into part A (``split_percentage``
        of the samples) and part B (the remainder). The train loader serves
        part A; the test loader serves part B concatenated with the official
        test split.

        Args:
            batch_size: Batch size used by both loaders.
            shuffle: Whether the loaders shuffle (applied to both of them).
            split_percentage: Fraction of the train split kept for training,
                in the range [0, 1].
            num_workers: Number of worker processes for both loaders.
            pin_memory: Forwarded to both loaders.

        Returns:
            Tuple ``(train_data_loader, test_data_loader)``.

        Raises:
            AssertionError: If ``split_percentage`` is outside [0, 1].
            ValueError: If the dataset name is not supported.
        """
        self.__check_split_percentage(split_percentage)

        # Split the dataset directly rather than through get_train_dataloader():
        # that method returns a ready DataLoader (not a pair of subsets) when
        # split_percentage == 1, which broke the unpacking below for the
        # default argument value.
        subset_trainA, subset_trainB = self.get_train_dataset_split(
            split_percentage=split_percentage)

        self.train_data_loader = torch.utils.data.DataLoader(
            subset_trainA, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory
        )

        self.__load_test_dataset()
        subset_test_dataset = ConcatDataset([subset_trainB, self.test_dataset])
        self.test_data_loader = torch.utils.data.DataLoader(
            subset_test_dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory
        )

        return self.train_data_loader, self.test_data_loader

    def get_train_dataloader(self, batch_size, split_percentage=1,
                             shuffle=True, num_workers=1, pin_memory=False):
        """Return a loader over the full train split, or a random split of it.

        Note that the return type depends on ``split_percentage``.

        Args:
            batch_size: Batch size of the loader (used only when
                ``split_percentage == 1``).
            split_percentage: Fraction of the train split for the first
                subset, in the range [0, 1].
            shuffle: Whether the loader shuffles (only for the full loader).
            num_workers: Worker processes of the loader (only for the full
                loader).
            pin_memory: Forwarded to the loader (only for the full loader).

        Returns:
            When ``split_percentage == 1``: a ``torch.utils.data.DataLoader``
            over the whole train split (also stored in
            ``self.train_data_loader``). Otherwise: the tuple
            ``(subset_trainA, subset_trainB)`` produced by
            :meth:`get_train_dataset_split`; no loader is created.

        Raises:
            AssertionError: If ``split_percentage`` is outside [0, 1].
            ValueError: If the dataset name is not supported.
        """
        self.__check_split_percentage(split_percentage)

        if split_percentage != 1:
            # get_train_dataset_split() loads the train dataset itself.
            return self.get_train_dataset_split(split_percentage=split_percentage)

        self.__load_train_dataset()
        # num_workers used to be printed and then dropped; it is now honoured.
        self.train_data_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory
        )
        return self.train_data_loader

    def get_train_dataset_split(self, split_percentage):
        """Randomly split the train dataset into two subsets.

        Every call reloads the dataset and draws a new random split.

        Args:
            split_percentage: Fraction of samples placed in the first subset.

        Returns:
            Tuple ``(subset_trainA, subset_trainB)``; A holds
            ``int(len(train) * split_percentage)`` samples and B the rest.

        Raises:
            ValueError: If the dataset name is not supported.
        """
        self.__load_train_dataset()

        total = len(self.train_dataset)
        subset_trainA_length = int(total * split_percentage)
        # Derive B from the total so that rounding never loses a sample.
        subset_trainB_length = total - subset_trainA_length

        subset_trainA, subset_trainB = random_split(
            self.train_dataset, [subset_trainA_length, subset_trainB_length])
        return subset_trainA, subset_trainB

    def get_test_dataset_split(self, split_percentage):
        """Return part A of a fresh train split concatenated with the test split.

        Args:
            split_percentage: Fraction of the train split placed in part A.

        Returns:
            ``ConcatDataset`` of ``subset_trainA`` and the official test split.

        Raises:
            ValueError: If the dataset name is not supported.
        """
        # NOTE(review): this uses part A, whereas get_train_test_dataloader()
        # adds part B to the test data. The split is also drawn anew on every
        # call, so it does not match a split obtained earlier. Behaviour is
        # kept as-is; confirm the intent with the callers.
        subset_trainA, _ = self.get_train_dataset_split(split_percentage=split_percentage)
        self.__load_test_dataset()
        return ConcatDataset([subset_trainA, self.test_dataset])

    def get_test_loader(self, batch_size, shuffle=False, num_workers=1, pin_memory=False):
        """Create a loader over the official test split.

        Args:
            batch_size: Batch size of the loader.
            shuffle: Whether the loader shuffles.
            num_workers: Number of worker processes of the loader.
            pin_memory: Forwarded to the loader.

        Returns:
            The loader, which is also stored in ``self.test_data_loader``.

        Raises:
            ValueError: If the dataset name is not supported.
        """
        self.__load_test_dataset()
        self.test_data_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory
        )
        return self.test_data_loader

    def get_classes(self):
        """Return the ``classes`` attribute of the train dataset.

        The train dataset is loaded on demand if no other method has loaded
        it yet (previously this raised ``AttributeError`` on ``None``).

        Raises:
            ValueError: If the dataset name is not supported.
        """
        if self.train_dataset is None:
            self.__load_train_dataset()
        return self.train_dataset.classes

    @staticmethod
    def __check_split_percentage(split_percentage):
        """Reject a split fraction outside [0, 1].

        Raises ``AssertionError`` explicitly instead of using ``assert``: the
        check then still runs under ``python -O`` while callers keep seeing
        the exception type they always got.
        """
        if not 0 <= split_percentage <= 1:
            raise AssertionError("[!] split_percentage should be in the range [0, 1].")

    def __unsupported_dataset_error(self):
        """Build the error raised when ``dataset_name`` is not supported."""
        return ValueError(
            f"Unsupported dataset {self.dataset_name!r}; "
            "expected 'MNIST' or 'FashionMNIST'."
        )

    def __load_train_dataset(self):
        """Load the train split of the selected dataset into ``train_dataset``."""
        if self.dataset_name == "MNIST":
            self.__load_train_mnist_dataset(train_transform=self.__get_transform())
        elif self.dataset_name == "FashionMNIST":
            self.__load_train_fashionMnist_dataset(train_transform=self.__get_transform())
        else:
            # Fail here rather than later with an unrelated error on None.
            raise self.__unsupported_dataset_error()

    def __load_train_mnist_dataset(self, train_transform=None):
        """Create the MNIST train dataset, downloading it if necessary."""
        self.train_dataset = datasets.MNIST(
            root='./data/mnist', train=True,
            download=True, transform=train_transform,
        )

    def __load_train_fashionMnist_dataset(self, train_transform=None):
        """Create the FashionMNIST train dataset, downloading it if necessary."""
        self.train_dataset = datasets.FashionMNIST(
            root='./data/fashionMnist', train=True,
            download=True, transform=train_transform,
        )

    def __load_test_dataset(self):
        """Load the test split of the selected dataset into ``test_dataset``."""
        if self.dataset_name == "MNIST":
            self.__load_test_mnist_dataset(transform=self.__get_transform())
        elif self.dataset_name == "FashionMNIST":
            self.__load_test_fashionMnist_dataset(transform=self.__get_transform())
        else:
            raise self.__unsupported_dataset_error()

    def __load_test_mnist_dataset(self, transform=None):
        """Create the MNIST test dataset, downloading it if necessary."""
        self.test_dataset = datasets.MNIST(
            root='data/mnist', train=False,
            download=True, transform=transform,
        )

    def __load_test_fashionMnist_dataset(self, transform=None):
        """Create the FashionMNIST test dataset, downloading it if necessary."""
        self.test_dataset = datasets.FashionMNIST(
            root='./data/fashionMnist', train=False,
            download=True, transform=transform,
        )

    def __get_transform(self):
        """Return the ``ToTensor`` + ``Normalize`` pipeline for the dataset.

        Raises:
            ValueError: If the dataset name is not supported.
        """
        # A local variable replaces the former ``global normalize``, which
        # leaked into the module namespace and could silently reuse a stale
        # transform from a previous call when the name was not recognised.
        if self.dataset_name == "MNIST":
            normalize = self.__get_mnist_normalize_val()
        elif self.dataset_name == "FashionMNIST":
            normalize = self.__get_fashionMnist_normalize_val()
        else:
            raise self.__unsupported_dataset_error()

        return transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    @staticmethod
    def __get_mnist_normalize_val():
        """Return the ``Normalize`` transform used for MNIST (mean 0.1307, std 0.3081)."""
        return transforms.Normalize((0.1307,), (0.3081,))

    @staticmethod
    def __get_fashionMnist_normalize_val():
        """Return the ``Normalize`` transform used for FashionMNIST (mean 0.5, std 0.5)."""
        return transforms.Normalize((0.5,), (0.5,))