From 4424cb48890010b91ca1262c87e0852a34a7aacb Mon Sep 17 00:00:00 2001 From: Robert Smith Date: Mon, 4 Nov 2024 15:50:43 +1100 Subject: [PATCH] dwidenoise: Modularise kernel Better separation of code responsible for fetching a batch of input data within a sliding spatial window from the code responsible for the denoising of the image data. --- cmd/dwidenoise.cpp | 218 ++++++++++++++++++++++++++++----------------- 1 file changed, 136 insertions(+), 82 deletions(-) diff --git a/cmd/dwidenoise.cpp b/cmd/dwidenoise.cpp index 20d9242cd4..100c362007 100644 --- a/cmd/dwidenoise.cpp +++ b/cmd/dwidenoise.cpp @@ -144,29 +144,98 @@ void usage() { using real_type = float; -template class DenoisingFunctor { +// Class to encode return information from kernel +template class KernelData { +public: + KernelData(const size_t volumes, const size_t kernel_size) + : centre_index(-1), // + voxel_count(kernel_size), // + X(MatrixType::Zero(volumes, kernel_size)) {} // + size_t centre_index; + size_t voxel_count; + MatrixType X; +}; + +template class KernelBase { +public: + KernelBase() : pos({-1, -1, -1}) {} + KernelBase(const KernelBase &) : pos({-1, -1, -1}) {} + +protected: + // Store / restore position of image before / after data loading + std::array pos; + template void stash_pos(const ImageType &image) { + for (size_t axis = 0; axis != 3; ++axis) + pos[axis] = image.index(axis); + } + template void restore_pos(ImageType &image) { + for (size_t axis = 0; axis != 3; ++axis) + image.index(axis) = pos[axis]; + } +}; + +template class KernelCube : public KernelBase { +public: + KernelCube(const std::vector &extent) + : half_extent({int(extent[0] / 2), int(extent[1] / 2), int(extent[2] / 2)}) { + for (auto e : extent) { + if (!(e % 2)) + throw Exception("Size of cubic kernel must be an odd integer"); + } + } + KernelCube(const KernelCube &) = default; + template void operator()(ImageType &image, KernelData &data) { + assert(data.X.cols() == size()); + KernelBase::stash_pos(image); + size_t k = 0; + for (int z = -half_extent[2]; z <= half_extent[2]; z++) { + image.index(2) = wrapindex(z, 2, image.size(2)); + for (int y = -half_extent[1]; y <= half_extent[1]; y++) { + image.index(1) = wrapindex(y, 1, image.size(1)); + for (int x = -half_extent[0]; x <= half_extent[0]; x++, k++) { + image.index(0) = wrapindex(x, 0, image.size(0)); + data.X.col(k) = image.row(3); + } + } + } + KernelBase::restore_pos(image); + data.voxel_count = size(); + data.centre_index = size() / 2; + } + size_t size() const { return (2 * half_extent[0] + 1) * (2 * half_extent[1] + 1) * (2 * half_extent[2] + 1); } + +private: + const std::vector half_extent; + + // patch handling at image edges + inline size_t wrapindex(int r, int axis, int max) const { + int rr = KernelBase::pos[axis] + r; + if (rr < 0) + rr = half_extent[axis] - r; + if (rr >= max) + rr = (max - 1) - half_extent[axis] - r; + return rr; + } +}; + +template class DenoisingFunctor { public: using MatrixType = Eigen::Matrix; using SValsType = Eigen::VectorXd; - DenoisingFunctor(int ndwi, - const std::vector &extent, - Image &mask, - Image &noise, - Image &rank, - bool exp1) - : extent{{extent[0] / 2, extent[1] / 2, extent[2] / 2}}, + DenoisingFunctor( + int ndwi, KernelType &kernel, Image &mask, Image &noise, Image &rank, bool exp1) + : data(ndwi, kernel.size()), + kernel(kernel), m(ndwi), - n(extent[0] * extent[1] * extent[2]), + n(kernel.size()), r(std::min(m, n)), q(std::max(m, n)), exp1(exp1), - X(m, n), XtX(r, r), eig(r), s(r), - pos{{0, 0, 0}}, mask(mask), noise(noise), rankmap(rank) {} @@ -180,13 +249,13 @@ template class DenoisingFunctor { } // Load data in local window - load_data(dwi); + kernel(dwi, data); // Compute Eigendecomposition: if (m <= n) - XtX.template triangularView() = X * X.adjoint(); + XtX.template triangularView() = data.X * data.X.adjoint(); else - XtX.template triangularView() = X.adjoint() * X; + XtX.template triangularView() = data.X.adjoint() * data.X; eig.compute(XtX); // eigenvalues sorted in increasing order: s = eig.eigenvalues().template cast(); @@ -215,14 +284,18 @@ template class DenoisingFunctor { s.head(cutoff_p).setZero(); s.tail(r - cutoff_p).setOnes(); if (m <= n) - X.col(n / 2) = eig.eigenvectors() * (s.cast().asDiagonal() * (eig.eigenvectors().adjoint() * X.col(n / 2))); + data.X.col(data.centre_index) = + eig.eigenvectors() * + (s.cast().asDiagonal() * (eig.eigenvectors().adjoint() * data.X.col(data.centre_index))); else - X.col(n / 2) = X * (eig.eigenvectors() * (s.cast().asDiagonal() * eig.eigenvectors().adjoint().col(n / 2))); + data.X.col(data.centre_index) = + data.X * + (eig.eigenvectors() * (s.cast().asDiagonal() * eig.eigenvectors().adjoint().col(data.centre_index))); } // Store output assign_pos_of(dwi).to(out); - out.row(3) = X.col(n / 2); + out.row(3) = data.X.col(data.centre_index); // store noise map if requested: if (noise.valid()) { @@ -237,60 +310,26 @@ template class DenoisingFunctor { } private: - const std::array extent; + KernelData data; + KernelType kernel; const ssize_t m, n, r, q; const bool exp1; - MatrixType X; MatrixType XtX; Eigen::SelfAdjointEigenSolver eig; SValsType s; - std::array pos; double sigma2; Image mask; Image noise; Image rankmap; - - template void load_data(ImageType &dwi) { - pos[0] = dwi.index(0); - pos[1] = dwi.index(1); - pos[2] = dwi.index(2); - // fill patch - X.setZero(); - size_t k = 0; - for (int z = -extent[2]; z <= extent[2]; z++) { - dwi.index(2) = wrapindex(z, 2, dwi.size(2)); - for (int y = -extent[1]; y <= extent[1]; y++) { - dwi.index(1) = wrapindex(y, 1, dwi.size(1)); - for (int x = -extent[0]; x <= extent[0]; x++, k++) { - dwi.index(0) = wrapindex(x, 0, dwi.size(0)); - X.col(k) = dwi.row(3); - } - } - } - // reset image position - dwi.index(0) = pos[0]; - dwi.index(1) = pos[1]; - dwi.index(2) = pos[2]; - } - - inline size_t wrapindex(int r, int axis, int max) const { - // patch handling at image edges - int rr = pos[axis] + r; - if (rr < 0) - rr = extent[axis] - r; - if (rr >= max) - rr = (max - 1) - extent[axis] - r; - return rr; - } }; -template +template void process_image(Header &data, Image &mask, Image &noise, Image &rank, const std::string &output_name, - const std::vector &extent, + KernelType &kernel, bool exp1) { auto input = data.get_image().with_direct_io(3); // create output @@ -298,24 +337,20 @@ void process_image(Header &data, header.datatype() = DataType::from(); auto output = Image::create(output_name, header); // run - DenoisingFunctor func(data.size(3), extent, mask, noise, rank, exp1); + DenoisingFunctor func(data.size(3), kernel, mask, noise, rank, exp1); ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input, output); } -void run() { - auto dwi = Header::open(argument[0]); - - if (dwi.ndim() != 4 || dwi.size(3) <= 1) - throw Exception("input image must be 4-dimensional"); - - Image mask; - auto opt = get_options("mask"); - if (!opt.empty()) { - mask = Image::open(opt[0][0]); - check_dimensions(mask, dwi, 0, 3); - } - - opt = get_options("extent"); +template +void make_kernel(Header &data, + Image &mask, + Image &noise, + Image &rank, + const std::string &output_name, + bool exp1) { + using KernelType = KernelCube>; + + auto opt = get_options("extent"); std::vector extent; if (!opt.empty()) { extent = parse_ints(opt[0][0]); @@ -326,25 +361,44 @@ void run() { for (int i = 0; i < 3; i++) { if (!(extent[i] & 1)) throw Exception("-extent must be a (list of) odd numbers"); - if (extent[i] > dwi.size(i)) + if (extent[i] > data.size(i)) throw Exception("-extent must not exceed the image dimensions"); } } else { uint32_t e = 1; - while (e * e * e < dwi.size(3)) + while (Math::pow3(e) < data.size(3)) e += 2; - extent = { - std::min(e, uint32_t(dwi.size(0))), std::min(e, uint32_t(dwi.size(1))), std::min(e, uint32_t(dwi.size(2)))}; + extent = {std::min(e, uint32_t(data.size(0))), // + std::min(e, uint32_t(data.size(1))), // + std::min(e, uint32_t(data.size(2)))}; // } INFO("selected patch size: " + str(extent[0]) + " x " + str(extent[1]) + " x " + str(extent[2]) + "."); - bool exp1 = get_option_value("estimator", 1) == 0; // default: Exp2 (unbiased estimator) + if (std::min(data.size(3), extent[0] * extent[1] * extent[2]) < 15) { + WARN("The number of volumes or the patch size is small. " + "This may lead to discretisation effects in the noise level " + "and cause inconsistent denoising between adjacent voxels."); + } + + KernelType kernel(extent); + process_image(data, mask, noise, rank, output_name, kernel, exp1); +} - if (std::min(dwi.size(3), extent[0] * extent[1] * extent[2]) < 15) { - WARN("The number of volumes or the patch size is small. This may lead to discretisation effects " - "in the noise level and cause inconsistent denoising between adjacent voxels."); +void run() { + auto dwi = Header::open(argument[0]); + + if (dwi.ndim() != 4 || dwi.size(3) <= 1) + throw Exception("input image must be 4-dimensional"); + + Image mask; + auto opt = get_options("mask"); + if (!opt.empty()) { + mask = Image::open(opt[0][0]); + check_dimensions(mask, dwi, 0, 3); } + bool exp1 = get_option_value("estimator", 1) == 0; // default: Exp2 (unbiased estimator) + Image noise; opt = get_options("noise"); if (!opt.empty()) { @@ -370,19 +424,19 @@ void run() { switch (prec) { case 0: INFO("select real float32 for processing"); - process_image(dwi, mask, noise, rank, argument[1], extent, exp1); + make_kernel(dwi, mask, noise, rank, argument[1], exp1); break; case 1: INFO("select real float64 for processing"); - process_image(dwi, mask, noise, rank, argument[1], extent, exp1); + make_kernel(dwi, mask, noise, rank, argument[1], exp1); break; case 2: INFO("select complex float32 for processing"); - process_image(dwi, mask, noise, rank, argument[1], extent, exp1); + make_kernel(dwi, mask, noise, rank, argument[1], exp1); break; case 3: INFO("select complex float64 for processing"); - process_image(dwi, mask, noise, rank, argument[1], extent, exp1); + make_kernel(dwi, mask, noise, rank, argument[1], exp1); break; } }