-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutil.py
executable file
·32 lines (26 loc) · 1.15 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
'''Modified from https://github.com/alinlab/LfF/blob/master/util.py'''
import io
import torch
import numpy as np
import torch.nn as nn
class EMA:
    """Per-sample exponential moving average, with a per-class maximum lookup.

    One running value is stored for every entry of ``label``; ``update``
    refreshes the values at the given sample indices and ``max_loss`` returns
    the largest running value among the samples that carry a given label.

    Attributes:
        label: Copy of the ``label`` tensor (on CUDA when available, else CPU).
        alpha: Default decay used by ``update`` when no ``curve`` is given.
        parameter: Running averages, one per sample, initialised to zero.
        updated: 0/1 flags marking which samples have been updated at least
            once.
        num_classes: Number of classes (given, or inferred from ``label``).
        max: Zero-initialised tensor of length ``num_classes``. It is not
            written by any method of this class.
    """

    def __init__(self, label, num_classes=None, alpha=0.9):
        """Allocate the running-average state for ``label.size(0)`` samples.

        Args:
            label: 1-D tensor holding one label per sample.
            num_classes: Number of classes. When ``None`` it is inferred as
                ``label.max() + 1`` (0 for an empty ``label``); this assumes
                labels are integer ids in ``0..C-1`` - pass the value
                explicitly if that does not hold.
            alpha: Decay factor used by ``update`` when ``curve`` is ``None``.
        """
        # Prefer the GPU, but stay usable on CPU-only hosts instead of
        # failing inside ``.cuda()``.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if num_classes is None:
            # ``torch.zeros(None)`` raises, so the default must be resolved
            # to a real size before allocating ``self.max``.
            num_classes = int(label.max().item()) + 1 if label.numel() else 0

        self.label = label.to(device)
        self.alpha = alpha
        self.parameter = torch.zeros(label.size(0))
        self.updated = torch.zeros(label.size(0))
        self.num_classes = num_classes
        self.max = torch.zeros(self.num_classes, device=device)

    def update(self, data, index, curve=None, iter_range=None, step=None):
        """Blend ``data`` into the running averages stored at ``index``.

        Args:
            data: New values, one per entry of ``index``.
            index: Tensor of sample indices whose averages are updated.
            curve: Optional base of a step-dependent decay. When given, the
                decay is ``curve ** -(step / iter_range)`` instead of
                ``self.alpha``.
            iter_range: Denominator of the exponent; required with ``curve``.
            step: Numerator of the exponent; required with ``curve``.
        """
        # The state follows the device of the incoming data.
        self.parameter = self.parameter.to(data.device)
        self.updated = self.updated.to(data.device)
        index = index.to(data.device)

        if curve is None:
            decay = self.alpha
        else:
            decay = curve ** -(step / iter_range)

        # For a sample that has never been updated, ``updated`` is 0 and
        # ``parameter`` is 0, so the new value is taken with weight 1.
        # Afterwards the weight on the new value is ``1 - decay``.
        self.parameter[index] = (
            decay * self.parameter[index]
            + (1 - decay * self.updated[index]) * data
        )
        self.updated[index] = 1

    def max_loss(self, label):
        """Return the largest running average among samples labelled ``label``.

        Raises:
            RuntimeError: Raised by ``Tensor.max`` when no sample has this
                label, because the selection is then empty.
        """
        label_index = torch.where(self.label == label)[0]
        # ``self.label`` and ``self.parameter`` may live on different devices
        # (the latter follows the data passed to ``update``), so align the
        # indices with the tensor being indexed.
        label_index = label_index.to(self.parameter.device)
        return self.parameter[label_index].max()