utils.py
import torch
import torch.distributed as dist


class dist_average:
    """Accumulates a scalar metric locally and averages it across all ranks."""

    def __init__(self, local_rank):
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.local_rank = local_rank
        self.acc = torch.zeros(1).to(local_rank)
        self.count = 0

    def step(self, input_):
        # Accumulate one value; detach tensors so no autograd graph is kept alive.
        self.count += 1
        if not isinstance(input_, torch.Tensor):
            input_ = torch.tensor(input_).to(self.local_rank, dtype=torch.float)
        else:
            input_ = input_.detach()
        self.acc += input_

    def get(self):
        # Sum the running totals from every rank, then average over ranks and
        # steps. Note this mutates self.acc, so call it once per accumulation.
        dist.all_reduce(self.acc, op=dist.ReduceOp.SUM)
        self.acc /= self.world_size
        return self.acc.item() / self.count
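

# Hedged usage sketch (not part of the original file): one way dist_average
# could track a loss across ranks in a torch.distributed training loop.
# `model`, `loader`, and `criterion` are hypothetical placeholders, and the
# process group is assumed to be initialised already (e.g. via torchrun).
def _example_dist_average(model, loader, criterion, local_rank):
    avg = dist_average(local_rank)
    for x, y in loader:
        x, y = x.to(local_rank), y.to(local_rank)
        loss = criterion(model(x), y)
        avg.step(loss)    # accumulates the detached per-batch loss
    return avg.get()      # one collective call: mean loss over all ranks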


def ACC(x, y):
    """Top-1 accuracy of logits x (batch, classes) against integer labels y."""
    with torch.no_grad():
        a = torch.max(x, dim=1)[1]  # predicted class indices
        acc = torch.sum(a == y).float() / x.shape[0]
        return acc
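

# Hedged sketch: ACC on a toy batch. The logits and labels below are made-up
# values for illustration only.
def _example_acc():
    logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
    labels = torch.tensor([0, 1, 1])
    return ACC(logits, labels)  # argmax matches 2 of 3 labels -> tensor(0.6667)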


def cont_grad(x, rate=1):
    # Identity in the forward pass; scales the gradient through x by `rate`
    # in the backward pass (rate=0 behaves like x.detach()).
    return rate * x + (1 - rate) * x.detach()
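

# Hedged sketch: cont_grad leaves the forward value unchanged but scales the
# gradient flowing back through x by `rate`. The numbers below are made up.
def _example_cont_grad():
    x = torch.tensor([3.0], requires_grad=True)
    y = cont_grad(x, rate=0.5) * 2.0  # forward value 6.0, same as 2 * x
    y.backward()
    return x.grad                     # tensor([1.]) instead of tensor([2.])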