-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathdataset.py
executable file
·122 lines (95 loc) · 3.47 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
__all__ = ['MultiScaleDataset',
'ImageDataset'
]
from io import BytesIO
import math
import lmdb
from PIL import Image
from torch.utils.data import Dataset
import torch
import numpy as np
import tensor_transforms as tt
class MultiScaleDataset(Dataset):
    """LMDB-backed dataset yielding a dict of multi-scale crops.

    Each item is a dict mapping scale level -> tensor of shape
    (C + coord_channels, crop_size, crop_size): the transformed image
    concatenated with a fixed coordinate grid, cropped densely at level 0
    and with spatial stride 2**level at levels log_size .. 1.

    Args:
        path: filesystem path of the LMDB database.
        transform: callable mapping a PIL image to a CHW tensor.
        resolution: full image side length (assumed square).
        to_crop: if True, random-crop the full image to `resolution` first.
        crop_size: side length of every returned patch.
        integer_values: passed through to the coordinate-grid builder.
    """

    def __init__(self, path, transform, resolution=256, to_crop=False, crop_size=64, integer_values=False):
        # Read-only, lock-free settings are the standard configuration for
        # concurrent LMDB reads from multiple DataLoader workers.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        self.crop_size = crop_size
        self.integer_values = integer_values
        self.n = resolution // crop_size          # crops per side at full resolution
        self.log_size = int(math.log(self.n, 2))  # number of strided scale levels
        self.crop = tt.RandomCrop(crop_size)
        self.crop_resolution = tt.RandomCrop(resolution)
        self.to_crop = to_crop
        # Fixed coordinate grid appended as extra channels to every image.
        self.coords = tt.convert_to_coord_format(1, resolution, resolution,
                                                 integer_values=self.integer_values)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = {}
        # Keys are zero-padded 7-digit decimal indices.
        with self.env.begin(write=False) as txn:
            key = f'{str(index).zfill(7)}'.encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img).unsqueeze(0)  # add batch dim for the crop ops
        if self.to_crop:
            img = self.crop_resolution(img)

        # Stack image channels with the coordinate grid along the channel dim.
        stack = torch.cat([img, self.coords], 1)
        del img

        # Level 0: a dense random crop of the full-resolution stack.
        data[0] = self.crop(stack).squeeze(0)
        stack = stack.squeeze(0)

        stack_strided = None
        for ls in range(self.log_size, 0, -1):
            n = 2 ** ls
            # Largest valid random offset so the strided view still covers
            # at least crop_size samples per side.
            bias = self.resolution - n * self.crop_size + n
            bw = np.random.randint(bias)
            bh = np.random.randint(bias)
            stack_strided = stack[:, bw::n, bh::n]
            # BUG FIX: the original compared tensor sizes against self.crop
            # (a tt.RandomCrop object), which is always unequal to an int, so
            # the crop branch ran unconditionally. Compare against the integer
            # self.crop_size as intended.
            if stack_strided.size(2) != self.crop_size or stack_strided.size(1) != self.crop_size:
                data[ls] = self.crop(stack_strided.unsqueeze(0)).squeeze(0)
            else:
                data[ls] = stack_strided

        del stack
        del stack_strided
        return data
class ImageDataset(Dataset):
    """LMDB-backed image dataset: decode, transform, optionally random-crop.

    Args:
        path: filesystem path of the LMDB database.
        transform: callable mapping a PIL image to a tensor.
        resolution: side length used for the optional random crop.
        to_crop: if True, random-crop each transformed image to `resolution`.
    """

    def __init__(self, path, transform, resolution=256, to_crop=False):
        # Read-only, lock-free environment suitable for parallel readers.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.crop = tt.RandomCrop(resolution)
        self.to_crop = to_crop
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)
        # Dataset length is stored in the database under the 'length' key.
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Keys are zero-padded 7-digit decimal indices.
        key = f'{str(index).zfill(7)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            raw = txn.get(key)

        image = Image.open(BytesIO(raw))
        image = self.transform(image)
        if self.to_crop:
            # Crop ops expect a batch dim; add it, crop, then strip it again.
            image = self.crop(image.unsqueeze(0)).squeeze(0)
        return image