-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_reconstructions.py
83 lines (64 loc) · 2.56 KB
/
get_reconstructions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
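'''
Generate and save example reconstructions from the trained autoencoder.

This script loads the trained ConvAE weights, draws one batch of EMRIs from
the validation split, normalises it, runs the model on it, and saves the
normalised inputs and the model's reconstructions as .npy files.

Testing data for the model can include things such as:
1. LISA noise
2. Noisy EMRIs not seen by the model
3. Other types of GW sources, e.g. MBHBs
4. Glitches
'''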
import numpy as np
import cupy as xp
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from test_and_train_loop import *
from model_architecture import ConvAE
from few.utils.constants import YRSID_SI
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from EMRI_generator_TDI import EMRIGeneratorTDI
# GPU check: fall back to the CPU when no CUDA device is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN benchmark and pick the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# Path to the trained model's state dict
# (alternative checkpoint: "model_group_conv_w_normalisation.pt").
model_state_dict_dir = "model_current.pt"
# Load the model's weights and architecture. map_location makes a checkpoint
# saved on a GPU loadable on a CPU-only machine (and vice versa).
model = ConvAE().to(device)
model.load_state_dict(torch.load(model_state_dict_dir, map_location=device))
model.eval()

# Specify EMRI generator params
# (alternative parameter file:
#  "training_data/EMRI_params_SNRs_20_100_fixed_redshift.npy").
EMRI_params_dir = "training_data/11011_EMRI_params_SNRs_60_100.npy"
batch_size = 4  # This needs to be such that val_dataset_size/batch_size is evenly divisible
dim = 2**20
TDI_channels = "AE"
dt = 10
seed = 2023
add_noise = False

# Seed torch's global RNG for reproducibility.
torch.manual_seed(seed)

# Rebuild the validation split deterministically (fixed random_state).
# NOTE(review): presumably this matches the split used during training —
# confirm the training script uses the same file, test_size and seed.
EMRI_params = np.load(EMRI_params_dir, allow_pickle=True)
_, val_params = train_test_split(EMRI_params, test_size=0.3, random_state=seed)
validation_set = EMRIGeneratorTDI(
    val_params,
    dim=dim,
    dt=dt,
    TDI_channels=TDI_channels,
    add_noise=add_noise,
    seed=seed,
)
# Initialise the data generator as a PyTorch dataloader.
validation_dataloader = torch.utils.data.DataLoader(
    validation_set, batch_size=batch_size, shuffle=True
)

# Generate one batch of data and put it on the same device as the model.
X_EMRIs, y_true_EMRIs = next(iter(validation_dataloader))
X_EMRIs = X_EMRIs.to(device)

# Normalise X channel-wise. The (2, 1) shape broadcasts one scale per channel.
# NOTE(review): the two constants are presumably the max-abs values of the
# A and E channels used at training time — confirm against the training code.
max_abs_tensor = torch.as_tensor([0.9098072, 0.5969127], device=device).reshape(2, 1)
X_EMRIs = X_EMRIs / max_abs_tensor

# Make predictions with the model. no_grad() stops autograd from recording
# the forward pass, so intermediate activations for this large (dim=2**20)
# input are not kept alive just to be discarded.
with torch.no_grad():
    y_pred_EMRIs = model(X_EMRIs)

# Convert everything to numpy arrays.
X_EMRIs = X_EMRIs.detach().cpu().numpy()
y_true_EMRIs = y_true_EMRIs.detach().cpu().numpy()  # NOTE(review): converted but not saved below
y_pred_EMRIs = y_pred_EMRIs.detach().cpu().numpy()

# Save the example EMRIs and their reconstructions!
np.save("Val_X_EMRIs_NORMALISED.npy", X_EMRIs)
np.save("Val_pred_EMRIs.npy", y_pred_EMRIs)