Skip to content

Commit

Permalink
Debugging PyTorchRGCN model
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Sep 16, 2024
1 parent 13c8999 commit 29cde05
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 73 deletions.
110 changes: 55 additions & 55 deletions configs/downstream/all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@ logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_test
datasets:
- name: Immunogenicity
task: classification
- name: Glycosylation
task: classification
- name: Taxonomy_Domain
task: multilabel
- name: Taxonomy_Kingdom
task: multilabel
- name: Taxonomy_Phylum
task: multilabel
- name: Taxonomy_Class
task: multilabel
- name: Taxonomy_Order
task: multilabel
- name: Taxonomy_Family
task: multilabel
- name: Taxonomy_Genus
task: multilabel
- name: Taxonomy_Species
task: multilabel
# - name: Glycosylation
# task: classification
#- name: Taxonomy_Domain
# task: multilabel
#- name: Taxonomy_Kingdom
# task: multilabel
#- name: Taxonomy_Phylum
# task: multilabel
#- name: Taxonomy_Class
# task: multilabel
#- name: Taxonomy_Order
# task: multilabel
#- name: Taxonomy_Family
# task: multilabel
#- name: Taxonomy_Genus
# task: multilabel
#- name: Taxonomy_Species
# task: multilabel
pre-transforms:
GIFFLARTransform:
GNNGLYTransform:
Expand All @@ -37,33 +37,33 @@ pre-transforms:
dim: 20
individual: False
model:
- name: rf
n_estimators: 500
n_jobs: -1
random_state: 42
- name: svm
random_state: 42
- name: xgb
random_state: 42
- name: mlp
feat_dim: 1024
hidden_dim: 1024
batch_size: 256
num_layers: 3
epochs: 100
patience: 30
learning_rate: 0
optimizer: Adam
- name: sweetnet
feat_dim: 128
hidden_dim: 1024
batch_size: 512
num_layers: 16
epochs: 1
patience: 30
learning_rate: 0.001
optimizer: Adam
suffix:
#- name: rf
# n_estimators: 500
# n_jobs: -1
# random_state: 42
#- name: svm
# random_state: 42
#- name: xgb
# random_state: 42
#- name: mlp
# feat_dim: 1024
# hidden_dim: 1024
# batch_size: 256
# num_layers: 3
# epochs: 1
# patience: 30
# learning_rate: 0
# optimizer: Adam
#- name: sweetnet
# feat_dim: 128
# hidden_dim: 1024
# batch_size: 512
# num_layers: 16
# epochs: 1
# patience: 30
# learning_rate: 0.001
# optimizer: Adam
# suffix:
- name: gnngly
feat_dim: 128
hidden_dim: 1024
Expand All @@ -74,15 +74,15 @@ model:
learning_rate: 0.001
optimizer: Adam
suffix:
- name: rgcn
feat_dim: 128
hidden_dim: 1024
batch_size: 256
num_layers: 8
epochs: 1
learning_rate: 0.001
optimizer: Adam
suffix:
#- name: rgcn
# feat_dim: 128
# hidden_dim: 1024
# batch_size: 256
# num_layers: 8
# epochs: 1
# learning_rate: 0.001
# optimizer: Adam
# suffix:
- name: pyrgcn
feat_dim: 128
hidden_dim: 1024
Expand All @@ -100,5 +100,5 @@ model:
epochs: 1
learning_rate: 0.001
optimizer: Adam
pooling: global_pool
pooling: global_mean
suffix: _128_8_gp
36 changes: 34 additions & 2 deletions gifflar/data/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def __getitem__(self, item: str) -> Any:
return getattr(self, item)


def determine_concat_dim(tensors):
if len(tensors[0].shape) == 1:
return 0
concat_dim = None
if all(tensor.shape[1] == tensors[0].shape[1] for tensor in tensors):
concat_dim = 0 # Concatenate along rows (dim 0)
elif all(tensor.shape[0] == tensors[0].shape[0] for tensor in tensors):
concat_dim = 1 # Concatenate along columns (dim 1)
return concat_dim


def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]]]) -> HeteroDataBatch:
"""
Collate a list of HeteroData objects to a batch thereof.
Expand All @@ -71,7 +82,7 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
edge_attr_dict = {}

# Include data for the baselines and other kwargs for house-keeping
baselines = {"gnngly", "sweetnet"}
baselines = {"gnngly", "sweetnet", "rgcn"}
kwargs = {key: [] for key in dict(data[0]) if all(b not in key for b in baselines)}

# Store the node counts to offset edge indices when collating
Expand Down Expand Up @@ -121,13 +132,31 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
edges = []
batch = []
e_types = []
n_types = []
node_counts = 0
for i, d in enumerate(data):
edges.append(d[f"{b}_edge_index"] + node_counts)
node_counts += d[f"{b}_num_nodes"]
if b == "rgcn":
e_types += d["rgcn_edge_type"]
n_types += d["rgcn_node_type"]
batch.append(torch.full((d[f"{b}_num_nodes"],), i, dtype=torch.long))
kwargs[f"{b}_edge_index"] = torch.cat(edges, dim=1)
kwargs[f"{b}_batch"] = torch.cat(batch, dim=0)
if b == "rgcn":
#for d in data:
# print(d["rgcn_x"].shape)
# print(d["rgcn_edge_type"].shape)
kwargs["rgcn_edge_type"] = torch.tensor(e_types)
kwargs["rgcn_node_type"] = n_types
if hasattr(data[0], "rgcn_rw_pe"):
kwargs["rgcn_rw_pe"] = torch.cat([d["rgcn_rw_pe"] for d in data], dim=0)
if hasattr(data[0], "rgcn_lap_pe"):
kwargs["rgcn_lap_pe"] = torch.cat([d["rgcn_lap_pe"] for d in data], dim=0)
#print(kwargs["rgcn_x"].shape)
#print(kwargs["rgcn_edge_type"].shape)
#print(len(kwargs["rgcn_node_type"]))

# Remove all incompletely given data and concat lists of tensors into single tensors
num_nodes = {node_type: x_dict[node_type].shape[0] for node_type in node_types}
Expand All @@ -137,7 +166,10 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]
elif len(value) != len(data):
del kwargs[key]
elif isinstance(value[0], torch.Tensor):
kwargs[key] = torch.cat(value, dim=0)
dim = determine_concat_dim(value)
if dim is None:
raise ValueError(f"Tensors for key {key} cannot be concatenated.")
kwargs[key] = torch.cat(value, dim=dim)

# Finally create and return the HeteroDataBatch
return HeteroDataBatch(x_dict=x_dict, edge_index_dict=edge_index_dict, edge_attr_dict=edge_attr_dict,
Expand Down
40 changes: 29 additions & 11 deletions gifflar/model/baselines/rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch import nn
from torch_geometric.nn import HeteroConv, GINConv, RGCNConv
from torch_geometric.nn import HeteroConv, GINConv, RGCNConv, global_mean_pool

from gifflar.data.hetero import HeteroDataBatch
from gifflar.model.downstream import DownstreamGGIN
Expand Down Expand Up @@ -57,21 +57,39 @@ def __init__(
dims += [hidden_dim]
dims += [hidden_dim] * (num_layers - 1)
self.convs = torch.nn.ModuleList([
RGCNConv(dims[i], dims[i + 1], hidden_dim, num_relations=5) for i in range(num_layers)
RGCNConv(dims[i], dims[i + 1], num_relations=5) for i in range(num_layers)
])

del self.pooling
self.pooling = global_mean_pool # GIFFLARPooling()


def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
"""
Compute the node embeddings.
def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
"""
Compute the node embeddings.
Args:
batch: The batch of data to process
Args:
batch: The batch of data to process
Returns:
node_embed: The node embeddings
"""

Returns:
node_embed: The node embeddings
"""
self.convs.forward(batch["rgcn_x"], batch["edge_index"], batch["edge_type"])
node_embeds = [torch.stack([self.embedding.forward(batch["rgcn_x"][i], batch["rgcn_node_type"][i]) for i in range(len(batch["rgcn_x"]))])]
for pe in self.addendum:
node_embeds.append(batch[f"rgcn_{pe}"])
node_embeds = torch.concat(node_embeds, dim=1)

for conv in self.convs:
node_embeds = conv(node_embeds, batch["rgcn_edge_index"], batch["rgcn_edge_type"])

graph_embed = self.pooling(node_embeds, batch["rgcn_batch"])
pred = self.head(graph_embed)
return {
"node_embed": node_embeds,
"graph_embed": graph_embed,
"preds": pred,
}


class RGCN(DownstreamGGIN):
Expand Down
3 changes: 2 additions & 1 deletion gifflar/model/downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def to(self, device: torch.device) -> "DownstreamGGIN":
self: The model moved to the specified device
"""
super(DownstreamGGIN, self).to(device)
self.pooling.to(device)
if isinstance(self.pooling, GIFFLARPooling):
self.pooling.to(device)
for split, metric in self.metrics.items():
self.metrics[split] = metric.to(device)
return self
Expand Down
12 changes: 8 additions & 4 deletions gifflar/pretransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def split_hetero_graph(data: HeteroData) -> tuple[Data, Data, Data]:
def hetero_to_homo(data: HeteroData) -> Data:
"""Convert a heterogeneous graph to a homogeneous by collapsing the node types and removing all node features."""
bond_edge_index = data["bonds", "boundary", "bonds"]["edge_index"] + data["atoms"]["num_nodes"]
monosacchs_edge_index = data["bonds", "boundary", "bonds"]["edge_index"] + data["atoms"]["num_nodes"] + \
monosacchs_edge_index = data["monosacchs", "boundary", "monosacchs"]["edge_index"] + data["atoms"]["num_nodes"] + \
data["bonds"]["num_nodes"]
return Data(
edge_index=torch.cat([
Expand Down Expand Up @@ -168,6 +168,7 @@ def __call__(self, data: HeteroData) -> HeteroData:
data["bonds"].x,
data["monosacchs"].x
])
data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x)
data["rgcn_num_nodes"] = len(data["rgcn_x"])
data["rgcn_edge_type"] = torch.tensor(
[0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1]
Expand Down Expand Up @@ -327,16 +328,19 @@ def __call__(self, data: HeteroData) -> HeteroData:
if self.k == 0:
d[self.attr_name] = torch.tensor([[]])
else:
super(LaplacianPE, self).forward(d)
d = super(LaplacianPE, self).forward(d)
pad = torch.zeros(d["num_nodes"], self.max_dim - d[self.attr_name].size(1))
d[self.attr_name] = torch.cat([d[self.attr_name], pad], dim=1)
self.k = self.max_dim
else:
super(LaplacianPE, self).forward(d)
d = super(LaplacianPE, self).forward(d)
data[f"{name}_{self.attr_name}"] = d[self.attr_name]
else: # or for the whole graph
d = hetero_to_homo(data)
super(LaplacianPE, self).forward(d)
if self.k == 0:
d[self.attr_name] = torch.tensor([[]])
else:
d = super(LaplacianPE, self).forward(d)
data[f"atoms_{self.attr_name}"] = d[self.attr_name][:data["atoms"]["num_nodes"]]
data[f"bonds_{self.attr_name}"] = d[self.attr_name][
data["atoms"]["num_nodes"]:-data["monosacchs"]["num_nodes"]]
Expand Down
1 change: 1 addition & 0 deletions gifflar/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def train(**kwargs: Any) -> None:
],
max_epochs=kwargs["model"]["epochs"],
logger=logger,
accelerator="cpu",
)
start = time.time()
trainer.fit(model, datamodule)
Expand Down

0 comments on commit 29cde05

Please sign in to comment.