Skip to content

Commit

Permalink
addresses review
Browse files Browse the repository at this point in the history
  • Loading branch information
dermen committed Aug 25, 2024
1 parent fa209fe commit f19bae4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 37 deletions.
62 changes: 34 additions & 28 deletions reciprocalspaceship/io/dials.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import glob
import logging

import gemmi
import msgpack
import numpy as np
import ray

LOGGER = logging.getLogger("rs.io.dials")
console = logging.StreamHandler()
console.setLevel(logging.INFO)
LOGGER.addHandler(console)
LOGGER.setLevel(logging.INFO)

import reciprocalspaceship as rs
from reciprocalspaceship.decorators import cellify, spacegroupify

MSGPACK_DTYPES = {
"double": np.float64,
Expand Down Expand Up @@ -38,36 +44,30 @@ def get_msgpack_data(data, name):
return vals.T


def get_fnames(dirnames):
fnames = []
for dirname in dirnames:
fnames += glob.glob(dirname + "/*integrated.refl")
print("Found %d files" % len(fnames))
return fnames


def _concat(refl_data):
refl_data = [ds for ds in refl_data if ds is not None]
"""combine output of _get_refl_data"""
print("Combining tables!")
LOGGER.debug("Combining tables!")
ds = rs.concat(refl_data)
expt_ids = set(ds.BATCH)
print(f"Found {len(ds)} refls from {len(expt_ids)} expts.")
print("Mapping batch column.")
LOGGER.debug(f"Found {len(ds)} refls from {len(expt_ids)} expts.")
LOGGER.debug("Mapping batch column.")
expt_id_map = {name: i for i, name in enumerate(expt_ids)}
ds.BATCH = [expt_id_map[eid] for eid in ds.BATCH]
ds = ds.infer_mtz_dtypes().set_index(["H", "K", "L"], drop=True)
return ds


def _get_refl_data(fnames, ucell, symbol, rank=0, size=1):
@cellify
@spacegroupify
def _get_refl_data(fnames, unitcell, spacegroup, rank=0, size=1):
"""
Parameters
----------
fnames: integrated refl files
ucell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
symbol: space group name e.g. P4
unitcell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
spacegroup: space group name e.g. P4
rank: process Id [0-N) where N is num proc
size: total number of proc (N)
Expand All @@ -77,14 +77,14 @@ def _get_refl_data(fnames, ucell, symbol, rank=0, size=1):
"""

sg_num = gemmi.find_spacegroup_by_name(symbol).number
all_ds = []

for i_f, f in enumerate(fnames):
if i_f % size != rank:
continue

if rank == 0:
print(f"Loading {i_f+1}/{len(fnames)}")
LOGGER.info(f"Loading {i_f+1}/{len(fnames)}")
_, _, R = msgpack.load(open(f, "rb"), strict_map_key=False)
refl_data = R["data"]
expt_id_map = R["identifiers"]
Expand All @@ -108,8 +108,8 @@ def _get_refl_data(fnames, ucell, symbol, rank=0, size=1):
"X": x,
"Y": y,
},
cell=ucell,
spacegroup=sg_num,
cell=unitcell,
spacegroup=spacegroup,
)
ds["SX"] = sx
ds["SY"] = sy
Expand All @@ -123,27 +123,33 @@ def _get_refl_data(fnames, ucell, symbol, rank=0, size=1):
return all_ds


def read_dials_stills(dirnames, ucell, symbol, nj=10):
@cellify
@spacegroupify
def read_dials_stills(fnames, unitcell, spacegroup, numjobs=10):
"""
Parameters
----------
dirnames: folders containing stills process results (integrated.refl)
ucell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
symbol: space group name e.g. P4
nj: number of jobs
fnames: integration files
unitcell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
spacegroup: space group name e.g. P4
numjobs: number of jobs
Returns
-------
RS dataset (pandas DataFrame)
"""
fnames = get_fnames(dirnames)
ray.init(num_cpus=nj)
ray.init(
num_cpus=numjobs, log_to_driver=LOGGER.level == logging.DEBUG
) # forward Ray worker output to the driver only when LOGGER is at DEBUG level

# get the refl data
get_refl_data = ray.remote(_get_refl_data)
refl_data = ray.get(
[get_refl_data.remote(fnames, ucell, symbol, rank, nj) for rank in range(nj)]
[
get_refl_data.remote(fnames, unitcell, spacegroup, rank, numjobs)
for rank in range(numjobs)
]
)

ds = _concat(refl_data)
Expand Down
17 changes: 8 additions & 9 deletions reciprocalspaceship/io/dials_mpi.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
from mpi4py import MPI

COMM = MPI.COMM_WORLD
from reciprocalspaceship.decorators import cellify, spacegroupify
from reciprocalspaceship.io import dials


def read_dials_stills_mpi(dirnames, ucell, symbol):
@cellify
@spacegroupify
def read_dials_stills_mpi(fnames, unitcell, spacegroup):
"""
Parameters
----------
dirnames: folders containing stills process results (integrated.refl)
ucell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
symbol: space group name e.g. P4
fnames: integrated reflection tables
unitcell: unit cell tuple (6 params Ang,Ang,Ang,deg,deg,deg)
spacegroup: space group name e.g. P4
Returns
-------
RS dataset (pandas DataFrame) if MPI rank==0 else None
"""
fnames = None
if COMM.rank == 0:
fnames = dials.get_fnames(dirnames)
fnames = COMM.bcast(fnames)

refl_data = dials._get_refl_data(fnames, ucell, symbol, COMM.rank, COMM.size)
refl_data = dials._get_refl_data(fnames, unitcell, spacegroup, COMM.rank, COMM.size)
refl_data = COMM.gather(refl_data)
ds = None
if COMM.rank == 0:
Expand Down

0 comments on commit f19bae4

Please sign in to comment.