Skip to content

Commit

Permalink
dwidenoise: Change code handling of estimator selection
Browse files Browse the repository at this point in the history
  • Loading branch information
Lestropie committed Nov 5, 2024
1 parent 3c76457 commit 6ae5583
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions cmd/dwidenoise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using namespace App;

const std::vector<std::string> dtypes = {"float32", "float64"};
const std::vector<std::string> estimators = {"exp1", "exp2"};
enum class estimator_type { EXP1, EXP2 };

const std::vector<std::string> shapes = {"cuboid", "sphere"};
enum class shape_type { CUBOID, SPHERE };
Expand Down Expand Up @@ -312,7 +313,7 @@ class KernelSphereRatio : public KernelSphereBase {
}
if (map_it == shared->end()) {
throw Exception( //
"Inadequate spherical kernel initialisation " //
std::string("Inadequate spherical kernel initialisation ") //
+ "(lookup table " + str(std::distance(shared->begin(), shared->end())) + "; " //
+ "min size " + str(min_size) + "; " //
+ "read size " + str(result.voxels.size()) + ")"); //
Expand Down Expand Up @@ -377,10 +378,10 @@ template <typename F> class DenoisingFunctor {
Image<real_type> &noise,
Image<uint16_t> &rank,
Image<uint16_t> &voxels,
bool exp1)
estimator_type estimator)
: kernel(kernel),
m(ndwi),
exp1(exp1),
estimator(estimator),
X(ndwi, kernel->estimated_size()),
XtX(std::min(m, kernel->estimated_size()), std::min(m, kernel->estimated_size())),
eig(std::min(m, kernel->estimated_size())),
Expand Down Expand Up @@ -445,7 +446,18 @@ template <typename F> class DenoisingFunctor {
{ // (as opposed to the paper where p is defined as the number of signal components)
double lam = std::max(s[p], 0.0) / q;
clam += lam;
double gam = double(p + 1) / (exp1 ? q : q - (r - p - 1));
double denominator;
switch (estimator) {
case estimator_type::EXP1:
denominator = q;
break;
case estimator_type::EXP2:
denominator = q - (r - p - 1);
break;
default:
assert(false);
}
double gam = double(p + 1) / denominator;
double sigsq1 = clam / double(p + 1);
double sigsq2 = (lam - lam_r) / (4.0 * std::sqrt(gam));
// sigsq2 > sigsq1 if signal else noise
Expand Down Expand Up @@ -493,7 +505,7 @@ template <typename F> class DenoisingFunctor {
private:
std::shared_ptr<KernelBase> kernel;
const ssize_t m;
const bool exp1;
const estimator_type estimator;
MatrixType X;
MatrixType XtX;
Eigen::SelfAdjointEigenSolver<MatrixType> eig;
Expand Down Expand Up @@ -522,14 +534,14 @@ void run(Header &data,
Image<uint16_t> &voxels,
const std::string &output_name,
std::shared_ptr<KernelBase> kernel,
bool exp1) {
estimator_type estimator) {
auto input = data.get_image<T>().with_direct_io(3);
// create output
Header header(data);
header.datatype() = DataType::from<T>();
auto output = Image<T>::create(output_name, header);
// run
DenoisingFunctor<T> func(data.size(3), kernel, mask, noise, rank, voxels, exp1);
DenoisingFunctor<T> func(data.size(3), kernel, mask, noise, rank, voxels, estimator);
ThreadedLoop("running MP-PCA denoising", data, 0, 3).run(func, input, output);
}

Expand All @@ -546,7 +558,10 @@ void run() {
check_dimensions(mask, dwi, 0, 3);
}

bool exp1 = get_option_value("estimator", 1) == 0; // default: Exp2 (unbiased estimator)
estimator_type estimator = estimator_type::EXP2; // default: Exp2 (unbiased estimator)
opt = get_options("estimator");
if (opt.size())
estimator = estimator_type(int(opt[0][0]));

Image<real_type> noise;
opt = get_options("noise");
Expand Down Expand Up @@ -638,19 +653,19 @@ void run() {
switch (prec) {
case 0:
INFO("select real float32 for processing");
run<float>(dwi, mask, noise, rank, voxels, argument[1], kernel, exp1);
run<float>(dwi, mask, noise, rank, voxels, argument[1], kernel, estimator);
break;
case 1:
INFO("select real float64 for processing");
run<double>(dwi, mask, noise, rank, voxels, argument[1], kernel, exp1);
run<double>(dwi, mask, noise, rank, voxels, argument[1], kernel, estimator);
break;
case 2:
INFO("select complex float32 for processing");
run<cfloat>(dwi, mask, noise, rank, voxels, argument[1], kernel, exp1);
run<cfloat>(dwi, mask, noise, rank, voxels, argument[1], kernel, estimator);
break;
case 3:
INFO("select complex float64 for processing");
run<cdouble>(dwi, mask, noise, rank, voxels, argument[1], kernel, exp1);
run<cdouble>(dwi, mask, noise, rank, voxels, argument[1], kernel, estimator);
break;
}
}

0 comments on commit 6ae5583

Please sign in to comment.