forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathImporterContext.cpp
116 lines (104 loc) · 3.89 KB
/
ImporterContext.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "ImporterContext.hpp"
namespace onnx2trt
{
//! Open a fresh base-name scope. Tensor registrations performed while this
//! scope is active are recorded so that popBaseNameScope() can undo them.
void ImporterContext::pushBaseNameScope()
{
    mBaseNameScopeStack.emplace_back();
}
//! Close the innermost base-name scope, rolling the tensor map back to the
//! state it had when the matching pushBaseNameScope() was called: bindings
//! newly created inside the scope are erased, and bindings that shadowed an
//! outer entry have the outer value restored.
void ImporterContext::popBaseNameScope()
{
    auto& tensorMap = tensors();
    auto& currentScope = mBaseNameScopeStack.back();
    for (auto& [name, saved] : currentScope)
    {
        auto& [wasNewEntry, previousValue] = saved;
        if (wasNewEntry)
        {
            // The name did not exist before this scope; remove it entirely.
            tensorMap.erase(name);
        }
        else
        {
            // The name shadowed an outer binding; put the old value back.
            tensorMap.at(name) = std::move(previousValue);
        }
    }
    mBaseNameScopeStack.pop_back();
}
//! Register a TensorOrWeights under the ONNX name `basename`.
//!
//! \param tensor          The tensor or weights to register (may be empty/null:
//!                        importModel registers null tensors to reserve names).
//! \param basename        The name the ONNX graph uses for this tensor.
//! \param checkUniqueName If true, register under the de-duplicated unique name
//!                        rather than the raw basename.
//! \throws std::runtime_error if the name is already bound in the current scope.
void ImporterContext::registerTensor(TensorOrWeights tensor, std::string const& basename, bool const checkUniqueName)
{
    // TRT requires unique tensor names.
    std::string const& uniqueName = generateUniqueName(mTensorNames, basename);
    if (tensor)
    {
        if (tensor.is_tensor())
        {
            tensor.tensor().setName(uniqueName.c_str());
            // Logging macro refers to ctx.
            auto* ctx = this;
            LOG_VERBOSE("Registering tensor: " << uniqueName << " for ONNX tensor: " << basename);
        }
        else if (tensor.is_weights())
        {
            auto const& weights = tensor.weights();
            // INT64 weights are converted down to INT32 here; presumably TRT
            // does not accept INT64 initializers directly, and convertINT64
            // handles any out-of-range values -- TODO confirm.
            if (tensor.weights().type == ::ONNX_NAMESPACE::TensorProto::INT64)
            {
                tensor = ShapedWeights{::ONNX_NAMESPACE::TensorProto::INT32,
                    convertINT64(reinterpret_cast<int64_t*>(weights.values), weights.shape, this), weights.shape};
            }
            // It may be possible for nested subgraphs to have different values for the same initializer.
            // For multiple name scopes - use unique name to keep track of weights.
            if (!mBaseNameScopeStack.empty())
            {
                tensor.weights().setName(uniqueName.c_str());
            }
            else
            {
                tensor.weights().setName(basename.c_str());
            }
        }
    }
    std::string const& nameToCheck = checkUniqueName ? uniqueName : basename;
    // emplace() either creates an empty slot for the name or, if the name is
    // already bound, returns the existing entry (p.second tells us which).
    auto const p = this->tensors().emplace(nameToCheck, TensorOrWeights{});
    bool nameIsDuplicate = false;
    if (!mBaseNameScopeStack.empty())
    {
        // Remember original binding so it can be restored when scope is popped.
        // The pair stored is (was-newly-inserted, previous-value); the previous
        // value is moved out of the map slot, which is then overwritten below.
        auto const q
            = mBaseNameScopeStack.back().emplace(nameToCheck, std::make_pair(p.second, std::move(p.first->second)));
        // Check that scope did not already have a binding for basename.
        nameIsDuplicate = !q.second;
    }
    else
    {
        // The condition here accounts for ModelImporter::importModel reserving
        // output names by registering null tensors.
        nameIsDuplicate = !p.second && !p.first->second.isNullTensor();
    }
    if (nameIsDuplicate)
    {
        throw std::runtime_error("ONNX graph has duplicate tensor name: " + nameToCheck);
    }
    // Overwrite (or fill) the map slot with the new binding.
    p.first->second = std::move(tensor);
}
//! Give `layer` a unique TRT name derived from the ONNX node name `basename`
//! (or from the layer's own name when basename is empty), and track constant
//! layers in mConstantLayers keyed by that unique name.
void ImporterContext::registerLayer(nvinfer1::ILayer* layer, std::string const& basename)
{
    // No layer will be added for Constant nodes in ONNX.
    if (layer == nullptr)
    {
        return;
    }
    std::string const nodeName = basename.empty() ? layer->getName() : basename;
    std::string const& uniqueName = generateUniqueName(mLayerNames, nodeName);
    auto* ctx = this; // To enable logging.
    LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);
    layer->setName(uniqueName.c_str());
    if (layer->getType() == nvinfer1::LayerType::kCONSTANT)
    {
        bool const alreadyTracked = mConstantLayers.find(uniqueName) != mConstantLayers.end();
        if (basename != uniqueName && alreadyTracked)
        {
            // A renamed constant colliding with an existing entry indicates an
            // internal bookkeeping error, not a user-model problem.
            LOG_ERROR("Constant layer: " << uniqueName << " can be a duplicate of: " << basename);
            assert(!"Internal error: duplicate constant layers for the same weights");
        }
        mConstantLayers.emplace(uniqueName, static_cast<nvinfer1::IConstantLayer*>(layer));
    }
}
} // namespace onnx2trt