-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
64 lines (47 loc) · 1.44 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging
import os
import numpy as np
import torch
import dgl
import pickle
def serialize(data):
    """Pickle *data* into a ``bytes`` payload.

    This is the counterpart of :func:`deserialize`.
    """
    payload = pickle.dumps(data)
    return payload
def deserialize(data):
    """Unpickle a payload produced by :func:`serialize` and return the object.

    WARNING: ``pickle.loads`` can execute arbitrary code while loading.
    Only pass bytes that this project wrote itself; never pass data from an
    untrusted source.
    """
    # NOTE(review): the original local was named ``data_tuple``, so callers
    # presumably store tuples here. Whatever object was pickled is returned
    # unchanged.
    return pickle.loads(data)
def get_g(tri_list):
    """Build a directed DGL graph from a list of triples.

    Each row of *tri_list* is read as ``(src, rel, dst)``:

    - column 0 is the source node id of the edge;
    - column 2 is the destination node id;
    - column 1 is stored per edge in ``g.edata['rel']``. It is presumably a
      relation id; confirm with the callers.

    Args:
        tri_list: sequence of integer triples, or an ``(E, 3)``-like array.
            An empty sequence yields a graph with no edges.

    Returns:
        A ``dgl`` graph with one edge per triple. No ``num_nodes`` is passed,
        so DGL infers the node count from the largest node id present.
    """
    # int64 makes the id dtype platform-independent (NumPy's default integer
    # is 32-bit on Windows), and it maps to torch.long for the edge feature.
    triples = np.asarray(tri_list, dtype=np.int64)
    if triples.size == 0:
        # An empty list becomes a 1-D array, and the column slicing below
        # would raise IndexError on it; give it an explicit (0, 3) shape.
        triples = triples.reshape(0, 3)
    g = dgl.graph((triples[:, 0], triples[:, 2]))
    g.edata['rel'] = torch.tensor(triples[:, 1])
    return g
def init_dir(args):
    """Create the run's output directories if they do not already exist.

    Reads three path attributes from *args* and creates each directory,
    including any missing parents:

    - ``args.state_dir``: state directory;
    - ``args.tb_log_dir``: TensorBoard log directory;
    - ``args.log_dir``: text-log directory.

    Calling this repeatedly, or with paths that coincide, is safe.

    Raises:
        FileExistsError: if one of the paths exists but is not a directory.
    """
    # exist_ok=True replaces the exists()-then-makedirs() pattern. That
    # pattern fails if another process creates the directory between the
    # check and the creation.
    for path in (args.state_dir, args.tb_log_dir, args.log_dir):
        os.makedirs(path, exist_ok=True)
class Log(object):
    """Configure a named logger that writes to a file and to the console.

    Records at INFO level and above go to ``<log_dir>/<name>.log`` and to
    ``sys.stderr``, formatted as ``'YYYY-mm-dd HH:MM:SS | name | message'``.
    """

    def __init__(self, log_dir, name):
        """Attach a file handler and a console handler to logger *name*.

        Args:
            log_dir: directory for the log file. It must already exist;
                ``logging.FileHandler`` does not create directories.
            name: logger name, also used as the log file's base name.
        """
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s | %(name)s | %(message)s',
                                      "%Y-%m-%d %H:%M:%S")

        # logging.getLogger returns the same object for the same name. Without
        # this, a second Log(...) with this name would stack another pair of
        # handlers and every message would be emitted twice. Replace them.
        for old in list(self.logger.handlers):
            self.logger.removeHandler(old)
            old.close()

        # File handler. Explicit UTF-8 means non-ASCII messages do not depend
        # on the platform's locale encoding.
        log_file = os.path.join(log_dir, name + '.log')
        fh = logging.FileHandler(log_file, encoding='utf-8')
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)

        # Console handler (StreamHandler writes to sys.stderr by default).
        sh = logging.StreamHandler()
        sh.setLevel(logging.INFO)
        sh.setFormatter(formatter)

        # The handlers stay open for the logger's lifetime. They are closed
        # when replaced by a later Log(...) with the same name, or by
        # logging.shutdown() at interpreter exit.
        self.logger.addHandler(fh)
        self.logger.addHandler(sh)

    def get_logger(self):
        """Return the configured :class:`logging.Logger`."""
        return self.logger