From 855c1cf1a7b59f3b65056ebadb3df5dc7a6a8f19 Mon Sep 17 00:00:00 2001 From: "Georg Schramm (Windows)" Date: Mon, 15 Jul 2024 10:51:26 +0200 Subject: [PATCH] wip on spdhg test script --- simulations/spdhg_test.py | 562 +++++++++++++++++++++----------------- 1 file changed, 305 insertions(+), 257 deletions(-) diff --git a/simulations/spdhg_test.py b/simulations/spdhg_test.py index 2d2e074..3fdde6f 100644 --- a/simulations/spdhg_test.py +++ b/simulations/spdhg_test.py @@ -41,8 +41,30 @@ # %% # input parameters +run_spdhg = True +run_sgd = True +run_svrg = True + +# prior type quad or rdp - must be quad if run_spdhg is True +prior_type = "rdp" + +seed = 1 + +# true counts, reasonable range: 1e6, 1e7 (high counts), 1e5 (low counts) +true_counts = 1e6 +# regularization weight, reasonable range: 0.14, 0.014, 1.4 for RDP +# regularization weight, reasonable range: for quad 30 * 0.14 for 1e6, 3*0.14 for 1e7, 300*0.14 for 1e5 +if prior_type == "rdp": + beta = 0.14 +elif prior_type == "quad": + beta = 30 * (1e6 / true_counts) * 0.14 +else: + raise ValueError("invalid prior_type") +# RDP gamma parameter +gamma_rdp = 2.0 + # number of epochs / subsets for stochastic gradient algorithms -num_epochs_sgd = 10 +num_epochs_sgd = 20 num_subsets_sgd = 54 # decrease step in stochastic gradient algorithms after every epoch @@ -59,7 +81,17 @@ # update period for SVRG = epochs when all gradients are recalculated svrg_update_period = 2 # (initial) step sizes to try -step_sizes = [0.03, 0.1, 0.3, 0.5, 1.0] +step_sizes = [0.1, 0.3, 0.5, 1.0, 2.0] + +# max number of updates for reference L-BFGS-B solution +num_iter_bfgs_ref = 400 + +# SDPHG parameters +rho = 1.0 # up to rho = 3 seems also to work for some gammas +# array of gamma values to try for SPDHG - these get divided by the "scale" of the OSEM image +gammas = np.array([0.1, 0.3, 1.0, 3.0, 10]) + +# %% # number of rings of simulated PET scanner, should be odd in this example num_rings = 11 @@ -67,6 +99,8 @@ fwhm_data_mm = 4.5 # simulated TOF or non-TOF system tof = False +# mean of contamination sinogram, relative to mean of trues sinogram, reasonable range: 0.5 - 1.0 +contam_fraction = 0.5 # show the geometry of the scanner / image volume show_geometry = False # verbose output @@ -74,35 +108,18 @@ # track cost function values after every update (slow) track_cost = False -# true counts, reasonable range: 1e6, 1e7 (high counts), 1e5 (low counts) -true_counts = 1e6 -# mean of contamination sinogram, relative to mean of trues sinogram, reasonable range: 0.5 - 1.0 -contam_fraction = 0.5 - -# prior type rdp or quad -prior_type = "quad" - -# regularization weight, reasonable range: 0.14, 0.014, 1.4 -beta = 10000 * 0.14 -# RDP gamma parameter -gamma = 2.0 - # number of epochs / subsets for intial OSEM num_epochs_osem = 1 num_subsets_osem = 27 -# max number of updates for reference L-BFGS-B solution -num_iter_bfgs_ref = 400 - -run_sgd = True - -# SDPHG parameters -rho = 0.999 -gammas = np.array([0.1, 0.3, 1.0, 3.0, 10.0, 30.0]) +# %% # random seed -seed = 1 np.random.seed(seed) + +if run_spdhg and prior_type != "quad": + raise ValueError("SPDHG only works with quadratic prior for now") + # %% # Setup of the forward model :math:`\bar{y}(x) = A x + s` # -------------------------------------------------------- @@ -158,14 +175,14 @@ c1 = proj.in_shape[1] // 2 x_true[(c0 - 4) : (c0 + 4), (c1 - 4) : (c1 + 4), :] = 3.0 -x_true[20:24, c1 : (c1 + 4), :] = 5.0 +x_true[28:32, c1 : (c1 + 4), :] = 5.0 x_true[c0 : (c0 + 4), 20:24, :] = 5.0 -x_true[-24:-20, c1 : (c1 + 4), :] = 0.1 +x_true[-32:-28, c1 : (c1 + 4), :] = 0.1 x_true[c0 : (c0 + 4), -24:-20, :] = 0.1 -x_true[:10, :, :] = 0 -x_true[-10:, :, :] = 0 +x_true[:25, :, :] = 0 +x_true[-25:, :, :] = 0 x_true[:, :10, :] = 0 x_true[:, -10:, :] = 0 @@ -278,6 +295,10 @@ fwd_sum = FwdSum(img_shape, xp) fwd_mult = FwdMult(img_shape, xp) +fwd_ones = pet_lin_op(xp.ones(pet_lin_op.in_shape, device=dev, dtype=xp.float32)) +fwd_osem = pet_lin_op(x_osem) + contamination +kappa = xp.sqrt(pet_lin_op.adjoint((data * fwd_ones) / (fwd_osem**2))) + if prior_type == "rdp": prior = RDP( fwd_diff, @@ -285,7 +306,7 @@ eps=float(xp.max(x_osem)) / 100, xp=xp, dev=dev, - gamma=gamma, + gamma=gamma_rdp, ) elif prior_type == "quad": prior = QuadraticPrior( @@ -297,7 +318,7 @@ raise ValueError("invalid prior_type") -# prior.weights = xp.sqrt(fwd_mult(kappa)) +prior.weights = xp.sqrt(fwd_mult(kappa)) prior.scale = beta pet_subset_lin_op_seq, subset_slices = split_fwd_model(pet_lin_op, num_subsets_sgd) @@ -315,7 +336,7 @@ if prior_type == "rdp": ref_file = Path( - f"{prior_type}_t_{true_counts:.2E}_b_{beta:.2E}_g_{gamma:.2E}_n_{num_iter_bfgs_ref}_nr_{num_rings}_tof_{tof}_cf_{contam_fraction}_s_{seed}.npy" + f"{prior_type}_t_{true_counts:.2E}_b_{beta:.2E}_g_{gamma_rdp:.2E}_n_{num_iter_bfgs_ref}_nr_{num_rings}_tof_{tof}_cf_{contam_fraction}_s_{seed}.npy" ) elif prior_type == "quad": ref_file = Path( @@ -360,123 +381,178 @@ # %% # SPDHG -op_G = fwd_diff -nrmse_pdhg = np.zeros((len(gammas), num_epochs_sgd * num_subsets_sgd), dtype=xp.float32) - -for ig, gamma in enumerate(gammas / x_osem_scale): - # initialize primal and dual variables - x_spdhg = 1.0 * x_osem - # initialize dual variable for the negative Poisson logL - y = 1 - data / (pet_lin_op(x_spdhg) + contamination) +if run_spdhg: + op_G = fwd_diff + nrmse_pdhg = np.zeros( + (len(gammas), num_epochs_sgd * num_subsets_sgd), dtype=xp.float32 + ) - # initialize dual variable for the gradient - w = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev) + # list for all SPDHG recons using different gamma values + x_spdhgs = [] + + for ig, gamma in enumerate(gammas / x_osem_scale): + # initialize primal and dual variables + x_spdhg = 1.0 * x_osem + # initialize dual variable for the negative Poisson logL + y = 1 - data / (pet_lin_op(x_spdhg) + contamination) + + # initialize dual variable for the gradient + w = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev) + # w = beta * prior.weights * fwd_diff(x_osem) + + # initialize z and zbar + z = pet_lin_op.adjoint(y) + op_G.adjoint(w) + zbar = 1.0 * z + + # %% + # calculate SPHDG step sizes + S_As = [] + T_As = [] + + # probability that we do a data update + p_a = 0.5 / num_subsets_sgd + # probability that we do a prior update + p_g = 0.5 + + for lin_op in pet_subset_lin_op_seq: + tmp = lin_op(xp.ones(lin_op.in_shape, dtype=xp.float32, device=dev)) + # replace zeros by smallest non-zero value, to avoid division by zero + tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp) + S_As.append(gamma * rho / tmp) + + T_As.append( + (1 / gamma) + * rho + * p_a + / lin_op.adjoint( + xp.ones(lin_op.out_shape, dtype=xp.float64, device=dev) + ) + ) - # initialize z and zbar - z = pet_lin_op.adjoint(y) + op_G.adjoint(w) - zbar = 1.0 * z + # element-wise minimum of the data T's + T_A_min = xp.min(xp.asarray(T_As), axis=0) - # %% - # calculate SPHDG step sizes - S_As = [] - T_As = [] - - p_a = 0.5 / num_subsets_sgd - p_g = 0.5 - - for lin_op in pet_subset_lin_op_seq: - tmp = lin_op(xp.ones(lin_op.in_shape, dtype=xp.float32, device=dev)) - tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp) - S_As.append(gamma * rho / tmp) - - T_As.append( - (1 / gamma) - * rho - * p_a - / lin_op.adjoint(xp.ones(lin_op.out_shape, dtype=xp.float64, device=dev)) - ) + # norm of the fwd difference operator in 3D + op_G_norm = np.sqrt(12) + S_G = gamma * rho / op_G_norm + T_G = (1 / gamma) * rho * p_g / op_G_norm - T_A_min = xp.min(xp.asarray(T_As), axis=0) + # "clip" element-wise minimum of the data T's with the prior T + T = xp.where(T_A_min < T_G, T_A_min, xp.full(pet_lin_op.in_shape, T_G)) - op_G_norm = np.sqrt(12) - S_G = gamma * rho / op_G_norm - T_G = (1 / gamma) * rho * p_g / op_G_norm + # %% - T = xp.where(T_A_min < T_G, T_A_min, xp.full(pet_lin_op.in_shape, T_G)) + num_updates_spdhg = num_epochs_sgd * num_subsets_sgd - # %% + for i in range(num_updates_spdhg): - num_updates_spdhg = num_epochs_sgd * num_subsets_sgd + subset = np.random.randint(2 * num_subsets_sgd) - for i in range(num_updates_spdhg): + if subset < num_subsets_sgd: + sl = subset_slices[subset] - subset = np.random.randint(2 * num_subsets_sgd) + x_spdhg -= T * zbar + x_spdhg = xp.where(x_spdhg < 0, xp.zeros_like(x_spdhg), x_spdhg) - if subset < num_subsets_sgd: - sl = subset_slices[subset] + y_plus = y[sl] + S_As[subset] * ( + pet_subset_lin_op_seq[subset](x_spdhg) + contamination[sl] + ) + # prox of convex conjugate of negative Poisson logL + y_plus = 0.5 * ( + y_plus + + 1 + - xp.sqrt((y_plus - 1) ** 2 + 4 * S_As[subset] * data[sl]) + ) + delta_z = pet_subset_lin_op_seq[subset].adjoint(y_plus - y[sl]) + y[sl] = y_plus - x_spdhg -= T * zbar - x_spdhg = xp.where(x_spdhg < 0, xp.zeros_like(x_spdhg), x_spdhg) + z = z + delta_z + zbar = z + delta_z / p_a + else: + w_plus = w + S_G * op_G(x_spdhg) - y_plus = y[sl] + S_As[subset] * ( - pet_subset_lin_op_seq[subset](x_spdhg) + contamination[sl] - ) - # prox of convex conjugate of negative Poisson logL - y_plus = 0.5 * ( - y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_As[subset] * data[sl]) - ) - delta_z = pet_subset_lin_op_seq[subset].adjoint(y_plus - y[sl]) - y[sl] = y_plus + # prox of convex dual of weighted quad diff. prior + w_plus = (beta * prior.weights / (S_G + beta * prior.weights)) * w_plus - z = z + delta_z - zbar = z + delta_z / p_a - else: - w_plus = (w + S_G * op_G(x_spdhg)) / beta + delta_z = op_G.adjoint(w_plus - w) + w = w_plus - # prox of dual of the squared L2 norm - w_plus /= 1 + S_G / beta - w_plus *= beta + z = z + delta_z + zbar = z + delta_z / p_g - delta_z = op_G.adjoint(w_plus - w) - w = w_plus + nrmse_pdhg[ig, i] = xp.sqrt(xp.mean((x_ref - x_spdhg) ** 2)) / scale_fac - z = z + delta_z - zbar = z + delta_z / p_g + if (i + 1) % num_subsets_sgd == 0: + print( + f"SPDHG epoch {((i+1)//num_subsets_sgd):04} / {num_epochs_sgd} NRMSE: {nrmse_pdhg[ig, i]:.2E}", + end="\r", + ) - nrmse_pdhg[ig, i] = xp.sqrt(xp.mean((x_ref - x_spdhg) ** 2)) / scale_fac + x_spdhgs.append(x_spdhg) - if (i + 1) % num_subsets_sgd == 0: - print( - f"SPDHG epoch {((i+1)//num_subsets_sgd):04} / {num_epochs_sgd} NRMSE: {nrmse_pdhg[ig, i]:.2E}", - end="\r", - ) + # %% + # SPDHG plots -# %% -fig, ax = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) -ax[0].imshow( - to_device(x_ref[:, :, img_shape[2] // 2], "cpu"), - cmap="Greys", - vmin=0, - vmax=1.2 * float(xp.max(x_true)), -) -ax[1].imshow( - to_device(x_spdhg[:, :, img_shape[2] // 2], "cpu"), - cmap="Greys", - vmin=0, - vmax=1.2 * float(xp.max(x_true)), -) + vmax = 1.2 * float(xp.max(x_true)) + sl = img_shape[2] // 2 + num_rows = 3 + num_cols = len(gammas) + 1 -for ig, gamma in enumerate(gammas): - ax[2].plot( - np.arange(num_updates_spdhg) / num_subsets_sgd, - nrmse_pdhg[ig], - label=f"gam {gamma}", + fig, ax = plt.subplots( + num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True ) -ax[2].axhline(0.01, color="black", ls="--") -ax[2].legend() -ax[2].set_ylim(0, nrmse_osem) -ax[2].grid(ls=":") -fig.show() + ax[0, -1].imshow( + to_device(x_ref[:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium") + ax[2, -1].set_axis_off() + + for ig, gamma in enumerate(gammas): + ax[0, ig].imshow( + to_device(x_spdhgs[ig][:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[1, ig].imshow( + to_device( + (x_spdhgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), + "cpu", + ), + cmap="seismic", + vmin=-0.2, + vmax=0.2, + ) + ax[2, ig].imshow( + to_device( + ((x_spdhgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3)) + > 0.01, + "cpu", + ), + cmap="Greys", + ) + ax[0, ig].set_title( + f"SPDHG, rho {rho}, gam {gamma}, {num_subsets_sgd}ss", fontsize="small" + ) + ax[1, ig].set_title(f"rel. bias", fontsize="small") + ax[2, ig].set_title(f"rel. bias > 1%", fontsize="small") + + ax[1, -1].plot( + np.arange(num_updates_spdhg) / num_subsets_sgd, + nrmse_pdhg[ig], + label=f"gam {gamma}", + ) + + ax[1, -1].set_title(f"NRMSE", fontsize="medium") + ax[1, -1].set_xlabel(f"epoch") + ax[1, -1].axhline(0.01, color="black", ls="--") + ax[1, -1].legend() + ax[1, -1].set_ylim(0, nrmse_osem) + ax[1, -1].grid(ls=":") + fig.show() # %% @@ -536,6 +612,71 @@ x_sgds.append(x_cur) # %% + # SGD plots + vmax = 1.2 * float(xp.max(x_true)) + sl = img_shape[2] // 2 + num_rows = 3 + num_cols = len(gammas) + 1 + + fig, ax = plt.subplots( + num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True + ) + ax[0, -1].imshow( + to_device(x_ref[:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium") + ax[2, -1].set_axis_off() + + for ig, step_size in enumerate(step_sizes): + ax[0, ig].imshow( + to_device(x_sgds[ig][:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[1, ig].imshow( + to_device( + (x_sgds[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), + "cpu", + ), + cmap="seismic", + vmin=-0.2, + vmax=0.2, + ) + ax[2, ig].imshow( + to_device( + ((x_sgds[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3)) + > 0.01, + "cpu", + ), + cmap="Greys", + ) + ax[0, ig].set_title( + f"SGD, step size {step_size}, {num_subsets_sgd}ss", fontsize="small" + ) + ax[1, ig].set_title(f"rel. bias", fontsize="small") + ax[2, ig].set_title(f"rel. bias > 1%", fontsize="small") + + ax[1, -1].plot( + np.arange(num_updates_sgd) / num_subsets_sgd, + nrmse_sgd[ig], + label=f"step size {step_size}", + ) + + ax[1, -1].set_title(f"NRMSE", fontsize="medium") + ax[1, -1].set_xlabel(f"epoch") + ax[1, -1].axhline(0.01, color="black", ls="--") + ax[1, -1].legend() + ax[1, -1].set_ylim(0, nrmse_osem) + ax[1, -1].grid(ls=":") + fig.show() + + +# %% +if run_svrg: x_svrgs = [] cost_svrg = np.zeros((len(step_sizes), num_updates_sgd)) @@ -585,158 +726,65 @@ x_svrgs.append(x_cur) # %% - # SGD plots + # SVRG plots + vmax = 1.2 * float(xp.max(x_true)) sl = img_shape[2] // 2 + num_rows = 3 + num_cols = len(gammas) + 1 - fig, ax = plt.subplots(4, 5, figsize=(15, 12), tight_layout=True) - - ax[0, 0].imshow(to_device(x_true[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - ax[0, 1].imshow(to_device(x_osem[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - ax[0, 2].imshow(to_device(x_ref[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - - ax[0, 0].set_title(f"true", fontsize="medium") - ax[0, 1].set_title(f"OSEM {num_epochs_osem}/{num_subsets_osem}", fontsize="medium") - ax[0, 2].set_title("L-BFGS-B (ref)", fontsize="medium") + fig, ax = plt.subplots( + num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), tight_layout=True + ) + ax[0, -1].imshow( + to_device(x_ref[:, :, img_shape[2] // 2], "cpu"), + cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), + ) + ax[0, -1].set_title(f"ref. (L-BFGS-B)", fontsize="medium") + ax[2, -1].set_axis_off() - for i in range(5): - ax[1, i].imshow( - to_device(x_sgds[i][:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax - ) - ax[2, i].imshow( - to_device( - (x_sgds[i][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), - "cpu", - ), - cmap="seismic", - vmin=-0.2, - vmax=0.2, - ) - ax[3, i].imshow( - to_device( - ( - xp.abs(x_sgds[i][:, :, sl] - x_ref[:, :, sl]) - / (x_ref[:, :, sl] + 1e-3) - ) - > 0.05, - "cpu", - ), + for ig, step_size in enumerate(step_sizes): + ax[0, ig].imshow( + to_device(x_svrgs[ig][:, :, img_shape[2] // 2], "cpu"), cmap="Greys", + vmin=0, + vmax=1.2 * float(xp.max(x_true)), ) - ax[1, i].set_title( - f"SGD {num_epochs_sgd}/{num_subsets_sgd} {step_sizes[i]}", fontsize="medium" - ) - ax[2, i].set_title( - f"SGD {num_epochs_sgd}/{num_subsets_sgd} rel.bias", fontsize="medium" - ) - ax[3, i].set_title( - f"SGD {num_epochs_sgd}/{num_subsets_sgd} rel.bias > 5%", fontsize="medium" - ) - - for i, step_size in enumerate(step_sizes): - if track_cost: - ax[0, 3].plot( - np.arange(num_updates_sgd) / num_subsets_sgd, - cost_sgd[i], - label=f"s.s.: {step_size}", - ) - ax[0, 4].plot( - np.arange(num_updates_sgd) / num_subsets_sgd, - nrmse_sgd[i], - label=f"s.s.: {step_size}", - ) - ax[0, 4].axhline(0.01, color="black", ls="--") - - if track_cost: - ax[0, 3].set_ylim(cost_ref, cost_osem) - ax[0, 3].set_xlabel("epoch") - ax[0, 3].set_title("cost", fontsize="medium") - ax[0, 3].grid(ls=":") - else: - ax[0, 3].set_axis_off() - - ax[0, 4].set_ylim(0, nrmse_osem) - ax[0, 4].set_xlabel("epoch") - ax[0, 4].set_title("NRMSE", fontsize="medium") - ax[0, 4].grid(ls=":") - ax[0, 4].legend(fontsize="small") - - fig.show() - - # %% - # SVRG plots - - fig, ax = plt.subplots(4, 5, figsize=(15, 12), tight_layout=True) - - ax[0, 0].imshow(to_device(x_true[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - ax[0, 1].imshow(to_device(x_osem[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - ax[0, 2].imshow(to_device(x_ref[:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax) - - ax[0, 0].set_title(f"true", fontsize="medium") - ax[0, 1].set_title(f"OSEM {num_epochs_osem}/{num_subsets_osem}", fontsize="medium") - ax[0, 2].set_title("L-BFGS-B (ref)", fontsize="medium") - - for i in range(5): - ax[1, i].imshow( - to_device(x_svrgs[i][:, :, sl], "cpu"), cmap="Greys", vmin=0, vmax=vmax - ) - ax[2, i].imshow( + ax[1, ig].imshow( to_device( - (x_svrgs[i][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), + (x_svrgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3), "cpu", ), cmap="seismic", vmin=-0.2, vmax=0.2, ) - ax[3, i].imshow( + ax[2, ig].imshow( to_device( - ( - xp.abs(x_svrgs[i][:, :, sl] - x_ref[:, :, sl]) - / (x_ref[:, :, sl] + 1e-3) - ) - > 0.05, + ((x_svrgs[ig][:, :, sl] - x_ref[:, :, sl]) / (x_ref[:, :, sl] + 1e-3)) + > 0.01, "cpu", ), cmap="Greys", ) - ax[1, i].set_title( - f"SVRG {num_epochs_sgd}/{num_subsets_sgd} {step_sizes[i]}", - fontsize="medium", - ) - ax[2, i].set_title( - f"SVRG {num_epochs_sgd}/{num_subsets_sgd} rel.bias", fontsize="medium" - ) - ax[3, i].set_title( - f"SVRG {num_epochs_sgd}/{num_subsets_sgd} rel.bias > 5%", fontsize="medium" + ax[0, ig].set_title( + f"SVRG, step size {step_size}, {num_subsets_sgd}ss", fontsize="small" ) + ax[1, ig].set_title(f"rel. bias", fontsize="small") + ax[2, ig].set_title(f"rel. bias > 1%", fontsize="small") - for i, step_size in enumerate(step_sizes): - if track_cost: - ax[0, 3].plot( - np.arange(num_updates_sgd) / num_subsets_sgd, - cost_svrg[i], - label=f"s.s.: {step_size}", - ) - ax[0, 4].plot( + ax[1, -1].plot( np.arange(num_updates_sgd) / num_subsets_sgd, - nrmse_svrg[i], - label=f"s.s.: {step_size}", + nrmse_svrg[ig], + label=f"step size {step_size}", ) - ax[0, 4].axhline(0.01, color="black", ls="--") - - if track_cost: - ax[0, 3].set_ylim(cost_ref, cost_osem) - ax[0, 3].set_xlabel("epoch") - ax[0, 3].set_title("cost", fontsize="medium") - ax[0, 3].grid(ls=":") - else: - ax[0, 3].set_axis_off() - - ax[0, 4].set_ylim(0, nrmse_osem) - ax[0, 4].set_xlabel("epoch") - ax[0, 4].set_title("NRMSE", fontsize="medium") - ax[0, 4].grid(ls=":") - ax[0, 4].legend(fontsize="small") + ax[1, -1].set_title(f"NRMSE", fontsize="medium") + ax[1, -1].set_xlabel(f"epoch") + ax[1, -1].axhline(0.01, color="black", ls="--") + ax[1, -1].legend() + ax[1, -1].set_ylim(0, nrmse_osem) + ax[1, -1].grid(ls=":") fig.show()