-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsubgraph.py
117 lines (98 loc) · 4.28 KB
/
subgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pickle
import torch
import numpy as np
from collections import defaultdict as ddict
import lmdb
from tqdm import tqdm
import random
from utils import serialize, get_g, get_hr2t_rt2h_sup_que
import dgl
def gen_subgraph_datasets(args):
    """Pre-sample train/valid sub-graphs and persist them into an LMDB store.

    Loads the background KG from ``args.data_path``, builds one training graph
    from its train/valid/test triples, estimates the serialized size of a
    sampled subgraph to size the LMDB map, then samples
    ``args.num_train_subgraph`` and ``args.num_valid_subgraph`` subgraphs and
    writes each (via ``serialize``) under a zero-padded 8-digit ASCII key into
    the ``train_subgraphs`` / ``valid_subgraphs`` named databases at
    ``args.db_path``.
    """
    print(f'-----There is no sub-graphs for {args.data_name}, so start generating sub-graphs before meta-training!-----')
    # Use a context manager so the pickle file handle is closed deterministically
    # (the original left it open).
    with open(args.data_path, 'rb') as f:
        data = pickle.load(f)
    train_g = get_g(data['train_graph']['train'] + data['train_graph']['valid']
                    + data['train_graph']['test'])
    # 2x safety margin over the measured average serialized subgraph size.
    BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, train_g) * 2
    # lmdb requires an integer map_size; the estimate above is a float
    # (true division in get_average_subgraph_size).
    map_size = int((args.num_train_subgraph + args.num_valid_subgraph) * BYTES_PER_DATUM)
    env = lmdb.open(args.db_path, map_size=map_size, max_dbs=2)
    # Train and valid splits are written identically — drive both from one loop.
    for db_name, num_subgraph in (('train_subgraphs', args.num_train_subgraph),
                                  ('valid_subgraphs', args.num_valid_subgraph)):
        subgraphs_db = env.open_db(db_name.encode())
        for idx in tqdm(range(num_subgraph)):
            str_id = '{:08}'.format(idx).encode('ascii')
            datum = sample_one_subgraph(args, train_g)
            with env.begin(write=True, db=subgraphs_db) as txn:
                txn.put(str_id, serialize(datum))
def sample_one_subgraph(args, bg_train_g):
    """Sample one subgraph from the background training graph and split its
    triples into a support set and a query set.

    Args:
        args: namespace providing random-walk hyperparameters ``rw_0`` (number
            of walk rounds), ``rw_1`` (walks per round), ``rw_2`` (walk length).
        bg_train_g: background DGL graph with relation ids in ``edata['rel']``.

    Returns:
        (sup_tris, que_tris, hr2t, rt2h): support triples, query triples
        (both as lists of [h, r, t] with entities re-indexed from 0), and the
        (h, r)->t / (r, t)->h lookup tables built by ``get_hr2t_rt2h_sup_que``.
    """
    # get graph with bi-direction: duplicate every edge in both directions so
    # random walks can traverse edges regardless of their original orientation.
    bg_train_g_undir = dgl.graph((torch.cat([bg_train_g.edges()[0], bg_train_g.edges()[1]]),
                                  torch.cat([bg_train_g.edges()[1], bg_train_g.edges()[0]])))
    # induce sub-graph by sampled nodes
    while True:  # outer retry: resample until the query set is large enough
        while True:  # inner retry: resample until the subgraph has >= 50 nodes
            sel_nodes = []
            for i in range(args.rw_0):
                # First round starts the walk from any node; later rounds
                # restart from already-selected nodes to grow a connected region.
                if i == 0:
                    cand_nodes = np.arange(bg_train_g.num_nodes())
                else:
                    cand_nodes = sel_nodes
                # Launch rw_1 walks of length rw_2 from one sampled seed node.
                rw, _ = dgl.sampling.random_walk(bg_train_g_undir,
                                                 np.random.choice(cand_nodes, 1, replace=False).repeat(args.rw_1),
                                                 length=args.rw_2)
                sel_nodes.extend(np.unique(rw.reshape(-1)))
            # random_walk pads early-terminated traces with -1; np.unique sorts,
            # so -1 (if present) is the first element and is dropped with [1:].
            sel_nodes = list(np.unique(sel_nodes)) if -1 not in sel_nodes else list(np.unique(sel_nodes))[1:]
            sub_g = dgl.node_subgraph(bg_train_g, sel_nodes)
            if sub_g.num_nodes() >= 50:
                break
        # Materialize the subgraph's triples as (head, relation, tail) rows.
        sub_tri = torch.stack([sub_g.edges()[0],
                               sub_g.edata['rel'],
                               sub_g.edges()[1]])
        sub_tri = sub_tri.T.tolist()
        random.shuffle(sub_tri)
        ent_freq = ddict(int)
        rel_freq = ddict(int)
        triples_reidx = []
        ent_reidx = dict()
        entidx = 0
        # Re-index entities densely from 0 in first-seen order and count
        # entity/relation occurrences for the query-eligibility test below.
        for tri in sub_tri:
            h, r, t = tri
            if h not in ent_reidx.keys():
                ent_reidx[h] = entidx
                entidx += 1
            if t not in ent_reidx.keys():
                ent_reidx[t] = entidx
                entidx += 1
            ent_freq[ent_reidx[h]] += 1
            ent_freq[ent_reidx[t]] += 1
            rel_freq[r] += 1
            triples_reidx.append([ent_reidx[h], r, ent_reidx[t]])
        # randomly get query triples: a triple may become a query only if its
        # head, tail, and relation each still appear in >2 remaining triples,
        # so the support set keeps every entity/relation represented.
        que_tris = []
        sup_tris = []
        for idx, tri in enumerate(triples_reidx):
            h, r, t = tri
            if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
                que_tris.append(tri)
                # Decrement so later picks see the reduced remaining counts.
                ent_freq[h] -= 1
                ent_freq[t] -= 1
                rel_freq[r] -= 1
            else:
                sup_tris.append(tri)
            # Cap the query set at 10% of all triples.
            if len(que_tris) >= int(len(triples_reidx)*0.1):
                break
        # Everything not examined (after an early break) stays in support.
        sup_tris.extend(triples_reidx[idx+1:])
        # Accept the sample only if queries reach at least 5% of the triples;
        # otherwise resample a fresh subgraph.
        if len(que_tris) >= int(len(triples_reidx)*0.05):
            break
    # hr2t, rt2h: lookup tables over support + query triples for filtering.
    hr2t, rt2h = get_hr2t_rt2h_sup_que(sup_tris, que_tris)
    return sup_tris, que_tris, hr2t, rt2h
def get_average_subgraph_size(args, sample_size, bg_train_g):
    """Estimate the mean serialized byte size of one sampled subgraph.

    Draws ``sample_size`` subgraphs with ``sample_one_subgraph``, serializes
    each, and returns the average length in bytes (a float).
    """
    byte_counts = [len(serialize(sample_one_subgraph(args, bg_train_g)))
                   for _ in range(sample_size)]
    return sum(byte_counts) / sample_size