-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathload_data.py
69 lines (60 loc) · 2.78 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import json
class Data:
    """Load a region-centric knowledge graph (KG), its metapath sub-KGs and a
    mobility adjacency matrix from ``data_dir``.

    File names are concatenated directly onto ``data_dir``, so ``data_dir``
    should end with a path separator. Files read:

    * ``region2info.json`` -- JSON object whose keys are region identifiers.
    * ``kg.txt`` -- full KG, one tab-separated ``head<TAB>relation<TAB>tail``
      triple per line (blank lines are ignored).
    * ``kg_{mp}.txt`` -- one sub-KG per entry ``mp`` of ``metapaths``, same
      format as ``kg.txt``.
    * ``mob-adj.npy`` -- NumPy array; only element ``[0]`` is kept.

    Attributes:
        metapaths: the ``metapaths`` argument, as given.
        reg2id: region identifier -> id in ``0 .. nreg-1`` (sorted key order).
        ent2id: full-KG entity -> id. Regions keep their ``reg2id`` ids and the
            remaining entities follow in sorted order.
        rel2id: full-KG relation -> id (sorted order, from 0).
        kg_data: full-KG triples as ``[head_id, rel_id, tail_id]`` lists.
        mp2data: metapath -> dict with keys ``'ent2id'``, ``'rel2id'``,
            ``'kg_data'`` (ids local to that sub-KG) and ``'ent2kgid'``
            (sub-KG entity -> its id in the full KG's ``ent2id``).
        nreg: number of regions.
        mob_adj: ``np.load(data_dir + 'mob-adj.npy')[0]``.
    """

    def __init__(self, data_dir, metapaths):
        self.metapaths = metapaths
        self.reg2id = self.load_region_data(data_dir)
        self.ent2id, self.rel2id, self.kg_data = self.load_full_kg(data_dir)
        # Sub-KG loading maps each sub-KG entity back into the full KG, so it
        # must run after load_full_kg has populated self.ent2id.
        self.mp2data = self.load_subkg_data(data_dir)
        self.nreg = len(self.reg2id)
        self.mob_adj = np.load(data_dir + 'mob-adj.npy')[0]
        print('number of node=%d, number of edge=%d, number of relations=%d' % (len(self.ent2id), len(self.kg_data), len(self.rel2id)))
        print('sub-KGs:', metapaths)
        print('region num={}'.format(len(self.reg2id)))
        print('load finished..')

    def load_region_data(self, data_dir):
        """Return a ``{region: id}`` mapping built from ``region2info.json``.

        Ids are assigned from 0 in sorted order of the JSON object's keys.
        """
        with open(data_dir + 'region2info.json', 'r', encoding='utf-8') as f:
            region2info = json.load(f)
        return {region: i for i, region in enumerate(sorted(region2info.keys()))}

    def load_full_kg(self, data_dir):
        """Load ``kg.txt`` and index it on top of ``self.reg2id``.

        Returns:
            ``(ent2id, rel2id, kg_data)`` as described in ``_index_triples``.
        """
        triples = self._read_triples(data_dir + 'kg.txt')
        return self._index_triples(triples, self.reg2id)

    def load_subkg_data(self, data_dir):
        """Load one sub-KG file ``kg_{mp}.txt`` per metapath in ``self.metapaths``.

        Each sub-KG is indexed independently (regions first, exactly like the
        full KG) and also gets ``ent2kgid``, which maps each of its entities
        to that entity's id in the full KG (``self.ent2id``).

        Returns:
            dict: metapath -> ``{'ent2id', 'rel2id', 'kg_data', 'ent2kgid'}``.

        Raises:
            KeyError: if a sub-KG entity does not occur in ``self.ent2id``.
        """
        mp2data = {}
        for mp in self.metapaths:
            triples = self._read_triples(data_dir + 'kg_{}.txt'.format(mp))
            ent2id, rel2id, kg_data = self._index_triples(triples, self.reg2id)
            ent2kgid = {ent: self.ent2id[ent] for ent in ent2id}
            mp2data[mp] = {'ent2id': ent2id, 'rel2id': rel2id, 'kg_data': kg_data, 'ent2kgid': ent2kgid}
        return mp2data

    @staticmethod
    def _read_triples(path):
        """Read ``(head, relation, tail)`` string tuples from a tab-separated file.

        Each line is whitespace-stripped; blank lines (such as a trailing
        empty line at the end of the file) are skipped.

        Raises:
            ValueError: if a non-blank line does not have exactly 3 fields.
        """
        triples = []
        with open(path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                fields = line.split('\t')
                if len(fields) != 3:
                    raise ValueError('{}:{}: expected 3 tab-separated fields, got {}'.format(path, lineno, len(fields)))
                triples.append(tuple(fields))
        return triples

    @staticmethod
    def _index_triples(triples, reg2id):
        """Assign integer ids to the entities and relations in ``triples``.

        Entity ids start from a copy of ``reg2id`` so regions keep their ids
        (this relies on ``reg2id`` values being ``0 .. len(reg2id)-1``, as
        ``load_region_data`` produces). Every other entity is appended in
        sorted order. Relations are numbered from 0 in sorted order.

        Returns:
            ``(ent2id, rel2id, kg_data)``. ``kg_data`` holds one
            ``[head_id, rel_id, tail_id]`` list per input triple, in input order.
        """
        ent2id = reg2id.copy()
        entities = {h for h, _, _ in triples} | {t for _, _, t in triples}
        for ent in sorted(entities):
            if ent not in ent2id:
                ent2id[ent] = len(ent2id)
        relations = sorted({r for _, r, _ in triples})
        rel2id = {rel: i for i, rel in enumerate(relations)}
        kg_data = [[ent2id[h], rel2id[r], ent2id[t]] for h, r, t in triples]
        return ent2id, rel2id, kg_data