From 52204d77b2423cfd9d077df101d899963b873786 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Onur=20=C3=9Clgen?=
Date: Tue, 14 Nov 2023 15:27:41 +0000
Subject: [PATCH] Implement reg_getNmiValue for CUDA #92

---
 niftyreg_build_version.txt   |   2 +-
 reg-lib/cpu/_reg_nmi.cpp     |  37 +----
 reg-lib/cpu/_reg_nmi.h       |  45 ++++--
 reg-lib/cuda/_reg_nmi_gpu.cu | 259 +++++++++++++++++++++++++++++------
 reg-lib/cuda/_reg_nmi_gpu.h  |   6 +
 5 files changed, 258 insertions(+), 91 deletions(-)

diff --git a/niftyreg_build_version.txt b/niftyreg_build_version.txt
index 2921a158..35329ed8 100644
--- a/niftyreg_build_version.txt
+++ b/niftyreg_build_version.txt
@@ -1 +1 @@
-360
+361
diff --git a/reg-lib/cpu/_reg_nmi.cpp b/reg-lib/cpu/_reg_nmi.cpp
index f8d0d548..9e3801c1 100755
--- a/reg-lib/cpu/_reg_nmi.cpp
+++ b/reg-lib/cpu/_reg_nmi.cpp
@@ -164,37 +164,6 @@ void reg_nmi::InitialiseMeasure(nifti_image *refImg,
     NR_FUNC_CALLED();
 }
 /* *************************************************************** */
-template<class PrecisionType>
-static PrecisionType GetBasisSplineValue(PrecisionType x) {
-    x = fabs(x);
-    PrecisionType value = 0;
-    if (x < 2.f) {
-        if (x < 1.f)
-            value = 2.f / 3.f + (0.5f * x - 1.f) * x * x;
-        else {
-            x -= 2.f;
-            value = -x * x * x / 6.f;
-        }
-    }
-    return value;
-}
-/* *************************************************************** */
-template<class PrecisionType>
-static PrecisionType GetBasisSplineDerivativeValue(PrecisionType ori) {
-    PrecisionType x = fabs(ori);
-    PrecisionType value = 0;
-    if (x < 2.f) {
-        if (x < 1.f)
-            value = (1.5f * x - 2.f) * ori;
-        else {
-            x -= 2.f;
-            value = -0.5f * x * x;
-            if (ori < 0) value = -value;
-        }
-    }
-    return value;
-}
-/* *************************************************************** */
 template<class DataType>
 void reg_getNmiValue(const nifti_image *referenceImage,
                      const nifti_image *warpedImage,
@@ -261,9 +230,7 @@ void reg_getNmiValue(const nifti_image *referenceImage,
                 }
             }
             // Convolve the histogram with a cubic B-spline kernel
-            double kernel[3];
-            kernel[0] = kernel[2] = GetBasisSplineValue(-1.0);
-            kernel[1] = GetBasisSplineValue(0.0);
+            constexpr double kernel[3]{ GetBasisSplineValue(-1.0), GetBasisSplineValue(0.0), GetBasisSplineValue(-1.0) };
             // Histogram is first smooth along the reference axis
             memset(jointHistoLogPtr, 0, totalBinNumber[t] * sizeof(double));
             for (int f = 0; f < floatingBinNumber[t]; ++f) {
@@ -361,8 +328,6 @@ void reg_getNmiValue(const nifti_image *referenceImage,
         } // if active time point
     } // iterate over all time point in the reference image
 }
-template void reg_getNmiValue<float>(const nifti_image*, const nifti_image*, const double*, const int, const unsigned short*, const unsigned short*, const unsigned short*, double**, double**, double**, const int*, const bool);
-template void reg_getNmiValue<double>(const nifti_image*, const nifti_image*, const double*, const int, const unsigned short*, const unsigned short*, const unsigned short*, double**, double**, double**, const int*, const bool);
 /* *************************************************************** */
 static double GetSimilarityMeasureValue(const nifti_image *referenceImage,
                                         const nifti_image *warpedImage,
diff --git a/reg-lib/cpu/_reg_nmi.h b/reg-lib/cpu/_reg_nmi.h
index 91f37bdb..1c01ba91 100755
--- a/reg-lib/cpu/_reg_nmi.h
+++ b/reg-lib/cpu/_reg_nmi.h
@@ -87,20 +87,6 @@ class reg_nmi: public reg_measure {
     void DeallocateHistogram();
 };
 /* *************************************************************** */
-template<class DataType>
-void reg_getNmiValue(const nifti_image *referenceImage,
-                     const nifti_image *warpedImage,
-                     const double *timePointWeight,
-                     const int referenceTimePoints,
-                     const unsigned short *referenceBinNumber,
-                     const unsigned short *floatingBinNumber,
-                     const unsigned short *totalBinNumber,
-                     double **jointHistogramLog,
-                     double **jointHistogramPro,
-                     double **entropyValues,
-                     const int *referenceMask,
-                     const bool approximation);
-/* *************************************************************** */
 // Simple class to dynamically manage an array of pointers
 // Needed for multi channel NMI
 template
@@ -283,3 +269,34 @@ void reg_getVoxelBasedMultiChannelNmiGradient3D(nifti_image *referenceImages,
                                                 int *mask,
                                                 bool approx);
 /* *************************************************************** */
+template<class PrecisionType>
+DEVICE constexpr PrecisionType GetBasisSplineValue(PrecisionType x) {
+    x = x < 0 ? -x : x;
+    PrecisionType value = 0;
+    if (x < 2.f) {
+        if (x < 1.f)
+            value = 2.f / 3.f + (0.5f * x - 1.f) * x * x;
+        else {
+            x -= 2.f;
+            value = -x * x * x / 6.f;
+        }
+    }
+    return value;
+}
+/* *************************************************************** */
+template<class PrecisionType>
+DEVICE constexpr PrecisionType GetBasisSplineDerivativeValue(const PrecisionType origX) {
+    PrecisionType x = origX < 0 ? -origX : origX;
+    PrecisionType value = 0;
+    if (x < 2.f) {
+        if (x < 1.f)
+            value = (1.5f * x - 2.f) * origX;
+        else {
+            x -= 2.f;
+            value = -0.5f * x * x;
+            if (origX < 0) value = -value;
+        }
+    }
+    return value;
+}
+/* *************************************************************** */
diff --git a/reg-lib/cuda/_reg_nmi_gpu.cu b/reg-lib/cuda/_reg_nmi_gpu.cu
index 170c128e..f48fff8f 100755
--- a/reg-lib/cuda/_reg_nmi_gpu.cu
+++ b/reg-lib/cuda/_reg_nmi_gpu.cu
@@ -34,53 +34,228 @@ void reg_nmi_gpu::InitialiseMeasure(nifti_image *refImg, cudaArray *refImgCuda,
                                     nifti_image *warpedImgBw, float *warpedImgBwCuda,
                                     nifti_image *warpedGradBw, float4 *warpedGradBwCuda,
                                     nifti_image *voxelBasedGradBw, float4 *voxelBasedGradBwCuda) {
-    this->DeallocateHistogram();
     reg_nmi::InitialiseMeasure(refImg, floImg, refMask, warpedImg, warpedGrad, voxelBasedGrad,
                                localWeightSim, floMask, warpedImgBw, warpedGradBw, voxelBasedGradBw);
     reg_measure_gpu::InitialiseMeasure(refImg, refImgCuda, floImg, floImgCuda, refMask, refMaskCuda, activeVoxNum,
                                        warpedImg, warpedImgCuda, warpedGrad, warpedGradCuda, voxelBasedGrad, voxelBasedGradCuda,
                                        localWeightSim, localWeightSimCuda, floMask, floMaskCuda, warpedImgBw, warpedImgBwCuda,
                                        warpedGradBw, warpedGradBwCuda, voxelBasedGradBw, voxelBasedGradBwCuda);
-    // Check if the input images have multiple timepoints
+    // Check if the input images have multiple time points
     if (this->referenceTimePoints > 1 || this->floatingImage->nt > 1)
-        NR_FATAL_ERROR("Multiple timepoints are not yet supported");
+        NR_FATAL_ERROR("Multiple time points are not yet supported");
     // The reference and floating images have to be updated on the device
     Cuda::TransferNiftiToDevice(this->referenceImageCuda, this->referenceImage);
     Cuda::TransferNiftiToDevice(this->floatingImageCuda, this->floatingImage);
+    // Create the joint histograms
+    this->jointHistogramLogCudaVecs.resize(this->referenceTimePoints);
+    this->jointHistogramProCudaVecs.resize(this->referenceTimePoints);
+    if (this->isSymmetric) {
+        this->jointHistogramLogBwCudaVecs.resize(this->referenceTimePoints);
+        this->jointHistogramProBwCudaVecs.resize(this->referenceTimePoints);
+    }
+    for (int i = 0; i < this->referenceTimePoints; ++i) {
+        if (this->timePointWeights[i] > 0) {
+            this->jointHistogramLogCudaVecs[i].resize(this->totalBinNumber[i]);
+            this->jointHistogramProCudaVecs[i].resize(this->totalBinNumber[i]);
+            if (this->isSymmetric) {
+                this->jointHistogramLogBwCudaVecs[i].resize(this->totalBinNumber[i]);
+                this->jointHistogramProBwCudaVecs[i].resize(this->totalBinNumber[i]);
+            }
+        }
+    }
     NR_FUNC_CALLED();
 }
 /* *************************************************************** */
-double GetSimilarityMeasureValue(const nifti_image *referenceImage,
-                                 nifti_image *warpedImage,
-                                 const float *warpedImageCuda,
-                                 const double *timePointWeights,
-                                 const unsigned short *referenceBinNumber,
-                                 const unsigned short *floatingBinNumber,
-                                 const unsigned short *totalBinNumber,
-                                 double **jointHistogramLog,
-                                 double **jointHistogramPro,
-                                 double **entropyValues,
-                                 const int *referenceMask,
-                                 const int referenceTimePoints,
-                                 const bool approximation) {
-    // TODO: Implement the NMI computation for CUDA
-    // The NMI computation is performed on the host for now
-    Cuda::TransferFromDeviceToNifti(warpedImage, warpedImageCuda);
-    reg_getNmiValue(referenceImage,
-                    warpedImage,
-                    timePointWeights,
-                    referenceTimePoints,
-                    referenceBinNumber,
-                    floatingBinNumber,
-                    totalBinNumber,
-                    jointHistogramLog,
-                    jointHistogramPro,
-                    entropyValues,
-                    referenceMask,
-                    approximation);
+void reg_getNmiValue_gpu(const nifti_image *referenceImage,
+                         const cudaArray *referenceImageCuda,
+                         const float *warpedImageCuda,
+                         const double *timePointWeights,
+                         const int referenceTimePoints,
+                         const unsigned short *referenceBinNumber,
+                         const unsigned short *floatingBinNumber,
+                         const unsigned short *totalBinNumber,
+                         vector<thrust::device_vector<double>>& jointHistogramLogCudaVecs,
+                         vector<thrust::device_vector<double>>& jointHistogramProCudaVecs,
+                         double **entropyValues,
+                         const int *maskCuda,
+                         const size_t activeVoxelNumber,
+                         const bool approximation) {
+    const size_t voxelNumber = NiftiImage::calcVoxelNumber(referenceImage, 3);
+    const int3 referenceImageDims = make_int3(referenceImage->nx, referenceImage->ny, referenceImage->nz);
+    auto referenceImageTexturePtr = Cuda::CreateTextureObject(referenceImageCuda, cudaResourceTypeArray);
+    auto maskTexturePtr = Cuda::CreateTextureObject(maskCuda, cudaResourceTypeLinear, activeVoxelNumber * sizeof(int),
+                                                    cudaChannelFormatKindSigned, 1);
+    auto referenceImageTexture = *referenceImageTexturePtr;
+    auto maskTexture = *maskTexturePtr;
+
+    // Iterate over all active time points
+    for (int t = 0; t < referenceTimePoints; t++) {
+        if (timePointWeights[t] <= 0) continue;
+        NR_DEBUG("Computing NMI for time point " << t);
+        const auto& curTotalBinNumber = totalBinNumber[t];
+        const auto& curRefBinNumber = referenceBinNumber[t];
+        const auto& curFloBinNumber = floatingBinNumber[t];
+        // Define the current histograms
+        thrust::fill(thrust::device, jointHistogramLogCudaVecs[t].begin(), jointHistogramLogCudaVecs[t].end(), 0.0);
+        thrust::fill(thrust::device, jointHistogramProCudaVecs[t].begin(), jointHistogramProCudaVecs[t].end(), 0.0);
+        double *jointHistogramLogCuda = jointHistogramLogCudaVecs[t].data().get();
+        double *jointHistogramProCuda = jointHistogramProCudaVecs[t].data().get();
+        // Define warped image texture
+        auto warpedImageTexturePtr = Cuda::CreateTextureObject(warpedImageCuda + t * voxelNumber, cudaResourceTypeLinear,
+                                                               voxelNumber * sizeof(float), cudaChannelFormatKindFloat, 1);
+        auto warpedImageTexture = *warpedImageTexturePtr;
+        // Fill the joint histograms
+        if (approximation == false) {
+            // No approximation is used for the Parzen windowing
+            thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [=]__device__(const unsigned index) {
+                const int& voxel = tex1Dfetch<int>(maskTexture, index);
+                const float& warValue = tex1Dfetch<float>(warpedImageTexture, voxel);
+                if (warValue != warValue) return;
+                auto&& [x, y, z] = reg_indexToDims_cuda(voxel, referenceImageDims);
+                const float& refValue = tex3D<float>(referenceImageTexture, x, y, z);
+                if (refValue != refValue) return;
+                for (int r = int(refValue - 1); r < int(refValue + 3); r++) {
+                    if (0 <= r && r < curRefBinNumber) {
+                        const double& refBasis = GetBasisSplineValue(refValue - r);
+                        for (int w = int(warValue - 1); w < int(warValue + 3); w++) {
+                            if (0 <= w && w < curFloBinNumber) {
+                                const double& warBasis = GetBasisSplineValue(warValue - w);
+                                atomicAdd(&jointHistogramProCuda[r + w * curRefBinNumber], refBasis * warBasis);
+                            }
+                        }
+                    }
+                }
+            });
+        } else {
+            // An approximation is used for the Parzen windowing. First, intensities are binarised,
+            // then the histogram is convolved with a spline kernel function.
+            thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), activeVoxelNumber, [=]__device__(const unsigned index) {
+                const int& voxel = tex1Dfetch<int>(maskTexture, index);
+                const float& warValue = tex1Dfetch<float>(warpedImageTexture, voxel);
+                if (warValue != warValue) return;
+                auto&& [x, y, z] = reg_indexToDims_cuda(voxel, referenceImageDims);
+                const float& refValue = tex3D<float>(referenceImageTexture, x, y, z);
+                if (refValue != refValue) return;
+                if (0 <= refValue && refValue < curRefBinNumber && 0 <= warValue && warValue < curFloBinNumber)
+                    atomicAdd(&jointHistogramProCuda[int(refValue) + int(warValue) * curRefBinNumber], 1.0);
+            });
+            // Convolve the histogram with a cubic B-spline kernel
+            // Histogram is first smoothed along the reference axis
+            thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), curFloBinNumber, [=]__device__(const unsigned short f) {
+                constexpr double kernel[3]{ GetBasisSplineValue(-1.0), GetBasisSplineValue(0.0), GetBasisSplineValue(-1.0) };
+                for (unsigned short r = 0; r < curRefBinNumber; r++) {
+                    double value = 0;
+                    short index = r - 1;
+                    double *histoPtr = &jointHistogramProCuda[index + curRefBinNumber * f];
+
+                    for (char it = 0; it < 3; it++, index++, histoPtr++)
+                        if (-1 < index && index < curRefBinNumber)
+                            value += *histoPtr * kernel[it];
+                    jointHistogramLogCuda[r + curRefBinNumber * f] = value;
+                }
+            });
+            // Histogram is then smoothed along the warped floating axis
+            thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), curRefBinNumber, [=]__device__(const unsigned short r) {
+                constexpr double kernel[3]{ GetBasisSplineValue(-1.0), GetBasisSplineValue(0.0), GetBasisSplineValue(-1.0) };
+                for (unsigned short f = 0; f < curFloBinNumber; f++) {
+                    double value = 0;
+                    short index = f - 1;
+                    double *histoPtr = &jointHistogramLogCuda[r + curRefBinNumber * index];
+
+                    for (char it = 0; it < 3; it++, index++, histoPtr += curRefBinNumber)
+                        if (-1 < index && index < curFloBinNumber)
+                            value += *histoPtr * kernel[it];
+                    jointHistogramProCuda[r + curRefBinNumber * f] = value;
+                }
+            });
+        }
+        // Normalise the histogram
+        const double& activeVoxel = thrust::reduce(thrust::device, jointHistogramProCudaVecs[t].begin(), jointHistogramProCudaVecs[t].end(), 0.0, thrust::plus<double>());
+        entropyValues[t][3] = activeVoxel;
+        thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), curTotalBinNumber, [=]__device__(const unsigned index) {
+            jointHistogramProCuda[index] /= activeVoxel;
+        });
+        // Marginalise over the reference axis
+        thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), curRefBinNumber, [=]__device__(const unsigned short r) {
+            double sum = 0;
+            unsigned short index = r;
+            for (unsigned short f = 0; f < curFloBinNumber; f++, index += curRefBinNumber)
+                sum += jointHistogramProCuda[index];
+            jointHistogramProCuda[curRefBinNumber * curFloBinNumber + r] = sum;
+        });
+        // Marginalise over the warped floating axis
+        thrust::for_each_n(thrust::device, thrust::make_counting_iterator(0), curFloBinNumber, [=]__device__(const unsigned short f) {
+            double sum = 0;
+            unsigned short index = curRefBinNumber * f;
+            for (unsigned short r = 0; r < curRefBinNumber; r++, index++)
+                sum += jointHistogramProCuda[index];
+            jointHistogramProCuda[curRefBinNumber * curFloBinNumber + curRefBinNumber + f] = sum;
+        });
+        // Compute the entropy of the reference image
+        thrust::counting_iterator<unsigned short> it(0);
+        entropyValues[t][0] = thrust::transform_reduce(thrust::device, it, it + curRefBinNumber, [=]__device__(const unsigned short r) {
+            const double& valPro = jointHistogramProCuda[curRefBinNumber * curFloBinNumber + r];
+            if (valPro > 0) {
+                const double& valLog = log(valPro);
+                jointHistogramLogCuda[curRefBinNumber * curFloBinNumber + r] = valLog;
+                return -valPro * valLog;
+            } else return 0.0;
+        }, 0.0, thrust::plus<double>());
+        // Compute the entropy of the warped floating image
+        it = thrust::counting_iterator<unsigned short>(0);
+        entropyValues[t][1] = thrust::transform_reduce(thrust::device, it, it + curFloBinNumber, [=]__device__(const unsigned short f) {
+            const double& valPro = jointHistogramProCuda[curRefBinNumber * curFloBinNumber + curRefBinNumber + f];
+            if (valPro > 0) {
+                const double& valLog = log(valPro);
+                jointHistogramLogCuda[curRefBinNumber * curFloBinNumber + curRefBinNumber + f] = valLog;
+                return -valPro * valLog;
+            } else return 0.0;
+        }, 0.0, thrust::plus<double>());
+        // Compute the joint entropy
+        it = thrust::counting_iterator<unsigned short>(0);
+        entropyValues[t][2] = thrust::transform_reduce(thrust::device, it, it + curRefBinNumber * curFloBinNumber, [=]__device__(const unsigned short index) {
+            const double& valPro = jointHistogramProCuda[index];
+            if (valPro > 0) {
+                const double& valLog = log(valPro);
+                jointHistogramLogCuda[index] = valLog;
+                return -valPro * valLog;
+            } else return 0.0;
+        }, 0.0, thrust::plus<double>());
+    } // iterate over all time points in the reference image
+}
+/* *************************************************************** */
+static double GetSimilarityMeasureValue(const nifti_image *referenceImage,
+                                        const cudaArray *referenceImageCuda,
+                                        const nifti_image *warpedImage,
+                                        const float *warpedImageCuda,
+                                        const double *timePointWeights,
+                                        const int referenceTimePoints,
+                                        const unsigned short *referenceBinNumber,
+                                        const unsigned short *floatingBinNumber,
+                                        const unsigned short *totalBinNumber,
+                                        vector<thrust::device_vector<double>>& jointHistogramLogCudaVecs,
+                                        vector<thrust::device_vector<double>>& jointHistogramProCudaVecs,
+                                        double **entropyValues,
+                                        const int *referenceMaskCuda,
+                                        const size_t activeVoxelNumber,
+                                        const bool approximation) {
+    reg_getNmiValue_gpu(referenceImage,
+                        referenceImageCuda,
+                        warpedImageCuda,
+                        timePointWeights,
+                        referenceTimePoints,
+                        referenceBinNumber,
+                        floatingBinNumber,
+                        totalBinNumber,
+                        jointHistogramLogCudaVecs,
+                        jointHistogramProCudaVecs,
+                        entropyValues,
+                        referenceMaskCuda,
+                        activeVoxelNumber,
+                        approximation);
     double nmi = 0;
-    for (int t = 0; t < referenceTimePoints; ++t) {
+    for (int t = 0; t < referenceTimePoints; t++) {
         if (timePointWeights[t] > 0)
             nmi += timePointWeights[t] * (entropyValues[t][0] + entropyValues[t][1]) / entropyValues[t][2];
     }
@@ -89,33 +264,37 @@ double GetSimilarityMeasureValue(const nifti_image *referenceImage,
 /* *************************************************************** */
 double reg_nmi_gpu::GetSimilarityMeasureValueFw() {
     return ::GetSimilarityMeasureValue(this->referenceImage,
+                                       this->referenceImageCuda,
                                        this->warpedImage,
                                        this->warpedImageCuda,
                                        this->timePointWeights,
+                                       this->referenceTimePoints,
                                        this->referenceBinNumber,
                                        this->floatingBinNumber,
                                        this->totalBinNumber,
-                                       this->jointHistogramLog,
-                                       this->jointHistogramPro,
+                                       this->jointHistogramLogCudaVecs,
+                                       this->jointHistogramProCudaVecs,
                                        this->entropyValues,
-                                       this->referenceMask,
-                                       this->referenceTimePoints,
+                                       this->referenceMaskCuda,
+                                       this->activeVoxelNumber,
                                        this->approximatePw);
 }
 /* *************************************************************** */
 double reg_nmi_gpu::GetSimilarityMeasureValueBw() {
     return ::GetSimilarityMeasureValue(this->floatingImage,
+                                       this->floatingImageCuda,
                                        this->warpedImageBw,
                                        this->warpedImageBwCuda,
                                        this->timePointWeights,
+                                       this->referenceTimePoints,
                                        this->floatingBinNumber,
                                        this->referenceBinNumber,
                                        this->totalBinNumber,
-                                       this->jointHistogramLogBw,
-                                       this->jointHistogramProBw,
+                                       this->jointHistogramLogBwCudaVecs,
+                                       this->jointHistogramProBwCudaVecs,
                                        this->entropyValuesBw,
-                                       this->floatingMask,
-                                       this->referenceTimePoints,
+                                       this->floatingMaskCuda,
+                                       this->activeVoxelNumber,
                                        this->approximatePw);
 }
 /* *************************************************************** */
diff --git a/reg-lib/cuda/_reg_nmi_gpu.h b/reg-lib/cuda/_reg_nmi_gpu.h
index 51bc12a8..c3f33d4c 100755
--- a/reg-lib/cuda/_reg_nmi_gpu.h
+++ b/reg-lib/cuda/_reg_nmi_gpu.h
@@ -56,6 +56,12 @@ class reg_nmi_gpu: public reg_nmi, public reg_measure_gpu {
     virtual void GetVoxelBasedSimilarityMeasureGradientFw(int currentTimePoint) override;
     /// @brief Compute the voxel-based nmi gradient backwards
     virtual void GetVoxelBasedSimilarityMeasureGradientBw(int currentTimePoint) override;
+
+protected:
+    vector<thrust::device_vector<double>> jointHistogramLogCudaVecs;
+    vector<thrust::device_vector<double>> jointHistogramProCudaVecs;
+    vector<thrust::device_vector<double>> jointHistogramLogBwCudaVecs;
+    vector<thrust::device_vector<double>> jointHistogramProBwCudaVecs;
 };
 /* *************************************************************** */
 /// @brief NMI measure of similarity class
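
Note on the histogram smoothing above (not part of the patch): both the CPU and the CUDA paths rely on the cubic B-spline basis being a partition of unity, i.e. GetBasisSplineValue(-1) + GetBasisSplineValue(0) + GetBasisSplineValue(1) = 1/6 + 2/3 + 1/6 = 1, so the three-tap convolution preserves the histogram total before normalisation; kernel[2] reuses GetBasisSplineValue(-1.0) because the basis is symmetric. The following stand-alone, host-only sketch is hypothetical illustration code, not NiftyReg source: it copies the basis function from the _reg_nmi.h addition above (dropping the DEVICE qualifier) and checks the kernel sum.

// Hypothetical host-side check, not part of the patch.
#include <cassert>
#include <cmath>
#include <cstdio>

// Copied from the _reg_nmi.h addition above; DEVICE qualifier dropped for host-only use
template<class PrecisionType>
constexpr PrecisionType GetBasisSplineValue(PrecisionType x) {
    x = x < 0 ? -x : x;
    PrecisionType value = 0;
    if (x < 2.f) {
        if (x < 1.f)
            value = 2.f / 3.f + (0.5f * x - 1.f) * x * x;
        else {
            x -= 2.f;
            value = -x * x * x / 6.f;
        }
    }
    return value;
}

int main() {
    // Same initialiser as the patch; kernel[0] == kernel[2] by symmetry of the B-spline
    constexpr double kernel[3]{ GetBasisSplineValue(-1.0), GetBasisSplineValue(0.0), GetBasisSplineValue(-1.0) };
    const double sum = kernel[0] + kernel[1] + kernel[2];
    std::printf("kernel = { %g, %g, %g }, sum = %.8f\n", kernel[0], kernel[1], kernel[2], sum);
    // The float literals inside the basis function limit the result to single precision,
    // hence the loose tolerance: 1/6 + 2/3 + 1/6 equals 1 only up to ~1e-7 here.
    assert(std::fabs(sum - 1.0) < 1e-6);
    return 0;
}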