Skip to content

Commit

Permalink
Merge pull request #1349 from ANTsX/UseGradientFilter
Browse files Browse the repository at this point in the history
ENH:  Expose UseGradientFilter variable for image metrics.
  • Loading branch information
cookpa authored May 4, 2022
2 parents e9c4cc0 + 4a37a40 commit ec91afc
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 27 deletions.
15 changes: 8 additions & 7 deletions Examples/antsRegistration.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ antsRegistrationInitializeCommandLineOptions(itk::ants::CommandLineParser * pars
std::string("one sample per voxel), otherwise it defines a point set over which to optimize the metric. ") +
std::string("The point set can be on a regular lattice or a random lattice of points slightly ") +
std::string("perturbed to minimize aliasing artifacts. samplingPercentage defines the ") +
std::string("fraction of points to select from the domain. ") +
std::string("fraction of points to select from the domain. useGradientFilter specifies whether a smoothing") +
std::string("filter is applied when estimating the metric gradient.") +
std::string("In addition, three point set metrics are available: Euclidean ") +
std::string("(ICP), Point-set expectation (PSE), and Jensen-Havrda-Charvet-Tsallis (JHCT).");

Expand All @@ -318,22 +319,22 @@ antsRegistrationInitializeCommandLineOptions(itk::ants::CommandLineParser * pars
option->SetShortName('m');
option->SetUsageOption(0,
"CC[fixedImage,movingImage,metricWeight,radius,<samplingStrategy={None,Regular,Random}>,<"
"samplingPercentage=[0,1]>]");
"samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(1,
"MI[fixedImage,movingImage,metricWeight,numberOfBins,<samplingStrategy={None,Regular,Random}"
">,<samplingPercentage=[0,1]>]");
">,<samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(2,
"Mattes[fixedImage,movingImage,metricWeight,numberOfBins,<samplingStrategy={None,Regular,"
"Random}>,<samplingPercentage=[0,1]>]");
"Random}>,<samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(3,
"MeanSquares[fixedImage,movingImage,metricWeight,radius=NA,<samplingStrategy={None,Regular,"
"Random}>,<samplingPercentage=[0,1]>]");
"Random}>,<samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(4,
"Demons[fixedImage,movingImage,metricWeight,radius=NA,<samplingStrategy={None,Regular,"
"Random}>,<samplingPercentage=[0,1]>]");
"Random}>,<samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(5,
"GC[fixedImage,movingImage,metricWeight,radius=NA,<samplingStrategy={None,Regular,Random}>,<"
"samplingPercentage=[0,1]>]");
"samplingPercentage=[0,1]>,<useGradientFilter=false>]");
option->SetUsageOption(
6, "ICP[fixedPointSet,movingPointSet,metricWeight,<samplingPercentage=[0,1]>,<boundaryPointsOnly=0>]");
option->SetUsageOption(7,
Expand Down
8 changes: 8 additions & 0 deletions Examples/antsRegistrationTemplateHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ DoRegistration(typename ParserType::Pointer & parser)
typename RegistrationHelperType::SamplingStrategy samplingStrategy = RegistrationHelperType::none;
unsigned int numberOfBins = 32;
unsigned int radius = 4;
bool useGradientFilter = false;

// assign default point-set variables

Expand All @@ -993,6 +994,12 @@ DoRegistration(typename ParserType::Pointer & parser)
{
samplingPercentage = parser->Convert<float>(metricOption->GetFunction(currentMetricNumber)->GetParameter(5));
}

if (metricOption->GetFunction(currentMetricNumber)->GetNumberOfParameters() > 6)
{
useGradientFilter = parser->Convert<bool>(metricOption->GetFunction(currentMetricNumber)->GetParameter(6));
}

std::string fixedFileName = metricOption->GetFunction(currentMetricNumber)->GetParameter(0);
std::string movingFileName = metricOption->GetFunction(currentMetricNumber)->GetParameter(1);

Expand Down Expand Up @@ -1196,6 +1203,7 @@ DoRegistration(typename ParserType::Pointer & parser)
samplingStrategy,
numberOfBins,
radius,
useGradientFilter,
useBoundaryPointsOnly,
pointSetSigma,
evaluationKNeighborhood,
Expand Down
9 changes: 8 additions & 1 deletion Examples/itkantsRegistrationHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,16 @@ class RegistrationHelper final : public itk::Object
SamplingStrategy samplingStrategy,
int numberOfBins,
unsigned int radius,
bool useGradientFilter,
bool useBoundaryPointsOnly,
RealType pointSetSigma,
unsigned int evaluationKNeighborhood,
RealType alpha,
bool useAnisotropicCovariances,
RealType samplingPercentage,
RealType intensityDistanceSigma,
RealType euclideanDistanceSigma)
RealType euclideanDistanceSigma
)
: m_MetricType(metricType)
, m_FixedImage(fixedImage)
, m_MovingImage(movingImage)
Expand All @@ -224,6 +226,7 @@ class RegistrationHelper final : public itk::Object
, m_SamplingStrategy(samplingStrategy)
, m_NumberOfBins(numberOfBins)
, m_Radius(radius)
, m_UseGradientFilter(useGradientFilter)
, m_FixedLabeledPointSet(fixedLabeledPointSet)
, m_MovingLabeledPointSet(movingLabeledPointSet)
, m_FixedIntensityPointSet(fixedIntensityPointSet)
Expand Down Expand Up @@ -302,6 +305,7 @@ class RegistrationHelper final : public itk::Object
SamplingStrategy m_SamplingStrategy;
int m_NumberOfBins;
unsigned int m_Radius; // Only for CC metric
bool m_UseGradientFilter;

// Variables for point-set metrics

Expand Down Expand Up @@ -481,6 +485,7 @@ class RegistrationHelper final : public itk::Object
SamplingStrategy samplingStrategy,
int numberOfBins,
unsigned int radius,
bool useGradientFilter,
bool useBoundaryPointsOnly,
RealType pointSetSigma,
unsigned int evaluationKNeighborhood,
Expand All @@ -500,6 +505,7 @@ class RegistrationHelper final : public itk::Object
SamplingStrategy samplingStrategy,
int numberOfBins,
unsigned int radius,
bool useGradientFilter,
RealType samplingPercentage)
{
this->AddMetric(metricType,
Expand All @@ -514,6 +520,7 @@ class RegistrationHelper final : public itk::Object
samplingStrategy,
numberOfBins,
radius,
useGradientFilter,
false,
1.0,
50,
Expand Down
43 changes: 25 additions & 18 deletions Examples/itkantsRegistrationHelper.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ RegistrationHelper<TComputeType, VImageDimension>::AddMetric(MetricEnumeration
SamplingStrategy samplingStrategy,
int numberOfBins,
unsigned int radius,
bool useGradientFilter,
bool useBoundaryPointsOnly,
RealType pointSetSigma,
unsigned int evaluationKNeighborhood,
Expand All @@ -279,6 +280,7 @@ RegistrationHelper<TComputeType, VImageDimension>::AddMetric(MetricEnumeration
samplingStrategy,
numberOfBins,
radius,
useGradientFilter,
useBoundaryPointsOnly,
pointSetSigma,
evaluationKNeighborhood,
Expand Down Expand Up @@ -783,8 +785,6 @@ template <typename TComputeType, unsigned VImageDimension>
int
RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
{
/** Can really impact performance */
const bool gradientfilter = false;
itk::TimeProbe totalTimer;

totalTimer.Start();
Expand Down Expand Up @@ -923,7 +923,9 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
{
const unsigned int radiusOption = stageMetricList[currentMetricNumber].m_Radius;
this->Logger() << " using the CC metric (radius = " << radiusOption
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting << ")" << std::endl;
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;
typedef itk::ANTSNeighborhoodCorrelationImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType>
CorrelationMetricType;
typename CorrelationMetricType::Pointer correlationMetric = CorrelationMetricType::New();
Expand All @@ -932,8 +934,6 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
radius.Fill(radiusOption);
correlationMetric->SetRadius(radius);
}
correlationMetric->SetUseMovingImageGradientFilter(gradientfilter);
correlationMetric->SetUseFixedImageGradientFilter(gradientfilter);

imageMetric = correlationMetric;
}
Expand All @@ -942,14 +942,14 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
{
const unsigned int binOption = stageMetricList[currentMetricNumber].m_NumberOfBins;
this->Logger() << " using the Mattes MI metric (number of bins = " << binOption
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting << ")" << std::endl;
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;
typedef itk::MattesMutualInformationImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType>
MutualInformationMetricType;
typename MutualInformationMetricType::Pointer mutualInformationMetric = MutualInformationMetricType::New();
// mutualInformationMetric = mutualInformationMetric;
mutualInformationMetric->SetNumberOfHistogramBins(binOption);
mutualInformationMetric->SetUseMovingImageGradientFilter(gradientfilter);
mutualInformationMetric->SetUseFixedImageGradientFilter(gradientfilter);
mutualInformationMetric->SetUseSampledPointSet(false);

imageMetric = mutualInformationMetric;
Expand All @@ -959,15 +959,15 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
{
const unsigned int binOption = stageMetricList[currentMetricNumber].m_NumberOfBins;
this->Logger() << " using the joint histogram MI metric (number of bins = " << binOption
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting << ")" << std::endl;
<< ", weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;
typedef itk::
JointHistogramMutualInformationImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType>
MutualInformationMetricType;
typename MutualInformationMetricType::Pointer mutualInformationMetric = MutualInformationMetricType::New();
// mutualInformationMetric = mutualInformationMetric;
mutualInformationMetric->SetNumberOfHistogramBins(binOption);
mutualInformationMetric->SetUseMovingImageGradientFilter(gradientfilter);
mutualInformationMetric->SetUseFixedImageGradientFilter(gradientfilter);
mutualInformationMetric->SetUseSampledPointSet(false);
mutualInformationMetric->SetVarianceForJointPDFSmoothing(1.0);

Expand All @@ -976,8 +976,10 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
break;
case MeanSquares:
{
this->Logger() << " using the MeanSquares metric (weight = "
<< stageMetricList[currentMetricNumber].m_Weighting << ")" << std::endl;
this->Logger() << " using the MeanSquares metric "
<< "( weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;

typedef itk::MeanSquaresImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType>
MeanSquaresMetricType;
Expand All @@ -989,7 +991,9 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
break;
case Demons:
{
this->Logger() << " using the Demons metric (weight = " << stageMetricList[currentMetricNumber].m_Weighting
this->Logger() << " using the Demons metric "
<< "( weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;

typedef itk::DemonsImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType> DemonsMetricType;
Expand All @@ -1000,8 +1004,10 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
break;
case GC:
{
this->Logger() << " using the global correlation metric (weight = "
<< stageMetricList[currentMetricNumber].m_Weighting << ")" << std::endl;
this->Logger() << " using the global correlation metric "
<< "( weight = " << stageMetricList[currentMetricNumber].m_Weighting
<< ", use gradient filter = " << stageMetricList[currentMetricNumber].m_UseGradientFilter
<< ")" << std::endl;
typedef itk::CorrelationImageToImageMetricv4<ImageType, ImageType, ImageType, TComputeType> corrMetricType;
typename corrMetricType::Pointer corrMetric = corrMetricType::New();

Expand Down Expand Up @@ -1137,8 +1143,9 @@ RegistrationHelper<TComputeType, VImageDimension>::DoRegistration()
// Set up the image metric and scales estimator

imageMetric->SetVirtualDomainFromImage(fixedImage);
imageMetric->SetUseMovingImageGradientFilter(gradientfilter);
imageMetric->SetUseFixedImageGradientFilter(gradientfilter);
imageMetric->SetUseMovingImageGradientFilter(stageMetricList[currentMetricNumber].m_UseGradientFilter);
imageMetric->SetUseFixedImageGradientFilter(stageMetricList[currentMetricNumber].m_UseGradientFilter);

metricWeights[currentMetricNumber] = stageMetricList[currentMetricNumber].m_Weighting;
if (useFixedImageMaskForThisStage)
{
Expand Down
3 changes: 2 additions & 1 deletion Examples/simpleSynRegistration.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ simpleSynReg(ImageType::Pointer & fixedImage,
constexpr float samplingPercentage = 1.0;
RegistrationHelperType::SamplingStrategy samplingStrategy = RegistrationHelperType::none;
constexpr unsigned int binOption = 200;
regHelper->AddMetric(curMetric, fixedImage, movingImage, 0, 1.0, samplingStrategy, binOption, 1, samplingPercentage);
bool useGradientFilter = false;
regHelper->AddMetric(curMetric, fixedImage, movingImage, 0, 1.0, samplingStrategy, binOption, 1, useGradientFilter, samplingPercentage);

const float learningRate(0.25F);
const float varianceForUpdateField(3.0F);
Expand Down

0 comments on commit ec91afc

Please sign in to comment.