diff --git a/configs/downstream/all.yaml b/configs/downstream/all.yaml
index fad4868..3c20873 100644
--- a/configs/downstream/all.yaml
+++ b/configs/downstream/all.yaml
@@ -5,24 +5,24 @@ logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/logs_test
 datasets:
   - name: Immunogenicity
     task: classification
-  - name: Glycosylation
-    task: classification
-  - name: Taxonomy_Domain
-    task: multilabel
-  - name: Taxonomy_Kingdom
-    task: multilabel
-  - name: Taxonomy_Phylum
-    task: multilabel
-  - name: Taxonomy_Class
-    task: multilabel
-  - name: Taxonomy_Order
-    task: multilabel
-  - name: Taxonomy_Family
-    task: multilabel
-  - name: Taxonomy_Genus
-    task: multilabel
-  - name: Taxonomy_Species
-    task: multilabel
+  # - name: Glycosylation
+  #   task: classification
+  #- name: Taxonomy_Domain
+  #  task: multilabel
+  #- name: Taxonomy_Kingdom
+  #  task: multilabel
+  #- name: Taxonomy_Phylum
+  #  task: multilabel
+  #- name: Taxonomy_Class
+  #  task: multilabel
+  #- name: Taxonomy_Order
+  #  task: multilabel
+  #- name: Taxonomy_Family
+  #  task: multilabel
+  #- name: Taxonomy_Genus
+  #  task: multilabel
+  #- name: Taxonomy_Species
+  #  task: multilabel
 pre-transforms:
   GIFFLARTransform:
   GNNGLYTransform:
@@ -37,33 +37,33 @@ pre-transforms:
     dim: 20
     individual: False
 model:
-  - name: rf
-    n_estimators: 500
-    n_jobs: -1
-    random_state: 42
-  - name: svm
-    random_state: 42
-  - name: xgb
-    random_state: 42
-  - name: mlp
-    feat_dim: 1024
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 3
-    epochs: 100
-    patience: 30
-    learning_rate: 0
-    optimizer: Adam
-  - name: sweetnet
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 512
-    num_layers: 16
-    epochs: 1
-    patience: 30
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
+  #- name: rf
+  #  n_estimators: 500
+  #  n_jobs: -1
+  #  random_state: 42
+  #- name: svm
+  #  random_state: 42
+  #- name: xgb
+  #  random_state: 42
+  #- name: mlp
+  #  feat_dim: 1024
+  #  hidden_dim: 1024
+  #  batch_size: 256
+  #  num_layers: 3
+  #  epochs: 1
+  #  patience: 30
+  #  learning_rate: 0
+  #  optimizer: Adam
+  #- name: sweetnet
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 512
+  #  num_layers: 16
+  #  epochs: 1
+  #  patience: 30
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
   - name: gnngly
     feat_dim: 128
     hidden_dim: 1024
@@ -74,15 +74,15 @@ model:
     learning_rate: 0.001
     optimizer: Adam
     suffix:
-  - name: rgcn
-    feat_dim: 128
-    hidden_dim: 1024
-    batch_size: 256
-    num_layers: 8
-    epochs: 1
-    learning_rate: 0.001
-    optimizer: Adam
-    suffix:
+  #- name: rgcn
+  #  feat_dim: 128
+  #  hidden_dim: 1024
+  #  batch_size: 256
+  #  num_layers: 8
+  #  epochs: 1
+  #  learning_rate: 0.001
+  #  optimizer: Adam
+  #  suffix:
   - name: pyrgcn
     feat_dim: 128
     hidden_dim: 1024
@@ -100,5 +100,5 @@ model:
     epochs: 1
     learning_rate: 0.001
     optimizer: Adam
-    pooling: global_pool
+    pooling: global_mean
     suffix: _128_8_gp
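With everything except the Immunogenicity dataset and the gnngly/pyrgcn entries commented out, this config now describes a minimal smoke-test run. A quick sketch for checking what remains active (assumes PyYAML is installed and the repo root is the working directory):

    import yaml

    # Load the downstream config and list the entries that are still enabled.
    with open("configs/downstream/all.yaml") as f:
        cfg = yaml.safe_load(f)

    print([d["name"] for d in cfg["datasets"]])  # e.g. ['Immunogenicity']
    print([m["name"] for m in cfg["model"]])     # e.g. ['gnngly', 'pyrgcn', ...]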
diff --git a/gifflar/data/hetero.py b/gifflar/data/hetero.py
index fb69dd4..b891d2d 100644
--- a/gifflar/data/hetero.py
+++ b/gifflar/data/hetero.py
@@ -46,6 +46,17 @@ def __getitem__(self, item: str) -> Any:
         return getattr(self, item)
 
 
+def determine_concat_dim(tensors):
+    if len(tensors[0].shape) == 1:
+        return 0
+    concat_dim = None
+    if all(tensor.shape[1] == tensors[0].shape[1] for tensor in tensors):
+        concat_dim = 0  # Concatenate along rows (dim 0)
+    elif all(tensor.shape[0] == tensors[0].shape[0] for tensor in tensors):
+        concat_dim = 1  # Concatenate along columns (dim 1)
+    return concat_dim
+
+
 def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData]]]) -> HeteroDataBatch:
     """
     Collate a list of HeteroData objects to a batch thereof.
@@ -71,7 +82,7 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData
     edge_attr_dict = {}
 
     # Include data for the baselines and other kwargs for house-keeping
-    baselines = {"gnngly", "sweetnet"}
+    baselines = {"gnngly", "sweetnet", "rgcn"}
     kwargs = {key: [] for key in dict(data[0]) if all(b not in key for b in baselines)}
 
     # Store the node counts to offset edge indices when collating
@@ -121,13 +132,31 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData
             kwargs[f"{b}_x"] = torch.cat([d[f"{b}_x"] for d in data], dim=0)
             edges = []
             batch = []
+            e_types = []
+            n_types = []
             node_counts = 0
             for i, d in enumerate(data):
                 edges.append(d[f"{b}_edge_index"] + node_counts)
                 node_counts += d[f"{b}_num_nodes"]
+                if b == "rgcn":
+                    e_types += d["rgcn_edge_type"]
+                    n_types += d["rgcn_node_type"]
                 batch.append(torch.full((d[f"{b}_num_nodes"],), i, dtype=torch.long))
             kwargs[f"{b}_edge_index"] = torch.cat(edges, dim=1)
             kwargs[f"{b}_batch"] = torch.cat(batch, dim=0)
+            if b == "rgcn":
+                #for d in data:
+                #    print(d["rgcn_x"].shape)
+                #    print(d["rgcn_edge_type"].shape)
+                kwargs["rgcn_edge_type"] = torch.tensor(e_types)
+                kwargs["rgcn_node_type"] = n_types
+                if hasattr(data[0], "rgcn_rw_pe"):
+                    kwargs["rgcn_rw_pe"] = torch.cat([d["rgcn_rw_pe"] for d in data], dim=0)
+                if hasattr(data[0], "rgcn_lap_pe"):
+                    kwargs["rgcn_lap_pe"] = torch.cat([d["rgcn_lap_pe"] for d in data], dim=0)
+    #print(kwargs["rgcn_x"].shape)
+    #print(kwargs["rgcn_edge_type"].shape)
+    #print(len(kwargs["rgcn_node_type"]))
 
     # Remove all incompletely given data and concat lists of tensors into single tensors
     num_nodes = {node_type: x_dict[node_type].shape[0] for node_type in node_types}
@@ -137,7 +166,10 @@ def hetero_collate(data: Optional[Union[list[list[HeteroData]], list[HeteroData
         elif len(value) != len(data):
             del kwargs[key]
         elif isinstance(value[0], torch.Tensor):
-            kwargs[key] = torch.cat(value, dim=0)
+            dim = determine_concat_dim(value)
+            if dim is None:
+                raise ValueError(f"Tensors for key {key} cannot be concatenated.")
+            kwargs[key] = torch.cat(value, dim=dim)
 
     # Finally create and return the HeteroDataBatch
     return HeteroDataBatch(x_dict=x_dict, edge_index_dict=edge_index_dict, edge_attr_dict=edge_attr_dict,
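The new determine_concat_dim helper infers the concatenation axis from shape agreement: 1D tensors and row-compatible 2D tensors concatenate along dim 0, column-compatible ones along dim 1, and None signals an impossible concatenation, which hetero_collate now turns into a ValueError. A small sanity sketch, assuming the function is importable from gifflar.data.hetero as added above:

    import torch
    from gifflar.data.hetero import determine_concat_dim

    # Same number of columns -> concatenate along rows (dim 0).
    assert determine_concat_dim([torch.zeros(3, 4), torch.zeros(5, 4)]) == 0

    # Same number of rows, different columns -> concatenate along columns (dim 1).
    assert determine_concat_dim([torch.zeros(3, 4), torch.zeros(3, 7)]) == 1

    # 1D tensors always concatenate along dim 0.
    assert determine_concat_dim([torch.zeros(3), torch.zeros(5)]) == 0

    # No shared dimension -> None, which the collate function raises on.
    assert determine_concat_dim([torch.zeros(2, 3), torch.zeros(4, 5)]) is None

Note that when both dimensions agree, the row check wins and dim 0 is returned.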
+ + Args: + batch: The batch of data to process - Args: - batch: The batch of data to process + Returns: + node_embed: The node embeddings + """ - Returns: - node_embed: The node embeddings - """ - self.convs.forward(batch["rgcn_x"], batch["edge_index"], batch["edge_type"]) + node_embeds = [torch.stack([self.embedding.forward(batch["rgcn_x"][i], batch["rgcn_node_type"][i]) for i in range(len(batch["rgcn_x"]))])] + for pe in self.addendum: + node_embeds.append(batch[f"rgcn_{pe}"]) + node_embeds = torch.concat(node_embeds, dim=1) + + for conv in self.convs: + node_embeds = conv(node_embeds, batch["rgcn_edge_index"], batch["rgcn_edge_type"]) + + graph_embed = self.pooling(node_embeds, batch["rgcn_batch"]) + pred = self.head(graph_embed) + return { + "node_embed": node_embeds, + "graph_embed": graph_embed, + "preds": pred, + } class RGCN(DownstreamGGIN): diff --git a/gifflar/model/downstream.py b/gifflar/model/downstream.py index cf3e4d8..29c146d 100644 --- a/gifflar/model/downstream.py +++ b/gifflar/model/downstream.py @@ -54,7 +54,8 @@ def to(self, device: torch.device) -> "DownstreamGGIN": self: The model moved to the specified device """ super(DownstreamGGIN, self).to(device) - self.pooling.to(device) + if isinstance(self.pooling, GIFFLARPooling): + self.pooling.to(device) for split, metric in self.metrics.items(): self.metrics[split] = metric.to(device) return self diff --git a/gifflar/pretransforms.py b/gifflar/pretransforms.py index bae36d7..3218f67 100644 --- a/gifflar/pretransforms.py +++ b/gifflar/pretransforms.py @@ -44,7 +44,7 @@ def split_hetero_graph(data: HeteroData) -> tuple[Data, Data, Data]: def hetero_to_homo(data: HeteroData) -> Data: """Convert a heterogeneous graph to a homogeneous by collapsing the node types and removing all node features.""" bond_edge_index = data["bonds", "boundary", "bonds"]["edge_index"] + data["atoms"]["num_nodes"] - monosacchs_edge_index = data["bonds", "boundary", "bonds"]["edge_index"] + data["atoms"]["num_nodes"] + \ + monosacchs_edge_index = data["monosacchs", "boundary", "monosacchs"]["edge_index"] + data["atoms"]["num_nodes"] + \ data["bonds"]["num_nodes"] return Data( edge_index=torch.cat([ @@ -168,6 +168,7 @@ def __call__(self, data: HeteroData) -> HeteroData: data["bonds"].x, data["monosacchs"].x ]) + data["rgcn_node_type"] = ["atoms"] * len(data["atoms"].x) + ["bonds"] * len(data["bonds"].x) + ["monosacchs"] * len(data["monosacchs"].x) data["rgcn_num_nodes"] = len(data["rgcn_x"]) data["rgcn_edge_type"] = torch.tensor( [0] * data["atoms", "coboundary", "atoms"].edge_index.shape[1] @@ -327,16 +328,19 @@ def __call__(self, data: HeteroData) -> HeteroData: if self.k == 0: d[self.attr_name] = torch.tensor([[]]) else: - super(LaplacianPE, self).forward(d) + d = super(LaplacianPE, self).forward(d) pad = torch.zeros(d["num_nodes"], self.max_dim - d[self.attr_name].size(1)) d[self.attr_name] = torch.cat([d[self.attr_name], pad], dim=1) self.k = self.max_dim else: - super(LaplacianPE, self).forward(d) + d = super(LaplacianPE, self).forward(d) data[f"{name}_{self.attr_name}"] = d[self.attr_name] else: # or for the whole graph d = hetero_to_homo(data) - super(LaplacianPE, self).forward(d) + if self.k == 0: + d[self.attr_name] = torch.tensor([[]]) + else: + d = super(LaplacianPE, self).forward(d) data[f"atoms_{self.attr_name}"] = d[self.attr_name][:data["atoms"]["num_nodes"]] data[f"bonds_{self.attr_name}"] = d[self.attr_name][ data["atoms"]["num_nodes"]:-data["monosacchs"]["num_nodes"]] diff --git a/gifflar/train.py b/gifflar/train.py index 
diff --git a/gifflar/train.py b/gifflar/train.py
index 106bd31..e727b8a 100644
--- a/gifflar/train.py
+++ b/gifflar/train.py
@@ -146,6 +146,7 @@ def train(**kwargs: Any) -> None:
         ],
         max_epochs=kwargs["model"]["epochs"],
         logger=logger,
+        accelerator="cpu",
     )
     start = time.time()
     trainer.fit(model, datamodule)
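Hard-coding accelerator="cpu" keeps these debugging runs off the GPU. If the choice should stay configurable, a sketch along these lines could work instead; the "accelerator" config key is hypothetical and not part of the current config schema:

    from pytorch_lightning import Trainer

    # Fall back to CPU unless the config explicitly requests another device.
    config = {"model": {"epochs": 1}}  # stand-in for the parsed YAML kwargs
    trainer = Trainer(
        max_epochs=config["model"]["epochs"],
        accelerator=config.get("accelerator", "cpu"),
    )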