-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch-dataset
41 lines (35 loc) · 1.26 KB
/
pytorch-dataset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import xarray as xr
import numpy as np
import torch
from torch.nn.functional import one_hot
from torch.utils.data import Dataset
class xarrayPixelDataset(Dataset):
    """
    Pixel time series dataset read from a single file through xarray.

    The file is not read in ``__init__``. It is loaded into memory on the
    first ``__getitem__`` call and kept on the instance afterwards.
    NOTE(review): the deferred load presumably exists so that each
    DataLoader worker loads its own copy -- confirm with the author.

    Args:
        data_file: Path of the file to read. The code expects it to expose a
            ``pixel`` dimension and ``images`` / ``labels`` variables.
        num_classes: Number of label classes. Stored on the instance; the
            dataset itself does not use it.
    """

    # One engine for every read, so ``__len__`` and ``__getitem__`` always
    # parse the file with the same backend.
    _ENGINE = "h5netcdf"

    def __init__(self, data_file, num_classes) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.data_is_loaded = False
        self.data_file = data_file
        # Populated lazily by ``__getitem__``.
        self.data = None
        # Populated lazily by ``__len__`` so the file is opened at most once
        # for size queries.
        self._length = None

    def __getitem__(self, index):
        """
        Return the ``(images, labels)`` tensor pair for one pixel.

        Args:
            index: Position along the ``pixel`` dimension, forwarded to
                ``Dataset.isel``.

        Returns:
            A tuple ``(images, labels)``: ``images`` is a ``float32`` tensor
            and ``labels`` an ``int64`` tensor.
        """
        if not self.data_is_loaded:
            self.data = xr.load_dataset(self.data_file, engine=self._ENGINE)
            self.data_is_loaded = True
        sample = self.data.isel(pixel=index)
        # Both variables pass through int16 before the final tensor dtype is
        # applied. NOTE(review): presumably this mirrors the on-disk dtype --
        # confirm, since the cast would truncate non-integer values.
        images = torch.tensor(
            sample.images.values.astype(np.int16), dtype=torch.float32
        )
        labels = torch.tensor(
            sample.labels.values.astype(np.int16), dtype=torch.int64
        )
        return images, labels

    def __len__(self):
        """Return the size of the ``pixel`` dimension, caching the result."""
        if self._length is None:
            if self.data_is_loaded:
                # The dataset is already in memory; no need to touch the file.
                self._length = self.data.sizes["pixel"]
            else:
                with xr.open_dataset(self.data_file, engine=self._ENGINE) as ds:
                    self._length = ds.sizes["pixel"]
        return self._length
def main(root_dir=".", num_classes=None, nc_files=None):
    """
    Build one concatenated training dataset from the files in ``<root_dir>/train``.

    Args:
        root_dir: Directory that contains the ``train`` sub-directory.
        num_classes: Value forwarded to every ``xarrayPixelDataset``.
        nc_files: File names (relative to ``<root_dir>/train``) to load. When
            ``None``, every ``*.nc`` file found in that directory is used, in
            sorted order.

    Returns:
        A ``ConcatDataset`` that chains one ``xarrayPixelDataset`` per file.

    Raises:
        FileNotFoundError: If there are no files to load.
    """
    # Imported here because the module's import block does not provide them.
    import glob
    import os

    from torch.utils.data import ConcatDataset

    train_dir = os.path.join(root_dir, "train")
    if nc_files is None:
        nc_files = sorted(
            os.path.basename(path)
            for path in glob.glob(os.path.join(train_dir, "*.nc"))
        )
    if not nc_files:
        # ConcatDataset rejects an empty list; fail earlier with a message
        # that names the directory that was searched.
        raise FileNotFoundError(f"no NetCDF files to load from {train_dir!r}")

    train_datasets = [
        xarrayPixelDataset(
            data_file=os.path.join(train_dir, file), num_classes=num_classes
        )
        for file in nc_files
    ]
    return ConcatDataset(train_datasets)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()