Skip to content

Commit

Permalink
Fix hangs at IndexErrors when when TF is imported after TRT
Browse files Browse the repository at this point in the history
Unpacking an trt.tensorrt.Dims object causes Python to hang when
Tensorflow is imported after TensorRT, only on Windows.

Signed-off-by: samurdhi karunaratne <skarunaratne@nvidia.com>
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
samurdhikaru authored and rajeevsrao committed Apr 14, 2022
1 parent 8b74bbf commit 9798a72
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 9 deletions.
2 changes: 2 additions & 0 deletions python/include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ void doNothingDel(const T& self)
issueDeprecationWarning("del obj");
}

void throwPyIndexError(std::string message = "out of bounds");

} // namespace utils
} // namespace tensorrt

Expand Down
4 changes: 3 additions & 1 deletion python/src/infer/pyCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ static const auto engine_getitem = [](ICudaEngine& self, int pyIndex) {
// Support python's negative indexing
size_t index = (pyIndex < 0) ? static_cast<int>(self.getNbBindings()) + pyIndex : pyIndex;
if (index >= self.getNbBindings())
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
return self.getBindingName(index);
};

Expand Down
16 changes: 12 additions & 4 deletions python/src/infer/pyFoundationalTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ static const auto dims_getter = [](const Dims& self, int pyIndex) -> const int&
// Without these bounds checks, horrible infinite looping will occur.
size_t index = (pyIndex < 0) ? static_cast<int>(self.nbDims) + pyIndex : pyIndex;
if (index >= self.nbDims)
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
return self.d[index];
};

Expand All @@ -110,7 +112,9 @@ static const auto dims_getter_slice = [](const Dims& self, py::slice slice) {
throw py::error_already_set();
// Disallow out-of-bounds things.
if (stop > self.nbDims)
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}

py::tuple ret{slicelength};
for (int i = start, index = 0; i < stop; i += step, ++index)
Expand All @@ -121,7 +125,9 @@ static const auto dims_getter_slice = [](const Dims& self, py::slice slice) {
static const auto dims_setter = [](Dims& self, int pyIndex, int item) {
size_t index = (pyIndex < 0) ? static_cast<int>(self.nbDims) + pyIndex : pyIndex;
if (index >= self.nbDims)
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
self.d[index] = item;
};

Expand All @@ -131,7 +137,9 @@ static const auto dims_setter_slice = [](Dims& self, py::slice slice, const Dims
throw py::error_already_set();
// Disallow out-of-bounds things.
if (stop >= self.nbDims)
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}

for (int i = start, index = 0; i < stop; i += step, ++index)
self.d[i] = other.d[index];
Expand Down
15 changes: 12 additions & 3 deletions python/src/infer/pyGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,20 @@ namespace tensorrt
static const auto permutation_getter = [] (const Permutation& self, int pyIndex) {
size_t index = (pyIndex < 0) ? static_cast<const int>(Dims::MAX_DIMS) + pyIndex : pyIndex;
// Static cast is REQUIRED here, or chaos ensues as MAX_DIMS is not pulled in at link time.
if (index >= static_cast<const size_t>(Dims::MAX_DIMS)) throw py::index_error();
if (index >= static_cast<const size_t>(Dims::MAX_DIMS))
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
return self.order[index];
};

static const auto permutation_setter = [] (Permutation& self, int pyIndex, int item) {
size_t index = (pyIndex < 0) ? static_cast<const int>(Dims::MAX_DIMS) + pyIndex : pyIndex;
// Static cast is REQUIRED here, or chaos ensues as MAX_DIMS is not pulled in at link time.
if (index >= static_cast<const size_t>(Dims::MAX_DIMS)) throw py::index_error();
if (index >= static_cast<const size_t>(Dims::MAX_DIMS))
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
self.order[index] = item;
};

Expand Down Expand Up @@ -197,7 +203,10 @@ namespace tensorrt
static const auto network_getitem = [](INetworkDefinition& self, int pyIndex) {
// Support python's negative indexing
size_t index = (pyIndex < 0) ? self.getNbLayers() + pyIndex : pyIndex;
if (index >= self.getNbLayers()) throw py::index_error();
if (index >= self.getNbLayers())
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
return self.getLayer(index);
};

Expand Down
5 changes: 4 additions & 1 deletion python/src/infer/pyPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// This file contains all bindings related to plugins.
#include "ForwardDeclarations.h"
#include "infer/pyPluginDoc.h"
#include "utils.h"
#include <cuda_runtime_api.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -217,7 +218,9 @@ void bindPlugin(py::module& m)
.def("__len__", [](PluginFieldCollection& self) { return self.nbFields; })
.def("__getitem__", [](PluginFieldCollection& self, int index) {
if (index >= self.nbFields)
throw py::index_error();
{
utils::throwPyIndexError(); // See definition of throwPyIndexError() for details
}
return self.fields[index];
});

Expand Down
10 changes: 10 additions & 0 deletions python/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,15 @@ void issueDeprecationWarning(const char* useInstead)
PyErr_WarnEx(PyExc_DeprecationWarning, msg.c_str(), 1);
}

// The following is a helper WAR to "throw py::index_error()", which results in an incompatibility
// with Tensorflow 2.5 and above--on Windows only--when Tensorflow is imported after TensorRT.
// The TF library fast_module_type.pyd hooks on to IndexErrors thrown through py::index_error()
// resulting in hangs at unpacking operations and out-of-bounds index accesses.
void throwPyIndexError(std::string message)
{
PyErr_SetString(PyExc_IndexError, message.data());
throw py::error_already_set();
}

} // namespace utils
} // namespace tensorrt

0 comments on commit 9798a72

Please sign in to comment.