-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
99de7cb
commit 3c07d1a
Showing
4 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
# Configuration for the OGB graph classification dataset ogbg-molhiv.
loader:
  _target_: topobenchmark.data.loaders.OGBGDatasetLoader
  parameters:
    data_domain: graph
    data_type: OGBGDataset
    data_name: MOLHIV
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}

# Dataset parameters
parameters:
  num_features: 9
  num_classes: 2
  task: classification
  loss_type: cross_entropy
  monitor_metric: accuracy
  task_level: graph

# Split settings
split_params:
  learning_setting: inductive
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
  data_seed: 0
  split_type: random # either "k-fold" or "random" strategies
  k: 10 # for "k-fold" Cross-Validation
  train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
  batch_size: 64
  num_workers: 0
  pin_memory: False
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
# Configuration for the OGB graph classification dataset ogbg-molpcba.
loader:
  _target_: topobenchmark.data.loaders.OGBGDatasetLoader
  parameters:
    data_domain: graph
    data_type: OGBGDataset
    data_name: MOLPCBA
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}

# Dataset parameters
parameters:
  num_features: 9
  num_classes: 2
  task: classification
  loss_type: cross_entropy
  monitor_metric: accuracy
  task_level: graph

# Split settings
split_params:
  learning_setting: inductive
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
  data_seed: 0
  split_type: random # either "k-fold" or "random" strategies
  k: 10 # for "k-fold" Cross-Validation
  train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
  batch_size: 64
  num_workers: 0
  pin_memory: False
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
# Configuration for the OGB graph classification dataset ogbg-ppa.
loader:
  _target_: topobenchmark.data.loaders.OGBGDatasetLoader
  parameters:
    data_domain: graph
    data_type: OGBGDataset
    data_name: PPA
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}

# Dataset parameters
parameters:
  num_features: 37
  num_classes: 37
  task: classification
  loss_type: cross_entropy
  monitor_metric: accuracy
  task_level: graph

# Split settings
split_params:
  learning_setting: inductive
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
  data_seed: 0
  split_type: random # either "k-fold" or "random" strategies
  k: 10 # for "k-fold" Cross-Validation
  train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
  batch_size: 64
  num_workers: 0
  pin_memory: False
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Loaders for Graph Property Prediction datasets.""" | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
from ogb.graphproppred import PygGraphPropPredDataset | ||
from omegaconf import DictConfig | ||
from torch_geometric.data import Dataset | ||
|
||
from topobenchmark.data.loaders.base import AbstractLoader | ||
|
||
|
||
class OGBGDatasetLoader(AbstractLoader):
    """Load molecule datasets (molhiv, molpcba, ppa) with predefined splits.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters containing:

        - data_dir: Root directory for data
        - data_name: Name of the dataset
        - data_type: Type of the dataset (e.g., "molecule")
    """

    def __init__(self, parameters: DictConfig) -> None:
        super().__init__(parameters)
        # Filled by _load_splits() in train/valid/test order.
        self.datasets: list[Dataset] = []

    def load_dataset(self) -> Dataset:
        """Load the molecule dataset with predefined splits.

        Returns
        -------
        Dataset
            The combined dataset, with the official OGB split indices
            attached as ``split_idx``.

        Raises
        ------
        RuntimeError
            If dataset loading fails.
        """
        split_idx = self._load_splits()
        combined_dataset = self._combine_splits()
        # Keep the official split indices alongside the combined data so
        # downstream code can recover the train/valid/test partition.
        combined_dataset.split_idx = split_idx
        return combined_dataset

    def _load_splits(self) -> dict:
        """Load the dataset and populate ``self.datasets`` with its splits.

        Returns
        -------
        dict
            The split indices for the dataset, keyed by
            ``"train"``/``"valid"``/``"test"``.
        """
        # NOTE(review): no ``root`` is passed, so OGB downloads to its
        # default location rather than the configured data_dir — confirm
        # this is intended.
        dataset = PygGraphPropPredDataset(
            name="ogbg-" + self.parameters.data_name.lower()
        )
        split_idx = dataset.get_idx_split()

        for split in ["train", "valid", "test"]:
            ds = dataset[split_idx[split]]
            # Cast node features to long (integer category indices).
            ds.x = ds.x.long()
            self.datasets.append(ds)
        return split_idx

    def _combine_splits(self) -> Dataset:
        """Concatenate the train/valid/test splits into a single dataset.

        Returns
        -------
        Dataset
            The combined dataset containing all splits.
        """
        return self.datasets[0] + self.datasets[1] + self.datasets[2]

    def get_data_dir(self) -> Path:
        """Get the data directory.

        Returns
        -------
        Path
            The path to the dataset directory.
        """
        # Return an actual Path: the original returned the str produced by
        # os.path.join, contradicting the declared return type.
        return Path(self.root_data_dir) / self.parameters.data_name