From cb0b8311a39b5b1c803f53b5c9364c28a6703ae6 Mon Sep 17 00:00:00 2001
From: "Georg Schramm (Windows)"
Date: Thu, 26 Sep 2024 14:31:45 +0200
Subject: [PATCH] first STIR SPD3O implementation

---
 SPD3O.py       | 772 +++++++++++++++++++++++++++++++++++++++++++++++++
 main.py        |  72 ++++-
 main_svrg.py   | 338 ++++++++++++++++++++++
 test_petric.py |  18 +-
 4 files changed, 1179 insertions(+), 21 deletions(-)
 create mode 100644 SPD3O.py
 create mode 100644 main_svrg.py

diff --git a/SPD3O.py b/SPD3O.py
new file mode 100644
index 0000000..d761e90
--- /dev/null
+++ b/SPD3O.py
@@ -0,0 +1,772 @@
+"""script to compare SPD3O, SGD and SVRG for PET reconstruction with the RDP prior"""
+
+# %%
+from __future__ import annotations
+
+try:
+    import array_api_compat.cupy as xp
+except ImportError:
+    import array_api_compat.numpy as xp
+
+import parallelproj
+import array_api_compat.numpy as np
+import matplotlib.pyplot as plt
+
+from array_api_compat import to_device
+from scipy.optimize import fmin_l_bfgs_b
+from pathlib import Path
+
+from utils import (
+    SubsetNegPoissonLogLWithPrior,
+    split_fwd_model,
+    OSEM,
+    SGD,
+    SVRG,
+    rdp_preconditioner,
+)
+
+import sys
+
+sys.path.append("../")
+from rdp import RDP, neighbor_product
+
+# choose a device (CPU or CUDA GPU)
+if "numpy" in xp.__name__:
+    # using numpy, device must be cpu
+    dev = "cpu"
+elif "cupy" in xp.__name__:
+    # using cupy, only cuda devices are possible
+    dev = xp.cuda.Device(0)
+
+# %%
+# set up
+
+# input parameters
+seed = 1
+
+# true counts, reasonable range: 1e6, 1e7 (high counts), 1e5 (low counts)
+true_counts = 1e6
+# regularization weight, reasonable range: 5e-5 * (true_counts / 1e6) is medium regularization
+beta = 5e-3 * (true_counts / 1e6)
+# RDP gamma parameter
+gamma_rdp = 2.0
+
+# number of epochs / subsets for stochastic gradient algorithms
+num_epochs = 2
+num_subsets = 108
+
+# max number of updates for reference L-BFGS-B solution
+num_iter_bfgs_ref = 400
+
+# number of rings of simulated PET scanner, should be odd in this example
+num_rings = 1
+# resolution of the simulated PET scanner in mm
+fwhm_data_mm = 4.5
+# simulated TOF or non-TOF system
+tof = False
+# mean of contamination sinogram, relative to mean of trues sinogram, reasonable range: 0.5 - 1.0
+contam_fraction = 0.5
+# verbose output
+verbose = False
+# track cost function values after every update (slow)
+track_cost = False
+
+# number of epochs / subsets for initial OSEM
+num_epochs_osem = 1
+num_subsets_osem = 27
+
+nrmse_limit = 5e-3
+
+# random seed
+np.random.seed(seed)
+
+# Setup of the forward model :math:`\bar{y}(x) = A x + s`
+# --------------------------------------------------------
+#
+# We set up a linear forward operator :math:`A` consisting of an
+# image-based resolution model, a non-TOF PET projector and an attenuation model
+#
+# .. note::
+#     The OSEM implementation below works with all linear operators that
+#     subclass :class:`.LinearOperator` (e.g. the high-level projectors).
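+
+# %%
+# Algorithm note (added sketch, following the PD3O / SPDHG literature):
+# the update loop further below implements a stochastic primal-dual three-operator
+# (SPD3O) scheme for min_x f(x) + h(x) + sum_i g_i(A_i x + s_i), where
+#   f   = non-negativity constraint        (prox step = clipping at 0)
+#   h   = smooth RDP prior                 (gradient step)
+#   g_i = subset negative Poisson logL     (prox of its convex conjugate)
+# With primal step size T and dual step sizes S_i, one update for a randomly
+# chosen subset i reads roughly:
+#
+#   x+    = prox_{T f}( x - T * (zbar + grad h(x)) )
+#   xbar  = x+ + T * (grad h(x) - grad h(x+))
+#   y_i+  = prox_{S_i g_i*}( y_i + S_i * (A_i xbar + s_i) )
+#   dz    = A_i^T (y_i+ - y_i)
+#   z+    = z + dz,  zbar+ = z+ + num_subsets * dz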
+ +scanner = parallelproj.RegularPolygonPETScannerGeometry( + xp, + dev, + radius=300.0, + num_sides=36, + num_lor_endpoints_per_side=12, + lor_spacing=4.0, + ring_positions=xp.linspace( + -5 * (num_rings - 1) / 2, 5 * (num_rings - 1) / 2, num_rings + ), + symmetry_axis=2, +) + +# setup the LOR descriptor that defines the sinogram + +img_shape = (100, 100, 2 * num_rings - 1) +voxel_size = (2.0, 2.0, 2.0) + +lor_desc = parallelproj.RegularPolygonPETLORDescriptor( + scanner, + radial_trim=140, + sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP, +) + +if lor_desc.num_views % num_subsets != 0: + raise ValueError( + f"num_subsets ({num_subsets}) must be a divisor of num_views ({lor_desc.num_views})" + ) +if lor_desc.num_views % num_subsets_osem != 0: + raise ValueError( + f"num_subsets_osem ({num_subsets_osem}) must be a divisor of num_views ({lor_desc.num_views})" + ) + +proj = parallelproj.RegularPolygonPETProjector( + lor_desc, img_shape=img_shape, voxel_size=voxel_size +) + +# setup a simple test image containing a few "hot rods" +x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32) +c0 = proj.in_shape[0] // 2 +c1 = proj.in_shape[1] // 2 +x_true[(c0 - 4) : (c0 + 4), (c1 - 4) : (c1 + 4), :] = 3.0 + +x_true[28:32, c1 : (c1 + 4), :] = 5.0 +x_true[c0 : (c0 + 4), 20:24, :] = 5.0 + +x_true[-32:-28, c1 : (c1 + 4), :] = 0.1 +x_true[c0 : (c0 + 4), -24:-20, :] = 0.1 + +x_true[:25, :, :] = 0 +x_true[-25:, :, :] = 0 +x_true[:, :10, :] = 0 +x_true[:, -10:, :] = 0 + +# Attenuation image and sinogram setup +# ------------------------------------ + +# setup an attenuation image +x_att = 0.01 * xp.astype(x_true > 0, xp.float32) +# calculate the attenuation sinogram +att_sino = xp.exp(-proj(x_att)) + +# Complete PET forward model setup +# -------------------------------- +# +# We combine an image-based resolution model, +# a non-TOF or TOF PET projector and an attenuation model +# into a single linear operator. + +# enable TOF - comment if you want to run non-TOF +if tof is True: + proj.tof_parameters = parallelproj.TOFParameters( + num_tofbins=13, tofbin_width=12.0, sigma_tof=12.0 + ) + +# setup the attenuation multiplication operator which is different +# for TOF and non-TOF since the attenuation sinogram is always non-TOF +if proj.tof: + att_op = parallelproj.TOFNonTOFElementwiseMultiplicationOperator( + proj.out_shape, att_sino + ) +else: + att_op = parallelproj.ElementwiseMultiplicationOperator(att_sino) + +res_model = parallelproj.GaussianFilterOperator( + proj.in_shape, sigma=fwhm_data_mm / (2.35 * proj.voxel_size) +) + +# compose all 3 operators into a single linear operator +pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model)) + +# Simulation of projection data +# ----------------------------- +# +# We setup an arbitrary ground truth :math:`x_{true}` and simulate +# noise-free and noisy data :math:`y` by adding Poisson noise. 
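+
+# %%
+# Reference note (added sketch): for measured counts d and affine forward model
+# l = A x + s, the subset data fidelity used below is the negative Poisson logL
+# g(l) = sum(l - d * log(l)). The SPD3O / SPDHG dual updates need the prox of
+# the convex conjugate g*, which has the closed form below (see e.g. Ehrhardt
+# et al., "Faster PET reconstruction with non-smooth priors by randomization
+# and preconditioning"). The function name is illustrative and not part of the
+# utils module:
+
+
+def prox_conj_neg_poisson_logl(y, step, d):
+    """prox of (step * g*) at y for count data d"""
+    return 0.5 * (y + 1 - xp.sqrt((y - 1) ** 2 + 4 * step * d))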
+
+# simulated noise-free data
+noise_free_data = pet_lin_op(x_true)
+
+if true_counts > 0:
+    scale_fac = true_counts / float(xp.sum(noise_free_data))
+    noise_free_data *= scale_fac
+    x_true *= scale_fac
+
+# generate a constant contamination sinogram
+contamination = xp.full(
+    noise_free_data.shape,
+    contam_fraction * float(xp.mean(noise_free_data)),
+    device=dev,
+    dtype=xp.float32,
+)
+
+noise_free_data += contamination
+
+# add Poisson noise
+data = xp.asarray(
+    np.random.poisson(parallelproj.to_numpy_array(noise_free_data)),
+    device=dev,
+    dtype=xp.float32,
+)
+
+# run quick OSEM with one iteration
+pet_subset_lin_op_seq_osem, subset_slices_osem = split_fwd_model(
+    pet_lin_op, num_subsets_osem
+)
+
+data_fidelity = SubsetNegPoissonLogLWithPrior(
+    data, pet_subset_lin_op_seq_osem, contamination, subset_slices_osem
+)
+
+x0 = xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32)
+osem_alg = OSEM(data_fidelity)
+x_osem = osem_alg.run(x0, num_epochs_osem)
+
+#### post filter osem to get better initial recon
+osem_filter = parallelproj.GaussianFilterOperator(
+    proj.in_shape, sigma=4.7 / (2.35 * proj.voxel_size)
+)
+x_init = osem_filter(x_osem)
+
+# setup of the cost function
+
+fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32))
+fwd_osem = pet_lin_op(x_osem) + contamination
+kappa_img = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2)))
+
+prior = RDP(
+    img_shape,
+    xp=xp,
+    dev=dev,
+    voxel_size=xp.asarray(voxel_size, device=dev),
+    eps=float(xp.max(x_osem)) / 100,
+    gamma=gamma_rdp,
+)
+
+prior.kappa = kappa_img
+prior.scale = beta
+
+adjoint_ones = pet_lin_op.adjoint(
+    xp.ones(pet_lin_op.out_shape, device=dev, dtype=xp.float32)
+)
+
+pet_subset_lin_op_seq, subset_slices = split_fwd_model(pet_lin_op, num_subsets)
+
+cost_function = SubsetNegPoissonLogLWithPrior(
+    data, pet_subset_lin_op_seq, contamination, subset_slices, prior=prior
+)
+
+# run L-BFGS-B without subsets as reference
+x0_bfgs = to_device(x_init.ravel(), "cpu")
+
+bounds = x0.size * [(0, None)]
+
+ref_file = Path(
+    f"rdp_t_{true_counts:.2E}_b_{beta:.2E}_g_{gamma_rdp:.2E}_n_{num_iter_bfgs_ref}_nr_{num_rings}_tof_{tof}_cf_{contam_fraction}_s_{seed}.npy"
+)
+
+if ref_file.exists():
+    print("ref solution exists: load it")
+    x_ref = xp.asarray(np.load(ref_file), device=dev)
+else:
+    print("ref solution does not exist: compute it")
+    res = fmin_l_bfgs_b(
+        cost_function,
+        x0_bfgs,
+        cost_function.gradient,
+        disp=True,
+        maxiter=num_iter_bfgs_ref,
+        bounds=bounds,
+        m=10,
+        factr=10.0,
+    )
+
+    x_ref = xp.asarray(res[0].reshape(img_shape), device=dev)
+    # save as a plain numpy array so it can be re-loaded with np.load above
+    np.save(ref_file, parallelproj.to_numpy_array(x_ref))
+
+cost_ref = cost_function(x_ref)
+
+x_osem_scale = float(xp.mean(x_init))
+
+cost_osem = cost_function(x_osem)
+nrmse_osem = xp.sqrt(xp.mean((x_ref - x_osem) ** 2)) / scale_fac
+nrmse_init = xp.sqrt(xp.mean((x_ref - x_init) ** 2)) / scale_fac
+
+# %%
+# SPD3O
+
+# Test a rough value for L (Lipschitz constant of the prior gradient)
+# using the gradient difference at two random test images
+x1 = np.random.rand(*x_init.shape)
+x2 = np.random.rand(*x_init.shape)
+g1 = prior.gradient(x1)
+g2 = prior.gradient(x2)
+
+Lest = np.linalg.norm(g1 - g2) / np.linalg.norm(x1 - x2)
+print(Lest)
+
+L = 60
+
+# %%
+# SPD3O parameters
+rhos = np.array([1.5])  # up to rho = 3 seems also to work for some gammas
+# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image
+gammas = np.array([0.8 / 5])
+
+# note: this overrides the num_epochs defined at the top of the script
+num_epochs = 20
+
+run_spd3o = True
+
+
+# NOTE: each entry of params is used as rho in the loop below;
+# gamma is fixed to 1 / x_osem_scale there, so rhos above is currently unused
+params = gammas
+
+if run_spd3o:
+    print("run spd3o")
+
+    nrmse_spd3o = np.zeros((len(params), num_epochs * num_subsets), dtype=xp.float32)
+
+    # list for all recons using different parameter values
+    x_spd3os = []
+
+    for ig, param in enumerate(params):
+
+        rho = param
+        gamma = 1 / x_osem_scale
+        print("run gamma", gamma)
+
+        # initialize primal and dual variables
+        x_spd3o = x_init.copy()
+        # initialize dual variable for the negative Poisson logL
+        y = 1 - data / (pet_lin_op(x_spd3o) + contamination)
+
+        # y = 0*data
+
+        # initialize z and zbar
+        z = pet_lin_op.adjoint(y)
+        # z = 0 * x_init
+        zbar = 1.0 * z
+
+        # calculate SPDHG-style step sizes
+        S_As = []
+        T_As = []
+
+        for lin_op in pet_subset_lin_op_seq:
+            tmp = lin_op(xp.ones(lin_op.in_shape, dtype=xp.float32, device=dev))
+            # replace zeros by smallest non-zero value, to avoid division by zero
+            tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp)
+            S_As.append(gamma * rho / tmp)
+
+            T_As.append(
+                (1 / (gamma * num_subsets))
+                * rho
+                / lin_op.adjoint(
+                    xp.ones(lin_op.out_shape, dtype=xp.float64, device=dev)
+                )
+            )
+
+        # element-wise minimum of the data T's
+        T = xp.min(xp.asarray(T_As), axis=0)
+
+        # print T_max * L / 2 (for PD3O-type step sizes, this product should typically stay below 1)
+        print(T.max() * L / 2)
+
+        num_updates = num_epochs * num_subsets
+
+        for i in range(num_updates):
+            subset = np.random.randint(num_subsets)
+            sl = subset_slices[subset]
+
+            if i == 0:
+                grad_h = prior.gradient(x_spd3o)
+
+            q = zbar + grad_h
+            x_spd3o = xp.clip(x_spd3o - T * q, 0, None)
+
+            grad_h_new = prior.gradient(x_spd3o)
+
+            xbar = x_spd3o + T * (grad_h - grad_h_new)
+            # xbar = x_spd3o
+
+            grad_h = grad_h_new
+
+            # forward step
+            y_plus = y[sl] + S_As[subset] * (
+                pet_subset_lin_op_seq[subset](xbar) + contamination[sl]
+            )
+            # prox of convex conjugate of negative Poisson logL
+            y_plus = 0.5 * (
+                y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_As[subset] * data[sl])
+            )
+            delta_z = pet_subset_lin_op_seq[subset].adjoint(y_plus - y[sl])
+            y[sl] = y_plus
+
+            z = z + delta_z
+            zbar = z + num_subsets * delta_z
+
+            nrmse_spd3o[ig, i] = xp.sqrt(xp.mean((x_ref - x_spd3o) ** 2)) / scale_fac
+
+            if (i + 1) % num_subsets == 0:
+                print(
+                    f"SPD3O epoch {((i+1)//num_subsets):04} / {num_epochs} NRMSE: {nrmse_spd3o[ig, i]:.2E}",
+                    end="\r",
+                )
+
+        x_spd3os.append(x_spd3o)
+
+    # %%
+    # SPD3O plots
+
+    vmax = 1.2 * float(xp.max(x_true))
+    sl = img_shape[2] // 2
+    num_rows = 3
+    num_cols = len(params) + 1
+
+    fig, ax = plt.subplots(
+        num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True
+    )
+    ax[0, -1].imshow(
+        to_device(x_ref[:, :, img_shape[2] // 2], "cpu"),
+        cmap="Greys",
+        vmin=0,
+        vmax=1.2 * float(xp.max(x_true)),
+    )
+    ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium")
+    ax[2, -1].set_axis_off()
+
+    for ig, param in enumerate(params):
+        ax[0, ig].imshow(
+            to_device(x_spd3os[ig][:, :, img_shape[2] // 2], "cpu"),
+            cmap="Greys",
+            vmin=0,
+            vmax=1.2 * float(xp.max(x_true)),
+        )
+        ax[1, ig].imshow(
+            to_device(
+                (x_spd3os[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3),
+                "cpu",
+            ),
+            cmap="seismic",
+            vmin=-0.2,
+            vmax=0.2,
+        )
+        ax[2, ig].imshow(
+            to_device(
+                ((x_spd3os[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3))
+                > 0.01,
+                "cpu",
+            ),
+            cmap="Greys",
+        )
+        ax[0, ig].set_title(f"SPD3O, param {param}, {num_subsets}ss", fontsize="small")
+        ax[1, ig].set_title(f"rel. bias", fontsize="small")
+        ax[2, ig].set_title(f"rel. bias > 1%", fontsize="small")
+
+        ax[1, -1].plot(
+            np.arange(num_updates) / num_subsets,
+            nrmse_spd3o[ig],
+            label=f"param {param}",
+        )
+
+    ax[1, -1].set_title(f"NRMSE", fontsize="medium")
+    ax[1, -1].set_xlabel(f"epoch")
+    ax[1, -1].axhline(nrmse_limit, color="black", ls="--")
+    ax[1, -1].legend()
+    ax[1, -1].set_ylim(0, nrmse_init)
+    ax[1, -1].grid(ls=":")
+
+    fig.suptitle(
+        f"True counts {true_counts:.2E}, prior RDP, beta {beta:.2E}, seed {seed}"
+    )
+    fig.savefig("fig_spd3o.png")
+    fig.show()
+
+
+# %%
+# Run stochastic gradient descent
+
+num_epochs_sgd = 5
+run_sgd = True
+decrease_step_size = False
+step_sizes = np.array([1e-1])
+
+# epoch numbers (starting at 0) where preconditioner is updated (use [] for no updates)
+precond_update_epochs = [
+    1,
+    2,
+]
+# type of preconditioner
+# (1: P_logL = x / adjoint_ones, 2: 1 / (1/P_logL + 1/diag_Hess_RDP))
+precond_version = 2
+step_size_decay_factor = 0.75
+
+
+num_updates_sgd = num_epochs_sgd * num_subsets
+
+init_precond = rdp_preconditioner(
+    x_init,
+    adjoint_ones,
+    prior,
+    precond_version,
+)
+
+
+if run_sgd:
+    print(f"cost init: {cost_osem}")
+    print(f"cost ref: {cost_ref}")
+    print(f"nrmse init: {nrmse_init}")
+    print()
+
+    cost_sgd = np.zeros((len(step_sizes), num_updates_sgd))
+    nrmse_sgd = np.zeros((len(step_sizes), num_updates_sgd))
+
+    x_sgds = []
+
+    # SGD
+    for i, step_size in enumerate(step_sizes):
+        print(f"SGD {i}, init step size: {step_size}")
+        sgd_alg = SGD(cost_function, x_init)
+        sgd_alg.step_size = step_size
+        sgd_alg.precond = init_precond
+
+        x_cur = x_init.copy()
+
+        for j in range(num_updates_sgd):
+            epoch = j // num_subsets
+            subset = j % num_subsets
+
+            if (epoch in precond_update_epochs) and (subset == 0):
+                print("  updating preconditioner")
+                sgd_alg.precond = rdp_preconditioner(
+                    x_cur,
+                    adjoint_ones,
+                    prior,
+                    precond_version,
+                )
+
+            x_cur = sgd_alg.update(x_cur)
+
+            if track_cost:
+                cost_sgd[i, j] = cost_function(x_cur)
+            nrmse_sgd[i, j] = xp.sqrt(xp.mean((x_ref - x_cur) ** 2)) / scale_fac
+
+            if decrease_step_size and (subset == 0) and (j > 0):
+                sgd_alg.step_size *= step_size_decay_factor
+                print(f"  decreasing step size {sgd_alg.step_size}")
+
+        x_sgds.append(x_cur)
+
+    # %%
+    # SGD plots
+    vmax = 1.2 * float(xp.max(x_true))
+    sl = img_shape[2] // 2
+    num_rows = 3
+    num_cols = len(step_sizes) + 1
+
+    fig, ax = plt.subplots(
+        num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True
+    )
+    ax[0, -1].imshow(
+        to_device(x_ref[:, :, img_shape[2] // 2], "cpu"),
+        cmap="Greys",
+        vmin=0,
+        vmax=1.2 * float(xp.max(x_true)),
+    )
+    ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium")
+    ax[2, -1].set_axis_off()
+
+    for ig, step_size in enumerate(step_sizes):
+        ax[0, ig].imshow(
+            to_device(x_sgds[ig][:, :, img_shape[2] // 2], "cpu"),
+            cmap="Greys",
+            vmin=0,
+            vmax=1.2 * float(xp.max(x_true)),
+        )
+        ax[1, ig].imshow(
+            to_device(
+                (x_sgds[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3),
+                "cpu",
+            ),
+            cmap="seismic",
+            vmin=-0.2,
+            vmax=0.2,
+        )
+        ax[2, ig].imshow(
+            to_device(
+                ((x_sgds[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3))
+                > 0.01,
+                "cpu",
+            ),
+            cmap="Greys",
+        )
+        ax[0, ig].set_title(
+            f"SGD, step size {step_size}, {num_subsets}ss", fontsize="small"
+        )
+        ax[1, ig].set_title(f"rel. bias", fontsize="small")
+        ax[2, ig].set_title(f"rel. 
bias > 1%", fontsize="small") + + ax[1, -1].plot( + np.arange(num_updates_sgd) / num_subsets, + nrmse_sgd[ig], + label=f"step size {step_size}", + ) + + ax[1, -1].set_title(f"NRMSE", fontsize="medium") + ax[1, -1].set_xlabel(f"epoch") + ax[1, -1].axhline(nrmse_limit, color="black", ls="--") + ax[1, -1].legend() + ax[1, -1].set_ylim(0, nrmse_init) + ax[1, -1].grid(ls=":") + fig.suptitle( + f"True counts {true_counts:.2E}, prior RDP, beta {beta:.2E}, seed {seed}" + ) + fig.savefig("fig_sgd.png") + fig.show() + + +# %% +run_svrg = True +# update period for SVRG = epochs when all gradients are recalculated +svrg_gradient_recalc_periods = [x for x in range(0, num_epochs_sgd, 2)] + +step_sizes = np.array([3e-1]) + + +if run_svrg: + x_svrgs = [] + + cost_svrg = np.zeros((len(step_sizes), num_updates_sgd)) + nrmse_svrg = np.zeros((len(step_sizes), num_updates_sgd)) + + # SVRG + for i, step_size in enumerate(step_sizes): + print(f"SVRG {i}, init step size: {step_size}") + svrg_alg = SVRG(cost_function, x_init) + svrg_alg.step_size = step_size + svrg_alg.precond = init_precond + + x_cur = x_init.copy() + + for j in range(num_updates_sgd): + epoch = j // num_subsets + subset = j % num_subsets + + if subset == 0: + print(f" epoch {epoch}") + + line_search = False + if epoch == 0 and subset < 4 and step_size == 0: + line_search = True + + if (epoch in precond_update_epochs) and (subset == 0): + print(f" update {j}, updating preconditioner") + svrg_alg.precond = rdp_preconditioner( + x_cur, + adjoint_ones, + prior, + precond_version, + ) + + if (epoch in svrg_gradient_recalc_periods) and (subset == 0): + x_cur = svrg_alg.update( + x_cur, recalc_subset_gradients=True, line_search=line_search + ) + else: + x_cur = svrg_alg.update( + x_cur, recalc_subset_gradients=False, line_search=line_search + ) + + if track_cost: + cost_svrg[i, j] = cost_function(x_cur) + nrmse_svrg[i, j] = xp.sqrt(xp.mean((x_ref - x_cur) ** 2)) / scale_fac + + if decrease_step_size and (subset == 0) and (j > 0): + svrg_alg.step_size *= step_size_decay_factor + print(f" update {j}, decreasing step size {svrg_alg.step_size}") + + x_svrgs.append(x_cur) + + # %% + # SVRG plots + + vmax = 1.2 * float(xp.max(x_true)) + + num_rows = 3 + num_cols = len(step_sizes) + 1 + + sl = img_shape[2] // 2 + + fig, ax = plt.subplots( + num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True + ) + ax[0, -1].imshow( + to_device(x_ref[:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium") + ax[2, -1].set_axis_off() + + for ig, step_size in enumerate(step_sizes): + ax[0, ig].imshow( + to_device(x_svrgs[ig][:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[1, ig].imshow( + to_device( + (x_svrgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), + "cpu", + ), + cmap="seismic", + vmin=-0.2, + vmax=0.2, + ) + ax[2, ig].imshow( + to_device( + ((x_svrgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3)) + > 0.01, + "cpu", + ), + cmap="Greys", + ) + ax[0, ig].set_title( + f"SVRG, step size {step_size}, {num_subsets}ss", fontsize="small" + ) + ax[1, ig].set_title(f"rel. bias", fontsize="small") + ax[2, ig].set_title(f"rel. 
bias > 1%", fontsize="small") + + ax[1, -1].plot( + np.arange(num_updates_sgd) / num_subsets, + nrmse_svrg[ig], + label=f"step size {step_size}", + ) + + ax[1, -1].set_title(f"NRMSE", fontsize="medium") + ax[1, -1].set_xlabel(f"epoch") + ax[1, -1].axhline(nrmse_limit, color="black", ls="--") + ax[1, -1].legend() + ax[1, -1].set_ylim(0, float(nrmse_init)) + ax[1, -1].grid(ls=":") + fig.suptitle( + f"True counts {true_counts:.2E}, prior RDP, beta {beta:.2E}, seed {seed}" + ) + fig.savefig(f"fig_svrg_{sl}.png") + fig.show() + + # %% + # compare + t = np.arange(num_updates_sgd) / num_subsets + plt.plot(t, nrmse_sgd[0], label=f"SGD") + plt.plot(t, nrmse_spd3o[0], label=f"SPD3O") + plt.plot(2 * t, nrmse_svrg[0], label=f"SVRG") + plt.xlabel(f"epoch") + plt.legend() + plt.ylim(0, float(nrmse_init)) + plt.grid(ls=":") + +# %% diff --git a/main.py b/main.py index 264f3d1..9c3f395 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ >>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) """ -import sirf.STIR as STIR +import math from cil.optimisation.algorithms import Algorithm from cil.optimisation.utilities import callbacks from petric import Dataset @@ -19,6 +19,16 @@ import numpy as np +def get_divisors(n): + """Returns a sorted list of all divisors of a positive integer n.""" + divisors = set() + for i in range(1, int(math.sqrt(n)) + 1): + if n % i == 0: + divisors.add(i) + divisors.add(n // i) + return sorted(divisors) + + class MaxIteration(callbacks.Callback): """ The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). @@ -48,10 +58,11 @@ class Submission(Algorithm): def __init__( self, data: Dataset, - num_subsets: int = 28, - update_objective_interval: int = 10, + approx_num_subsets: int = 28, + update_objective_interval: int | None = None, base_gamma: float = 1.0, rho: float = 1.0, + seed: int = 1, **kwargs ): """ @@ -59,6 +70,9 @@ def __init__( NB: in practice, `num_subsets` should likely be determined from the data. This is just an example. Try to modify and improve it! """ + + np.random.seed(seed) + self.acquisition_models = [] self.prompts = [] self.sensitivities = [] @@ -66,11 +80,11 @@ def __init__( self.x = data.OSEM_image.clone() self.prior = data.prior - self.num_subsets = num_subsets - - # find views in each subset - # (note that SIRF can currently only do subsets over views) - views = data.mult_factors.dimensions()[2] + num_views = data.mult_factors.dimensions()[2] + num_views_divisors = np.array(get_divisors(num_views)) + self.num_subsets = num_views_divisors[ + np.argmin(np.abs(num_views_divisors - approx_num_subsets)) + ] self.y = [] @@ -88,14 +102,13 @@ def __init__( self.rho = rho self.gamma = base_gamma / self.x.max() - Ts = [] self.S_As = [] ones_image = self.x.get_uniform_copy(1) Ts_np = np.zeros((self.num_subsets,) + self.x.shape) - for i in range(num_subsets): + for i in range(self.num_subsets): # we need to use the linear part of the acq. 
subset model
             # otherwise forward() includes already the additive term
             acqm = self.acquisition_models[i].get_linear_acquisition_model()
@@ -108,7 +121,12 @@
 
         # calculate step sizes S_As
         tmp = acqm.forward(ones_image)
-        self.S_As.append(tmp.power(-1) * (self.gamma * self.rho))
+        S_A = tmp.power(-1) * (self.gamma * self.rho)
+
+        # clip inf values
+        max_S_A = S_A.as_array()[S_A.as_array() != np.inf].max()
+        S_A.minimum(max_S_A, out=S_A)
+        self.S_As.append(S_A)
 
         # calcualte Ts
         tmp2 = acqm.backward(ones_subset_sino)
@@ -118,14 +136,33 @@
         self.T = self.x.get_uniform_copy(0)
         self.T.fill(Ts_np.min(0))
 
+        # clip inf values
+        max_T = self.T.as_array()[self.T.as_array() != np.inf].max()
+        self.T.minimum(max_T, out=self.T)
+
+        # derive FOV mask and multiply step size T with it
+        self.fov_mask = self.x.get_uniform_copy(0)
+        tmp = 1.0 * (data.OSEM_image.as_array() > 0)
+        self.fov_mask.fill(tmp)
+        self.T *= self.fov_mask
 
         self.zbar = self.z.clone()
         self.grad_h = None
 
+        self.subset_number_list = []
+
+        if update_objective_interval is None:
+            update_objective_interval = self.num_subsets
+
         super().__init__(update_objective_interval=update_objective_interval, **kwargs)
         self.configured = True  # required by Algorithm
 
     def update(self):
+        if self.subset_number_list == []:
+            self.create_subset_number_list()
+
+        self.subset = self.subset_number_list.pop()
+
         if self.grad_h is None:
             self.grad_h = self.prior.gradient(self.x)
 
@@ -145,7 +182,7 @@
         )
 
         # prox of convex conjugate of negative Poisson logL
-        tmp = (y_plus - 1) * (y_plus - 1) + 4 * self.S_As[self.subset] * self.data[
+        tmp = (y_plus - 1) * (y_plus - 1) + 4 * self.S_As[self.subset] * self.prompts[
             self.subset
         ]
         tmp.sqrt(out=tmp)
@@ -155,10 +192,12 @@
             y_plus - self.y[self.subset]
         )
 
+        self.y[self.subset] = y_plus
+
         self.z = self.z + delta_z
         self.zbar = self.z + self.num_subsets * delta_z
 
-        self.subset = (self.subset + 1) % len(self.prompts)
+        # debug output (disabled)
+        # print(self.x.min(), self.x.max())
 
     def update_objective(self):
         """
@@ -167,5 +206,10 @@
         """
         return 0
 
+    def create_subset_number_list(self):
+        tmp = np.arange(self.num_subsets)
+        np.random.shuffle(tmp)
+        self.subset_number_list = tmp.tolist()
+
 
-submission_callbacks = [MaxIteration(660)]
+submission_callbacks = [MaxIteration(300)]
diff --git a/main_svrg.py b/main_svrg.py
new file mode 100644
index 0000000..2b6e5ac
--- /dev/null
+++ b/main_svrg.py
@@ -0,0 +1,338 @@
+"""Main file to modify for submissions.
+
+Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows:
+
+>>> from main import Submission, submission_callbacks
+>>> from petric import data, metrics
+>>> algorithm = Submission(data)
+>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
+"""
+
+import math
+import sirf.STIR as STIR
+from cil.optimisation.algorithms import Algorithm
+from cil.optimisation.utilities import callbacks
+from sirf.contrib.partitioner import partitioner
+
+import numpy as np
+import array_api_compat.cupy as xp
+from array_api_compat import to_device
+
+# import pure python re-implementation of the RDP -> only used to get diagonal of the RDP Hessian!
+from rdp import RDP
+
+from petric import Dataset
+
+
+def get_divisors(n):
+    """Returns a sorted list of all divisors of a positive integer n."""
+    divisors = set()
+    for i in range(1, int(math.sqrt(n)) + 1):
+        if n % i == 0:
+            divisors.add(i)
+            divisors.add(n // i)
+    return sorted(divisors)
+
+
+class MaxIteration(callbacks.Callback):
+    """
+    The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout).
+    This callback forces stopping after `max_iteration` instead.
+    """
+
+    def __init__(self, max_iteration: int, verbose: int = 1):
+        super().__init__(verbose)
+        self.max_iteration = max_iteration
+
+    def __call__(self, algorithm: Algorithm):
+        if algorithm.iteration >= self.max_iteration:
+            raise StopIteration
+
+
+class Submission(Algorithm):
+    """
+    Preconditioned SVRG for the PETRIC objective (Poisson log-likelihood + RDP prior).
+    NB: subset likelihood gradients stored at an anchor image are combined into an
+    SVRG estimate of the full posterior gradient (see `update()` below).
+    NB: a diagonal preconditioner based on the subset sensitivities and the diagonal
+    of the RDP Hessian is used and refreshed in selected epochs.
+    NB: see https://github.com/SyneRBI/SIRF-Contribs/tree/master/src/Python/sirf/contrib/BSREM
+    """
+
+    def __init__(
+        self,
+        data: Dataset,
+        step_size_factor: float = 1.0,  # multiplicative factor to increase / decrease default step sizes
+        approx_num_subsets: int = 25,  # approximate number of subsets, closest divisor of num_views will be used
+        update_objective_interval: int | None = None,
+        complete_gradient_epochs: None | list[int] = None,
+        precond_update_epochs: None | list[int] = None,
+        precond_hessian_factor: float = 0.75,
+        precond_filter_fwhm_mm: float = 5.0,
+        verbose: bool = False,
+        seed: int = 1,
+        **kwargs,
+    ):
+        """
+        Initialisation function, setting up data & (hyper)parameters.
+        NB: in practice, `num_subsets` should likely be determined from the data.
+        This is just an example. Try to modify and improve it!
+        """
+
+        np.random.seed(seed)
+
+        self._verbose = verbose
+        self.subset = 0
+
+        # --- setup the number of subsets
+
+        num_views = data.mult_factors.dimensions()[2]
+        num_views_divisors = np.array(get_divisors(num_views))
+        self._num_subsets = num_views_divisors[
+            np.argmin(np.abs(num_views_divisors - approx_num_subsets))
+        ]
+
+        if self._num_subsets not in num_views_divisors:
+            # defensive check; cannot fire since num_subsets is picked from the divisors
+            raise ValueError(
+                f"Number of subsets {self._num_subsets} is not a divisor of {num_views}. 
Divisors are {num_views_divisors}" + ) + + if self._verbose: + print(f"num_subsets: {self._num_subsets}") + + # --- setup the initial image as a slightly smoothed version of the OSEM image + self.x = data.OSEM_image.clone() + + self._update = 0 + self._step_size_factor = step_size_factor + self._step_size = self._step_size_factor * 2.0 + self._subset_number_list = [] + self._precond_hessian_factor = precond_hessian_factor + + self._data_sub, self._acq_models, self._subset_likelihood_funcs = ( + partitioner.data_partition( + data.acquired_data, + data.additive_term, + data.mult_factors, + self._num_subsets, + initial_image=data.OSEM_image, + mode="staggered", + ) + ) + + penalization_factor = data.prior.get_penalisation_factor() + + # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations) + data.prior.set_penalisation_factor(penalization_factor / self._num_subsets) + data.prior.set_up(data.OSEM_image) + + self._subset_prior_fct = data.prior + + self._adjoint_ones = self.x.get_uniform_copy(0) + + for i in range(self._num_subsets): + if self._verbose: + print(f"Calculating subset {i} sensitivity") + self._adjoint_ones += self._subset_likelihood_funcs[ + i + ].get_subset_sensitivity(0) + + self._fov_mask = self.x.get_uniform_copy(0) + # tmp = 1.0 * (self._adjoint_ones.as_array() > 0) + tmp = 1.0 * (data.OSEM_image.as_array() > 0) + self._fov_mask.fill(tmp) + + # add a small number in the adjoint ones outside the FOV to avoid NaN in division + self._adjoint_ones += 1e-6 * (-self._fov_mask + 1.0) + + # initialize list / ImageData for all subset gradients and sum of gradients + self._summed_subset_gradients = self.x.get_uniform_copy(0) + self._subset_gradients = [] + + if complete_gradient_epochs is None: + self._complete_gradient_epochs: list[int] = [x for x in range(0, 1000, 2)] + else: + self._complete_gradient_epochs = complete_gradient_epochs + + if precond_update_epochs is None: + self._precond_update_epochs: list[int] = [1, 2, 3] + else: + self._precond_update_epochs = precond_update_epochs + + # setup python re-implementation of the RDP + # only used to get the diagonal of the RDP Hessian for preconditioning! 
+ # (diag of RDP Hessian is not available in SIRF yet) + if "cupy" in xp.__name__: + self._dev = xp.cuda.Device(0) + else: + self._dev = "cpu" + + self._python_prior = RDP( + data.OSEM_image.shape, + xp, + self._dev, + xp.asarray(data.OSEM_image.spacing, device=self._dev), + eps=data.prior.get_epsilon(), + gamma=data.prior.get_gamma(), + ) + self._python_prior.kappa = xp.asarray( + data.kappa.as_array().astype(xp.float64), device=self._dev + ) + self._python_prior.scale = penalization_factor + + # small relative number for the preconditioner (to avoid zeros in the preconditioner) + self._precond_delta_rel = 0.0 # 1e-6 + + self._precond_filter = STIR.SeparableGaussianImageFilter() + self._precond_filter.set_fwhms( + [precond_filter_fwhm_mm, precond_filter_fwhm_mm, precond_filter_fwhm_mm] + ) + self._precond_filter.set_up(data.OSEM_image) + + # calculate the initial preconditioner based on the initial image + self._precond = self.calc_precond(self.x) + + if update_objective_interval is None: + update_objective_interval = self._num_subsets + + super().__init__(update_objective_interval=update_objective_interval, **kwargs) + self.configured = True # required by Algorithm + + @property + def epoch(self): + return self._update // self._num_subsets + + def update_step_size(self): + if self.epoch <= 4: + self._step_size = self._step_size_factor * 2.0 + elif self.epoch > 4 and self.epoch <= 8: + self._step_size = self._step_size_factor * 1.5 + elif self.epoch > 8 and self.epoch <= 12: + self._step_size = self._step_size_factor * 1.0 + else: + self._step_size = self._step_size_factor * 0.5 + + if self._verbose: + print(self._update, self.epoch, self._step_size) + + def calc_precond( + self, + x: STIR.ImageData, + ) -> STIR.ImageData: + + # generate a smoothed version of the input image + # to avoid high values, especially in first and last slices + x_sm = self._precond_filter.process(x) + + prior_diag_hess = x_sm.get_uniform_copy(0) + prior_diag_hess.fill( + to_device( + self._python_prior.diag_hessian( + xp.asarray(x_sm.as_array().astype(xp.float64), device=self._dev) + ), + "cpu", + ) + ) + + if self._precond_delta_rel > 0: + x_sm += self._precond_delta_rel * x_sm.max() + + precond = ( + self._fov_mask + * x_sm + / ( + self._adjoint_ones + + (self._precond_hessian_factor * 2) * prior_diag_hess * x_sm + ) + ) + + return precond + + def update_all_subset_gradients(self) -> None: + + self._summed_subset_gradients = self.x.get_uniform_copy(0) + self._subset_gradients = [] + + subset_prior_gradient = self._subset_prior_fct.gradient(self.x) + + # remember that the objective has to be maximized + # posterior = log likelihood - log prior ("minus" instead of "plus"!) 
+        for i in range(self._num_subsets):
+            self._subset_gradients.append(
+                self._subset_likelihood_funcs[i].gradient(self.x)
+                - subset_prior_gradient
+            )
+            self._summed_subset_gradients += self._subset_gradients[i]
+
+    def update(self):
+
+        update_all_subset_gradients = (
+            self._update % self._num_subsets == 0
+        ) and self.epoch in self._complete_gradient_epochs
+
+        update_precond = (
+            self._update % self._num_subsets == 0
+        ) and self.epoch in self._precond_update_epochs
+
+        if self._update % self._num_subsets == 0:
+            self.update_step_size()
+
+        if update_precond:
+            if self._verbose:
+                print(f"  {self._update}, updating preconditioner")
+            self._precond = self.calc_precond(self.x)
+
+        if update_all_subset_gradients:
+            if self._verbose:
+                print(
+                    f"  {self._update}, {self.subset}, recalculating all subset gradients"
+                )
+            self.update_all_subset_gradients()
+            approximated_gradient = self._summed_subset_gradients
+        else:
+            if self._subset_number_list == []:
+                self.create_subset_number_list()
+
+            self.subset = self._subset_number_list.pop()
+            if self._verbose:
+                print(f"  {self._update}, {self.subset}, subset gradient update")
+
+            subset_prior_gradient = self._subset_prior_fct.gradient(self.x)
+
+            # remember that the objective has to be maximized
+            # posterior = log likelihood - log prior ("minus" instead of "plus"!)
+            approximated_gradient = (
+                self._num_subsets
+                * (
+                    (
+                        self._subset_likelihood_funcs[self.subset].gradient(self.x)
+                        - subset_prior_gradient
+                    )
+                    - self._subset_gradients[self.subset]
+                )
+                + self._summed_subset_gradients
+            )
+
+        ### Objective has to be maximized -> "+" for gradient ascent
+        self.x = self.x + self._step_size * self._precond * approximated_gradient
+
+        # enforce non-negative constraint
+        self.x.maximum(0, out=self.x)
+
+        self._update += 1
+
+    def update_objective(self) -> None:
+        """
+        NB: The objective value is not required by this algorithm nor by PETRIC, so `0` is appended.
+        NB: It should be `sum(prompts * log(acq_model.forward(self.x)) - self.x * sensitivity)` across all subsets.
+        """
+
+        self.loss.append(0)
+
+    def create_subset_number_list(self):
+        tmp = np.arange(self._num_subsets)
+        np.random.shuffle(tmp)
+        self._subset_number_list = tmp.tolist()
+
+
+submission_callbacks = [MaxIteration(300)]
diff --git a/test_petric.py b/test_petric.py
index 02940ec..4fbcfe0 100644
--- a/test_petric.py
+++ b/test_petric.py
@@ -423,10 +423,14 @@ def test_petric(ds: int, num_iter: int, suffix: str = "", **kwargs):
             precond_update_epochs=precond_update_epochs,
         )
     else:
-        for ns in [25]:
-            for i in [5]:
-                test_petric(
-                    ds=i,
-                    num_iter=200,
-                    approx_num_subsets=ns,
-                )
+        for rho in [1.0, 1.5, 0.5]:
+            for bg in [1e7, 1e6, 1e5]:
+                # for i in range(5):
+                for i in [0]:
+                    test_petric(
+                        ds=i,
+                        num_iter=3 * 28,
+                        base_gamma=bg,
+                        rho=rho,
+                        suffix=f"bg_{bg}_rho_{rho}_SPD3O",
+                    )
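
A sketch of the SVRG gradient estimator used in main_svrg.py, with illustrative
names that are not part of the patch: with subset gradients stored at an anchor
image, a subset-i update approximates the full posterior gradient as
num_subsets * (grad_i(x) - g_i_anchor) + sum_j g_j_anchor, which is an unbiased
estimate of the sum of all subset gradients. Minimal NumPy illustration:

    import numpy as np

    def svrg_gradient_estimate(grad_subset, x, anchor_subset_grads, i):
        """SVRG estimate of the full gradient at x using subset i.

        grad_subset: callable (i, x) -> gradient of subset i at x
        anchor_subset_grads: list of subset gradients evaluated at the anchor image
        """
        n = len(anchor_subset_grads)
        summed = np.sum(anchor_subset_grads, axis=0)
        return n * (grad_subset(i, x) - anchor_subset_grads[i]) + summed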