-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
92 lines (72 loc) · 2.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import pandas as pd
import config
import preproc
import cross_val
import model_dispatcher
from sklearn import metrics
import joblib
import os
import argparse
import warnings
# Suppress every warning for the whole process, including warnings raised
# by the libraries imported above.
warnings.filterwarnings('ignore')
def run(folds: int, model: str) -> None:
    """Train and evaluate one model on a single cross-validation fold.

    Reads the training CSV at ``config.TRAINING_FILE``, preprocesses it,
    assigns folds, writes the fold-annotated frame to
    ``input/cross_val_5folds.csv``, fits the requested model on every row
    whose ``kfold`` differs from ``folds``, prints accuracy and a
    classification report for the held-out rows, and saves the fitted model
    under ``config.MODEL_OUTPUT``.

    Args:
        folds: Value of the ``kfold`` column to hold out for validation.
        model: Key into ``model_dispatcher.models``. The key ``'dnn'`` takes
            a separate path that calls ``summary()``, fits with
            ``config.batch_size`` / ``config.epochs``, thresholds the
            predictions at 0.5 and saves with ``clf.save`` as ``.h5``
            (presumably a Keras model -- TODO confirm). Every other key is
            fitted with ``fit(X, y)`` and persisted with ``joblib``.

    Raises:
        ValueError: If no row belongs to fold ``folds``.
        KeyError: Presumably, if ``model`` is not present in
            ``model_dispatcher.models`` (assumes a dict-like registry).
    """
    target = 'is_canceled'
    # Columns that must never reach the estimator: the label itself and the
    # fold index, which is bookkeeping rather than a real feature.
    non_features = [target, 'kfold']
    is_dnn = model == 'dnn'

    # Load, preprocess and assign cross-validation folds.
    df = pd.read_csv(config.TRAINING_FILE)
    df = preproc.preprocessing(df)
    df = cross_val.create_folds(df)
    df.to_csv("input/cross_val_5folds.csv", index=False)

    # Split into training rows and the held-out validation fold.
    df_train = df[df.kfold != folds].reset_index(drop=True)
    df_valid = df[df.kfold == folds].reset_index(drop=True)
    if df_valid.empty:
        # Fail before the (potentially long) training run instead of
        # crashing later inside predict() on an empty array.
        raise ValueError(f"No rows found for fold {folds!r}")

    X_train = df_train.drop(columns=non_features).values
    y_train = df_train[target].values
    X_valid = df_valid.drop(columns=non_features).values
    y_valid = df_valid[target].values

    clf = model_dispatcher.models[model]

    if is_dnn:
        clf.summary()
        print("Training...")
        # NOTE: validation data and the early-stopping callback are
        # currently not passed to fit().
        clf.fit(
            X_train,
            y_train,
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        print("Done!!")
        # Convert the predicted scores into hard 0/1 labels.
        preds = (clf.predict(X_valid) > 0.5).astype(int)
    else:
        print("Training...")
        clf.fit(X_train, y_train)
        print("Done!!")
        preds = clf.predict(X_valid)

    acc = metrics.accuracy_score(y_valid, preds)
    print(f"Fold = {folds} Accuracy = {acc}")
    print("-------Classification Report")
    print(metrics.classification_report(y_valid, preds))

    # Persist the fitted model; both branches build the path the same way so
    # the result does not depend on MODEL_OUTPUT ending with a separator.
    if is_dnn:
        clf.save(os.path.join(config.MODEL_OUTPUT, f"dnn_model_fold_{folds}.h5"))
    else:
        joblib.dump(
            clf,
            os.path.join(config.MODEL_OUTPUT, f"{model}_fold_{folds}.bin"),
        )
if __name__ == "__main__":
    # Command-line entry point: train one model on one held-out fold.
    parser = argparse.ArgumentParser(
        description="Train a model on one cross-validation fold."
    )
    # Both options are required: without them run() would receive None and
    # fail late with an unrelated-looking error.
    parser.add_argument(
        "--folds",
        type=int,
        required=True,
        help="index of the fold to hold out for validation",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="key of the model in model_dispatcher.models",
    )
    args = parser.parse_args()
    run(folds=args.folds, model=args.model)