Skip to content

Commit 4aa62de

Browse files
authored
Merge branch 'master' into test-monty-reverse-read
2 parents 2598b2c + 389ec50 commit 4aa62de

35 files changed

+873
-584
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ci:
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.9.2
11+
rev: v0.9.4
1212
hooks:
1313
- id: ruff
1414
args: [--fix, --unsafe-fixes]
@@ -27,7 +27,7 @@ repos:
2727
- id: mypy
2828

2929
- repo: https://github.com/codespell-project/codespell
30-
rev: v2.3.0
30+
rev: v2.4.1
3131
hooks:
3232
- id: codespell
3333
stages: [pre-commit, commit-msg]
@@ -48,7 +48,7 @@ repos:
4848
- id: blacken-docs
4949

5050
- repo: https://github.com/igorshubovych/markdownlint-cli
51-
rev: v0.43.0
51+
rev: v0.44.0
5252
hooks:
5353
- id: markdownlint
5454
# MD013: line too long
@@ -65,6 +65,6 @@ repos:
6565
args: [--drop-empty-cells, --keep-output]
6666

6767
- repo: https://github.com/RobertCraigie/pyright-python
68-
rev: v1.1.391
68+
rev: v1.1.393
6969
hooks:
7070
- id: pyright

dev_scripts/potcar_scrambler.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class PotcarScrambler:
2323
"""
24-
Takes a POTCAR and replaces its values with completely random values
24+
Takes a POTCAR and replaces its values with completely random values.
2525
Does type matching and attempts precision matching on floats to ensure
2626
file is read correctly by Potcar and PotcarSingle classes.
2727
@@ -40,14 +40,15 @@ class PotcarScrambler:
4040

4141
def __init__(self, potcars: Potcar | PotcarSingle) -> None:
4242
self.PSP_list = [potcars] if isinstance(potcars, PotcarSingle) else potcars
43-
self.scrambled_potcars_str = ""
43+
self.scrambled_potcars_str: str = ""
4444
for psp in self.PSP_list:
4545
scrambled_potcar_str = self.scramble_single_potcar(psp)
4646
self.scrambled_potcars_str += scrambled_potcar_str
4747

4848
def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5) -> float:
49-
n_prec = len(input_str.split(".")[1])
50-
bd = max(1, bloat * abs(float(input_str))) # ensure we don't get 0
49+
"""Generate a random float from str to replace true values."""
50+
n_prec: int = len(input_str.split(".")[1])
51+
bd: float = max(1.0, bloat * abs(float(input_str))) # ensure we don't get 0
5152
return round(bd * np.random.default_rng().random(), n_prec)
5253

5354
def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5):
@@ -124,14 +125,16 @@ def scramble_single_potcar(self, potcar: PotcarSingle) -> str:
124125
return scrambled_potcar_str
125126

126127
def to_file(self, filename: str) -> None:
128+
"""Write scrambled POTCAR to file."""
127129
with zopen(filename, mode="wt", encoding="utf-8") as file:
128130
file.write(self.scrambled_potcars_str)
129131

130132
@classmethod
131133
def from_file(cls, input_filename: str, output_filename: str | None = None) -> Self:
134+
"""Read a POTCAR from file and generate a scrambled version."""
132135
psp = Potcar.from_file(input_filename)
133136
psp_scrambled = cls(psp)
134-
if output_filename:
137+
if output_filename is not None:
135138
psp_scrambled.to_file(output_filename)
136139
return psp_scrambled
137140

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ dependencies = [
6060
"networkx>=2.7", # PR4116
6161
"palettable>=3.3.3",
6262
"pandas>=2",
63-
"plotly>=4.5.0,<6.0.0",
63+
"plotly>=5.0.0",
6464
"pybtex>=0.24.0",
6565
"requests>=2.32",
6666
"ruamel.yaml>=0.17.0",

src/pymatgen/analysis/eos.py

+61-36
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig, pretty_plot
2525

2626
if TYPE_CHECKING:
27-
from typing import ClassVar
27+
from collections.abc import Sequence
28+
from typing import Any, ClassVar
2829

2930
import matplotlib.pyplot as plt
3031

@@ -40,7 +41,11 @@ class EOSBase(ABC):
4041
implementations.
4142
"""
4243

43-
def __init__(self, volumes, energies):
44+
def __init__(
45+
self,
46+
volumes: Sequence[float],
47+
energies: Sequence[float],
48+
) -> None:
4449
"""
4550
Args:
4651
volumes (Sequence[float]): in Ang^3.
@@ -50,18 +55,28 @@ def __init__(self, volumes, energies):
5055
self.energies = np.array(energies)
5156
# minimum energy(e0), buk modulus(b0),
5257
# derivative of bulk modulus w.r.t. pressure(b1), minimum volume(v0)
53-
self._params = None
58+
self._params: Sequence | None = None
5459
# the eos function parameters. It is the same as _params except for
5560
# equation of states that uses polynomial fits(delta_factor and
5661
# numerical_eos)
57-
self.eos_params = None
62+
self.eos_params: Sequence | None = None
5863

59-
def _initial_guess(self):
64+
def __call__(self, volume: float) -> float:
65+
"""
66+
Args:
67+
volume (float | list[float]): volume(s) in Ang^3.
68+
69+
Returns:
70+
Compute EOS with this volume.
71+
"""
72+
return self.func(volume)
73+
74+
def _initial_guess(self) -> tuple[float, float, float, float]:
6075
"""
6176
Quadratic fit to get an initial guess for the parameters.
6277
6378
Returns:
64-
tuple: 4 floats for (e0, b0, b1, v0)
79+
tuple[float, float, float, float]: e0, b0, b1, v0
6580
"""
6681
a, b, c = np.polyfit(self.volumes, self.energies, 2)
6782
self.eos_params = [a, b, c]
@@ -78,7 +93,7 @@ def _initial_guess(self):
7893

7994
return e0, b0, b1, v0
8095

81-
def fit(self):
96+
def fit(self) -> None:
8297
"""
8398
Do the fitting. Does least square fitting. If you want to use custom
8499
fitting, must override this.
@@ -120,24 +135,20 @@ def func(self, volume):
120135
"""
121136
return self._func(np.array(volume), self.eos_params)
122137

123-
def __call__(self, volume: float) -> float:
124-
"""
125-
Args:
126-
volume (float | list[float]): volume(s) in Ang^3.
127-
128-
Returns:
129-
Compute EOS with this volume.
130-
"""
131-
return self.func(volume)
132-
133138
@property
134139
def e0(self) -> float:
135140
"""The min energy."""
141+
if self._params is None:
142+
raise RuntimeError("params have not be initialized.")
143+
136144
return self._params[0]
137145

138146
@property
139147
def b0(self) -> float:
140148
"""The bulk modulus in units of energy/unit of volume^3."""
149+
if self._params is None:
150+
raise RuntimeError("params have not be initialized.")
151+
141152
return self._params[1]
142153

143154
@property
@@ -156,11 +167,18 @@ def v0(self):
156167
return self._params[3]
157168

158169
@property
159-
def results(self):
170+
def results(self) -> dict[str, Any]:
160171
"""A summary dict."""
161172
return {"e0": self.e0, "b0": self.b0, "b1": self.b1, "v0": self.v0}
162173

163-
def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
174+
def plot(
175+
self,
176+
width: float = 8,
177+
height: float | None = None,
178+
ax: plt.Axes = None,
179+
dpi: float | None = None,
180+
**kwargs,
181+
) -> plt.Axes:
164182
"""
165183
Plot the equation of state.
166184
@@ -170,7 +188,7 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
170188
golden ratio.
171189
ax (plt.Axes): If supplied, changes will be made to the existing Axes.
172190
Otherwise, new Axes will be created.
173-
dpi:
191+
dpi (float): DPI.
174192
kwargs (dict): additional args fed to pyplot.plot.
175193
supported keys: style, color, text, label
176194
@@ -211,16 +229,18 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
211229
return ax
212230

213231
@add_fig_kwargs
214-
def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
232+
def plot_ax(
233+
self,
234+
ax: plt.Axes | None = None,
235+
fontsize: float = 12,
236+
**kwargs,
237+
) -> plt.Figure:
215238
"""
216239
Plot the equation of state on axis `ax`.
217240
218241
Args:
219242
ax: matplotlib Axes or None if a new figure should be created.
220243
fontsize: Legend fontsize.
221-
color (str): plot color.
222-
label (str): Plot label
223-
text (str): Legend text (options)
224244
225245
Returns:
226246
plt.Figure: matplotlib figure.
@@ -270,7 +290,7 @@ def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
270290
class Murnaghan(EOSBase):
271291
"""Murnaghan EOS."""
272292

273-
def _func(self, volume, params):
293+
def _func(self, volume, params: tuple[float, float, float, float]):
274294
"""From PRB 28,5480 (1983)."""
275295
e0, b0, b1, v0 = tuple(params)
276296
return e0 + b0 * volume / b1 * (((v0 / volume) ** b1) / (b1 - 1.0) + 1.0) - v0 * b0 / (b1 - 1.0)
@@ -279,7 +299,7 @@ def _func(self, volume, params):
279299
class Birch(EOSBase):
280300
"""Birch EOS."""
281301

282-
def _func(self, volume, params):
302+
def _func(self, volume, params: tuple[float, float, float, float]):
283303
"""From Intermetallic compounds: Principles and Practice, Vol. I:
284304
Principles Chapter 9 pages 195-210 by M. Mehl. B. Klein,
285305
D. Papaconstantopoulos.
@@ -296,7 +316,7 @@ def _func(self, volume, params):
296316
class BirchMurnaghan(EOSBase):
297317
"""BirchMurnaghan EOS."""
298318

299-
def _func(self, volume, params):
319+
def _func(self, volume, params: tuple[float, float, float, float]):
300320
"""BirchMurnaghan equation from PRB 70, 224107."""
301321
e0, b0, b1, v0 = tuple(params)
302322
eta = (v0 / volume) ** (1 / 3)
@@ -306,7 +326,7 @@ def _func(self, volume, params):
306326
class PourierTarantola(EOSBase):
307327
"""Pourier-Tarantola EOS."""
308328

309-
def _func(self, volume, params):
329+
def _func(self, volume, params: tuple[float, float, float, float]):
310330
"""Pourier-Tarantola equation from PRB 70, 224107."""
311331
e0, b0, b1, v0 = tuple(params)
312332
eta = (volume / v0) ** (1 / 3)
@@ -317,7 +337,7 @@ def _func(self, volume, params):
317337
class Vinet(EOSBase):
318338
"""Vinet EOS."""
319339

320-
def _func(self, volume, params):
340+
def _func(self, volume, params: tuple[float, float, float, float]):
321341
"""Vinet equation from PRB 70, 224107."""
322342
e0, b0, b1, v0 = tuple(params)
323343
eta = (volume / v0) ** (1 / 3)
@@ -335,7 +355,7 @@ class PolynomialEOS(EOSBase):
335355
def _func(self, volume, params):
336356
return np.poly1d(list(params))(volume)
337357

338-
def fit(self, order):
358+
def fit(self, order: int) -> None:
339359
"""
340360
Do polynomial fitting and set the parameters. Uses numpy polyfit.
341361
@@ -345,7 +365,7 @@ def fit(self, order):
345365
self.eos_params = np.polyfit(self.volumes, self.energies, order)
346366
self._set_params()
347367

348-
def _set_params(self):
368+
def _set_params(self) -> None:
349369
"""
350370
Use the fit polynomial to compute the parameter e0, b0, b1 and v0
351371
and set to the _params attribute.
@@ -372,7 +392,7 @@ def _func(self, volume, params):
372392
x = volume ** (-2 / 3.0)
373393
return np.poly1d(list(params))(x)
374394

375-
def fit(self, order=3):
395+
def fit(self, order: int = 3) -> None:
376396
"""Overridden since this eos works with volume**(2/3) instead of volume."""
377397
x = self.volumes ** (-2 / 3.0)
378398
self.eos_params = np.polyfit(x, self.energies, order)
@@ -407,7 +427,12 @@ def _set_params(self):
407427
class NumericalEOS(PolynomialEOS):
408428
"""A numerical EOS."""
409429

410-
def fit(self, min_ndata_factor=3, max_poly_order_factor=5, min_poly_order=2):
430+
def fit(
431+
self,
432+
min_ndata_factor: int = 3,
433+
max_poly_order_factor: int = 5,
434+
min_poly_order: int = 2,
435+
) -> None:
411436
"""Fit the input data to the 'numerical eos', the equation of state employed
412437
in the quasiharmonic Debye model described in the paper:
413438
10.1103/PhysRevB.90.174107.
@@ -539,7 +564,7 @@ class EOS:
539564
eos_fit.plot()
540565
"""
541566

542-
MODELS: ClassVar = {
567+
MODELS: ClassVar[dict[str, Any]] = {
543568
"murnaghan": Murnaghan,
544569
"birch": Birch,
545570
"birch_murnaghan": BirchMurnaghan,
@@ -549,7 +574,7 @@ class EOS:
549574
"numerical_eos": NumericalEOS,
550575
}
551576

552-
def __init__(self, eos_name="murnaghan"):
577+
def __init__(self, eos_name: str = "murnaghan") -> None:
553578
"""
554579
Args:
555580
eos_name (str): Type of EOS to fit.
@@ -562,7 +587,7 @@ def __init__(self, eos_name="murnaghan"):
562587
self._eos_name = eos_name
563588
self.model = self.MODELS[eos_name]
564589

565-
def fit(self, volumes, energies):
590+
def fit(self, volumes: Sequence[float], energies: Sequence[float]) -> EOSBase:
566591
"""Fit energies as function of volumes.
567592
568593
Args:

0 commit comments

Comments
 (0)