# utils.py (forked from zkx06111/WSRGlow)
import traceback
from functools import wraps

import numpy as np
import torch
import torch.nn as nn

def load_ckpt(model, checkpoint_path):
    # Checkpoints were saved from a task object whose `model` attribute held
    # the network, so parameter keys carry a "model." prefix. Wrapping the
    # model in a throwaway nn.Module with the same attribute name lets
    # strict key matching succeed.
    fake_task = nn.Module()
    fake_task.model = model
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    fake_task.load_state_dict(state_dict, strict=True)
    print(f"| load ckpt from {checkpoint_path}.")

def num_params(model, print_out=True, model_name="model"):
    # Count trainable parameters only, reported in millions.
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum(np.prod(p.size()) for p in parameters) / 1_000_000
    if print_out:
        print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
    return parameters

def print_arch(model, model_name='model'):
    print(f"| {model_name} Arch: ", model)
    num_params(model, model_name=model_name)

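# Usage sketch with a toy module (illustrative only); this also exercises
# num_params above:
#
#   >>> net = nn.Linear(4, 2)
#   >>> print_arch(net, model_name='toy')
#   | toy Arch:  Linear(in_features=4, out_features=2, bias=True)
#   | toy Trainable Parameters: 0.000M
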
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            # Shift the sequence one step right and place shift_id
            # (e.g. a BOS token) at the front.
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res

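# Example (sketch): padding two variable-length sequences to a common length.
#
#   >>> vals = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
#   >>> collate_1d(vals, pad_idx=0)
#   tensor([[1, 2, 3],
#           [4, 5, 0]])
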
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 2d tensors into a padded 3d tensor.

    All inputs must share the same second (feature) dimension; padding is
    applied along the first (time) dimension.
    """
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res

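# Example (sketch): two [time, feature] tensors padded to equal time length.
#
#   >>> a = torch.ones(2, 3)
#   >>> b = torch.ones(1, 3)
#   >>> collate_2d([a, b]).shape
#   torch.Size([2, 2, 3])
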
def data_loader(fn):
    """Decorator that lazily evaluates a dataloader-building method and
    memoizes the result on the instance."""
    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)  # Lazy evaluation, done only once.
                if (
                    value is not None and
                    not isinstance(value, list) and
                    fn.__name__ in ['test_dataloader', 'val_dataloader']
                ):
                    # Test/val loaders are consumed as lists downstream.
                    value = [value]
            except AttributeError as e:
                # Guard against AttributeError suppression. (Issue #142)
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)  # Memoize evaluation.
        return value

    return _get_data_loader

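# Usage sketch (the Task class is illustrative): the wrapped method runs only
# once, and val/test results are normalized to lists.
#
#   >>> class Task:
#   ...     @data_loader
#   ...     def val_dataloader(self):
#   ...         print('building')
#   ...         return 'loader'
#   >>> t = Task()
#   >>> t.val_dataloader()
#   building
#   ['loader']
#   >>> t.val_dataloader()
#   ['loader']
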
class AvgrageMeter(object):
    """Tracks a running (weighted) average over update() calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        # val is the mean over a batch of n samples.
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

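# Example (sketch): averaging per-batch losses weighted by batch size.
#
#   >>> meter = AvgrageMeter()
#   >>> meter.update(0.5, n=2)
#   >>> meter.update(1.0, n=2)
#   >>> meter.avg
#   0.75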