Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/collate fn #93

Merged
merged 13 commits into from
Dec 19, 2023
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
matrix:
python-version: ["3.9"]
os: ["ubuntu-latest"] #, "macos-latest", "windows-latest"]
pytorch-version: ["1.12"]
pytorch-version: ["1.13"]

runs-on: ${{ matrix.os }}
timeout-minutes: 30
Expand Down
11 changes: 6 additions & 5 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- dglteam

dependencies:
- python >=3.8
- python >3.8
- pip
- tqdm
- pyyaml
Expand Down Expand Up @@ -33,18 +33,19 @@ dependencies:
- mordredcommunity

# ML
- pytorch =1.12
- pytorch >=1.13
- scikit-learn
- fcd_torch

# Optional: featurizers
- dgl
- dgllife
- dgl >=1.1.1
- dgllife >=0.3.2
- graphormer-pretrained >=0.2.3
- transformers
- tokenizers <0.13.2
- sentencepiece
- biotite # required for ESM models
- biotite # required for ESM model
- pytorch_geometric >=2.4.0
zhu0619 marked this conversation as resolved.
Show resolved Hide resolved

# Optional: viz
- nglview
Expand Down
27 changes: 24 additions & 3 deletions molfeat/trans/graph/adj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from functools import partial
from typing import Any
from typing import Optional
from typing import Callable
from typing import List
from typing import Union
from typing import Sequence
from typing import TYPE_CHECKING

import torch
import datamol as dm
Expand All @@ -27,10 +30,18 @@
if requires.check("dgllife"):
from dgllife import utils as dgllife_utils


if requires.check("torch_geometric"):
from torch_geometric.data import Data
from torch_geometric.loader.dataloader import Collater

if TYPE_CHECKING:
from torch_geometric.data import Dataset as PygDataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
else:
PygDataset, BaseData, DatasetAdapter = Any, Any, Any


class GraphTransformer(MoleculeTransformer):
"""
Expand Down Expand Up @@ -659,7 +670,9 @@ def patch_feats(*args, **kwargs):


class PYGGraphTransformer(AdjGraphTransformer):
"""Graph transformer for the PYG models"""
"""
Graph transformer for the PYG models
"""

def _graph_featurizer(self, mol: dm.Mol):
# we have used bond_calculator, therefore we need to
Expand Down Expand Up @@ -727,23 +740,31 @@ def transform(self, mols: List[Union[dm.Mol, str]], **kwargs):

def get_collate_fn(
self,
dataset: Optional[Union[PygDataset, Sequence[BaseData], DatasetAdapter]] = None,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
return_pair: Optional[bool] = True,
**kwargs,
):
"""
Get collate function for pyg graphs
Get collate function for pyg graphs.
Note: The `collate_fn` is not required when using `torch_geometric.loader.dataloader.DataLoader`.

Args:
dataset: The dataset from which to load the data and apply the collate function.
This is required if the dataset is a <torch_geometric.data.on_disk_dataset.OnDiskDataset>.
follow_batch: Creates assignment batch vectors for each key in the list. (default: :obj:`None`)
exclude_keys: Will exclude each key in the list. (default: :obj:`None`)
return_pair: whether to return a pair of X,y or a databatch (default: :obj:`True`)

Returns:
Collated samples.

See Also:
<torch_geometric.loader.dataloader.Collater>
<torch_geometric.loader.dataloader.DataLoader>
"""
collator = Collater(follow_batch=follow_batch, exclude_keys=exclude_keys)
collator = Collater(dataset=dataset, follow_batch=follow_batch, exclude_keys=exclude_keys)
zhu0619 marked this conversation as resolved.
Show resolved Hide resolved
return partial(self._collate_batch, collator=collator, return_pair=return_pair)

@staticmethod
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dynamic = ["version"]
authors = [{ name = "Emmanuel Noutahi", email = "emmanuel.noutahi@hotmail.ca" }]
readme = "README.md"
license = { text = "Apache" }
requires-python = ">=3.8"
requires-python = ">3.8"
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
Expand Down Expand Up @@ -39,7 +39,7 @@ dependencies = [
"pandas",
"numpy",
"scipy",
"torch",
"torch>=1.13",
"datamol >=0.8.0",
"pyyaml",
"platformdirs",
Expand All @@ -57,7 +57,7 @@ dependencies = [
]

[project.optional-dependencies]
dgl = ["dgl", "dgllife"]
dgl = ["dgl>=1.1.1", "dgllife>=0.3.2"]

graphormer = ["graphormer-pretrained"]

Expand All @@ -67,6 +67,8 @@ fcd = ["fcd_torch"]

viz = ["nglview", "ipywidgets"]

pyg = ["pytorch_geometric >=2.4.0"]

all = [
"dgl",
"dgllife",
Expand All @@ -76,6 +78,7 @@ all = [
"fcd_torch",
"nglview",
"ipywidgets",
"pytorch_geometric >=2.4.0"
]

test = ["pytest >=6.0","pytest-dotenv", "pytest-cov", "pytest-xdist", "black >=22", "ruff"]
Expand Down