Skip to content

Commit 0f22690

Browse files
authored
add optional gruneisen parameter colorbar plot (#3908)
1 parent 1bddad2 commit 0f22690

File tree

2 files changed

+100
-10
lines changed

2 files changed

+100
-10
lines changed

src/pymatgen/phonon/plotter.py

+65-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import matplotlib.pyplot as plt
99
import numpy as np
1010
import scipy.constants as const
11+
from matplotlib import colors
1112
from matplotlib.collections import LineCollection
13+
from matplotlib.colors import LinearSegmentedColormap
1214
from monty.json import jsanitize
1315
from pymatgen.electronic_structure.plotter import BSDOSPlotter, plot_brillouin_zone
1416
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
@@ -1052,35 +1054,75 @@ def bs_plot_data(self) -> dict[str, Any]:
10521054
"lattice": self._bs.lattice_rec.as_dict(),
10531055
}
10541056

1055-
def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
1057+
def get_plot_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> Axes:
10561058
"""Get a matplotlib object for the Gruneisen bandstructure plot.
10571059
10581060
Args:
10591061
ylim: Specify the y-axis (gruneisen) limits; by default None let
10601062
the code choose.
1063+
plot_ph_bs_with_gruneisen (bool): Plot phonon band-structure with bands coloured
1064+
as per Gruneisen parameter values on a logarithmic scale
10611065
**kwargs: additional keywords passed to ax.plot().
10621066
"""
1067+
u = freq_units(kwargs.get("units", "THz"))
10631068
ax = pretty_plot(12, 8)
10641069

1070+
# Create a colormap (default is red to blue)
1071+
cmap = LinearSegmentedColormap.from_list("cmap", kwargs.get("cmap", ["red", "blue"]))
1072+
10651073
kwargs.setdefault("linewidth", 2)
10661074
kwargs.setdefault("marker", "o")
10671075
kwargs.setdefault("markersize", 2)
10681076

10691077
data = self.bs_plot_data()
1070-
for dist_idx in range(len(data["distances"])):
1078+
1079+
# extract min and max Grüneisen parameter values
1080+
max_gruneisen = np.array(data["gruneisen"]).max()
1081+
min_gruneisen = np.array(data["gruneisen"]).min()
1082+
1083+
# LogNormalize colormap based on the min and max Grüneisen parameter values
1084+
norm = colors.SymLogNorm(
1085+
vmin=min_gruneisen,
1086+
vmax=max_gruneisen,
1087+
linthresh=1e-2,
1088+
linscale=1,
1089+
)
1090+
1091+
sc = None
1092+
for (dists_inx, dists), (_, freqs) in zip(enumerate(data["distances"]), enumerate(data["frequency"])):
10711093
for band_idx in range(self.n_bands):
1072-
ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))]
1094+
if plot_ph_bs_with_gruneisen:
1095+
ys = [freqs[band_idx][j] * u.factor for j in range(len(dists))]
1096+
ys_gru = [
1097+
data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx]))
1098+
]
1099+
sc = ax.scatter(dists, ys, c=ys_gru, cmap=cmap, norm=norm, marker="o", s=1)
1100+
else:
1101+
keys_to_remove = ("units", "cmap") # needs to be removed before passing to line-plot
1102+
for k in keys_to_remove:
1103+
kwargs.pop(k, None)
1104+
ys = [
1105+
data["gruneisen"][dists_inx][band_idx][idx] for idx in range(len(data["distances"][dists_inx]))
1106+
]
10731107

1074-
ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs)
1108+
ax.plot(data["distances"][dists_inx], ys, "b-", **kwargs)
10751109

10761110
self._make_ticks(ax)
10771111

10781112
# plot y=0 line
10791113
ax.axhline(0, linewidth=1, color="black")
10801114

10811115
# Main X and Y Labels
1082-
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
1083-
ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)
1116+
if plot_ph_bs_with_gruneisen:
1117+
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
1118+
units = kwargs.get("units", "THz")
1119+
ax.set_ylabel(f"Frequencies ({units})", fontsize=30)
1120+
1121+
cbar = plt.colorbar(sc, ax=ax)
1122+
cbar.set_label(r"$\gamma \ \mathrm{(logarithmized)}$", fontsize=30)
1123+
else:
1124+
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
1125+
ax.set_ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)
10841126

10851127
# X range (K)
10861128
# last distance point
@@ -1094,24 +1136,37 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
10941136

10951137
return ax
10961138

1097-
def show_gs(self, ylim: float | None = None) -> None:
1139+
def show_gs(self, ylim: float | None = None, plot_ph_bs_with_gruneisen: bool = False, **kwargs) -> None:
10981140
"""Show the plot using matplotlib.
10991141
11001142
Args:
11011143
ylim: Specifies the y-axis limits.
1144+
plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
1145+
as per Gruneisen parameter values on a logarithmic scale
1146+
**kwargs: kwargs passed to get_plot_gs
11021147
"""
1103-
self.get_plot_gs(ylim)
1148+
self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs)
11041149
plt.show()
11051150

1106-
def save_plot_gs(self, filename: str | PathLike, img_format: str = "eps", ylim: float | None = None) -> None:
1151+
def save_plot_gs(
1152+
self,
1153+
filename: str | PathLike,
1154+
img_format: str = "eps",
1155+
ylim: float | None = None,
1156+
plot_ph_bs_with_gruneisen: bool = False,
1157+
**kwargs,
1158+
) -> None:
11071159
"""Save matplotlib plot to a file.
11081160
11091161
Args:
11101162
filename: Filename to write to.
11111163
img_format: Image format to use. Defaults to EPS.
11121164
ylim: Specifies the y-axis limits.
1165+
plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
1166+
as per Gruneisen parameter values on a logarithmic scale
1167+
**kwargs: kwargs passed to get_plot_gs
11131168
"""
1114-
self.get_plot_gs(ylim=ylim)
1169+
self.get_plot_gs(ylim=ylim, plot_ph_bs_with_gruneisen=plot_ph_bs_with_gruneisen, **kwargs)
11151170
plt.savefig(filename, format=img_format)
11161171
plt.close()
11171172

tests/phonon/test_gruneisen.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

33
import matplotlib.pyplot as plt
4+
import numpy as np
45
import pytest
6+
from matplotlib import colors
57
from pymatgen.io.phonopy import get_gruneisen_ph_bs_symm_line, get_gruneisenparameter
68
from pymatgen.phonon.gruneisen import GruneisenParameter
79
from pymatgen.phonon.plotter import GruneisenPhononBandStructureSymmLine, GruneisenPhononBSPlotter, GruneisenPlotter
@@ -31,6 +33,39 @@ def test_plot(self):
3133
ax = plotter.get_plot_gs()
3234
assert isinstance(ax, plt.Axes)
3335

36+
def test_ph_plot_w_gruneisen(self):
37+
plotter = GruneisenPhononBSPlotter(bs=self.bs_symm_line)
38+
ax = plotter.get_plot_gs(plot_ph_bs_with_gruneisen=True, units="THz", cmap=["red", "royalblue"])
39+
assert ax.get_ylabel() == "Frequencies (THz)"
40+
assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$"
41+
assert ax.get_figure()._localaxes[-1].get_ylabel() == "$\\gamma \\ \\mathrm{(logarithmized)}$"
42+
assert len(ax._children) == plotter.n_bands + 1 # check for number of bands
43+
# check for x and y data is really the band-structure data
44+
for inx, band in enumerate(plotter._bs.bands):
45+
xy_data = {
46+
"x": [point[0] for point in ax._children[inx].get_offsets().data],
47+
"y": [point[1] for point in ax._children[inx].get_offsets().data],
48+
}
49+
assert band == pytest.approx(xy_data["y"])
50+
assert plotter._bs.distance == pytest.approx(xy_data["x"])
51+
52+
# check if color bar max value matches maximum gruneisen parameter value
53+
data = plotter.bs_plot_data()
54+
55+
# get reference min and max Grüneisen parameter values
56+
max_gruneisen = np.array(data["gruneisen"]).max()
57+
min_gruneisen = np.array(data["gruneisen"]).min()
58+
59+
norm = colors.SymLogNorm(
60+
vmin=min_gruneisen,
61+
vmax=max_gruneisen,
62+
linthresh=1e-2,
63+
linscale=1,
64+
)
65+
66+
assert max(norm.inverse(ax.get_figure()._localaxes[-1].get_yticks())) == pytest.approx(max_gruneisen)
67+
assert isinstance(ax, plt.Axes)
68+
3469
def test_as_dict_from_dict(self):
3570
new_dict = self.bs_symm_line.as_dict()
3671
self.new_bs_symm_line = GruneisenPhononBandStructureSymmLine.from_dict(new_dict)

0 commit comments

Comments
 (0)