Skip to content

Commit 7cafd2f

Browse files
committed
Add test_wavecar.py and update wavecar.py***
1 parent d54ec87 commit 7cafd2f

File tree

3 files changed

+72
-13
lines changed

3 files changed

+72
-13
lines changed

test/test_wavecar.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import unittest
2+
from unittest.mock import patch
3+
from vsh.scripts.wavecar import handle_string_inputs
4+
from pymatgen.io.vasp.outputs import Wavecar
5+
import argparse
6+
7+
test_wavecar1_file = '/Users/wladerer/research/pt3sn/bulk/WAVECAR'
8+
test_poscar1_file = '/Users/wladerer/research/pt3sn/bulk/POSCAR'
9+
10+
# nk = 180, nb =112, spin =1
11+
test_wavecar2_file = '/Users/wladerer/research/tetrads/bi2te3/band/WAVECAR'
12+
test_poscar2_file = '/Users/wladerer/research/tetrads/bi2te3/band/POSCAR'
13+
14+
#lets get some info from the wavecar file
15+
wave = Wavecar(test_wavecar1_file)
16+
print(wave.nk, wave.nb, wave.spin)

vsh/scripts/wavecar.py

+55-13
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,74 @@
11
from pymatgen.io.vasp import Wavecar
22
from pymatgen.io.vasp import Poscar
33
from pymatgen.io.wannier90 import Unk
4+
from itertools import product
45
import numpy as np
56

67

8+
def get_partial_charge_density(wavecar: Wavecar, structure: Poscar, kpoints: list[int], bands: list[int], spins: list[int], spinors: list[int], phases: list[int], scale: float):
9+
'''Returns the partial charge densities for given kpoints, bands, spins, spinors, phases, and scale.'''
10+
wave = Wavecar(wavecar)
11+
poscar = Poscar.from_file(structure)
12+
13+
combinations = list(product(kpoints, bands, spins, spinors, phases))
14+
15+
#initialize the first chgcar
16+
parchg = wave.get_parchg(poscar, *combinations[0], scale=scale)
17+
18+
#add the rest of the chgcar
19+
for kpoint, band, spin, spinor, phase in combinations[1:]:
20+
parchg += wave.get_parchg(poscar, kpoint, band, spin, spinor, phase, scale)
21+
22+
return parchg
23+
24+
def handle_string_inputs(args, wavecar):
25+
'''Checks if all is passed as an input and returns a list of all possible values.'''
26+
args_mapping = {
27+
'kpoints': (wavecar.nk, list(range(wavecar.nk))),
28+
'bands': (wavecar.nb, list(range(wavecar.nb))),
29+
'spins': (wavecar.spin, list(range(wavecar.spin))),
30+
'spinors': (2, list(range(2))),
31+
'phases': (2, list(range(2)))
32+
}
33+
34+
for arg, (max_value, default_value) in args_mapping.items():
35+
try:
36+
if args.__dict__[arg] == 'all':
37+
args.__dict__[arg] = default_value
38+
elif ':' in args.__dict__[arg]:
39+
start, end = map(int, args.__dict__[arg].split(':'))
40+
args.__dict__[arg] = list(range(start, end+1))
41+
else:
42+
args.__dict__[arg] = [int(args.__dict__[arg])]
43+
except (ValueError, TypeError):
44+
print(f"Invalid input for {arg}. Please provide a valid input.")
45+
46+
return args
47+
48+
749
def generate_parchg(args):
850
'''Generates a PARCHG file from a WAVECAR file. '''
951
wave = Wavecar(args.input)
1052
poscar = Poscar.from_file(args.structure)
11-
12-
chgcar = wave.get_parchg(poscar, args.kpoint, args.band, args.spin, args.spinor, args.phase, args.scale)
53+
args = handle_string_inputs(args, wave)
54+
parchg = get_partial_charge_density(wave, poscar, args.kpoints, args.bands, args.spins, args.spinors, args.phases, args.scale)
1355

1456
if args.output:
1557
if args.cube:
16-
chgcar.to_cube(args.output)
58+
parchg.to_cube(args.output)
1759
else:
18-
chgcar.write_file(args.output)
60+
parchg.write_file(args.output)
1961
else:
20-
print(chgcar.__str__())
62+
print(parchg.__str__())
2163

22-
def generate_fft_mesh(args):
23-
'''Generates a COEFFS file from a WAVECAR file. '''
24-
mesh = Wavecar(args.input).fft_mesh(args.kpoint, args.band, args.spin, args.spinor, args.shift)
25-
evals = np.fft.ifftn(mesh)
26-
if args.output:
27-
np.save(args.output, evals)
28-
else:
29-
print(evals)
64+
# def generate_fft_mesh(args):
65+
# '''Generates a COEFFS file from a WAVECAR file. '''
66+
# mesh = Wavecar(args.input).fft_mesh(args.kpoints, args.bands, args.spins, args.spinors)
67+
# evals = np.fft.ifftn(mesh)
68+
# if args.output:
69+
# np.save(args.output, evals)
70+
# else:
71+
# print(evals)
3072

3173
def generate_unk(args):
3274
"""

vsh/utils/structure_tools.py

+1
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,4 @@ def plot_atomic_drift(initial_file: str, final_file: str, output_file: str = Non
147147

148148

149149

150+

0 commit comments

Comments
 (0)