-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalysis.py
131 lines (95 loc) · 3.34 KB
/
analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Authors:
    Jason Youn - jyoun@ucdavis.edu
Description:
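    Entry point of the analysis: runs the cross-validated model for every
    classifier listed in the configuration file and saves the resulting
    PR and ROC curve figures.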
To-do:
"""
# standard imports
import argparse
import logging as log
import os
import sys
# third party imports
import matplotlib.pyplot as plt
# local imports
from managers.model_manager import ModelManager
from utils.config_parser import ConfigParser
from utils.set_logging import set_logging
from utils.visualization import plot_pr, plot_roc, save_figure
# global variables
# path of the .ini configuration file used when --config_file is not given
DEFAULT_CONFIG_FILE = './config/analysis.ini'
def parse_argument():
    """
    Build the command line interface and parse the process arguments.

    The only option is '--config_file', which falls back to
    DEFAULT_CONFIG_FILE when it is not supplied.

    Returns:
        - namespace holding the parsed arguments
    """
    cli = argparse.ArgumentParser(description='Description goes here.')

    # keyword settings of the single supported option
    config_file_option = {
        'default': DEFAULT_CONFIG_FILE,
        'help': 'Path to the .ini configuration file.',
    }
    cli.add_argument('--config_file', **config_file_option)

    parsed = cli.parse_args()
    return parsed
def _plot_classifier_curves(
        classifiers_ys, plot_fn, xlabel, ylabel, title, filepath,
        xlim=None, ylim=None):
    """
    Draw one curve per classifier on a new figure and save the figure.

    Inputs:
        - classifiers_ys: dictionary mapping a classifier name to the
            (y_trues, y_preds, y_probs) triple collected in main()
        - plot_fn: curve plotting function (plot_pr or plot_roc); it is
            called as plot_fn(y_trues, y_probs_1, classifier) and must
            return a (line, label) pair used to build the legend
        - xlabel: label of the x axis
        - ylabel: label of the y axis
        - title: title of the figure
        - filepath: where the figure is saved via save_figure()
        - xlim: optional x axis limits; left untouched when None
        - ylim: optional y axis limits; left untouched when None
    """
    fig = plt.figure()

    lines = []
    labels = []
    for classifier, (y_trues, _, y_probs) in classifiers_ys.items():
        # NOTE(review): presumably every y_prob is a DataFrame whose column 1
        # holds the probability of the positive class - confirm against
        # ModelManager.run_model_cv
        y_probs_1 = tuple(y_prob[1].to_numpy() for y_prob in y_probs)
        line, label = plot_fn(y_trues, y_probs_1, classifier)
        lines.append(line)
        labels.append(label)

    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(lines, labels, loc='lower right', prop={'size': 8})

    save_figure(fig, filepath)


def main():
    """
    Main function.

    Reads the configuration file given on the command line, runs
    preprocessing, feature selection and the cross-validated model for each
    configured classifier, and finally saves 'pr_curve.png' and
    'roc_curve.png' into the configured visualization directory.
    """
    # parse args
    args = parse_argument()

    # load main config file and set logging
    main_config = ConfigParser(args.config_file)
    set_logging(log_file=main_config.get_str('log_file'))

    # initialize model manager object
    model_manager = ModelManager()

    # perform analysis on these classifiers
    classifiers = main_config.get_str_list('classifier')

    # do prediction
    classifiers_ys = {}
    for classifier in classifiers:
        log.info('Running model for classifier \'%s\'', classifier)

        # load config parsers; they are re-created on every iteration because
        # classifier_config is overwritten below with the current classifier
        preprocess_config = ConfigParser(main_config.get_str('preprocess_config'))
        classifier_config = ConfigParser(main_config.get_str('classifier_config'))

        # perform preprocessing
        X, y = model_manager.preprocess(preprocess_config, section=classifier)

        # run classification model
        classifier_config.overwrite('classifier', classifier)
        X = model_manager.feature_selector(X, y, classifier_config)

        # only the third returned element is needed for the plots
        _, _, ys = model_manager.run_model_cv(X, y, 'f1', classifier_config)
        classifiers_ys[classifier] = ys

    visualization_dir = main_config.get_str('visualization_dir')

    # plot PR curve
    _plot_classifier_curves(
        classifiers_ys,
        plot_pr,
        xlabel='Recall',
        ylabel='Precision',
        title='PR Curve',
        filepath=os.path.join(visualization_dir, 'pr_curve.png'))

    # plot ROC curve
    _plot_classifier_curves(
        classifiers_ys,
        plot_roc,
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='ROC Curve',
        filepath=os.path.join(visualization_dir, 'roc_curve.png'),
        xlim=[0.0, 1.0],
        ylim=[0.0, 1.05])
# run the analysis only when this file is executed as a script
if __name__ == '__main__':
    main()