-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_coref_degradation_metrics.py
108 lines (92 loc) · 3.48 KB
/
plot_coref_degradation_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations
import argparse, json
import pathlib as pl
from more_itertools.recipes import flatten
import scienceplots
import matplotlib.pyplot as plt
def precision_recall_metrics(metrics: set[str]) -> set[str]:
    """Expand each base metric name into its precision and recall variants.

    Every name ``m`` in *metrics* yields ``f"{m}_precision"`` and
    ``f"{m}_recall"``; the union of all of them is returned as a set.

    >>> sorted(precision_recall_metrics({"node"}))
    ['node_precision', 'node_recall']
    """
    return {
        f"{base}_{suffix}"
        for base in metrics
        for suffix in ("precision", "recall")
    }
# Hand-picked LaTeX labels for the graph metrics; these exact names are looked
# up before any prefix handling in ``capitalize_snakecase_text``.
GMETRIC_2_PRETTYNAME = {
    "node_f1": "$F1_V$",
    "node_precision": "$Pre_V$",
    "node_recall": "$Rec_V$",
    "edge_f1": "$F1_E$",
    "edge_precision": "$Pre_E$",
    "edge_recall": "$Rec_E$",
    "weighted_edge_f1": "$WF1_E$",
    "weighted_edge_precision": "$WPre_E$",
    "weighted_edge_recall": "$WRec_E$",
}

# Coreference-metric name prefixes and their display forms, tried in order.
# "blanc_" is included so BLANC gets acronym casing like the other metrics
# instead of falling through to the generic "Blanc" capitalisation.
_COREF_PREFIX_2_PRETTYNAME = (
    ("lea_", "LEA"),
    ("ceaf_", "CEAF"),
    ("b_cubed_", "$B^3$"),
    ("muc_", "MUC"),
    ("blanc_", "BLANC"),
)


def capitalize_snakecase_text(text: str) -> str:
    """Turn a snake_case metric or run name into a human-readable plot label.

    Resolution order:

    1. Exact names in ``GMETRIC_2_PRETTYNAME`` map to their LaTeX label.
    2. A known coreference-metric prefix (``lea_``, ``ceaf_``, ``b_cubed_``,
       ``muc_``, ``blanc_``) is replaced by its display form, and the rest of
       the name is converted recursively.
    3. Otherwise every ``_``-separated word is capitalised and the words are
       joined with spaces.

    >>> capitalize_snakecase_text("muc_f1")
    'MUC F1'
    >>> capitalize_snakecase_text("add_spurious_link")
    'Add Spurious Link'
    """
    if text in GMETRIC_2_PRETTYNAME:
        return GMETRIC_2_PRETTYNAME[text]
    for prefix, pretty in _COREF_PREFIX_2_PRETTYNAME:
        if text.startswith(prefix):
            return f"{pretty} {capitalize_snakecase_text(text[len(prefix):])}"
    return " ".join(word.capitalize() for word in text.split("_"))
if __name__ == "__main__":
    # Script-only dependencies, imported here so module import stays light.
    import ast
    import math

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--run-dict",
        # Parsed as a Python literal, e.g. "{'name': './runs/name/'}".
        # ast.literal_eval accepts dict literals like eval does, but cannot
        # execute arbitrary code passed on the command line.
        type=ast.literal_eval,
        default={
            "add_spurious_mention": "./runs/add_spurious_mention/",
            "add_spurious_link": "./runs/add_spurious_link/",
            "remove_correct_mention": "./runs/remove_correct_mention/",
            "remove_correct_link": "./runs/remove_correct_link/",
        },
    )
    parser.add_argument("-o", "--output", type=str)
    args = parser.parse_args()

    # Fail early with a usage message instead of crashing later (an empty
    # mapping would otherwise leave no axis to build the legend from).
    if not isinstance(args.run_dict, dict) or not args.run_dict:
        parser.error("--run-dict must be a non-empty mapping of run name -> run directory")

    FONTSIZE = 8
    TEXT_WIDTH_IN = 6.29921
    ASPECT_RATIO = 0.6
    MAX_COLS = 2

    plt.style.use(["science", "grid"])
    plt.rc("xtick", labelsize=FONTSIZE)
    plt.rc("ytick", labelsize=FONTSIZE)
    cmap = plt.get_cmap("tab20")

    # One subplot per run on a grid with at most MAX_COLS columns; the default
    # four runs yield the original 2x2 layout and figure size. The figure
    # height scales with the number of rows.
    n_runs = len(args.run_dict)
    ncols = min(MAX_COLS, n_runs)
    nrows = math.ceil(n_runs / ncols)
    fig, axs = plt.subplots(
        nrows,
        ncols,
        figsize=(TEXT_WIDTH_IN, TEXT_WIDTH_IN * ASPECT_RATIO * nrows / 2),
        squeeze=False,  # always a 2-D array, so flattening works for any grid
    )
    flat_axs = list(flatten(axs))
    for unused_ax in flat_axs[n_runs:]:
        unused_ax.set_visible(False)

    TASK_METRICS = {"muc_f1", "b_cubed_f1", "ceaf_f1", "blanc_f1", "lea_f1"}
    GRAPH_METRICS = precision_recall_metrics({"node", "edge", "weighted_edge"})
    # Sorted once so every subplot draws the metrics in the same order, giving
    # each metric the same colour (and legend entry) across subplots.
    ALL_METRICS = sorted(TASK_METRICS | GRAPH_METRICS)

    for ax, (pkey, ppath) in zip(flat_axs, args.run_dict.items()):
        with open(pl.Path(ppath) / "metrics.json", encoding="utf-8") as f:
            metrics_dict = json.load(f)
        for metric_i, metric in enumerate(ALL_METRICS):
            # NOTE(review): assumes each "mean_<metric>" entry holds parallel
            # "steps"/"values" lists, as read below — schema not defined here.
            metrics_key = f"mean_{metric}"
            steps = metrics_dict[metrics_key]["steps"]
            values = metrics_dict[metrics_key]["values"]
            ax.plot(
                steps,
                values,
                # solid lines: graph metrics; dashed lines: task metrics
                linestyle="-" if metric in GRAPH_METRICS else "--",
                label=capitalize_snakecase_text(metric),
                linewidth=1.5,
                color=cmap(metric_i),
            )
        ax.set_title(capitalize_snakecase_text(pkey), fontsize=FONTSIZE)
        ax.set_xlabel("Degradation Steps", fontsize=FONTSIZE)

    # Every subplot plots the same labelled metrics, so a single figure-level
    # legend built from the first subplot covers all runs.
    handles, labels = flat_axs[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        fancybox=True,
        loc="upper left",
        bbox_to_anchor=(1.0, 0.9),
        fontsize=FONTSIZE,
    )
    plt.tight_layout()
    if args.output:
        # The legend sits just outside the figure's right edge; a tight
        # bounding box enlarges the saved canvas so it is not cropped.
        plt.savefig(args.output, bbox_inches="tight")
        print(f"plot saved at {args.output}")
    else:
        plt.show()