Skip to content

Commit

Permalink
adds mpi support
Browse files Browse the repository at this point in the history
  • Loading branch information
dermen committed Aug 20, 2024
1 parent 1233625 commit fa209fe
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
44 changes: 27 additions & 17 deletions reciprocalspaceship/io/dials.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,29 @@ def get_msgpack_data(data, name):
return vals.T


def get_refl_data(fnames, ucell, symbol, rank=0, size=1):
def get_fnames(dirnames):
fnames = []
for dirname in dirnames:
fnames += glob.glob(dirname + "/*integrated.refl")
print("Found %d files" % len(fnames))
return fnames


def _concat(refl_data):
refl_data = [ds for ds in refl_data if ds is not None]
"""combine output of _get_refl_data"""
print("Combining tables!")
ds = rs.concat(refl_data)
expt_ids = set(ds.BATCH)
print(f"Found {len(ds)} refls from {len(expt_ids)} expts.")
print("Mapping batch column.")
expt_id_map = {name: i for i, name in enumerate(expt_ids)}
ds.BATCH = [expt_id_map[eid] for eid in ds.BATCH]
ds = ds.infer_mtz_dtypes().set_index(["H", "K", "L"], drop=True)
return ds


def _get_refl_data(fnames, ucell, symbol, rank=0, size=1):
"""
Parameters
Expand Down Expand Up @@ -115,26 +137,14 @@ def read_dials_stills(dirnames, ucell, symbol, nj=10):
-------
RS dataset (pandas Dataframe)
"""
fnames = []
for dirname in dirnames:
fnames += glob.glob(dirname + "/*integrated.refl")
print("Found %d files" % len(fnames))
fnames = get_fnames(dirnames)
ray.init(num_cpus=nj)

# get the refl data
_get_refl_data = ray.remote(get_refl_data)
get_refl_data = ray.remote(_get_refl_data)
refl_data = ray.get(
[_get_refl_data.remote(fnames, ucell, symbol, rank, nj) for rank in range(nj)]
[get_refl_data.remote(fnames, ucell, symbol, rank, nj) for rank in range(nj)]
)
refl_data = [ds for ds in refl_data if ds is not None]

print("Combining tables!")
ds = rs.concat(refl_data)
expt_ids = set(ds.BATCH)
print(f"Found {len(ds)} refls from {len(expt_ids)} expts.")
print("Mapping batch column.")
expt_id_map = {name: i for i, name in enumerate(expt_ids)}
ds.BATCH = [expt_id_map[eid] for eid in ds.BATCH]

ds = ds.infer_mtz_dtypes().set_index(["H", "K", "L"], drop=True)
ds = _concat(refl_data)
return ds
31 changes: 31 additions & 0 deletions reciprocalspaceship/io/dials_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from mpi4py import MPI

COMM = MPI.COMM_WORLD
from reciprocalspaceship.io import dials


def read_dials_stills_mpi(dirnames, ucell, symbol):
"""
Parameters
----------
dirnames: folders containing stills process results (integrated.refl)
ucell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
symbol: space group name e.g. P4
Returns
-------
RS dataset (pandas Dataframe) if MPI rank==0 else None
"""
fnames = None
if COMM.rank == 0:
fnames = dials.get_fnames(dirnames)
fnames = COMM.bcast(fnames)

refl_data = dials._get_refl_data(fnames, ucell, symbol, COMM.rank, COMM.size)
refl_data = COMM.gather(refl_data)
ds = None
if COMM.rank == 0:
ds = dials._concat(refl_data)

return ds

0 comments on commit fa209fe

Please sign in to comment.