Skip to content

Commit

Permalink
Implement reg_getNmiValue for CUDA #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Nov 14, 2023
1 parent f4c3c15 commit 52204d7
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 91 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
360
361
37 changes: 1 addition & 36 deletions reg-lib/cpu/_reg_nmi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,37 +164,6 @@ void reg_nmi::InitialiseMeasure(nifti_image *refImg,
NR_FUNC_CALLED();
}
/* *************************************************************** */
template<class PrecisionType>
static PrecisionType GetBasisSplineValue(PrecisionType x) {
x = fabs(x);
PrecisionType value = 0;
if (x < 2.f) {
if (x < 1.f)
value = 2.f / 3.f + (0.5f * x - 1.f) * x * x;
else {
x -= 2.f;
value = -x * x * x / 6.f;
}
}
return value;
}
/* *************************************************************** */
template<class PrecisionType>
static PrecisionType GetBasisSplineDerivativeValue(PrecisionType ori) {
PrecisionType x = fabs(ori);
PrecisionType value = 0;
if (x < 2.f) {
if (x < 1.f)
value = (1.5f * x - 2.f) * ori;
else {
x -= 2.f;
value = -0.5f * x * x;
if (ori < 0) value = -value;
}
}
return value;
}
/* *************************************************************** */
template <class DataType>
void reg_getNmiValue(const nifti_image *referenceImage,
const nifti_image *warpedImage,
Expand Down Expand Up @@ -261,9 +230,7 @@ void reg_getNmiValue(const nifti_image *referenceImage,
}
}
// Convolve the histogram with a cubic B-spline kernel
double kernel[3];
kernel[0] = kernel[2] = GetBasisSplineValue(-1.0);
kernel[1] = GetBasisSplineValue(0.0);
constexpr double kernel[3]{ GetBasisSplineValue(-1.0), GetBasisSplineValue(0.0), GetBasisSplineValue(-1.0) };
// Histogram is first smooth along the reference axis
memset(jointHistoLogPtr, 0, totalBinNumber[t] * sizeof(double));
for (int f = 0; f < floatingBinNumber[t]; ++f) {
Expand Down Expand Up @@ -361,8 +328,6 @@ void reg_getNmiValue(const nifti_image *referenceImage,
} // if active time point
} // iterate over all time point in the reference image
}
template void reg_getNmiValue<float>(const nifti_image*, const nifti_image*, const double*, const int, const unsigned short*, const unsigned short*, const unsigned short*, double**, double**, double**, const int*, const bool);
template void reg_getNmiValue<double>(const nifti_image*, const nifti_image*, const double*, const int, const unsigned short*, const unsigned short*, const unsigned short*, double**, double**, double**, const int*, const bool);
/* *************************************************************** */
static double GetSimilarityMeasureValue(const nifti_image *referenceImage,
const nifti_image *warpedImage,
Expand Down
45 changes: 31 additions & 14 deletions reg-lib/cpu/_reg_nmi.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,6 @@ class reg_nmi: public reg_measure {
void DeallocateHistogram();
};
/* *************************************************************** */
template <class DataType>
void reg_getNmiValue(const nifti_image *referenceImage,
const nifti_image *warpedImage,
const double *timePointWeight,
const int referenceTimePoints,
const unsigned short *referenceBinNumber,
const unsigned short *floatingBinNumber,
const unsigned short *totalBinNumber,
double **jointHistogramLog,
double **jointHistogramPro,
double **entropyValues,
const int *referenceMask,
const bool approximation);
/* *************************************************************** */
// Simple class to dynamically manage an array of pointers
// Needed for multi channel NMI
template<class DataTYPE>
Expand Down Expand Up @@ -283,3 +269,34 @@ void reg_getVoxelBasedMultiChannelNmiGradient3D(nifti_image *referenceImages,
int *mask,
bool approx);
/* *************************************************************** */
template<class PrecisionType>
DEVICE constexpr PrecisionType GetBasisSplineValue(PrecisionType x) {
x = x < 0 ? -x : x;
PrecisionType value = 0;
if (x < 2.f) {
if (x < 1.f)
value = 2.f / 3.f + (0.5f * x - 1.f) * x * x;
else {
x -= 2.f;
value = -x * x * x / 6.f;
}
}
return value;
}
/* *************************************************************** */
template<class PrecisionType>
DEVICE constexpr PrecisionType GetBasisSplineDerivativeValue(const PrecisionType origX) {
PrecisionType x = origX < 0 ? -origX : origX;
PrecisionType value = 0;
if (x < 2.f) {
if (x < 1.f)
value = (1.5f * x - 2.f) * origX;
else {
x -= 2.f;
value = -0.5f * x * x;
if (origX < 0) value = -value;
}
}
return value;
}
/* *************************************************************** */
Loading

0 comments on commit 52204d7

Please sign in to comment.