Skip to content

Commit

Permalink
link MPtrj dataset from /contribute page "Direct Download" section
Browse files Browse the repository at this point in the history
update MACE readme for 16M MPtrj checkpoint from pbenner
define formula_col to ensure consistency across code base
  • Loading branch information
janosh committed Nov 15, 2023
1 parent 4ce353b commit db999fd
Show file tree
Hide file tree
Showing 24 changed files with 94 additions and 70 deletions.
6 changes: 4 additions & 2 deletions data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from tqdm import tqdm

from matbench_discovery import ROOT, id_col, today
from matbench_discovery import ROOT, formula_col, id_col, today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.plots import plt
Expand Down Expand Up @@ -68,7 +68,9 @@


# %%
df_wbm["chem_sys"] = df_wbm.formula.str.replace("[0-9]+", "", regex=True).str.split()
df_wbm["chem_sys"] = (
df_wbm[formula_col].str.replace("[0-9]+", "", regex=True).str.split()
)
df_wbm["anion"] = None
df_wbm["anion"][df_wbm.chem_sys.astype(str).str.contains("'O'")] = "oxide"
df_wbm["anion"][df_wbm.chem_sys.astype(str).str.contains("'S'")] = "sulfide"
Expand Down
25 changes: 17 additions & 8 deletions data/wbm/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
)
from pymatviz.io import save_fig

from matbench_discovery import PDF_FIGS, ROOT, SITE_FIGS, STABILITY_THRESHOLD, id_col
from matbench_discovery import (
PDF_FIGS,
ROOT,
SITE_FIGS,
STABILITY_THRESHOLD,
formula_col,
id_col,
)
from matbench_discovery import plots as plots
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.energy import mp_elem_reference_entries
Expand All @@ -35,8 +42,10 @@


# %%
wbm_occu_counts = count_elements(df_wbm.formula, count_mode="occurrence").astype(int)
wbm_comp_counts = count_elements(df_wbm.formula, count_mode="composition")
wbm_occu_counts = count_elements(df_wbm[formula_col], count_mode="occurrence").astype(
int
)
wbm_comp_counts = count_elements(df_wbm[formula_col], count_mode="composition")

mp_occu_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence").astype(
int
Expand All @@ -60,16 +69,16 @@
df_wbm["step"] = df_wbm.index.str.split("-").str[1].astype(int)
assert df_wbm.step.between(1, 5).all()
for batch in range(1, 6):
count_elements(df_wbm[df_wbm.step == batch].formula).to_json(
count_elements(df_wbm[df_wbm.step == batch][formula_col]).to_json(
f"{data_page}/wbm-element-counts-{batch=}.json"
)

# export element counts by arity (how many elements in the formula)
comp_col = "composition"
df_wbm[comp_col] = df_wbm.formula.map(Composition)
df_wbm[comp_col] = df_wbm[formula_col].map(Composition)

for arity, df_mp in df_wbm.groupby(df_wbm[comp_col].map(len)):
count_elements(df_mp.formula).to_json(
count_elements(df_mp[formula_col]).to_json(
f"{data_page}/wbm-element-counts-{arity=}.json"
)

Expand Down Expand Up @@ -206,7 +215,7 @@
y="2d t-SNE 2",
color=color_col,
hover_name=id_col,
hover_data=("formula", each_true_col),
hover_data=(formula_col, each_true_col),
range_color=(0, clr_range_max),
)
fig.show()
Expand All @@ -219,7 +228,7 @@
y="3d t-SNE 2",
z="3d t-SNE 3",
color=color_col,
custom_data=[id_col, "formula", each_true_col, color_col],
custom_data=[id_col, formula_col, each_true_col, color_col],
range_color=(0, clr_range_max),
)
fig.data[0].hovertemplate = (
Expand Down
18 changes: 9 additions & 9 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pymatviz.io import save_fig
from tqdm import tqdm

from matbench_discovery import SITE_FIGS, id_col, today
from matbench_discovery import SITE_FIGS, formula_col, id_col, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.plots import pio
Expand Down Expand Up @@ -289,7 +289,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:

# %%
col_map = {
"# comp": "formula",
"# comp": formula_col,
"nsites": "n_sites",
"vol": "volume",
"e": "uncorrected_energy",
Expand Down Expand Up @@ -319,7 +319,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:

assert sum(no_id_mask := df_summary.index.isna()) == 6, f"{sum(no_id_mask)=}"
# the 'None' materials have 0 volume, energy, n_sites, bandgap, etc.
assert all(df_summary[no_id_mask].drop(columns=["formula"]) == 0)
assert all(df_summary[no_id_mask].drop(columns=[formula_col]) == 0)
assert len(df_summary.query("volume > 0")) == len(df_wbm) + len(nan_init_structs_ids)
# make sure dropping materials with 0 volume removes exactly 6 materials, the same ones
# listed in bad_struct_ids above
Expand Down Expand Up @@ -378,13 +378,13 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:

# sort formulas alphabetically
df_summary["alph_formula"] = [
Composition(x).alphabetical_formula for x in df_summary.formula
Composition(x).alphabetical_formula for x in df_summary[formula_col]
]
# alphabetical formula and original formula differ due to spaces, number 1 after element
# symbols (FeO vs Fe1 O1), and element order (FeO vs OFe)
assert sum(df_summary.alph_formula != df_summary.formula) == 257_483
assert sum(df_summary.alph_formula != df_summary[formula_col]) == 257_483

df_summary["formula"] = df_summary.pop("alph_formula")
df_summary[formula_col] = df_summary.pop("alph_formula")


# %% write initial structures and computed structure entries to compressed json
Expand All @@ -404,10 +404,10 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
# df_summary and df_wbm formulas differ because summary formulas are reduced while
# df_wbm formulas are not (e.g. Ac6 U2 vs Ac3 U1 in summary). unreduced is more
# informative so we use it.
assert sum(df_summary.formula != df_wbm.formula_from_cse) == 114_273
assert sum(df_summary.formula == df_wbm.formula_from_cse) == 143_214
assert sum(df_summary[formula_col] != df_wbm.formula_from_cse) == 114_273
assert sum(df_summary[formula_col] == df_wbm.formula_from_cse) == 143_214

df_summary.formula = df_wbm.formula_from_cse
df_summary[formula_col] = df_wbm.formula_from_cse


# fix bad energy which is 0 in df_summary but a more realistic -63.68 in CSE
Expand Down
1 change: 1 addition & 0 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
init_struct_col = "initial_structure"
struct_col = "structure"
e_form_col = "formation_energy_per_atom"
formula_col = "formula"
8 changes: 4 additions & 4 deletions models/chgnet/analyze_chgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
from pymatviz.io import save_fig

from matbench_discovery import PDF_FIGS, id_col
from matbench_discovery import PDF_FIGS, formula_col, id_col
from matbench_discovery import plots as plots
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.preds import PRED_FILES
Expand All @@ -26,7 +26,7 @@
df_chgnet_v020 = pd.read_csv(
f"{module_dir}/2023-03-06-chgnet-0.2.0-wbm-IS2RE.csv.gz", index_col=id_col
)
df_chgnet["formula"] = df_wbm.formula
df_chgnet[formula_col] = df_wbm[formula_col]

e_form_2000 = "e_form_per_atom_chgnet_relax_steps_2000"
e_form_500 = "e_form_per_atom_chgnet_relax_steps_500"
Expand All @@ -51,15 +51,15 @@
x=e_form_500,
y=e_form_2000,
hover_name=id_col,
hover_data=["formula"],
hover_data=[formula_col],
backend="plotly",
title=f"{len(df_diff)} structures have > {min_e_diff} eV/atom energy diff after "
"longer relaxation",
)


# %%
fig = ptable_heatmap_plotly(df_bad.formula)
fig = ptable_heatmap_plotly(df_bad[formula_col])
title = "structures with larger error<br>after longer relaxation"
fig.layout.title.update(text=f"{len(df_diff)} {title}", x=0.4, y=0.9)
fig.show()
Expand Down
4 changes: 2 additions & 2 deletions models/chgnet/ctk_structure_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer

from matbench_discovery import id_col
from matbench_discovery import formula_col, id_col
from matbench_discovery.preds import PRED_FILES

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -47,7 +47,7 @@
y=e_form_2000,
backend="plotly",
hover_name=id_col,
hover_data=["formula"],
hover_data=[formula_col],
labels=plot_labels,
size=e_form_abs_diff,
color=e_form_abs_diff,
Expand Down
6 changes: 3 additions & 3 deletions models/chgnet/join_chgnet_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymatviz import density_scatter
from tqdm import tqdm

from matbench_discovery import id_col
from matbench_discovery import formula_col, id_col
from matbench_discovery.data import as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.preds import df_preds, e_form_col
Expand Down Expand Up @@ -54,11 +54,11 @@

# %% compute corrected formation energies
e_form_chgnet_col = "e_form_per_atom_chgnet"
df_chgnet["formula"] = df_preds.formula
df_chgnet[formula_col] = df_preds[formula_col]
df_chgnet[e_form_chgnet_col] = [
get_e_form_per_atom(dict(energy=ene, composition=formula))
for formula, ene in tqdm(
df_chgnet.set_index("formula").chgnet_energy.items(), total=len(df_chgnet)
df_chgnet.set_index(formula_col).chgnet_energy.items(), total=len(df_chgnet)
)
]
df_preds[e_form_chgnet_col] = df_chgnet[e_form_chgnet_col]
Expand Down
4 changes: 2 additions & 2 deletions models/chgnet/test_chgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pymatgen.core import Structure
from tqdm import tqdm

from matbench_discovery import id_col, timestamp, today
from matbench_discovery import formula_col, id_col, timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
Expand Down Expand Up @@ -125,7 +125,7 @@
df_wbm[e_pred_col] = df_out[e_pred_col]
table = wandb.Table(
dataframe=df_wbm.dropna()[
["uncorrected_energy", e_pred_col, "formula"]
["uncorrected_energy", e_pred_col, formula_col]
].reset_index()
)

Expand Down
4 changes: 2 additions & 2 deletions models/mace/analyze_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
from pymatviz.io import save_fig

from matbench_discovery import id_col
from matbench_discovery import formula_col, id_col
from matbench_discovery import plots as plots
from matbench_discovery.data import df_wbm
from matbench_discovery.preds import PRED_FILES
Expand Down Expand Up @@ -44,7 +44,7 @@


# %%
fig = ptable_heatmap_plotly(df_low.formula)
fig = ptable_heatmap_plotly(df_low[formula_col])
title = f"Elements in {len(df_low):,} MACE severe energy underpredictions"
fig.layout.title.update(text=title, x=0.4, y=0.95)
fig.show()
Expand Down
8 changes: 4 additions & 4 deletions models/mace/join_mace_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pymatviz import density_scatter
from tqdm import tqdm

from matbench_discovery import id_col
from matbench_discovery import formula_col, id_col
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.preds import e_form_col
Expand Down Expand Up @@ -80,11 +80,11 @@

# %% compute corrected formation energies
e_form_mace_col = "e_form_per_atom_mace"
df_mace["formula"] = df_wbm.formula
df_mace[formula_col] = df_wbm[formula_col]
df_mace[e_form_mace_col] = [
get_e_form_per_atom(dict(energy=cse.energy, composition=formula))
for formula, cse in tqdm(
df_mace.set_index("formula")[entry_col].items(), total=len(df_mace)
df_mace.set_index(formula_col)[entry_col].items(), total=len(df_mace)
)
]
df_wbm[e_form_mace_col] = df_mace[e_form_mace_col]
Expand All @@ -106,6 +106,6 @@
df_bad[e_form_col] = df_wbm[e_form_col]
df_bad.to_csv(f"{out_path}-bad.csv")

# in_path = f"{module_dir}/2023-08-14-mace-wbm-IS2RE-FIRE"
# in_path = f"{module_dir}/2023-11-02-mace-wbm-IS2RE-FIRE"
# df_mace = pd.read_csv(f"{in_path}.csv.gz").set_index(id_col)
# df_mace = pd.read_json(f"{in_path}.json.gz").set_index(id_col)
5 changes: 2 additions & 3 deletions models/mace/json_to_extxyz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""This script converts the MPTrj relaxation trajectories from JSON to
extended XYZ format. The JSON data was downloaded from
https://figshare.com/articles/dataset/23713842.
"""This script converts the MPTrj relaxation trajectories downloaded from
https://figshare.com/articles/dataset/23713842 from JSON to extended XYZ format.
"""

import json
Expand Down
2 changes: 1 addition & 1 deletion models/mace/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ hyperparams:

notes:
description: |
The Many-body Atomic Convolutional Energies (MACE) is a higher-order equivariant message-passing neural network for fast and accurate force fields.
MACE is a higher-order equivariant message-passing neural network for fast and accurate force fields.
training: Using pre-trained model released with paper. Training set unspecified at time of writing.
corrections: None
# n_params: 2_026_624 # 2023-09-03-mace-yuan-mptrj-slower-14-lr-13_run-3
Expand Down
4 changes: 3 additions & 1 deletion models/mace/readme.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
## MACE formation energy predictions on WBM test set

This submission uses the [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) checkpoint trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
The original MACE submission used the 2M parameter checkpoint [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery v1 submission.

In late October (received 2023-10-29), Philipp Benner trained a much larger 16M parameter MACE for over 100 epochs in MPtrj which achieved an (at the time SOTA) F1 score of 0.64 and DAF of 3.13.

### Convergence criteria

MACE relaxed each test set structure until the maximum force in the training set dropped below 0.05 eV/Å or 500 optimization steps were reached, whichever occurred first.
Expand Down
7 changes: 4 additions & 3 deletions models/mace/test_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm

from matbench_discovery import ROOT, id_col, timestamp, today
from matbench_discovery import ROOT, formula_col, id_col, timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
Expand Down Expand Up @@ -60,7 +60,8 @@

# %%
slurm_array_task_id = int(os.getenv("SLURM_ARRAY_TASK_ID", "0"))
out_path = f"{out_dir}/mace-preds-{slurm_array_task_id:>03}.json.gz"
slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
out_path = f"{out_dir}/{slurm_job_id}-{slurm_array_task_id:>03}.json.gz"

if os.path.isfile(out_path):
raise SystemExit(f"{out_path=} already exists, exciting early")
Expand Down Expand Up @@ -164,7 +165,7 @@
df_wbm[e_pred_col] = df_out[e_pred_col]
table = wandb.Table(
dataframe=df_wbm.dropna()[
["uncorrected_energy", e_pred_col, "formula"]
["uncorrected_energy", e_pred_col, formula_col]
].reset_index()
)

Expand Down
4 changes: 2 additions & 2 deletions models/voronoi/train_test_voronoi_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline

from matbench_discovery import ROOT, id_col, today
from matbench_discovery import ROOT, formula_col, id_col, today
from matbench_discovery.data import DATA_FILES, df_wbm, glob_to_df
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.preds import e_form_col as test_e_form_col
Expand Down Expand Up @@ -123,7 +123,7 @@
df_wbm[pred_col].round(4).to_csv(out_path)

table = wandb.Table(
dataframe=df_wbm[["formula", test_e_form_col, pred_col]].reset_index()
dataframe=df_wbm[[formula_col, test_e_form_col, pred_col]].reset_index()
)

df_wbm[pred_col].isna().sum()
Expand Down
4 changes: 2 additions & 2 deletions models/wrenformer/analyze_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pymatviz.ptable import ptable_heatmap_plotly
from pymatviz.utils import add_identity_line, bin_df_cols

from matbench_discovery import PDF_FIGS, SITE_FIGS, id_col
from matbench_discovery import PDF_FIGS, SITE_FIGS, formula_col, id_col
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col

Expand Down Expand Up @@ -85,7 +85,7 @@


# %%
fig = ptable_heatmap_plotly(df_bad.formula)
fig = ptable_heatmap_plotly(df_bad[formula_col])
fig.layout.title = f"Elements in {title}"
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
fig.show()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ running-models = [

"aviary@git+https://github.com/CompRhys/aviary",
"m3gnet",
"mace@git+https://github.com/ACEsuit/mace",
"maml",
"megnet",
]
Expand Down
Loading

0 comments on commit db999fd

Please sign in to comment.