# HI_utils.py
import csv

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import Compose


def load_HI_data(r_path: str = '', return_funcs: bool = False, test_confs: int = 3, n_params: int = 2,
                 sigmoid_normalization: bool = True):
    """
    Loads and normalizes the HI maps and their simulation parameters.
    :param r_path: root path for the HI data files
    :param return_funcs: whether to also return the parameter names and the normalization/inverse functions
    :param test_confs: number of maps per simulation (out of 15) held out for the test set
    :param n_params: number of simulation parameters to return (2 keeps only the cosmological ones)
    :param sigmoid_normalization: whether the normalization should pass the standardized data through a sigmoid
    :return: (train maps, train parameters), (test maps, test parameters) and, if return_funcs is True,
        also the parameter names and the (forward, inverse) normalization functions for maps and parameters
    """
    # ---------------------------------------------------------------------- Load HI maps and transform
    data = np.load(r_path + 'Images_HI_IllustrisTNG_z=5.99.npy')
    # preprocessing: take the log10 of the maps
    data = np.log10(data)
    # turn into a torch tensor with a channel dimension
    data = torch.from_numpy(data)[:, None].float()
    # ---------------------------------------------------------------------- Define data normalization transforms
    if sigmoid_normalization:
        # standardize the data before the sigmoid squashing
        dmean, dstd = torch.mean(data), torch.std(data)
    else:
        # without the sigmoid, rescale with min/range so the data ends up in [0, 1]
        dmean, dstd = torch.min(data), (torch.max(data) - torch.min(data))

    def norm_func(d: torch.Tensor) -> torch.Tensor:
        d = (d - dmean) / dstd
        if sigmoid_normalization:
            # map the standardized data to (-1, 1) through a rescaled sigmoid
            d = 2 / (1 + torch.exp(-d)) - 1
        return d

    def ret_func(d: torch.Tensor) -> torch.Tensor:
        # inverse of norm_func: undo the sigmoid (with clamping for numerical safety), then de-standardize
        if sigmoid_normalization:
            d = (torch.clamp(d, -.999, .999) + 1) / 2
            d = torch.log(d / (1 - d))
        return d * dstd + dmean

    data_funcs = (norm_func, ret_func)
    data = norm_func(data)
# ---------------------------------------------------------------------- Load parameter values
sim_params = []
param_names = [
r'$\Omega_m$',
r'$\sigma_8$',
r'$A_{SN1}$',
r'$A_{AGN1}$',
r'$A_{SN2}$',
r'$A_{AGN2}$',
]
with open(r_path + 'CAMELs_params.csv') as csv_file:
csv_reader = csv.reader(csv_file, delimiter=',')
for i, row in enumerate(csv_reader):
if i > 0:
r = [float(it) for it in row[1:-1]]
for _ in range(15): sim_params.append(r)
conds = torch.from_numpy(np.array(sim_params)).float()
# use only the number of simulation parameters that are needed (2 for cosmological)
conds = conds[:, :n_params]
param_names = param_names[:n_params]
    # ---------------------------------------------------------------------- Define parameter normalization transforms
    cmin, cmax = conds.min(dim=0)[0], conds.max(dim=0)[0]
    # map each parameter linearly from [cmin, cmax] to [-1, 1], and back
    def cnorm_func(c: torch.Tensor) -> torch.Tensor: return 2 * (c - cmin[None]) / (cmax - cmin)[None] - 1
    def cret_func(n: torch.Tensor) -> torch.Tensor: return (cmax - cmin)[None] * (n + 1) / 2 + cmin[None]
    cond_funcs = (cnorm_func, cret_func)
    conds = cnorm_func(conds)
    # ---------------------------------------------------------------------- Split into train and test
    # the last `test_confs` of the 15 maps of each simulation form the test set
    train_inds = [i % 15 < 15 - test_confs for i in range(data.shape[0])]
    test_inds = [not ind for ind in train_inds]
    dtrain, dtest = data[train_inds], data[test_inds]
    ctrain, ctest = conds[train_inds], conds[test_inds]
    if return_funcs:
        return (dtrain, ctrain), (dtest, ctest), param_names, data_funcs, cond_funcs
    return (dtrain, ctrain), (dtest, ctest)
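

# A minimal usage sketch for load_HI_data, assuming the CAMELS files sit under 'data/'
# (both the '_example_' helper name and the path are assumptions, not part of the pipeline):
# loading with return_funcs=True keeps the inverse functions around so samples can later be
# mapped back to log10 map values and physical parameters.
def _example_load_HI_data(r_path: str = 'data/'):
    (dtrain, ctrain), (dtest, ctest), names, (dnorm, dret), (cnorm, cret) = load_HI_data(
        r_path, return_funcs=True, n_params=2, sigmoid_normalization=True)
    print(f'train maps {tuple(dtrain.shape)}, test maps {tuple(dtest.shape)}, parameters {names}')
    # invert the normalizations to recover log10 maps and physical parameter values
    return dret(dtrain), cret(ctrain)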


def HI_dataset(r_path: str = '', train: bool = True, sigmoid_normalization: bool = False):
    """Returns the train or test split of the HI data as a TensorDataset of (map, parameters) pairs."""
    (dtrain, ctrain), (dtest, ctest) = load_HI_data(r_path, sigmoid_normalization=sigmoid_normalization)
    return TensorDataset(dtrain, ctrain) if train else TensorDataset(dtest, ctest)


class CyclicShiftTransform:
    """Randomly rolls an image periodically along its two last (spatial) axes."""
    def __init__(self, image_size: int):
        self.size = image_size

    def __call__(self, sample):
        image, label = sample
        # draw independent shifts in [0, size - 2] and roll the two last axes
        shiftx, shifty = np.random.choice(self.size - 1, 2)
        return torch.roll(image, shifts=[shiftx, shifty], dims=[-2, -1]), label


class RandomFlip:
    """Randomly flips an image along one axis with probability 0.5."""
    def __init__(self, dim: int = 0):
        self.dim = dim

    def __call__(self, sample):
        image, label = sample
        # flip along axis dim + 1, i.e. skipping the leading batch/channel axis
        if np.random.rand() > .5:
            image = torch.flip(image, dims=[self.dim + 1])
        return image, label


# default augmentation pipeline for the 64 x 64 HI maps: a random periodic shift followed by random flips
HI_transform = Compose([
    CyclicShiftTransform(64),
    RandomFlip(1),
    RandomFlip(2),
])
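

# A minimal augmentation sketch (the '_example_' helper, batch size and 'data/' path are
# assumptions): HI_transform acts on an (images, labels) pair, and RandomFlip(1)/RandomFlip(2)
# flip axes 2 and 3, so it is applied here to whole [N, 1, 64, 64] batches from the DataLoader.
def _example_augmented_batches(r_path: str = 'data/', batch_size: int = 32):
    loader = DataLoader(HI_dataset(r_path, train=True), batch_size=batch_size, shuffle=True)
    for images, params in loader:
        images, params = HI_transform((images, params))
        # images keeps its [N, 1, 64, 64] shape; hand the augmented batch to the model here
        yield images, params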


class EarlyStop:
    """Early stopping based on a validation loss evaluated every `validate_every` epochs."""
    def __init__(self, test_loader: DataLoader, validate_every: int = 5, patience: int = 5):
        self.loader = test_loader
        self.counter = 0
        self.epoch = 0
        self.losses = []
        self.epochs = []
        self.validate_every = validate_every
        self.patience = patience
        self.stop = False

    def __call__(self, loss_func):
        self.epoch += 1
        if self.epoch % self.validate_every == 0:
            # evaluate the loss on the held-out loader without tracking gradients
            loss = []
            with torch.no_grad():
                for d in self.loader:
                    loss.append(loss_func(d).item())
            self.losses.append(np.mean(loss))
            self.epochs.append(self.epoch)
            # count consecutive validations in which the loss increased (skip the very first one)
            if len(self.losses) > 1 and self.losses[-1] > self.losses[-2]:
                self.counter += 1
            else:
                self.counter = 0
            if self.counter > self.patience:
                self.stop = True
        return self.stop
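

# A minimal training-loop sketch showing how EarlyStop is meant to be called (the optimizer and
# batch_loss callable are hypothetical placeholders, not part of this module): the stopper evaluates
# batch_loss on the held-out loader every `validate_every` epochs and asks to stop after more than
# `patience` consecutive increases of that validation loss.
def _example_train_with_early_stopping(optimizer, batch_loss, train_loader, test_loader, max_epochs: int = 500):
    stopper = EarlyStop(test_loader, validate_every=5, patience=5)
    for _ in range(max_epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            loss = batch_loss(batch)  # batch_loss takes a (maps, params) batch and returns a scalar tensor
            loss.backward()
            optimizer.step()
        if stopper(batch_loss):
            break
    return stopper.epochs, stopper.losses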