-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
157 lines (133 loc) · 5.53 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import jax
import jax.numpy as jnp
import jraph
import networkx as nx
import numpy as onp
import pickle
from typing import Tuple
def load_dataset(path_to_dataset: str) -> jraph.GraphsTuple:
    """Unpickle and return the object stored in the file at ``path_to_dataset``.

    The return annotation presumes the pickle holds a ``jraph.GraphsTuple``;
    the loaded object's type is not checked here.

    Security note: ``pickle`` can execute arbitrary code while loading, so only
    call this on dataset files from a trusted source.
    """
    with open(path_to_dataset, 'rb') as dataset_file:
        return pickle.load(dataset_file)
def train_val_test_split_edges(graph: jraph.GraphsTuple,
                               val_perc: float = 0.05,
                               test_perc: float = 0.1):
    """Split edges in input graph into train, val and test splits.

    For val and test sets, also include negative edges.
    Based on torch_geometric.utils.train_test_split_edges.
    JAX implementation based fully on:
    https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb

    Only pairs with ``sender < receiver`` are split, so the input is assumed
    to store every undirected edge in both directions. The training edges are
    made undirected again (both directions) in the returned graph. Randomness
    comes from NumPy's global RNG (``onp.random``).

    Args:
      graph: graph whose edges are split. ``graph.n_node[0]`` is used as the
        node count for the negative adjacency mask. If ``graph.edges`` is not
        None it is assumed to be an array with one row per edge.
      val_perc: fraction of the ``sender < receiver`` edges used for validation.
      test_perc: fraction of the ``sender < receiver`` edges used for testing.

    Returns:
      ``(train_graph, neg_adj_mask, val_senders, val_receivers,
      val_neg_senders, val_neg_receivers, test_senders, test_receivers,
      test_neg_senders, test_neg_receivers)``. ``neg_adj_mask[i, j]`` is True
      iff ``i < j`` and ``i -> j`` is not a (sender, receiver) pair of
      ``graph``. The val/test negatives are drawn from that mask without
      replacement.
    """
    # Keep one direction of each edge so an undirected edge lands in one split.
    mask = graph.senders < graph.receivers
    senders = graph.senders[mask]
    receivers = graph.receivers[mask]
    num_edges = senders.shape[0]
    num_val = int(val_perc * num_edges)
    num_test = int(test_perc * num_edges)
    num_held_out = num_val + num_test

    permuted_indices = onp.random.permutation(num_edges)
    senders = senders[permuted_indices]
    receivers = receivers[permuted_indices]

    val_senders = senders[:num_val]
    val_receivers = receivers[:num_val]
    test_senders = senders[num_val:num_held_out]
    test_receivers = receivers[num_val:num_held_out]
    train_senders = senders[num_held_out:]
    train_receivers = receivers[num_held_out:]

    train_edges = None
    if graph.edges is not None:
        # Edge features must get the same mask AND permutation as the
        # senders/receivers: permuted_indices index the masked edge list,
        # not the original one.
        edges = jnp.asarray(graph.edges)[mask][permuted_indices]
        train_edges = edges[num_held_out:]
        # Duplicate so features stay aligned with the reverse edges added below.
        train_edges = jnp.concatenate((train_edges, train_edges))

    # make training edges undirected by adding reverse edges back in
    train_senders, train_receivers = (
        jnp.concatenate((train_senders, train_receivers)),
        jnp.concatenate((train_receivers, train_senders)),
    )

    # Negative edges.
    num_nodes = int(graph.n_node[0])
    # neg_adj_mask[i, j] = True iff i < j and edge i->j does not exist in the
    # original graph (upper triangle only, diagonal excluded by k=1).
    neg_adj_mask = onp.triu(
        onp.ones((num_nodes, num_nodes), dtype=onp.uint8), k=1)
    neg_adj_mask[onp.asarray(graph.senders), onp.asarray(graph.receivers)] = 0
    neg_adj_mask = neg_adj_mask.astype(bool)

    neg_senders, neg_receivers = neg_adj_mask.nonzero()
    perm = onp.random.permutation(len(neg_senders))
    neg_senders = neg_senders[perm]
    neg_receivers = neg_receivers[perm]
    val_neg_senders = neg_senders[:num_val]
    val_neg_receivers = neg_receivers[:num_val]
    test_neg_senders = neg_senders[num_val:num_held_out]
    test_neg_receivers = neg_receivers[num_val:num_held_out]

    train_graph = jraph.GraphsTuple(
        nodes=graph.nodes,
        edges=train_edges,
        senders=train_senders,
        receivers=train_receivers,
        n_node=graph.n_node,
        n_edge=jnp.array([len(train_senders)]),
        globals=graph.globals)
    return (
        train_graph, neg_adj_mask,
        val_senders, val_receivers, val_neg_senders, val_neg_receivers,
        test_senders, test_receivers, test_neg_senders, test_neg_receivers)
def negative_sampling(
        graph: jraph.GraphsTuple, num_neg_samples: int,
        key: jax.Array) -> Tuple[jax.Array, jax.Array]:
    """Draw random node pairs that are not (sender, receiver) pairs of ``graph``.

    Based fully on:
    https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb

    Candidates are drawn uniformly with replacement over all ``n * n`` pairs
    (``n = graph.n_node[0]``), so results may contain duplicates and self-pairs.
    Draws hitting an existing edge are discarded. The boolean-mask filtering
    produces a data-dependent shape, so this is meant for eager (non-jit) use.

    Args:
      graph: input graph; only ``n_node[0]``, ``senders`` and ``receivers``
        are read.
      num_neg_samples: maximum number of negative pairs to return.
      key: JAX PRNG key for the candidate draw.

    Returns:
      ``(neg_senders, neg_receivers)`` with at most ``num_neg_samples`` entries.
      Fewer are returned when too many draws hit existing edges.
    """
    n = graph.n_node[0]
    num_pairs = n**2
    # Flatten each (sender, receiver) pair to the single index s * n + r.
    existing = graph.senders * n + graph.receivers
    # Over-sample relative to the edge density so that, after discarding hits
    # on existing edges, enough candidates most likely remain.
    edge_density = graph.senders.shape[0] / num_pairs
    oversample = jnp.abs(1 / (1 - 1.1 * edge_density))
    candidates = jax.random.randint(
        key,
        shape=(int(oversample * num_neg_samples),),
        minval=0,
        maxval=num_pairs,
        dtype=jnp.uint32)
    hits_existing = jnp.isin(candidates, existing)
    kept = candidates[~hits_existing][:num_neg_samples]
    # Undo the flattening: index -> (index // n, index % n).
    return kept // n, kept % n
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Build a directed NetworkX graph from the first graph of a GraphsTuple.

    Based fully on:
    https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb

    Nodes ``0 .. n_node[0] - 1`` are added first, with a ``node_feature``
    attribute when node features exist. Then the first ``n_edge[0]`` edges are
    added as ``sender -> receiver``, with an ``edge_feature`` attribute when
    edge features exist. The result is an ``nx.DiGraph``.
    """
    node_feats = jraph_graph.nodes
    edge_feats = jraph_graph.edges
    src = jraph_graph.senders
    dst = jraph_graph.receivers

    nx_graph = nx.DiGraph()
    for idx in range(jraph_graph.n_node[0]):
        if node_feats is None:
            nx_graph.add_node(idx)
        else:
            nx_graph.add_node(idx, node_feature=node_feats[idx])

    for idx in range(jraph_graph.n_edge[0]):
        # Cast to Python ints so node keys match the integer ids added above.
        u, v = int(src[idx]), int(dst[idx])
        if edge_feats is None:
            nx_graph.add_edge(u, v)
        else:
            nx_graph.add_edge(u, v, edge_feature=edge_feats[idx])
    return nx_graph
def compute_norm_and_weights(graph: jraph.GraphsTuple) -> Tuple[float, float]:
    """Compute the positive-class weight and normaliser of the dense adjacency.

    With ``N`` nodes and ``E`` ones in the dense adjacency matrix (distinct
    directed edges, as built by ``convert_jraph_to_networkx_graph``):

      pos_weight = (N**2 - E) / E
      norm_adj   = N**2 / (2 * (N**2 - E))

    These presumably feed a weighted reconstruction loss in the style of the
    graph auto-encoder (Kipf & Welling) reference code -- confirm with callers.

    Args:
      graph: a single (non-batched) graph; ``graph.n_node.item()`` must succeed.

    Returns:
      ``(pos_weight, norm_adj)`` as Python floats. A graph with no edges
      (``E == 0``) yields a non-finite ``pos_weight``.
    """
    graph_n_node = graph.n_node.item()
    # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array yields the
    # same entries as a plain ndarray.
    graph_adj = nx.to_numpy_array(convert_jraph_to_networkx_graph(graph))
    adj_sum = onp.sum(graph_adj)
    num_entries = graph_n_node**2
    pos_weight = float(num_entries - adj_sum) / adj_sum
    # The denominator must be parenthesised: `a / 2.0*(b)` parses as (a/2)*b.
    norm_adj = num_entries / (2.0 * (num_entries - adj_sum))
    return float(pos_weight), float(norm_adj)