From d90e0d1df80d7d50bd7603fa1dc30773046d36ae Mon Sep 17 00:00:00 2001 From: Samurdhi Karunaratne Date: Thu, 21 Jul 2022 16:41:55 -0700 Subject: [PATCH] Fix Normalize_TRT plugin segfault Signed-off-by: Kevin Chen --- plugin/normalizePlugin/normalizePlugin.cpp | 26 +++++++++++++++++----- plugin/normalizePlugin/normalizePlugin.h | 7 +++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/plugin/normalizePlugin/normalizePlugin.cpp b/plugin/normalizePlugin/normalizePlugin.cpp index 8bccf210..fd40f07f 100644 --- a/plugin/normalizePlugin/normalizePlugin.cpp +++ b/plugin/normalizePlugin/normalizePlugin.cpp @@ -35,7 +35,7 @@ const char* NORMALIZE_PLUGIN_NAME{"Normalize_TRT"}; PluginFieldCollection NormalizePluginCreator::mFC{}; std::vector NormalizePluginCreator::mPluginAttributes; -Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps) +Normalize::Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps) : acrossSpatial(acrossSpatial) , channelShared(channelShared) , eps(eps) @@ -44,11 +44,13 @@ Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, PLUGIN_VALIDATE(nbWeights == 1); PLUGIN_VALIDATE(weights[0].count >= 1); mWeights = copyToDevice(weights[0].values, weights[0].count); + mScalarScale = static_cast(weights[0].values)[0]; } Normalize::Normalize( - const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W) - : acrossSpatial(acrossSpatial) + Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W) + : mScalarScale(scalarScale) + , acrossSpatial(acrossSpatial) , channelShared(channelShared) , eps(eps) , C(C) @@ -74,6 +76,7 @@ Normalize::Normalize(const void* buffer, size_t length) mNbWeights = read(d); int count = read(d); + std::memcpy(&mScalarScale, d, sizeof(float)); mWeights = deserializeToDevice(d, count); PLUGIN_VALIDATE(d == a + length); } @@ -111,8 +114,19 @@ int Normalize::enqueue( { const void* inputData = inputs[0]; void* outputData = outputs[0]; - pluginStatus_t status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps, - static_cast(mWeights.values), inputData, outputData, workspace); + + pluginStatus_t status; + + if(acrossSpatial && channelShared) // Since cublasPointerMode_t is CUBLAS_POINTER_MODE_HOST, scale should be on the host + { + status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps, + &mScalarScale, inputData, outputData, workspace); + } + else // No risk of device pointers being passed to cublas as alpha or beta + { + status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps, + static_cast(mWeights.values), inputData, outputData, workspace); + } return status; } @@ -254,7 +268,7 @@ IPluginV2Ext* Normalize::clone() const noexcept try { // Create a new instance - IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, acrossSpatial, channelShared, eps, C, H, W); + IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, mScalarScale, acrossSpatial, channelShared, eps, C, H, W); // Set the namespace plugin->setPluginNamespace(mPluginNamespace.c_str()); diff --git a/plugin/normalizePlugin/normalizePlugin.h b/plugin/normalizePlugin/normalizePlugin.h index 1b904916..c9807717 100644 --- a/plugin/normalizePlugin/normalizePlugin.h +++ b/plugin/normalizePlugin/normalizePlugin.h @@ -31,10 +31,10 @@ namespace plugin class Normalize : public IPluginV2Ext { public: - Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps); + Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps); Normalize( - const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W); + Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W); Normalize(const void* buffer, size_t length); @@ -93,8 +93,9 @@ class Normalize : public IPluginV2Ext cublasHandle_t mCublas; - Weights mWeights{}; + Weights mWeights{}; // mWeights.values is on the device int mNbWeights{}; + float mScalarScale{}; // keep track of scale on the host (for when channelShared is true) bool acrossSpatial{}; bool channelShared{}; float eps{};