-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathrecipe.py
68 lines (57 loc) · 2.13 KB
/
recipe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import time
from typing import List
from rich import box
from rich.table import Table
from rich.console import Console
import prodigy
from prodigy.components.loaders import CSV
class ProgressTable:
    """Tally annotation answers and print a Rich summary table on every update.

    Counts are kept per answer type ("accept", "reject", "ignore") together
    with an annotations-per-hour rate measured from when the object was
    created.
    """

    # Maps the example's "answer" value to the row label used in the table.
    # Answers not listed here are not counted.
    _ANSWER_KEYS = {"accept": "n_accept", "reject": "n_reject", "ignore": "n_skip"}

    def __init__(self):
        # Wall-clock reference point for the elapsed time and hourly rates.
        self.start_time = time.time()
        # Running counts; dict insertion order is the row order in the table.
        self.n_examples = {
            "n_accept": 0,
            "n_reject": 0,
            "n_skip": 0,
        }
        self.console = Console()

    @staticmethod
    def _per_hour(count, seconds):
        """Return *count* as an integer rate per hour over *seconds*.

        Returns 0 when no time has elapsed, instead of dividing by zero.
        """
        if seconds <= 0:
            return 0
        # Multiply before dividing and truncate only once, so the footer and
        # the per-row rates use exactly the same formula.
        return int(count * 3600 / seconds)

    @staticmethod
    def _format_elapsed(seconds):
        """Format a non-negative duration in seconds as e.g. ``"2m5s"``."""
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes}m{secs}s"

    def make_table(self):
        """Generates a pretty Rich table from the results.

        Returns:
            A ``rich.table.Table`` with one row per answer type (count and
            annotations per hour) and a footer with the totals.
        """
        elapsed = time.time() - self.start_time
        total_counts = sum(self.n_examples.values())
        table = Table(
            title=f"Summary at {self._format_elapsed(elapsed)}", box=box.SIMPLE
        )
        table.add_column("Answer", style="magenta", footer="Total")
        table.add_column(
            "Count", justify="right", style="cyan", footer=str(total_counts)
        )
        table.add_column(
            "Annot per Hour",
            justify="right",
            style="green",
            footer=str(self._per_hour(total_counts, elapsed)),
        )
        for key, value in self.n_examples.items():
            table.add_row(key, str(value), str(self._per_hour(value, elapsed)))
        table.show_footer = True
        return table

    def _tally(self, examples: List[dict]):
        """Add each example's answer to the running counts in a single pass.

        Raises:
            KeyError: if an example has no ``"answer"`` key.
        """
        for example in examples:
            key = self._ANSWER_KEYS.get(example["answer"])
            if key is not None:
                self.n_examples[key] += 1

    def update(self, examples: List[dict]):
        """Count a batch of answered examples and print the refreshed table."""
        self._tally(examples)
        table = self.make_table()
        self.console.print(table)
@prodigy.recipe(
    "progress",
    dataset=("Dataset to save answers to", "positional", None, str),
    examples_csv=("Examples in CSV format to load locally", "positional", None, str),
)
def progress(dataset: str, examples_csv: str):
    """Serve CSV examples in the classification view while reporting progress.

    Args:
        dataset: Name of the dataset the answers are saved to.
        examples_csv: Path of the CSV file the examples are loaded from.

    Returns:
        The recipe components dict. Its ``update`` callback is a fresh
        ``ProgressTable``'s ``update`` method, which prints a summary table.
    """
    examples = CSV(examples_csv)
    tracker = ProgressTable()
    components = dict(
        dataset=dataset,
        view_id="classification",
        stream=examples,
        update=tracker.update,
    )
    return components