Skip to content

Commit

Permalink
Freeze arguments to pick_calculator (#2695)
Browse files Browse the repository at this point in the history
## Summary of Changes

`lru_cache` is used to make sure that each MLP calculator is only
instantiated once (great!). However, if you want to pass dictionaries
through to the underlying calculator (say, config overrides to the
fairchem calculator), you will get errors because dictionaries are
mutable and lru_cache is unhappy. We can patch this by converting any
dictionary kwargs to a frozendict.

### Requirements

- [x] My PR is focused on a [single feature addition or
bugfix](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/getting-started/best-practices-for-pull-requests#write-small-prs).
- [x] My PR has relevant, comprehensive [unit
tests](https://quantum-accelerators.github.io/quacc/dev/contributing.html#unit-tests).
- [x] My PR is on a [custom
branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-and-deleting-branches-within-your-repository)
(i.e. is _not_ named `main`).

Note: If you are an external contributor, you will see a comment from
[@buildbot-princeton](https://github.com/buildbot-princeton). This is
solely for the maintainers.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andrew S. Rosen <asrosen93@gmail.com>
  • Loading branch information
3 people authored Mar 11, 2025
1 parent 6bc9a2d commit 875e93d
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 7 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,15 @@ jobs:
run: |
echo "COVERAGE_CORE=sysmon" >> $GITHUB_ENV
- name: huggingface hub login
- name: HuggingFace Hub Login
env:
HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
run: huggingface-cli login --token $HF_TOKEN
run: |
if [ -n "$HF_TOKEN" ]; then
huggingface-cli login --token "$HF_TOKEN"
else
echo "HF_TOKEN is not set. Skipping login."
fi
- name: Run tests with pytest
run: pytest -k 'mlp or newtonnet or geodesic' --durations=10 --cov=quacc --cov-report=xml
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"ase>=3.24.0", # for Atoms object and calculators
"custodian>=2024.10.15", # for automated error corrections
"emmet-core>=0.84.3rc6", # for pre-made schemas
"frozendict>=2.4.6", # for caching of dictionaries in @lru_cache
"maggma>=0.64.0", # for database handling
"monty>=2024.5.15", # miscellaneous Python utilities
"numpy>=1.25.0", # for array handling
Expand All @@ -45,7 +46,7 @@ covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-clou
dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"]
defects = ["pymatgen-analysis-defects>=2024.10.22", "shakenbreak>=3.2.0"]
jobflow = ["jobflow>=0.1.14", "jobflow-remote[gui]>=0.1.0"]
mlp1 = ["chgnet>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=0.4.1", "fairchem-core>=1.4.0"]
mlp1 = ["chgnet>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=0.4.1", "fairchem-core>=1.6.0"]
mlp2 = ["mace-torch>=0.3.3", "matgl>=1.1.2"]
mp = ["atomate2>=0.0.14"]
newtonnet = ["newtonnet>=1.1"]
Expand Down
22 changes: 21 additions & 1 deletion src/quacc/recipes/mlp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

from functools import lru_cache
from functools import lru_cache, wraps
from importlib.util import find_spec
from logging import getLogger
from typing import TYPE_CHECKING

from frozendict import frozendict

if TYPE_CHECKING:
from typing import Literal

Expand All @@ -15,6 +17,24 @@
LOGGER = getLogger(__name__)


def freezeargs(func):
"""Convert a mutable dictionary into immutable.
Useful to make sure dictionary args are compatible with cache
From https://stackoverflow.com/a/53394430
"""

@wraps(func)
def wrapped(*args, **kwargs):
args = (frozendict(arg) if isinstance(arg, dict) else arg for arg in args)
kwargs = {
k: frozendict(v) if isinstance(v, dict) else v for k, v in kwargs.items()
}
return func(*args, **kwargs)

return wrapped


@freezeargs
@lru_cache
def pick_calculator(
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb", "fairchem"],
Expand Down
20 changes: 17 additions & 3 deletions tests/core/recipes/mlp_recipes/test_core_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
if has_orb := find_spec("orb_models"):
methods.append("orb")

if has_fairchem := find_spec("fairchem"):
methods.append("fairchem")
if find_spec("fairchem"):
from huggingface_hub.utils._auth import get_token

if get_token() is not None:
has_fairchem_and_huggingface_token = True
methods.append("fairchem")


@pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed")
Expand Down Expand Up @@ -77,6 +81,16 @@ def test_static_job(tmp_path, monkeypatch, method):
assert output["atoms"] == atoms


@pytest.mark.skipif(has_sevennet is None, reason="sevennet not installed")
def test_static_job_with_dict_kwargs(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

atoms = bulk("Cu")

# Make sure that pick_calculator works even with dictionary kwargs
static_job(atoms, method="sevennet", sevennet_config={"test": 1})


def test_relax_job_missing_pynanoflann(monkeypatch):
def mock_find_spec(name):
if name == "pynanoflann":
Expand All @@ -85,7 +99,7 @@ def mock_find_spec(name):

import quacc.recipes.mlp._base

quacc.recipes.mlp._base.pick_calculator.cache_clear()
quacc.recipes.mlp._base.pick_calculator.__wrapped__.cache_clear()
monkeypatch.setattr("importlib.util.find_spec", mock_find_spec)
monkeypatch.setattr("quacc.recipes.mlp._base.find_spec", mock_find_spec)
with pytest.raises(ImportError, match=r"orb-models requires pynanoflann"):
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
ase==3.24.0
custodian==2024.10.16
emmet-core==0.84.5
frozendict==2.4.6
maggma==0.71.5
monty==2025.3.3
numpy==1.26.4
Expand Down

0 comments on commit 875e93d

Please sign in to comment.