-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_performance_report.py
69 lines (57 loc) · 2.4 KB
/
get_performance_report.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
'''
Project: PolyGNN
-------------------------------------------------------------------------------
'''
import pandas as pd
from sklearn.metrics import (mean_absolute_error, mean_squared_error, r2_score,
mean_absolute_percentage_error
)
import numpy as np
# Epoch count embedded in every model directory name that the report loop
# reads (GHGNN_epochs_<n_epochs>_fold_<k>); must match the trained models.
n_epochs: int = 50
def save_latextable(df, filename):
"""dataframe to latex"""
LATEX_TABLE = r'''\documentclass{{standalone}}
\usepackage{{booktabs}}
\usepackage{{multirow}}
\usepackage{{graphicx}}
\usepackage{{xcolor,colortbl}}
\begin{{document}}
{}
\end{{document}}
'''
a_str = df.style.to_latex()
with open(filename, 'w') as f:
f.write(LATEX_TABLE.format(a_str))
# Performance report over all data splits, target properties and train/test
# predictions.
#
# For every split, spec and mode this reads, for each fold k:
#     <split>/<spec>/GHGNN_epochs_<n_epochs>_fold_<k>/<mode>_pred.csv
# It compares the 'log_omega' column (true values) with the column named
# after the model (predictions), and collects MAE, R2, RMSE and MAPE (%)
# per fold. It prints the fold means and writes the per-fold table to
#     <split>/<spec>/performance_<mode>.csv
#     <split>/<spec>/report_performance_<mode>.txt   (LaTeX)

# Number of folds evaluated for each split (the random split uses 10, the
# extrapolation splits use 5).
N_FOLDS_PER_SPLIT = {
    'Random_split': 10,
    'Extrapolation_solute': 5,
    'Extrapolation_solvent': 5,
}

for split, n_folds in N_FOLDS_PER_SPLIT.items():
    print('=' * 70)
    print(split)
    print('=' * 70)
    for spec in ['MN', 'MW', 'PDI']:
        out_dir = f'{split}/{spec}'
        for mode in ['train', 'test']:
            print('=' * 50)
            # Label each summary; the bare metric lines are otherwise ambiguous.
            print(f'{spec} - {mode}')
            metrics = {'MAE': [], 'R2': [], 'RMSE': [], 'MAPE': []}
            for fold in range(n_folds):
                model_name = f'GHGNN_epochs_{n_epochs}_fold_{fold}'
                df_pred = pd.read_csv(f'{out_dir}/{model_name}/{mode}_pred.csv')
                y_true = df_pred['log_omega'].to_numpy()
                y_pred = df_pred[model_name].to_numpy()
                metrics['MAE'].append(mean_absolute_error(y_true, y_pred))
                metrics['R2'].append(r2_score(y_true, y_pred))
                metrics['RMSE'].append(mean_squared_error(y_true, y_pred) ** 0.5)
                # MAPE divides by |y_true|. The target is a log quantity
                # ('log_omega'), so values near zero may inflate this
                # metric; interpret it with care.
                metrics['MAPE'].append(
                    mean_absolute_percentage_error(y_true, y_pred) * 100)
            df_res = pd.DataFrame(metrics)
            print('MAE : ', np.mean(metrics['MAE']))
            print('R2 : ', np.mean(metrics['R2']))
            print('RMSE: ', np.mean(metrics['RMSE']))
            print('MAPE: ', np.mean(metrics['MAPE']))
            # One CSV per mode. A single mode-less file name would let the
            # test results overwrite the train results.
            df_res.to_csv(f'{out_dir}/performance_{mode}.csv', index=False)
            if mode == 'test':
                # Backward compatibility: the previous mode-less
                # performance.csv always ended up holding the last mode
                # processed ('test').
                df_res.to_csv(f'{out_dir}/performance.csv', index=False)
            save_latextable(df_res, f'{out_dir}/report_performance_{mode}.txt')