Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Freeze arguments to pick_calculator #2695

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,15 @@ jobs:
run: |
echo "COVERAGE_CORE=sysmon" >> $GITHUB_ENV

- name: huggingface hub login
- name: HuggingFace Hub Login
env:
HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
run: huggingface-cli login --token $HF_TOKEN
run: |
if [ -n "$HF_TOKEN" ]; then
huggingface-cli login --token "$HF_TOKEN"
else
echo "HF_TOKEN is not set. Skipping login."
fi

- name: Run tests with pytest
run: pytest -k 'mlp or newtonnet or geodesic' --durations=10 --cov=quacc --cov-report=xml
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"ase>=3.24.0", # for Atoms object and calculators
"custodian>=2024.10.15", # for automated error corrections
"emmet-core>=0.84.3rc6", # for pre-made schemas
"frozendict>=2.4.6", # for caching of dictionaries in @lru_cache
"maggma>=0.64.0", # for database handling
"monty>=2024.5.15", # miscellaneous Python utilities
"numpy>=1.25.0", # for array handling
Expand All @@ -45,7 +46,7 @@ covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-clou
dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"]
defects = ["pymatgen-analysis-defects>=2024.10.22", "shakenbreak>=3.2.0"]
jobflow = ["jobflow>=0.1.14", "jobflow-remote[gui]>=0.1.0"]
mlp1 = ["chgnet>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=0.4.1", "fairchem-core>=1.4.0"]
mlp1 = ["chgnet>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=0.4.1", "fairchem-core>=1.6.0"]
mlp2 = ["mace-torch>=0.3.3", "matgl>=1.1.2"]
mp = ["atomate2>=0.0.14"]
newtonnet = ["newtonnet>=1.1"]
Expand Down
22 changes: 21 additions & 1 deletion src/quacc/recipes/mlp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

from functools import lru_cache
from functools import lru_cache, wraps
from importlib.util import find_spec
from logging import getLogger
from typing import TYPE_CHECKING

from frozendict import frozendict

if TYPE_CHECKING:
from typing import Literal

Expand All @@ -15,6 +17,24 @@
LOGGER = getLogger(__name__)


def freezeargs(func):
"""Convert a mutable dictionary into immutable.
Useful to make sure dictionary args are compatible with cache
From https://stackoverflow.com/a/53394430
"""

@wraps(func)
def wrapped(*args, **kwargs):
args = (frozendict(arg) if isinstance(arg, dict) else arg for arg in args)
kwargs = {
k: frozendict(v) if isinstance(v, dict) else v for k, v in kwargs.items()
}
return func(*args, **kwargs)

return wrapped


@freezeargs
@lru_cache
def pick_calculator(
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb", "fairchem"],
Expand Down
20 changes: 17 additions & 3 deletions tests/core/recipes/mlp_recipes/test_core_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
if has_orb := find_spec("orb_models"):
methods.append("orb")

if has_fairchem := find_spec("fairchem"):
methods.append("fairchem")
if find_spec("fairchem"):
from huggingface_hub.utils._auth import get_token

if get_token() is not None:
has_fairchem_and_huggingface_token = True
methods.append("fairchem")


@pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed")
Expand Down Expand Up @@ -77,6 +81,16 @@ def test_static_job(tmp_path, monkeypatch, method):
assert output["atoms"] == atoms


@pytest.mark.skipif(has_sevennet is None, reason="sevennet not installed")
def test_static_job_with_dict_kwargs(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

atoms = bulk("Cu")

# Make sure that pick_calculator works even with dictionary kwargs
static_job(atoms, method="sevennet", sevennet_config={"test": 1})


def test_relax_job_missing_pynanoflann(monkeypatch):
def mock_find_spec(name):
if name == "pynanoflann":
Expand All @@ -85,7 +99,7 @@ def mock_find_spec(name):

import quacc.recipes.mlp._base

quacc.recipes.mlp._base.pick_calculator.cache_clear()
quacc.recipes.mlp._base.pick_calculator.__wrapped__.cache_clear()
monkeypatch.setattr("importlib.util.find_spec", mock_find_spec)
monkeypatch.setattr("quacc.recipes.mlp._base.find_spec", mock_find_spec)
with pytest.raises(ImportError, match=r"orb-models requires pynanoflann"):
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
ase==3.24.0
custodian==2024.10.16
emmet-core==0.84.5
frozendict==2.4.6
maggma==0.71.5
monty==2025.3.3
numpy==1.26.4
Expand Down
Loading