-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinference.py
66 lines (44 loc) · 1.86 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from deepsleep import InferenceHeartSequenceLoader
import numpy as np
import os, glob
# from data_generator import InferenceHeartSequenceLoader
from keras.models import load_model, model_from_json
def prepare_data():
    """Run dataset preparation over every ``*.hdf`` file in ``./deepsleep/data``.

    For each matching file (relative to the current working directory), a
    ``PrepareDataset`` is constructed with ``seq_len=7500``, ``batch_size=64``
    and ``n_classes=4``, and its ``prep_dataset()`` method is called.

    NOTE(review): ``PrepareDataset`` is not imported anywhere in the visible
    file, so calling this raises ``NameError`` unless it is brought into scope
    elsewhere (presumably from the ``deepsleep`` package -- confirm).
    """
    hdf_dir = os.path.join(os.getcwd(), "deepsleep", "data")
    # glob returns files in arbitrary order; sort so runs are reproducible.
    for hdf_file in sorted(glob.glob(os.path.join(hdf_dir, '*.hdf'))):
        # Single-argument print() behaves the same on Python 2 and 3.
        print(hdf_file)
        prep_data = PrepareDataset(hdf_file, seq_len=7500, batch_size=64, n_classes=4)
        prep_data.prep_dataset()
def load_data(subject_file, seq_len=7500, batch_size=16, n_classes=4):
    """Load the heart signal and labels for one subject file.

    Args:
        subject_file: File handed to ``InferenceHeartSequenceLoader.get_data``
            (the ``__main__`` block passes an ``.npz`` filename).
        seq_len: Sequence length forwarded to the loader (default 7500).
        batch_size: Batch size forwarded to the loader (default 16).
        n_classes: Number of classes forwarded to the loader (default 4).

    Returns:
        ``(heart_signal, labels)`` exactly as returned by ``loader.get_data``.
    """
    # Defaults match the previously hard-coded values, so existing callers
    # are unaffected; callers can now align these with their model settings.
    loader = InferenceHeartSequenceLoader(
        seq_len=seq_len, batch_size=batch_size, n_classes=n_classes)
    heart_signal, labels = loader.get_data(subject_file)
    return heart_signal, labels
def initialize_model(model_name):
    """Load a saved Keras model from disk and hand it back.

    Args:
        model_name: Path of the saved model file (the ``__main__`` block
            passes ``full_model.hdf5``).

    Returns:
        The object produced by ``keras.models.load_model`` for that path.
    """
    return load_model(model_name)
def initialize_model_json(json_model, model_weights):
    """Rebuild a Keras model from its JSON architecture and load its weights.

    Args:
        json_model: Either the JSON architecture string itself, or a path to
            a file containing it (the commented-out call in ``__main__``
            passes the path ``model_arch.json``).
        model_weights: Path to a weights file accepted by
            ``Model.load_weights``.

    Returns:
        The reconstructed model, with weights loaded in place.
    """
    if os.path.isfile(json_model):
        # model_from_json parses a JSON *string*, so read the file first
        # when a path is supplied.
        with open(json_model) as arch_file:
            json_model = arch_file.read()
    loaded_model = model_from_json(json_model)
    # load_weights mutates the model in place and does not return the model;
    # assigning its result would replace the model with that return value.
    loaded_model.load_weights(model_weights)
    return loaded_model
def inference(model, heart_data):
    """Run ``model.predict`` on ``heart_data`` and derive class indices.

    Args:
        model: Any object exposing ``predict`` (here, a Keras model).
        heart_data: Input passed straight through to ``model.predict``.

    Returns:
        A ``(predictions, pred_class)`` pair. ``predictions`` is the raw
        ``predict`` output. ``pred_class`` is its ``argmax`` over the last
        axis (presumably per-class scores -> class index; confirm against the
        model's output layer).
    """
    scores = model.predict(heart_data)
    return scores, scores.argmax(axis=-1)
def evaluate_inference(model, heart_data, labels):
    """Score ``model`` against labelled data via its ``evaluate`` method.

    Args:
        model: Any object exposing ``evaluate(x=..., y=...)`` (here, a Keras
            model).
        heart_data: Inputs, passed as ``x``.
        labels: Targets, passed as ``y``.

    Returns:
        Whatever ``model.evaluate`` returns, unchanged.
    """
    return model.evaluate(x=heart_data, y=labels)
if __name__ == '__main__':
    # Script entry point: load one subject's data, load a saved model, run
    # prediction and print the predicted class indices.
    # prepare_data()
    subject_data = "Subject17_03082017" + '.npz'
    model_path = os.path.join("model_output", "8b2fResReg1Val_PTr_128")
    model_arch = os.path.join(model_path, "model_arch.json")
    # The weights checkpoint and the full saved model are different files.
    # Keep them in separate variables so the JSON+weights alternative below
    # really loads the weights file and not full_model.hdf5.
    model_weights = os.path.join(model_path, 'weights.29-0.73-0.92.hdf5')
    model_name = os.path.join(model_path, "full_model.hdf5")
    heart_signal, labels = load_data(subject_data)
    model = initialize_model(model_name)
    # Alternative: rebuild from the architecture JSON plus the weights file.
    # model = initialize_model_json(model_arch, model_weights)
    predictions, classes = inference(model, heart_signal)
    # score = evaluate_inference(model, heart_signal, labels)
    print(classes)