Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New prototype.py module with moyopy-powered AFLOW-style proto-structure labeling #198

Merged
merged 18 commits into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed data/mp/2023-01-10-mp-energies.csv.gz
Binary file not shown.
19 changes: 11 additions & 8 deletions data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Download all MP formation and above hull energies on 2023-01-10.
The main purpose of this script is produce the file at DataFiles.mp_energies.path.
Related EDA of MP formation energies:
https://github.com/janosh/pymatviz/blob/-/examples/mp_bimodal_e_form.ipynb
"""
Expand All @@ -9,14 +11,14 @@

import pandas as pd
import pymatviz as pmv
from aviary.wren.utils import get_protostructure_label_from_spglib
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatviz.enums import Key
from tqdm import tqdm

from matbench_discovery import STABILITY_THRESHOLD, today
from matbench_discovery.data import DataFiles
from matbench_discovery.structure import prototype

__author__ = "Janosh Riebesell"
__date__ = "2023-01-10"
Expand Down Expand Up @@ -63,17 +65,18 @@
)

df_cse[Key.structure] = [
Structure.from_dict(cse[Key.structure]) for cse in tqdm(df_cse.entry)
Structure.from_dict(cse[Key.structure])
for cse in tqdm(df_cse.entry, desc="Hydrating structures")
]
df_cse[Key.wyckoff] = [
get_protostructure_label_from_spglib(struct, errors="ignore")
for struct in tqdm(df_cse.structure)
df_cse[f"{Key.protostructure}_moyo"] = [
prototype.get_protostructure_label(struct)
for struct in tqdm(df_cse.structure, desc="Calculating proto-structure labels")
]
# make sure symmetry detection succeeded for all structures
assert df_cse[Key.wyckoff].str.startswith("invalid").sum() == 0
df_mp[Key.wyckoff] = df_cse[Key.wyckoff]
assert df_cse[f"{Key.protostructure}_moyo"].str.startswith("invalid").sum() == 0
df_mp[f"{Key.protostructure}_moyo"] = df_cse[f"{Key.protostructure}_moyo"]

spg_nums = df_mp[Key.wyckoff].str.split("_").str[2].astype(int)
spg_nums = df_mp[f"{Key.protostructure}_moyo"].str.split("_").str[2].astype(int)
# make sure all our spacegroup numbers match MP's
assert (spg_nums.sort_index() == df_spg["number"].sort_index()).all()

Expand Down
64 changes: 29 additions & 35 deletions data/wbm/compile_wbm_test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from matbench_discovery.data import DataFiles
from matbench_discovery.energy import calc_energy_from_e_refs, mp_elemental_ref_energies
from matbench_discovery.enums import MbdKey
from matbench_discovery.structure import prototype

try:
import gdown
Expand Down Expand Up @@ -629,41 +630,34 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:


# %%
try:
from aviary.wren.utils import get_protostructure_label_from_spglib

# from initial structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(MbdKey.init_wyckoff)):
continue # Aflow label already computed
try:
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
df_summary.loc[idx, MbdKey.init_wyckoff] = (
get_protostructure_label_from_spglib(struct)
)
except Exception as exc:
print(f"{idx=} {exc=}")

# from relaxed structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(Key.wyckoff)):
continue

try:
cse = df_wbm.loc[idx, Key.computed_structure_entry]
struct = Structure.from_dict(cse["structure"])
df_summary.loc[idx, Key.wyckoff] = get_protostructure_label_from_spglib(
struct
)
except Exception as exc:
print(f"{idx=} {exc=}")

assert df_summary[MbdKey.init_wyckoff].isna().sum() == 0
assert df_summary[Key.wyckoff].isna().sum() == 0
except ImportError:
print("aviary not installed, skipping Wyckoff label generation")
except Exception as exception:
print(f"Generating Aflow labels raised {exception=}")
# from initial structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(MbdKey.init_wyckoff)):
continue # Aflow label already computed
try:
struct = Structure.from_dict(df_wbm.loc[idx, Key.init_struct])
df_summary.loc[idx, f"{Key.protostructure}_moyo_init"] = (
prototype.get_protostructure_label(struct)
)
except Exception as exc:
print(f"{idx=} {exc=}")

# from relaxed structures
for idx in tqdm(df_wbm.index):
if not pd.isna(df_summary.loc[idx].get(Key.wyckoff)):
continue

try:
cse = df_wbm.loc[idx, Key.computed_structure_entry]
struct = Structure.from_dict(cse["structure"])
df_summary.loc[idx, f"{Key.protostructure}_moyo_relaxed"] = (
prototype.get_protostructure_label(struct)
)
except Exception as exc:
print(f"{idx=} {exc=}")

assert df_summary[MbdKey.init_wyckoff].isna().sum() == 0
assert df_summary[Key.wyckoff].isna().sum() == 0


# %%
Expand Down
8 changes: 4 additions & 4 deletions matbench_discovery/data-files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ mp_elemental_ref_entries:
md5: 6e93b6f38d6e27d6c811d3cafb23a070

mp_energies:
url: https://figshare.com/ndownloader/files/49083124
path: mp/2023-01-10-mp-energies.csv.gz
description: Materials Project formation energies and energies above convex hull in eV/atom as a fast-to-load CSV file
md5: 888579e287c8417e2202330c40e1367f
url: https://figshare.com/ndownloader/files/52080797
path: mp/2025-02-01-mp-energies.csv.gz
description: Materials Project formation energies and energies above convex hull in eV/atom as a fast-to-load CSV file.
md5: 7eb0c49fc169ba92783f2f6d0d19d741

mp_patched_phase_diagram:
url: https://figshare.com/ndownloader/files/48241624
Expand Down
6 changes: 3 additions & 3 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class DataFiles(Files):
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
)
mp_elemental_ref_entries = "mp/2023-02-07-mp-elemental-reference-entries.json.gz"
mp_energies = "mp/2023-01-10-mp-energies.csv.gz"
mp_energies = "mp/2025-02-01-mp-energies.csv.gz"
mp_patched_phase_diagram = "mp/2023-02-07-ppd-mp.pkl.gz"
mp_trj_json_gz = "mp/2022-09-16-mp-trj.json.gz"
mp_trj_extxyz = "mp/2024-09-03-mp-trj.extxyz.zip"
Expand Down Expand Up @@ -423,8 +423,8 @@ class Model(Files, base_dir=f"{ROOT}/models"):
cgcnn_p = "cgcnn/cgcnn+p.yml"

# DeepMD-DPA3 models
dpa3_v1_mptrj = "deepmd_dpa3/dpa3-v1-mptrj.yml"
dpa3_v1_openlam = "deepmd_dpa3/dpa3-v1-openlam.yml"
dpa3_v1_mptrj = "deepmd/dpa3-v1-mptrj.yml"
dpa3_v1_openlam = "deepmd/dpa3-v1-openlam.yml"

# original M3GNet straight from publication, not re-trained
m3gnet = "m3gnet/m3gnet.yml"
Expand Down
34 changes: 34 additions & 0 deletions matbench_discovery/structure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Perturb atomic coordinates of a pymatgen structure."""

import numpy as np
from pymatgen.core import Structure

__author__ = "Janosh Riebesell"
__date__ = "2022-12-02"

rng = np.random.default_rng(seed=0) # ensure reproducible structure perturbations


def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
"""Perturb the atomic coordinates of a pymatgen structure. Used for CGCNN+P
training set augmentation.

Not identical but very similar to the perturbation method used in
https://nature.com/articles/s41524-022-00891-8#Fig5.

Args:
struct (Structure): pymatgen structure to be perturbed
gamma (float, optional): Weibull distribution parameter. Defaults to 1.5.

Returns:
Structure: Perturbed structure
"""
perturbed = struct.copy()
for site in perturbed:
magnitude = rng.weibull(gamma)
vec = rng.normal(3) # TODO maybe make func recursive to deal with 0-vector
vec /= np.linalg.norm(vec) # unit vector
site.coords += vec * magnitude
site.to_unit_cell(in_place=True)

return perturbed
Loading
Loading