GNN.py
import torch
import torch.nn as nn

from conv import myGATConv

class DistMult(nn.Module):
    """Bilinear decoder: score(h, r, t) = h^T W_r t, with one matrix per relation."""
    def __init__(self, num_rel, dim):
        super(DistMult, self).__init__()
        self.W = nn.Parameter(torch.empty(num_rel, dim, dim))
        nn.init.xavier_normal_(self.W, gain=1.414)

    def forward(self, left_emb, right_emb, r_id):
        thW = self.W[r_id]                         # (batch, dim, dim)
        left_emb = torch.unsqueeze(left_emb, 1)    # (batch, 1, dim)
        right_emb = torch.unsqueeze(right_emb, 2)  # (batch, dim, 1)
        return torch.bmm(torch.bmm(left_emb, thW), right_emb).squeeze()
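
# Sanity note (illustrative, not part of the original file): with left_emb of
# shape (B, D), right_emb of shape (B, D), and thW of shape (B, D, D), the two
# bmm calls above compute the bilinear score
#   score[b] = sum_{i,j} left_emb[b, i] * thW[b, i, j] * right_emb[b, j],
# which is equivalent to torch.einsum('bi,bij,bj->b', left_emb, thW, right_emb).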

class Dot(nn.Module):
    """Dot-product decoder: score(h, r, t) = h^T t (the relation id is ignored)."""
    def __init__(self):
        super(Dot, self).__init__()

    def forward(self, left_emb, right_emb, r_id):
        # r_id is unused; it is kept so Dot and DistMult share the same interface.
        left_emb = torch.unsqueeze(left_emb, 1)    # (batch, 1, dim)
        right_emb = torch.unsqueeze(right_emb, 2)  # (batch, dim, 1)
        return torch.bmm(left_emb, right_emb).squeeze()

class myGAT(nn.Module):
    def __init__(self,
                 g,
                 edge_dim,
                 num_etypes,
                 in_dims,
                 num_hidden,
                 num_classes,
                 num_layers,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual,
                 alpha,
                 decode='distmult'):
        super(myGAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # One linear projection per node type, mapping heterogeneous input
        # features into a shared num_hidden-dimensional space.
        self.fc_list = nn.ModuleList([nn.Linear(in_dim, num_hidden, bias=True) for in_dim in in_dims])
        for fc in self.fc_list:
            nn.init.xavier_normal_(fc.weight, gain=1.414)
        # input projection (no residual)
        self.gat_layers.append(myGATConv(edge_dim, num_etypes,
                                         num_hidden, num_hidden, heads[0],
                                         feat_drop, attn_drop, negative_slope, False, self.activation, alpha=alpha))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head attention, in_dim = num_hidden * num_heads
            self.gat_layers.append(myGATConv(edge_dim, num_etypes,
                                             num_hidden * heads[l - 1], num_hidden, heads[l],
                                             feat_drop, attn_drop, negative_slope, residual, self.activation, alpha=alpha))
        # output projection
        self.gat_layers.append(myGATConv(edge_dim, num_etypes,
                                         num_hidden * heads[-2], num_classes, heads[-1],
                                         feat_drop, attn_drop, negative_slope, residual, None, alpha=alpha))
        # Small constant guarding l2_norm against division by zero; registered
        # as a buffer so it follows the module across devices instead of being
        # hard-coded onto CUDA.
        self.register_buffer('epsilon', torch.tensor([1e-12]))
        # The decoder sees the concatenation of num_layers + 2 normalized
        # blocks (input projection, each hidden layer, output projection).
        # The dimension below assumes num_hidden == num_classes, so that each
        # block has size num_classes.
        if decode == 'distmult':
            self.decoder = DistMult(num_etypes, num_classes * (num_layers + 2))
        elif decode == 'dot':
            self.decoder = Dot()
        else:
            raise ValueError('unknown decoder: {}'.format(decode))

    def l2_norm(self, x):
        # Row-wise L2 normalization; equivalent to tf.math.l2_normalize (see
        # https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/l2_normalize),
        # with epsilon keeping zero rows from producing NaN.
        return x / torch.max(torch.norm(x, dim=1, keepdim=True), self.epsilon)
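    # Illustrative check (not from the original file):
    #   l2_norm(torch.tensor([[3., 4.]])) -> tensor([[0.6000, 0.8000]]);
    # an all-zero row stays zero rather than becoming NaN, because the
    # denominator is clamped at epsilon.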

    def forward(self, features_list, e_feat, left, right, mid):
        # Project each node type's features, then stack all nodes into one tensor.
        h = []
        for fc, feature in zip(self.fc_list, features_list):
            h.append(fc(feature))
        h = torch.cat(h, 0)
        emb = [self.l2_norm(h)]
        res_attn = None
        for l in range(self.num_layers):
            h, res_attn = self.gat_layers[l](self.g, h, e_feat, res_attn=res_attn)
            emb.append(self.l2_norm(h.mean(1)))  # average over attention heads
            h = h.flatten(1)                     # concatenate heads for the next layer
        # output projection
        logits, _ = self.gat_layers[-1](self.g, h, e_feat, res_attn=res_attn)
        logits = logits.mean(1)
        logits = self.l2_norm(logits)
        emb.append(logits)
        # Final node representation: all intermediate embeddings concatenated.
        logits = torch.cat(emb, 1)
        left_emb = logits[left]
        right_emb = logits[right]
        return self.decoder(left_emb, right_emb, mid)
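
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of wiring the model up, assuming a DGL graph `g`
# whose edges carry integer type ids `e_feat`, per-node-type feature tensors
# `features_list`, and link-prediction index tensors `left`, `right`, `mid`
# (head node ids, tail node ids, relation ids). All concrete values below are
# assumptions; the real training script supplies them.
#
#   import torch.nn.functional as F
#
#   num_layers = 2
#   heads = [8] * num_layers + [1]  # one entry per layer, plus the output layer
#   model = myGAT(g, edge_dim=64, num_etypes=num_etypes,
#                 in_dims=[feat.shape[1] for feat in features_list],
#                 num_hidden=64, num_classes=64,  # equal, as the decoder assumes
#                 num_layers=num_layers, heads=heads, activation=F.elu,
#                 feat_drop=0.5, attn_drop=0.5, negative_slope=0.05,
#                 residual=True, alpha=0.05)
#   scores = model(features_list, e_feat, left, right, mid)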