rx_train.py
"""AI based RX
name: model training module
status: draft
data file naming "{iq}_t{tau:.1f}_s{snr}"
v0.0.6 > ce_train.py merged to rx_train (data load part is just copied, not integrated yet)
v0.0.5 > get_data() updated
v0.0.4 > data file naming convention {iq}_t{tau:.1f}_s{snr}
last update : (17 May 2024, 10:02)
"""
import wandb
import pickle
import numpy as np
import pandas as pd  # needed by the 'load_ce' branch (merged from ce_train.py)
import tensorflow as tf
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint
from ber_util import gen_data, add_awgn
from rx_utils import get_data, show_train, prep_ts_data
# from rx_utils import check_data, get_song_data
from rx_models import gru_temel, save_mdl
# from ce_models import ce_temel, ce_plus
# model: gru_temel, base_bpsk, dense_nn_qpsk, dense_nn_deep, lstm_bpsk, save_mdl, song_bpsk
from rx_config import init_gpu
from constants import FS, G_DELAY, dict_h, data_gen_seed
init_gpu()
TAU = 0.60 # 0.50, 0.60, 0.70, 0.80, 0.90, 1.00
SNR = 10 # 0, 1, .., 10, 'NoNoise' # noqa # FUTURE multiple noise support will be added, e.g. [8, 10, 12]
IQ = 'bpsk' # bpsk, qpsk
init_lr = 0.0005
# train parameters
epochs = 3 # 70
batch_size = 128 # 8192 # reduce batch size for big models...
NoS = int(1e4)  # number of symbols, default: 1e6
val_split = 0.1
DATA_MODE = 'load'  # 'load', 'load_ce', 'load_npy', 'generate'
WB_ON = False
# number of consecutive samples considered on each side of a symbol during calculation, min 2
LoN = 4  # e.g. LoN=3: [. . . S . . .], 7 samples in total
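# For intuition, prep_ts_data (rx_utils) is assumed to build one window per symbol,
# roughly like this sketch (not the actual implementation):
#   x_pad = np.pad(X_i, LoN)  # zero-pad LoN samples at both ends
#   X_win = np.lib.stride_tricks.sliding_window_view(x_pad, 2 * LoN + 1)
#   # X_win.shape == (len(X_i), 2 * LoN + 1): one (2*LoN+1)-sample window per symbol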
# FUTURE support for variable sampling frequency, FS
assert FS == 10, "Only FS=10 is supported; the current implementation does not support FS={fs}".format(fs=FS)
step = int(TAU * FS)
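# worked example: TAU=0.60, FS=10 -> step=6, i.e. successive symbols are placed
# 6 samples apart instead of the Nyquist spacing of FS=10 samples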
# f = open('run/{id}/console.log'.format(id=datestr), 'w') TODO
# TODO model selection and handling
model = gru_temel(lon=LoN, init_lr=init_lr)
# 'base', 'dense', 'lstm', 'gru'  # TODO review model naming and check consistency
# L = 512 # Length of symbol block/window, {(L-m)//2, m, (L-m)//2} see: song2019
# m = 32 # number of symbols to process at each inference, amount of shift see: song2019
# model = song_bpsk(L=L, m=m)
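# worked example: L=512, m=32 -> (L-m)//2 = 240 context symbols on each side of
# the m symbols decided per inference (see: song2019)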
if DATA_MODE == 'load':
# Load the training data
X_i, y_i = get_data(modulation=IQ, tau=TAU, snr=SNR, NoD=NoS)
    if IQ != 'bpsk':
        # compact the data into 1D; no need to keep the real (I) and imaginary (Q) parts as separate dimensions
        X_i = np.reshape(X_i, (-1,))
# TODO revise the single data format, is it efficient? why?
elif DATA_MODE == 'load_ce':
    raise NotImplementedError
    # NOTE: the block below is copied verbatim from ce_train.py (see v0.0.6) and is unreachable until integrated
    # data for ce models (channel estimation)
# Load the training data
df = pd.read_csv('data/data_ce/rx_data_DL_Ch_Est_BPSK_Ntap3_SNR10dB_alpha03_tau1_Ks10_N256.csv',
header=None)
dfy = pd.read_csv('data/data_ce/tx_data_DL_Ch_Est_BPSK_Ntap3_SNR10dB_alpha03_tau1_Ks10_N256.csv',
usecols=[32, 33, 34, 35, 36, 37], header=None)
for i in range(32):
        # convert MATLAB-style complex strings ('a+bi' -> 'a+bj') to Python complex
        df[i] = df[i].str.replace('i', 'j').apply(complex)  # np.complex_ was removed in NumPy 2.0
# df[str(i)+'r'] = df[i].apply(lambda x: np.real(x))
# df[str(i)+'i'] = df[i].apply(lambda x: np.imag(x))
X_r = np.real(df)
X_i = np.imag(df)
y_r = np.array(dfy[[32, 34, 36]])
y_i = np.array(dfy[[33, 35, 37]])
# X = np.concatenate((dfr[:, :26], dfi[:, :26]), axis=1)
# y = np.concatenate((dfr[:, 26:], dfi[:, 26:]), axis=1)
# instead of keeping real/imag as single data just flat it
# X = np.concatenate((dfr[:, :26], dfi[:, :26]), axis=0)
# y = np.concatenate((dfr[:, 26:], dfi[:, 26:]), axis=0)
X = np.concatenate((X_r, X_i), axis=0).astype(np.float16)
y = np.concatenate((y_r, y_i), axis=0).astype(np.float16)
elif DATA_MODE == 'load_npy':
    try:
        X_i, y_i = np.load('data/{iq}_t{tau:.1f}_s{snr}.npy'.format(iq=IQ, tau=TAU, snr=SNR))
    except FileNotFoundError as exc:
        raise FileNotFoundError("data file not found; try DATA_MODE = 'generate' instead!") from exc
    except ValueError as exc:
        raise ValueError('data file content is not compatible!') from exc
else:
# G_DELAY FS based h generation
hPSF = np.array(dict_h[G_DELAY]).astype(np.float16)
assert np.array_equal(hPSF, hPSF[::-1]), 'symmetry mismatch!'
# limit amount of data to generate
    assert NoS < int(1e7) + 1, 'too much data to generate; load from a file instead'
# [SOURCE] Data Generation
data, bits = gen_data(n=NoS, mod=IQ, seed=data_gen_seed) # IQ options: ('bpsk', 'qpsk')
# [TX] up-sample
    # extend the data by up-sampling (so that FTN time compression can be applied)
s_up_sampled = np.zeros(step * len(data), dtype=np.float16)
s_up_sampled[::step] = data
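    # e.g. step=6: [d0, d1, ...] -> [d0, 0, 0, 0, 0, 0, d1, 0, 0, 0, 0, 0, ...]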
# [TX] apply FTN (tau)
# apply the filter
tx_data = np.convolve(hPSF, s_up_sampled)
# [CHANNEL] add AWGN noise (snr)
# Channel Modelling, add noise
rch = add_awgn(inputs=tx_data, snr=SNR, seed=1234)
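    # add_awgn (ber_util) is assumed to scale the noise to the requested SNR,
    # roughly like this sketch (not the actual implementation):
    #   p_sig = np.mean(np.abs(tx_data) ** 2)
    #   sigma = np.sqrt(p_sig / 10 ** (SNR / 10))
    #   rch = tx_data + np.random.default_rng(1234).normal(0.0, sigma, tx_data.shape)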
# [RX] apply matched filter
mf = np.convolve(hPSF, rch)
    # [RX] down-sample (subsample)
    # first symbol peak at 2 * G_DELAY * FS = 80 for G_DELAY=4 and FS=10:
    # G_DELAY * FS = 40 samples of delay from conv@TX plus 40 more from conv@RX
    # remove the extra prefix and suffix samples introduced by the convolutions
    rx_data = mf[2 * G_DELAY * FS:-(2 * G_DELAY * FS):step]
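    # length check (assuming len(hPSF) == 2*G_DELAY*FS + 1 == 81): each convolution
    # adds len(hPSF)-1 = 80 samples, so len(mf) == step*NoS + 160 and the slice
    # above yields exactly NoS symbol-rate samples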
# X_i, y_i
X_i = rx_data
y_i = bits
    # if AUTO_SAVE:
    # keep the file name consistent with the 'load_npy' branch: {iq}_t{tau:.1f}_s{snr}.npy
    np.save('data/{iq}_t{tau:.1f}_s{snr}.npy'.format(iq=IQ, tau=TAU, snr=SNR), [X_i, y_i])
model.summary()  # prints the model architecture (summary() prints directly and returns None)
confs = { # TODO improve, get content from model.get_compile_config()
'loss': model.loss,
'optimizer_class_name': model.optimizer.__class__.__name__,
'optimizer_config': model.optimizer.get_config(),
'metrics': [i['class_name'] for i in model.get_compile_config()['metrics']], # noqa
}
for k, v in confs.items():
print(k, v)
# DATA pre-processing
# [DEBUG] noise generation
# Xs = add_awgn(y*2-1, snr=10)
# [DEBUG] data control
# check_data(rx_data=X_i, ref_bit=y_i, modulation=IQ)
# if 'song' in model.name:
# X, y = get_song_data(X_i, y_i, L=L, m=m)
if 'lstm' in model.name or 'gru' in model.name:  # TODO fix/automate/parameterize this step
    # convert single samples into time-series windows for the recurrent models
    X = prep_ts_data(X_i, lon=LoN)
else:
    X = X_i
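# for the recurrent models X presumably has shape (len(X_i), 2*LoN+1[, 1]),
# i.e. (samples, timesteps[, features]); otherwise X stays 1D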
# update label type to float for evaluating performance metrics
y = y_i.astype(np.float16)
y = np.expand_dims(y, axis=1) # TODO check and fix
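# e.g. y: (len(y_i),) -> (len(y_i), 1), aligning labels with the model's per-symbol output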
# TODO merge config_lc and configs(wandb)
config_lc = {'Modulation': IQ, 'TAU': TAU, 'SNR': SNR, 'Number of sample': NoS if NoS != -1 else 'all',
'model': model.name, 'Half window length': LoN, 'Sampling Frequency': FS, 'Group Delay': G_DELAY,
# 'merge features:': merge,
# 'Decision Tree Max.Depth': max_depth, 'Decision Tree criterion': criterion,
# 'D.T. random_state': random_state,
# 'DT splitter': splitter, 'DT min_samples_split': min_samples_split,
# 'DT min_samples_leaf': min_samples_leaf,
# 'DT max_features': max_features, 'training test_ratio': test_ratio, '[RF] n_estimators': n_estimators,
# 'Selected Features': f_set
}
# Weight and Biases integration
# https://docs.wandb.ai/tutorials/keras_models
# TODO set configurations
configs = dict(
modulation=IQ,
tau=TAU, snr=SNR,
model=model.name,
lon=LoN,
dropout=0.3,
optimizer=model.optimizer.get_config(),
batch_size=batch_size,
data_source=DATA_MODE,
num_of_syms=NoS,
num_of_data=len(y_i),
validation_split=val_split,
learning_rate=init_lr,
epochs=epochs
)
if WB_ON:
wandb.init(project='rx_ai',
config=configs
)
callbacks = [WandbMetricsLogger(log_freq='epoch',
initial_global_step=0),
# WandbModelCheckpoint(filepath=os.getcwd()+'/models/tau{:.2f}_'.format(TAU)
# + model.name+'_{epoch:02d}',
WandbModelCheckpoint(filepath='./models/tau{:.2f}_'.format(TAU) + model.name, # TODO fix path
save_best_only=False,
# monitor='val_f1_score',
)
] # WandbCallback()
else:
callbacks = None
tf.keras.backend.clear_session()
history = model.fit(X, y,
validation_split=val_split,
epochs=epochs,
batch_size=batch_size,
callbacks=callbacks,
)
dir_path = save_mdl(model, modulation=IQ, config=config_lc, history=history)
if WB_ON:
wandb.finish()
# plot the training process and get the figure object
fig = show_train(history)
# save the figure as an image
fig.savefig(dir_path + '/train_fig.png')
# save the figure object itself
with open(dir_path + '/train_fig.pickle', 'wb') as f:
    pickle.dump(fig, f)
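# to restore the interactive figure later:
#   with open(dir_path + '/train_fig.pickle', 'rb') as f:
#       fig = pickle.load(f)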
# TODO save the source code ? ~x~ CANCELLED ~x~
# results = model.evaluate(x_test, y_test, batch_size=128)
y_pred = model.predict(X[:10, :])
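# hard-decision sketch, assuming the model outputs per-bit probabilities in [0, 1]:
#   bits_hat = (y_pred > 0.5).astype(int)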
# y_true = y[:10, :]
# TODO summarize and save the prediction result
# references
# https://wandb.ai/ayush-thakur/dl-question-bank/reports/LSTM-RNN-in-Keras-Examples-of-One-to-Many-Many-to-One-Many-to-Many---VmlldzoyMDIzOTM
# info
# song2019, "Receiver Design for Faster-than-Nyquist Signaling: Deep-learning-based Architectures"
# sample-index scratch notes:
# 0 1 2 ... 9 10 11 ..... 107 108 109 ... 128
# 107 108 109 ... .. ... 215 216 217 ... 236 ?