-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
38 lines (29 loc) · 1.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import time
import torch
class Timer:
    """Lap timer driven by ``time.time()``.

    ``times`` keeps every timestamp taken so far (the first one is taken at
    construction) and ``total_time`` sums the laps that were flagged for
    inclusion.
    """

    def __init__(self):
        self.total_time = 0.0
        self.times = [time.time()]

    def __call__(self, include_in_total=True):
        """Take a new timestamp and return the seconds since the previous one.

        The lap is added to ``total_time`` unless ``include_in_total`` is
        false; it is recorded in ``times`` either way.
        """
        now = time.time()
        self.times.append(now)
        # The entry before the one just appended is the previous reading.
        elapsed = now - self.times[-2]
        if include_in_total:
            self.total_time += elapsed
        return elapsed
class TableLogger():
    """Print dict rows to stdout as a fixed-width (12 character) table.

    The columns are fixed by the first row passed to ``append``; the header
    is printed once, immediately before that first row.
    """

    def append(self, output):
        """Print one row, preceded by the header on the first call.

        Args:
            output: Mapping from column name to value. Column names are
                formatted with ``:>12s`` so they must be strings. Floats are
                printed with four decimals; every other value uses its
                default formatting padded to width 12.

        Raises:
            KeyError: If ``output`` lacks a column fixed by the first row.
        """
        if not hasattr(self, 'keys'):
            # Snapshot the column names. A live ``dict.keys()`` view would
            # track later mutation of the first row dict, so subsequent rows
            # could silently drift out of sync with the printed header.
            self.keys = list(output.keys())
            print(*(f'{k:>12s}' for k in self.keys))
        filtered = [output[k] for k in self.keys]
        print(*(f'{v:12.4f}' if isinstance(v, float) else f'{v:12}' for v in filtered))
class StatsLogger():
    """Collect tensors under a fixed set of keys and report their means."""

    def __init__(self, keys):
        """Create an empty history list for every name in ``keys``."""
        self._stats = {k: [] for k in keys}

    def append(self, output):
        """Store the detached tensor for each tracked key.

        Args:
            output: Mapping that holds a tensor for every tracked key; extra
                entries are ignored.

        Raises:
            KeyError: If a tracked key is missing from ``output``.
        """
        for key, history in self._stats.items():
            history.append(output[key].detach())

    def mean(self, key):
        """Return the mean over all elements stored for ``key`` as a float.

        Each stored tensor is flattened before concatenation so that 0-dim
        tensors (e.g. an already-reduced loss) are accepted alongside batched
        ones; ``torch.cat`` refuses zero-dimensional inputs.

        NOTE(review): ``torch.cat`` is expected to reject an empty list, so
        call this only after at least one ``append`` - confirm if an empty
        history must be supported.
        """
        flattened = [t.reshape(-1) for t in self._stats[key]]
        return torch.cat(flattened).float().mean().item()
def shuffle_tensor(tensor):
    """Return ``tensor`` with its entries along the first dimension randomly permuted."""
    permutation = torch.randperm(len(tensor))
    return tensor[permutation]