-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmk_figuresnstats.py
executable file
·136 lines (116 loc) · 4.6 KB
/
mk_figuresnstats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
"""
This script is adapted from http://handbook.datalad.org/en/latest/basics/101-130-yodaproject.html
"""
import pandas as pd
import seaborn as sns
import datalad.api as dl
from sklearn import model_selection
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
# Default input file: relative path to the iris CSV. It is handed to main()
# when this file is executed as a script.
data = "input/iris.csv"
def read_data(data):
    """
    Fetch the input file and load it as a table.

    :param data: str; path to the csv file. Its content is obtained with
        datalad get before the file is read.
    :return: pandas dataframe whose five columns are named sepal_length,
        sepal_width, petal_length, petal_width and class
    """
    column_names = [
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "class",
    ]
    # make sure the file content is available locally before reading it
    dl.get(data)
    frame = pd.read_csv(data)
    # replace whatever header the file carries with the names used
    # throughout this script
    frame.columns = column_names
    return frame
def plot_relationships(df):
    """
    Draw the pairwise relationships between the columns of the dataset and
    write the resulting figure to img/pairwise_relationships.png.

    :param df: pandas dataframe; its 'class' column is passed to seaborn as
        the hue variable
    """
    # build the grid of pairwise plots and store it as a png in one go
    sns.pairplot(df, hue='class', palette='muted').savefig(
        'img/pairwise_relationships.png')
def knn(df, test_size=0.20, seed=7, report_path='prediction_report.csv'):
    """
    Perform a K-nearest-neighbours classification with scikit-learn and save
    the results as a csv file.

    :param df: pandas dataframe; columns 0-3 are used as features and
        column 4 as the label to predict
    :param test_size: float; passed to train_test_split as the share of rows
        held out for testing
    :param seed: int; random state of the train/test split, fixed so that
        repeated runs use the same split
    :param report_path: str; path the classification report is written to
    :return: dict; sklearn classification report
    """
    # split the data into training and testing data
    values = df.values
    features = values[:, 0:4]
    labels = values[:, 4]
    X_train, X_test, Y_train, Y_test = model_selection.train_test_split(
        features, labels, test_size=test_size, random_state=seed)

    # Fit the model and make predictions on the test dataset.
    # The estimator gets its own name so it does not shadow this function.
    classifier = KNeighborsClassifier()
    classifier.fit(X_train, Y_train)
    predictions = classifier.predict(X_test)

    # Save the classification report. to_csv() returns None when given a
    # path, so there is nothing worth binding to a variable here.
    report = classification_report(Y_test, predictions, output_dict=True)
    pd.DataFrame(report).transpose().to_csv(report_path, float_format='%.2f')
    return report
def print_prediction_report(report):
    """
    Write the prediction report to stdout as LaTeX macro definitions.

    One newcommand definition is printed per class/average and statistic,
    e.g. SetosaPrecision or WAF, followed by one for the overall accuracy.
    The output can be collected in a .tex file and embedded into a
    manuscript.

    :param report: dict; sklearn classification report. Must contain the
        keys 'Setosa', 'Versicolor', 'Virginica', 'macro avg',
        'weighted avg' and 'accuracy'.
    """
    # (key in the report, prefix of the LaTeX macro name)
    row_labels = (
        ('Setosa', 'Setosa'),
        ('Versicolor', 'Versicolor'),
        ('Virginica', 'Virginica'),
        ('macro avg', 'MA'),
        ('weighted avg', 'WA'),
    )
    # (statistic in the report, suffix of the LaTeX macro name)
    metric_labels = (
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('f1-score', 'F'),
        ('support', 'Support'),
    )
    # every per-class statistic is rendered with two decimal places
    two_decimals = "%.2f"

    for row_key, macro_head in row_labels:
        row = report[row_key]
        for metric_key, macro_tail in metric_labels:
            macro_name = macro_head + macro_tail
            rendered = two_decimals % row[metric_key]
            print('\\newcommand{\\%s}{%s}' % (macro_name, rendered))

    # the accuracy is a single number and is printed as-is, without rounding
    print('\\newcommand{\\accuracy}{%s}' % (report['accuracy']))
def main(data, figure=True, stats=True):
    """
    Load the data, then produce whichever outputs were requested.

    :param data: str; path to input data (will be retrieved with datalad, if necessary)
    :param figure: bool; if True, plot_relationships() is executed to save a figure
    :param stats: bool; if True, the KNN classification is run, its report
        is saved and the statistics are printed as LaTeX variables
    """
    frame = read_data(data)

    if figure:
        plot_relationships(frame)

    if not stats:
        return

    # train, predict, evaluate and save the results, then print the variables
    print_prediction_report(knn(frame))
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-f', '--figure',
        help='Produce the figure. Useful if you want to separate statistics '
             'from figure generation. If neither -f nor -s is given, both '
             'the figure and the statistics are produced.',
        action='store_true')
    parser.add_argument(
        '-s', '--stats',
        help='Produce the statistics. If neither -f nor -s is given, both '
             'the figure and the statistics are produced.',
        action='store_true')
    args = parser.parse_args()

    # A store_true flag combined with default=True can never become False,
    # which would make both switches no-ops. The flags therefore default to
    # False, and a call without any flag still generates everything.
    run_everything = not (args.figure or args.stats)

    # generate & save figures; export the stats
    main(data,
         figure=args.figure or run_everything,
         stats=args.stats or run_everything)