Skip to content

Commit 96f05f2

Browse files
committed
removed spin ARG in bond()
- alpha/beta split is purely defined by `omods` ARG (evaluated in `get_repr()`)
1 parent 69f67f7 commit 96f05f2

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

qstack/spahm/rho/bond.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def bond(mols, dms,
1212
bpath=defaults.bpath, cutoff=defaults.cutoff, omods=defaults.omod,
13-
spin=None, elements=None, only_m0=False, zeros=False, printlevel=0,
13+
elements=None, only_m0=False, zeros=False, printlevel=0,
1414
pairfile=None, dump_and_exit=False, same_basis=False, only_z=[]):
1515
""" Computes SPAHM-b representations for a set of molecules.
1616
@@ -20,7 +20,6 @@ def bond(mols, dms,
2020
- bpath (str): path to the directory containing bond-optimized basis-functions (.bas)
2121
- cutoff (float): the cutoff distance (angstrom) between atoms to be considered as bond
2222
- omods (list of str): the selected mode for open-shell computations
23-
- spin (list of int): list of spins for each molecule
2423
- elements (list of str): list of all elements present in the set of molecules
2524
- only_m0 (bool): use only basis functions with `m=0`
2625
- zeros (bool): add zeros features for non-existing bond pairs
@@ -40,8 +39,6 @@ def bond(mols, dms,
4039
elements, mybasis, qqs0, qqs4q, idx, M = dmbb.read_basis_wrapper(mols, bpath, only_m0, printlevel,
4140
elements=elements, cutoff=cutoff,
4241
pairfile=pairfile, dump_and_exit=dump_and_exit, same_basis=same_basis)
43-
if np.array(spin==None, ndmin=1).all():
44-
omods = [None]
4542
qqs = qqs0 if zeros else qqs4q
4643
maxlen = max([dmbb.bonds_dict_init(qqs[q0], M)[1] for q0 in elements])
4744
if len(only_z) > 0:
@@ -116,10 +113,13 @@ def get_repr(mols, xyzlist, guess, xc=defaults.xc, spin=None, readdm=None,
116113
all_atoms = np.array([z for mol in mols for z in mol.elements if z in only_z], ndmin=2)
117114
else:
118115
all_atoms = np.array([mol.elements for mol in mols])
116+
spin = np.array(spin) ## a bit dirty but couldn't find a better way to ensure Iterable type!
117+
if (spin == None).all():
118+
omods = [None]
119119

120120
allvec = bond(mols, dms, bpath, cutoff, omods,
121-
spin=spin, elements=elements,
122-
only_m0=only_m0, zeros=zeros, printlevel=printlevel,
121+
elements=elements, only_m0=only_m0,
122+
zeros=zeros, printlevel=printlevel,
123123
pairfile=pairfile, dump_and_exit=dump_and_exit, same_basis=same_basis, only_z=only_z)
124124
maxlen=allvec.shape[-1]
125125
natm = allvec.shape[-2]

qstack/spahm/rho/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def mols_guess(mols, xyzlist, guess, xc=defaults.xc, spin=None, readdm=False, pr
6060
if printlevel>0: print(xyzfile, flush=True)
6161
if not readdm:
6262
e, v = spahm.get_guess_orbitals(mol, guess, xc=xc)
63-
dm = guesses.get_dm(v, mol.nelec, mol.spin if spin is not None else None)
63+
dm = guesses.get_dm(v, mol.nelec, mol.spin if spin is not None else None) # mol.spin can not be `None`
6464
else:
6565
dm = np.load(readdm+'/'+os.path.basename(xyzfile)+'.npy')
6666
if spin and dm.ndim==2:

tests/test_spahm_b.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ def test_water():
99
mols = utils.load_mols([xyz_in], [0], [0], 'minao')
1010
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[0])
1111
X = bond.get_repr(mols, [xyz_in], 'LB', spin=[0], with_symbols=False, same_basis=False)
12-
#X = np.hstack(X) # merging alpha-beta components for spin unrestricted representation #TODO: should be included into function not in main
1312
true_file = path+'/data/H2O_spahm_b.npy_alpha_beta.npy'
1413
X_true = np.load(true_file)
1514
assert(X_true.shape == X.shape)
@@ -21,7 +20,7 @@ def test_water_O_only():
2120
xyz_in = path+'/data/H2O.xyz'
2221
mols = utils.load_mols([xyz_in], [0], [0], 'minao')
2322
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[0])
24-
X = bond.bond(mols, dms, spin=[0], only_z=['O'])
23+
X = bond.bond(mols, dms, only_z=['O'])
2524
X = np.squeeze(X) #contains a single elements but has shape (1,Nfeat)
2625
X = np.hstack(X) # merging alpha-beta components for spin unrestricted representation #TODO: should be included into function not in main
2726
true_file = path+'/data/H2O_spahm_b.npy_alpha_beta.npy'
@@ -36,7 +35,7 @@ def test_water_same_basis():
3635
xyz_in = path+'/data/H2O.xyz'
3736
mols = utils.load_mols([xyz_in], [0], [0], 'minao')
3837
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[0])
39-
X = bond.bond(mols, dms, spin=[0], same_basis=True)
38+
X = bond.bond(mols, dms, same_basis=True)
4039
X = np.squeeze(X) #contains a single elements but has shape (1,Nfeat)
4140
X = np.hstack(X) # merging alpha-beta components for spin unrestricted representation #TODO: should be included into function not in main
4241
true_file = path+'/data/H2O_spahm_b_CCbas.npy_alpha_beta.npy'
@@ -50,7 +49,7 @@ def test_ecp():
5049
xyz_in = path+'/data/I2.xyz'
5150
mols = utils.load_mols([xyz_in], [0], [None], 'minao', ecp='def2-svp')
5251
dms = utils.mols_guess(mols, [xyz_in], 'LB', spin=[None])
53-
X = bond.bond(mols, dms, spin=[None], same_basis=True)
52+
X = bond.bond(mols, dms, same_basis=True)
5453
X = np.squeeze(X) #contains a single elements but has shape (1,Nfeat)
5554
X = np.hstack(X) # merging alpha-beta components for spin unrestricted representation #TODO: should be included into function not in main
5655
true_file = path+'/data/I2_spahm-b_minao-def2-svp_alpha-beta.npy'

0 commit comments

Comments
 (0)