Skip to content

Commit

Permalink
ENH: Add support for multiprocess CPU thread for MultiTEMapping class
Browse files Browse the repository at this point in the history
  • Loading branch information
acsenrafilho committed Nov 16, 2024
1 parent 96daf9a commit a5ad9e7
Showing 1 changed file with 185 additions and 102 deletions.
287 changes: 185 additions & 102 deletions asltk/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
att_map = None
brain_mask = None
asl_data = None
ld_arr = None
pld_arr = None
te_arr = None
tblgm_map = None
t2bl = None
t2gm = None


class CBFMapping(MRIParameters):
Expand Down Expand Up @@ -131,6 +137,7 @@ def create_map(
or len(self._asl_data.get_pld()) == 0
):
raise ValueError('LD or PLD list of values must be provided.')
# TODO Testar se retirando esse if do LD PLD sizes, continua rodando... isso é erro do ASLData

global asl_data, brain_mask
asl_data = self._asl_data
Expand Down Expand Up @@ -176,6 +183,47 @@ def create_map(
}


def _cbf_init_globals(
cbf_map_, att_map_, brain_mask_, asl_data_
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
global cbf_map, att_map, brain_mask, asl_data
cbf_map = cbf_map_
att_map = att_map_
brain_mask = brain_mask_
asl_data = asl_data_


def _cbf_process_slice(
i, x_axis, y_axis, z_axis, BuxtonX, par0, lb, ub
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
for j in range(y_axis):
for k in range(z_axis):
if brain_mask[k, j, i] != 0:
m0_px = asl_data('m0')[k, j, i]

def mod_buxton(Xdata, par1, par2):
return asl_model_buxton(
Xdata[0], Xdata[1], m0_px, par1, par2
)

Ydata = asl_data('pcasl')[0, :, k, j, i]

# Calculate the processing index for the 3D space
index = k * (y_axis * x_axis) + j * x_axis + i

try:
par_fit, _ = curve_fit(
mod_buxton, BuxtonX, Ydata, p0=par0, bounds=(lb, ub)
)
cbf_map[index] = par_fit[0]
att_map[index] = par_fit[1]
except RuntimeError:
cbf_map[index] = 0.0
att_map[index] = 0.0


class MultiTE_ASLMapping(MRIParameters):
def __init__(self, asl_data: ASLData) -> None:
"""Basic MultiTE_ASLMapping constructor
Expand Down Expand Up @@ -295,6 +343,7 @@ def create_map(
ub: list = [np.inf],
lb: list = [0.0],
par0: list = [400],
cores=cpu_count(),
):
"""Create the T1 relaxation exchange between blood and Grey Matter (GM)
, i.e. the T1blGM map resulted from the multi-compartiment TE ASL model.
Expand Down Expand Up @@ -327,11 +376,12 @@ def create_map(
ub (list, optional): The upper limit values. Defaults to [1.0, 5000.0].
lb (list, optional): The lower limit values. Defaults to [0.0, 0.0].
par0 (list, optional): The initial guess parameter for non-linear fitting. Defaults to [1e-5, 1000].
cores (int, optional): Defines how many CPU threads can be used for the class. Defaults is using all the availble threads.
Returns:
(dict): A dictionary with 'cbf', 'att' and 'cbf_norm'
"""
# TODO As entradas ub, lb e par0 não são aplicadas para CBF. Pensar se precisa ter essa flexibilidade para acertar o CBF interno à chamada
# # TODO As entradas ub, lb e par0 não são aplicadas para CBF. Pensar se precisa ter essa flexibilidade para acertar o CBF interno à chamada
self._basic_maps.set_brain_mask(self._brain_mask)

basic_maps = {'cbf': self._cbf_map, 'att': self._att_map}
Expand All @@ -344,59 +394,52 @@ def create_map(
self._cbf_map = basic_maps['cbf']
self._att_map = basic_maps['att']

global asl_data, brain_mask, cbf_map, att_map, t2bl, t2gm
asl_data = self._asl_data
brain_mask = self._brain_mask
cbf_map = self._cbf_map
att_map = self._att_map
ld_arr = self._asl_data.get_ld()
pld_arr = self._asl_data.get_pld()
te_arr = self._asl_data.get_te()
t2bl = self.T2bl
t2gm = self.T2gm

x_axis = self._asl_data('m0').shape[2] # height
y_axis = self._asl_data('m0').shape[1] # width
z_axis = self._asl_data('m0').shape[0] # depth

for i in track(
range(x_axis), description='[green]multiTE-ASL processing...'
):
for j in range(y_axis):
for k in range(z_axis):
if self._brain_mask[k, j, i] != 0:
m0_px = self._asl_data('m0')[k, j, i]
tblgm_map_shared = Array('d', z_axis * y_axis * x_axis, lock=False)

def mod_2comp(Xdata, par1):
return asl_model_multi_te(
Xdata[:, 0],
Xdata[:, 1],
Xdata[:, 2],
m0_px,
basic_maps['cbf'][k, j, i],
basic_maps['att'][k, j, i],
par1,
self.T2bl,
self.T2gm,
)

Ydata = (
self._asl_data('pcasl')[:, :, k, j, i]
.reshape(
(
len(self._asl_data.get_ld())
* len(self._asl_data.get_te()),
1,
)
)
.flatten()
)
with Pool(
processes=cores,
initializer=_multite_init_globals,
initargs=(
cbf_map,
att_map,
brain_mask,
asl_data,
ld_arr,
pld_arr,
te_arr,
tblgm_map_shared,
t2bl,
t2gm,
),
) as pool:
pool.starmap(
_tblgm_multite_process_slice,
[
(i, x_axis, y_axis, z_axis, par0, lb, ub)
for i in track(
range(x_axis), description='multiTE-ASL processing...'
)
],
)

try:
Xdata = self._create_x_data(
self._asl_data.get_ld(),
self._asl_data.get_pld(),
self._asl_data.get_te(),
)
par_fit, _ = curve_fit(
mod_2comp,
Xdata,
Ydata,
p0=par0,
bounds=(lb, ub),
)
self._t1blgm_map[k, j, i] = par_fit[0]
except RuntimeError: # pragma: no cover
self._t1blgm_map[k, j, i] = 0.0
self._t1blgm_map = np.frombuffer(tblgm_map_shared).reshape(
z_axis, y_axis, x_axis
)

# Adjusting output image boundaries
self._t1blgm_map = self._adjust_image_limits(self._t1blgm_map, par0[0])
Expand All @@ -408,19 +451,6 @@ def mod_2comp(Xdata, par1):
't1blgm': self._t1blgm_map,
}

def _create_x_data(self, ld, pld, te):
# array for the x values, assuming an arbitrary size based on the PLD
# and TE vector size
Xdata = np.zeros((len(pld) * len(te), 3))

count = 0
for i in range(len(pld)):
for j in range(len(te)):
Xdata[count] = [ld[i], pld[i], te[j]]
count += 1

return Xdata

def _adjust_image_limits(self, map, init_guess):
img = sitk.GetImageFromArray(map)
thr_filter = sitk.ThresholdImageFilter()
Expand All @@ -433,6 +463,100 @@ def _adjust_image_limits(self, map, init_guess):
return sitk.GetArrayFromImage(img)


def _multite_init_globals(
cbf_map_,
att_map_,
brain_mask_,
asl_data_,
ld_arr_,
pld_arr_,
te_arr_,
tblgm_map_,
t2bl_,
t2gm_,
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
global cbf_map, att_map, brain_mask, asl_data, ld_arr, te_arr, pld_arr, tblgm_map, t2bl, t2gm
cbf_map = cbf_map_
att_map = att_map_
brain_mask = brain_mask_
asl_data = asl_data_
ld_arr = ld_arr_
pld_arr = pld_arr_
te_arr = te_arr_
tblgm_map = tblgm_map_
t2bl = t2bl_
t2gm = t2gm_


def _tblgm_multite_process_slice(
i, x_axis, y_axis, z_axis, par0, lb, ub
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
for j in range(y_axis):
for k in range(z_axis):
if brain_mask[k, j, i] != 0:
m0_px = asl_data('m0')[k, j, i]

def mod_2comp(Xdata, par1):
return asl_model_multi_te(
Xdata[:, 0],
Xdata[:, 1],
Xdata[:, 2],
m0_px,
cbf_map[k, j, i],
att_map[k, j, i],
par1,
t2bl,
t2gm,
)

Ydata = (
asl_data('pcasl')[:, :, k, j, i]
.reshape(
(
len(ld_arr) * len(te_arr),
1,
)
)
.flatten()
)

# Calculate the processing index for the 3D space
index = k * (y_axis * x_axis) + j * x_axis + i

try:
Xdata = _multite_create_x_data(
ld_arr,
pld_arr,
te_arr,
)
par_fit, _ = curve_fit(
mod_2comp,
Xdata,
Ydata,
p0=par0,
bounds=(lb, ub),
)
tblgm_map[index] = par_fit[0]
except RuntimeError: # pragma: no cover
tblgm_map[index] = 0.0


def _multite_create_x_data(ld, pld, te): # pragma: no cover
# array for the x values, assuming an arbitrary size based on the PLD
# and TE vector size
Xdata = np.zeros((len(pld) * len(te), 3))

count = 0
for i in range(len(pld)):
for j in range(len(te)):
Xdata[count] = [ld[i], pld[i], te[j]]
count += 1

return Xdata


class MultiDW_ASLMapping(MRIParameters):
def __init__(self, asl_data: ASLData):
super().__init__()
Expand Down Expand Up @@ -704,44 +828,3 @@ def _check_mask_values(mask, label, ref_shape):
raise TypeError(
f'Image mask dimension does not match with input 3D volume. Mask shape {mask_shape} not equal to {ref_shape}'
)


def _cbf_init_globals(
cbf_map_, att_map_, brain_mask_, asl_data_
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
global cbf_map, att_map, brain_mask, asl_data
cbf_map = cbf_map_
att_map = att_map_
brain_mask = brain_mask_
asl_data = asl_data_


def _cbf_process_slice(
i, x_axis, y_axis, z_axis, BuxtonX, par0, lb, ub
): # pragma: no cover
# indirect call method by CBFMapping().create_map()
for j in range(y_axis):
for k in range(z_axis):
if brain_mask[k, j, i] != 0:
m0_px = asl_data('m0')[k, j, i]

def mod_buxton(Xdata, par1, par2):
return asl_model_buxton(
Xdata[0], Xdata[1], m0_px, par1, par2
)

Ydata = asl_data('pcasl')[0, :, k, j, i]

# Calculate the processing index for the 3D space
index = k * (y_axis * x_axis) + j * x_axis + i

try:
par_fit, _ = curve_fit(
mod_buxton, BuxtonX, Ydata, p0=par0, bounds=(lb, ub)
)
cbf_map[index] = par_fit[0]
att_map[index] = par_fit[1]
except RuntimeError:
cbf_map[index] = 0.0
att_map[index] = 0.0

0 comments on commit a5ad9e7

Please sign in to comment.