Skip to content

Commit a56bc74

Browse files
committed
ruff formatting
1 parent b46f3ce commit a56bc74

27 files changed

+1624
-1228
lines changed

setup.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
from setuptools import setup, find_packages
22

33
setup(
4-
name='vsh',
5-
version='0.4',
4+
name="vsh",
5+
version="0.4",
66
packages=find_packages(),
77
install_requires=[
8-
'ase',
9-
'numpy',
10-
'matplotlib',
11-
'pymatgen',
12-
'mp-api',
13-
'pyprocar',
14-
'pydantic'
8+
"ase",
9+
"numpy",
10+
"matplotlib",
11+
"pymatgen",
12+
"mp-api",
13+
"pyprocar",
14+
"pydantic",
1515
],
16-
entry_points={
17-
'console_scripts': [
18-
'vsh = vsh.cli:main'
19-
]
20-
}
16+
entry_points={"console_scripts": ["vsh = vsh.cli:main"]},
2117
)

test/test_kpoints.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,46 @@
44
import argparse
55
from vsh.scripts.kpoints import write_kpoints
66

7-
class TestWriteKpoints(unittest.TestCase):
87

8+
class TestWriteKpoints(unittest.TestCase):
99
def print_expected_and_result(self):
10-
args = argparse.Namespace(mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None)
10+
args = argparse.Namespace(
11+
mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None
12+
)
1113
expected_kpoints = Kpoints.monkhorst_automatic(kpts=(4, 4, 4))
1214

13-
with patch('builtins.print') as mock_print, patch('pymatgen.io.vasp.inputs.Kpoints.write_file') as mock_write_file:
15+
with patch("builtins.print") as mock_print, patch(
16+
"pymatgen.io.vasp.inputs.Kpoints.write_file"
17+
) as mock_write_file:
1418
result = write_kpoints(args)
1519

1620
print(expected_kpoints)
1721
print(result)
1822

19-
2023
def test_write_kpoints_monkhorst(self):
21-
args = argparse.Namespace(mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None)
24+
args = argparse.Namespace(
25+
mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None
26+
)
2227
expected_kpoints = Kpoints.monkhorst_automatic(kpts=(4, 4, 4))
2328

24-
with patch('builtins.print') as mock_print, patch('pymatgen.io.vasp.inputs.Kpoints.write_file') as mock_write_file:
29+
with patch("builtins.print") as mock_print, patch(
30+
"pymatgen.io.vasp.inputs.Kpoints.write_file"
31+
) as mock_write_file:
2532
result = write_kpoints(args)
2633

2734
self.assertEqual(result, expected_kpoints)
2835
mock_print.assert_not_called()
2936
mock_write_file.assert_not_called()
3037

3138
def test_write_kpoints_gamma(self):
32-
args = argparse.Namespace(mesh_type="gamma", mesh=(4, 4, 4), input=None, output=None)
39+
args = argparse.Namespace(
40+
mesh_type="gamma", mesh=(4, 4, 4), input=None, output=None
41+
)
3342
expected_kpoints = Kpoints.gamma_automatic(kpts=(4, 4, 4))
3443

35-
with patch('builtins.print') as mock_print, patch('pymatgen.io.vasp.inputs.Kpoints.write_file') as mock_write_file:
44+
with patch("builtins.print") as mock_print, patch(
45+
"pymatgen.io.vasp.inputs.Kpoints.write_file"
46+
) as mock_write_file:
3647
result = write_kpoints(args)
3748

3849
self.assertEqual(result, expected_kpoints)
@@ -51,26 +62,35 @@ def test_write_kpoints_gamma(self):
5162
# mock_write_file.assert_not_called()
5263

5364
def test_write_kpoints_no_output(self):
54-
args = argparse.Namespace(mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None)
65+
args = argparse.Namespace(
66+
mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output=None
67+
)
5568
expected_kpoints = Kpoints.monkhorst_automatic(kpts=(4, 4, 4))
5669

57-
with patch('builtins.print') as mock_print, patch('pymatgen.io.vasp.inputs.Kpoints.write_file') as mock_write_file:
70+
with patch("builtins.print") as mock_print, patch(
71+
"pymatgen.io.vasp.inputs.Kpoints.write_file"
72+
) as mock_write_file:
5873
result = write_kpoints(args)
5974

6075
self.assertEqual(result, expected_kpoints)
6176
mock_print.assert_called_once_with(expected_kpoints)
6277
mock_write_file.assert_not_called()
6378

6479
def test_write_kpoints_with_output(self):
65-
args = argparse.Namespace(mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output="KPOINTS")
80+
args = argparse.Namespace(
81+
mesh_type="monkhorst", mesh=(4, 4, 4), input=None, output="KPOINTS"
82+
)
6683
expected_kpoints = Kpoints.monkhorst_automatic(kpts=(4, 4, 4))
6784

68-
with patch('builtins.print') as mock_print, patch('pymatgen.io.vasp.inputs.Kpoints.write_file') as mock_write_file:
85+
with patch("builtins.print") as mock_print, patch(
86+
"pymatgen.io.vasp.inputs.Kpoints.write_file"
87+
) as mock_write_file:
6988
result = write_kpoints(args)
7089

7190
self.assertEqual(result, expected_kpoints)
7291
mock_print.assert_not_called()
73-
mock_write_file.assert_called_once_with('KPOINTS')
92+
mock_write_file.assert_called_once_with("KPOINTS")
93+
7494

75-
if __name__ == '__main__':
95+
if __name__ == "__main__":
7696
unittest.main()

test/test_wavecar.py

+15-20
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pymatgen.io.vasp.outputs import Wavecar
44
import argparse
55

6-
'''
6+
"""
77
Attributes:
88
filename (str): String of the input file (usually WAVECAR).
99
vasp_type (str): String that determines VASP type the WAVECAR was generated with.
@@ -31,50 +31,45 @@
3131
(e.g. self.coeffs[kp][b] corresponds to k-point kp and band b). For spin-polarized calculations,
3232
the first index is for the spin. If the calculation was non-collinear, then self.coeffs[kp][b] will have
3333
two columns (one for each component of the spinor).
34-
'''
34+
"""
3535

3636

37-
bi2se3_wavecar = '/home/wladerer/github/vsh/test/files/Bi2Se3/WAVECAR'
38-
bi2se3_poscar = '/home/wladerer/github/vsh/test/files/Bi2Se3/POSCAR'
37+
bi2se3_wavecar = "/home/wladerer/github/vsh/test/files/Bi2Se3/WAVECAR"
38+
bi2se3_poscar = "/home/wladerer/github/vsh/test/files/Bi2Se3/POSCAR"
3939
# nk = 6, nb = 64, spin = 1, efermi = 4.0126579574404495, encut = 600.0
40-
#lets get some info from the wavecar file
40+
# lets get some info from the wavecar file
4141
import pandas as pd
4242
from pymatgen.io.vasp.outputs import Wavecar
4343

44-
def get_band_occupancy_info(wavecar: Wavecar):
4544

45+
def get_band_occupancy_info(wavecar: Wavecar):
4646
band_energy = wavecar.band_energy
4747
dfs = []
4848
for kpoint, bands in enumerate(band_energy):
49-
df = pd.DataFrame(bands, columns=['energy', 'idk', 'occupancy'])
50-
df['kpoint'] = kpoint
49+
df = pd.DataFrame(bands, columns=["energy", "idk", "occupancy"])
50+
df["kpoint"] = kpoint
5151

5252
# Update each row of bands to include band index (starting at 0)
5353
for i, row in df.iterrows():
54-
df.at[i, 'band'] = i
54+
df.at[i, "band"] = i
5555

5656
dfs.append(df)
5757

5858
df = pd.concat(dfs)
5959

6060
# Sum the occupation of each band and divide by the number of kpoints to get a relative occupancy
61-
df['relative_occupancy'] = df.groupby('band')['occupancy'].transform('sum') / wavecar.nk
61+
df["relative_occupancy"] = (
62+
df.groupby("band")["occupancy"].transform("sum") / wavecar.nk
63+
)
6264

6365
# Drop the kpoint and idk columns
64-
df = df.drop(columns=['kpoint', 'idk', 'occupancy'])
66+
df = df.drop(columns=["kpoint", "idk", "occupancy"])
6567
# Remove duplicate bands
6668
df = df.drop_duplicates()
6769

68-
return df[df['relative_occupancy'] > 0].tail(19)
70+
return df[df["relative_occupancy"] > 0].tail(19)
71+
6972

7073
wavecar = Wavecar(bi2se3_wavecar)
7174
df = get_band_occupancy_info(wavecar)
7275
print(df)
73-
74-
75-
76-
77-
78-
79-
80-

vsh/cli.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@ def lazy_import(name):
1010

1111

1212
def parse_app_args(args=None):
13-
parser = argparse.ArgumentParser(description='vsh command line utility')
14-
subparsers = parser.add_subparsers(dest='command', required=True)
13+
parser = argparse.ArgumentParser(description="vsh command line utility")
14+
subparsers = parser.add_subparsers(dest="command", required=True)
1515
scripts.setup(subparsers)
1616
return parser.parse_args()
1717

1818

1919
def main():
2020
args = parse_app_args()
21-
command = lazy_import('vsh.scripts.'+args.command)
21+
command = lazy_import("vsh.scripts." + args.command)
2222
command.run(args)
2323

24+
2425
if __name__ == "__main__":
2526
main()

0 commit comments

Comments
 (0)