Skip to content

Commit 195f19e

Browse files
authored
Update inequality in get_linear_interpolated_value (#4299)
* Update inequality in `get_linear_interpolated_value` * Add test for updated edge case behaviour * Rename `firs` to `fict_ionic_radii` for clarity and to avoid `codespell` errors in linting * Add `if not below_fermi or not above_fermi:` catch to `DOS.get_interpolated_gap()` (to be consistent with `Dos.get_interpolated_gap()`) * Return zero gap for case of VBM index = CBM index -1 (i.e. no gap found)
1 parent 1b8c5f8 commit 195f19e

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

src/pymatgen/analysis/local_env.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3745,23 +3745,23 @@ def get_nn_info(self, structure: Structure, n: int):
37453745

37463746
if self.use_fictive_radius:
37473747
# calculate fictive ionic radii
3748-
firs = [_get_fictive_ionic_radius(site, neighbor) for neighbor in neighbors]
3748+
fict_ionic_radii = [_get_fictive_ionic_radius(site, neighbor) for neighbor in neighbors]
37493749
else:
37503750
# just use the bond distance
3751-
firs = [neighbor.nn_distance for neighbor in neighbors]
3751+
fict_ionic_radii = [neighbor.nn_distance for neighbor in neighbors]
37523752

37533753
# calculate mean fictive ionic radius
3754-
mefir = _get_mean_fictive_ionic_radius(firs)
3754+
mefir = _get_mean_fictive_ionic_radius(fict_ionic_radii)
37553755

37563756
# iteratively solve MEFIR; follows equation 4 in Hoppe's EconN paper
37573757
prev_mefir = float("inf")
37583758
while abs(prev_mefir - mefir) > 1e-4:
37593759
# this is guaranteed to converge
37603760
prev_mefir = mefir
3761-
mefir = _get_mean_fictive_ionic_radius(firs, minimum_fir=mefir)
3761+
mefir = _get_mean_fictive_ionic_radius(fict_ionic_radii, minimum_fir=mefir)
37623762

37633763
siw = []
3764-
for nn, fir in zip(neighbors, firs, strict=True):
3764+
for nn, fir in zip(neighbors, fict_ionic_radii, strict=True):
37653765
if nn.nn_distance < self.cutoff:
37663766
w = math.exp(1 - (fir / mefir) ** 6)
37673767
if w > self.tol:

src/pymatgen/electronic_structure/dos.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,12 @@ def get_interpolated_gap(
107107
energies = self.x
108108
below_fermi = [i for i in range(len(energies)) if energies[i] < self.efermi and tdos[i] > tol]
109109
above_fermi = [i for i in range(len(energies)) if energies[i] > self.efermi and tdos[i] > tol]
110+
if not below_fermi or not above_fermi:
111+
return 0.0, self.efermi, self.efermi
112+
110113
vbm_start = max(below_fermi)
111114
cbm_start = min(above_fermi)
112-
if vbm_start == cbm_start:
115+
if vbm_start in [cbm_start, cbm_start - 1]:
113116
return 0.0, self.efermi, self.efermi
114117

115118
# Interpolate between adjacent values
@@ -311,7 +314,7 @@ def get_interpolated_gap(
311314

312315
vbm_start = max(below_fermi)
313316
cbm_start = min(above_fermi)
314-
if vbm_start == cbm_start:
317+
if vbm_start in [cbm_start, cbm_start - 1]:
315318
return 0.0, self.efermi, self.efermi
316319

317320
# Interpolate between adjacent values

src/pymatgen/util/coord.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def get_linear_interpolated_value(x_values: ArrayLike, y_values: ArrayLike, x: f
135135
"""
136136
arr = np.array(sorted(zip(x_values, y_values, strict=True), key=lambda d: d[0]))
137137

138-
indices = np.where(arr[:, 0] >= x)[0]
138+
indices = np.where(arr[:, 0] > x)[0]
139139

140140
if len(indices) == 0 or indices[0] == 0:
141141
raise ValueError(f"{x=} is out of range of provided x_values ({min(x_values)}, {max(x_values)})")

tests/util/test_coord.py

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def test_get_linear_interpolated_value(self):
1919
with pytest.raises(ValueError, match=r"x=6 is out of range of provided x_values \(0, 5\)"):
2020
coord.get_linear_interpolated_value(x_vals, y_vals, 6)
2121

22+
# test when x is equal to first value in x_vals (previously broke, fixed in #4299):
23+
assert coord.get_linear_interpolated_value(x_vals, y_vals, 0) == approx(3)
24+
with pytest.raises(ValueError, match=r"x=-0.5 is out of range of provided x_values \(0, 5\)"):
25+
coord.get_linear_interpolated_value(x_vals, y_vals, -0.5)
26+
2227
def test_in_coord_list(self):
2328
coords = [[0, 0, 0], [0.5, 0.5, 0.5]]
2429
test_coord = [0.1, 0.1, 0.1]

0 commit comments

Comments
 (0)