8
8
import matplotlib .pyplot as plt
9
9
import numpy as np
10
10
import scipy .constants as const
11
+ from matplotlib import colors
11
12
from matplotlib .collections import LineCollection
13
+ from matplotlib .colors import LinearSegmentedColormap
12
14
from monty .json import jsanitize
13
15
from pymatgen .electronic_structure .plotter import BSDOSPlotter , plot_brillouin_zone
14
16
from pymatgen .phonon .bandstructure import PhononBandStructureSymmLine
@@ -1052,35 +1054,75 @@ def bs_plot_data(self) -> dict[str, Any]:
1052
1054
"lattice" : self ._bs .lattice_rec .as_dict (),
1053
1055
}
1054
1056
1055
- def get_plot_gs (self , ylim : float | None = None , ** kwargs ) -> Axes :
1057
+ def get_plot_gs (self , ylim : float | None = None , plot_ph_bs_with_gruneisen : bool = False , ** kwargs ) -> Axes :
1056
1058
"""Get a matplotlib object for the Gruneisen bandstructure plot.
1057
1059
1058
1060
Args:
1059
1061
ylim: Specify the y-axis (gruneisen) limits; by default None let
1060
1062
the code choose.
1063
+ plot_ph_bs_with_gruneisen (bool): Plot phonon band-structure with bands coloured
1064
+ as per Gruneisen parameter values on a logarithmic scale
1061
1065
**kwargs: additional keywords passed to ax.plot().
1062
1066
"""
1067
+ u = freq_units (kwargs .get ("units" , "THz" ))
1063
1068
ax = pretty_plot (12 , 8 )
1064
1069
1070
+ # Create a colormap (default is red to blue)
1071
+ cmap = LinearSegmentedColormap .from_list ("cmap" , kwargs .get ("cmap" , ["red" , "blue" ]))
1072
+
1065
1073
kwargs .setdefault ("linewidth" , 2 )
1066
1074
kwargs .setdefault ("marker" , "o" )
1067
1075
kwargs .setdefault ("markersize" , 2 )
1068
1076
1069
1077
data = self .bs_plot_data ()
1070
- for dist_idx in range (len (data ["distances" ])):
1078
+
1079
+ # extract min and max Grüneisen parameter values
1080
+ max_gruneisen = np .array (data ["gruneisen" ]).max ()
1081
+ min_gruneisen = np .array (data ["gruneisen" ]).min ()
1082
+
1083
+ # LogNormalize colormap based on the min and max Grüneisen parameter values
1084
+ norm = colors .SymLogNorm (
1085
+ vmin = min_gruneisen ,
1086
+ vmax = max_gruneisen ,
1087
+ linthresh = 1e-2 ,
1088
+ linscale = 1 ,
1089
+ )
1090
+
1091
+ sc = None
1092
+ for (dists_inx , dists ), (_ , freqs ) in zip (enumerate (data ["distances" ]), enumerate (data ["frequency" ])):
1071
1093
for band_idx in range (self .n_bands ):
1072
- ys = [data ["gruneisen" ][dist_idx ][band_idx ][idx ] for idx in range (len (data ["distances" ][dist_idx ]))]
1094
+ if plot_ph_bs_with_gruneisen :
1095
+ ys = [freqs [band_idx ][j ] * u .factor for j in range (len (dists ))]
1096
+ ys_gru = [
1097
+ data ["gruneisen" ][dists_inx ][band_idx ][idx ] for idx in range (len (data ["distances" ][dists_inx ]))
1098
+ ]
1099
+ sc = ax .scatter (dists , ys , c = ys_gru , cmap = cmap , norm = norm , marker = "o" , s = 1 )
1100
+ else :
1101
+ keys_to_remove = ("units" , "cmap" ) # needs to be removed before passing to line-plot
1102
+ for k in keys_to_remove :
1103
+ kwargs .pop (k , None )
1104
+ ys = [
1105
+ data ["gruneisen" ][dists_inx ][band_idx ][idx ] for idx in range (len (data ["distances" ][dists_inx ]))
1106
+ ]
1073
1107
1074
- ax .plot (data ["distances" ][dist_idx ], ys , "b-" , ** kwargs )
1108
+ ax .plot (data ["distances" ][dists_inx ], ys , "b-" , ** kwargs )
1075
1109
1076
1110
self ._make_ticks (ax )
1077
1111
1078
1112
# plot y=0 line
1079
1113
ax .axhline (0 , linewidth = 1 , color = "black" )
1080
1114
1081
1115
# Main X and Y Labels
1082
- ax .set_xlabel (r"$\mathrm{Wave\ Vector}$" , fontsize = 30 )
1083
- ax .set_ylabel (r"$\mathrm{Grüneisen\ Parameter}$" , fontsize = 30 )
1116
+ if plot_ph_bs_with_gruneisen :
1117
+ ax .set_xlabel (r"$\mathrm{Wave\ Vector}$" , fontsize = 30 )
1118
+ units = kwargs .get ("units" , "THz" )
1119
+ ax .set_ylabel (f"Frequencies ({ units } )" , fontsize = 30 )
1120
+
1121
+ cbar = plt .colorbar (sc , ax = ax )
1122
+ cbar .set_label (r"$\gamma \ \mathrm{(logarithmized)}$" , fontsize = 30 )
1123
+ else :
1124
+ ax .set_xlabel (r"$\mathrm{Wave\ Vector}$" , fontsize = 30 )
1125
+ ax .set_ylabel (r"$\mathrm{Grüneisen\ Parameter}$" , fontsize = 30 )
1084
1126
1085
1127
# X range (K)
1086
1128
# last distance point
@@ -1094,24 +1136,37 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
1094
1136
1095
1137
return ax
1096
1138
1097
- def show_gs (self , ylim : float | None = None ) -> None :
1139
+ def show_gs (self , ylim : float | None = None , plot_ph_bs_with_gruneisen : bool = False , ** kwargs ) -> None :
1098
1140
"""Show the plot using matplotlib.
1099
1141
1100
1142
Args:
1101
1143
ylim: Specifies the y-axis limits.
1144
+ plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
1145
+ as per Gruneisen parameter values on a logarithmic scale
1146
+ **kwargs: kwargs passed to get_plot_gs
1102
1147
"""
1103
- self .get_plot_gs (ylim )
1148
+ self .get_plot_gs (ylim = ylim , plot_ph_bs_with_gruneisen = plot_ph_bs_with_gruneisen , ** kwargs )
1104
1149
plt .show ()
1105
1150
1106
- def save_plot_gs (self , filename : str | PathLike , img_format : str = "eps" , ylim : float | None = None ) -> None :
1151
+ def save_plot_gs (
1152
+ self ,
1153
+ filename : str | PathLike ,
1154
+ img_format : str = "eps" ,
1155
+ ylim : float | None = None ,
1156
+ plot_ph_bs_with_gruneisen : bool = False ,
1157
+ ** kwargs ,
1158
+ ) -> None :
1107
1159
"""Save matplotlib plot to a file.
1108
1160
1109
1161
Args:
1110
1162
filename: Filename to write to.
1111
1163
img_format: Image format to use. Defaults to EPS.
1112
1164
ylim: Specifies the y-axis limits.
1165
+ plot_ph_bs_with_gruneisen: Plot phonon band-structure with bands coloured
1166
+ as per Gruneisen parameter values on a logarithmic scale
1167
+ **kwargs: kwargs passed to get_plot_gs
1113
1168
"""
1114
- self .get_plot_gs (ylim = ylim )
1169
+ self .get_plot_gs (ylim = ylim , plot_ph_bs_with_gruneisen = plot_ph_bs_with_gruneisen , ** kwargs )
1115
1170
plt .savefig (filename , format = img_format )
1116
1171
plt .close ()
1117
1172
0 commit comments