-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
110 lines (80 loc) · 3.02 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from __future__ import with_statement
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
def random_crop(hr, lr, size, scale):
    """Cut an aligned random patch pair out of an HR/LR image pair.

    A ``size`` x ``size`` window is drawn uniformly at random from ``lr``.
    The window at the corresponding position in ``hr`` is also cut; its
    offset and side length are both multiplied by ``scale``.

    Args:
        hr: high-resolution array indexed as ``hr[row, col, ...]``; assumed
            to be ``scale`` times larger than ``lr`` in both spatial axes.
        lr: low-resolution array indexed as ``lr[row, col, ...]``; both 2-D
            (H, W) and 3-D (H, W, C) arrays are accepted.
        size: side length of the LR patch in pixels.
        scale: integer upscaling factor between ``lr`` and ``hr``.

    Returns:
        ``(crop_hr, crop_lr)`` as independent array copies with spatial
        side lengths ``size * scale`` and ``size`` respectively.

    Raises:
        ValueError: if ``size`` exceeds either spatial dimension of ``lr``.
    """
    # Only the first two axes are spatial; slicing [:2] (instead of [:-1])
    # also handles single-channel 2-D arrays.
    h, w = lr.shape[:2]
    if size > h or size > w:
        raise ValueError(
            "patch size {} exceeds LR image size {}x{}".format(size, h, w))
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    hsize = size * scale
    hx, hy = x * scale, y * scale
    crop_lr = lr[y:y + size, x:x + size].copy()
    crop_hr = hr[hy:hy + hsize, hx:hx + hsize].copy()
    return crop_hr, crop_lr
def random_flip_and_rotate(im1, im2):
    """Apply one shared random flip/rotation to both images of a pair.

    The same augmentation is applied to both inputs:
    * with probability 1/2, an up/down flip (``np.flipud``);
    * then, with probability 1/2, a left/right flip (``np.fliplr``);
    * then ``np.rot90`` by a uniformly chosen number of quarter turns (0-3).

    Returns:
        ``(im1, im2)`` after augmentation, as fresh array copies.
    """
    images = (im1, im2)
    if random.random() < 0.5:
        images = tuple(np.flipud(im) for im in images)
    if random.random() < 0.5:
        images = tuple(np.fliplr(im) for im in images)
    quarter_turns = random.choice([0, 1, 2, 3])
    # flipud/fliplr/rot90 return views of the input; copying hands the
    # caller's transform function standalone arrays.
    return tuple(np.rot90(im, quarter_turns).copy() for im in images)
class TrainDataset(data.Dataset):
    """Super-resolution training set backed by an HDF5 file.

    All images are read into memory when the dataset is constructed. Each
    item is a list holding one (HR, LR) pair per loaded scale. Each pair is
    randomly cropped, randomly flipped/rotated, and passed through
    ``ToTensor``.
    """

    def __init__(self, path, size, scale):
        """Load every HR image and its ``X<scale>`` LR counterpart.

        Args:
            path: path of the HDF5 file; it must contain an ``"HR"`` group
                and an ``"X<scale>"`` group (e.g. ``"X2"``).
            size: LR patch side length used by ``random_crop``.
            scale: upscaling factor; selects the LR group and sets the HR
                patch size.
        """
        super(TrainDataset, self).__init__()
        self.size = size
        # `with` closes the file even if reading one of the datasets raises.
        with h5py.File(path, "r") as h5f:
            # v[:] reads each dataset fully into an in-memory array.
            self.hr = [v[:] for v in h5f["HR"].values()]
            # Parallel lists: self.lr[i] holds the LR images for
            # self.scale[i]; only a single scale is loaded here.
            self.scale = [scale]
            # NOTE(review): HR and LR images are paired purely by iteration
            # position within their groups; this presumes both groups use
            # matching dataset names — confirm against the HDF5 writer.
            self.lr = [[v[:] for v in h5f["X{}".format(scale)].values()]]
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        """Return ``[(hr_tensor, lr_tensor), ...]``, one pair per scale."""
        hr = self.hr[index]
        pairs = [(hr, lr_images[index]) for lr_images in self.lr]
        # Every pair is cropped before any pair is flipped/rotated.
        pairs = [random_crop(hr_img, lr_img, self.size, self.scale[i])
                 for i, (hr_img, lr_img) in enumerate(pairs)]
        pairs = [random_flip_and_rotate(hr_img, lr_img)
                 for hr_img, lr_img in pairs]
        return [(self.transform(hr_img), self.transform(lr_img))
                for hr_img, lr_img in pairs]

    def __len__(self):
        """Number of HR images (one item per HR image)."""
        return len(self.hr)
class TestDataset(data.Dataset):
    """Evaluation set of full-size HR/LR image pairs read from PNG files.

    The directory layout is chosen by the directory's base name:

    * name contains ``"DIV"``: HR images in ``<dirname>_HR/*.png`` and LR
      images in ``<dirname>_LR_bicubic/X<scale>/*.png``;
    * otherwise: HR and LR images both in ``<dirname>/x<scale>/*.png``,
      told apart by ``"HR"`` / ``"LR"`` in the file name.

    HR and LR files are paired by position after sorting each list.
    """

    def __init__(self, dirname, scale):
        """Collect and sort the HR/LR file paths for ``scale``.

        Args:
            dirname: dataset directory (a trailing separator is tolerated).
            scale: upscaling factor used to select the LR files.
        """
        super(TestDataset, self).__init__()
        # normpath drops a trailing separator, so the base name is never
        # empty (e.g. "data/DIV2K_valid/" -> "DIV2K_valid") and the
        # "<dirname>_HR" paths built below stay well-formed.
        dirname = os.path.normpath(dirname)
        self.name = os.path.basename(dirname)
        self.scale = scale

        if "DIV" in self.name:
            self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
            self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname),
                                             "X{}/*.png".format(scale)))
        else:
            all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
            self.hr, self.lr = self._split_hr_lr(all_files)

        # NOTE(review): pairing relies on sorted order; the two lists'
        # lengths are not checked against each other.
        self.hr.sort()
        self.lr.sort()

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    @staticmethod
    def _split_hr_lr(paths):
        """Partition ``paths`` into ``(hr_paths, lr_paths)`` by file-name tag.

        Only the base name is inspected. A parent directory whose name
        contains "HR" or "LR" therefore cannot misclassify files. Input
        order is preserved within each list.
        """
        hr = [p for p in paths if "HR" in os.path.basename(p)]
        lr = [p for p in paths if "LR" in os.path.basename(p)]
        return hr, lr

    def __getitem__(self, index):
        """Return ``(hr_tensor, lr_tensor, filename)`` for the index-th pair.

        ``filename`` is the base name of the HR file.
        """
        hr_path = self.hr[index]
        # Context managers release the file handles. convert() returns a new
        # image, so the results stay usable after the files are closed.
        with Image.open(hr_path) as hr_img:
            hr = hr_img.convert("RGB")
        with Image.open(self.lr[index]) as lr_img:
            lr = lr_img.convert("RGB")
        # basename understands the platform separator; split("/") would keep
        # the whole path when glob returns backslash-separated paths.
        filename = os.path.basename(hr_path)
        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        """Number of HR images."""
        return len(self.hr)