Skip to content

Commit

Permalink
Merge pull request #1788 from borisfom/fix-onnx-surgeon-ort
Browse files Browse the repository at this point in the history
Fixing onnxruntime-1.10 compatibility
  • Loading branch information
pranavm-nvidia authored Feb 23, 2022
2 parents 42805f0 + ab48358 commit 97c5b58
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def fold_constants(self, fold_shapes=True, recurse_subgraphs=True, partitioning=
import onnxruntime as rt
from onnx_graphsurgeon.exporters.onnx_exporter import export_onnx

ORT_PROVIDERS=['CPUExecutionProvider']
PARTITIONING_MODES = [None, "basic", "recursive"]
if partitioning not in PARTITIONING_MODES:
G_LOGGER.critical("Argument for parameter 'partitioning' must be one of: {:}".format(PARTITIONING_MODES))
Expand Down Expand Up @@ -795,7 +796,7 @@ def get_out_node_ids():

try:
# Determining types is not trivial, and ONNX-RT does its own type inference.
sess = rt.InferenceSession(export_onnx(part, do_type_check=False).SerializeToString())
sess = rt.InferenceSession(export_onnx(part, do_type_check=False).SerializeToString(), providers=ORT_PROVIDERS)
values = sess.run(names, {})
except Exception as err:
G_LOGGER.warning("Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(part.name, err))
Expand Down Expand Up @@ -842,7 +843,7 @@ def should_eval_foldable(tensor):
else:
names = [t.name for t in graph_clone.outputs]
try:
sess = rt.InferenceSession(export_onnx(graph_clone, do_type_check=False).SerializeToString())
sess = rt.InferenceSession(export_onnx(graph_clone, do_type_check=False).SerializeToString(), providers=ORT_PROVIDERS)
values = sess.run(names, {})
constant_values.update({name: val for name, val in zip(names, values)})
except Exception as err:
Expand Down

0 comments on commit 97c5b58

Please sign in to comment.