Skip to content

Commit

Permalink
fix ort
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 15, 2024
1 parent 4b25160 commit b028e0f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 5 additions & 3 deletions optimum_benchmark/backends/onnxruntime/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from hydra.utils import get_class
from onnxruntime import SessionOptions
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.onnxruntime import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
Expand Down Expand Up @@ -298,8 +297,11 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

if self.config.model_type not in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs.pop("position_ids", None)
if self.config.library == "transformers":
for key, value in list(inputs.items()):
if key in ["position_ids", "token_type_ids"]:
if key not in self.pretrained_model.input_names:
inputs.pop(key)

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
Expand Down
5 changes: 3 additions & 2 deletions optimum_benchmark/scenarios/inference/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:

self.logger.info("\t+ Preparing input shapes for Inference")
self.config.input_shapes = backend.prepare_input_shapes(input_shapes=self.config.input_shapes)
self.logger.info("\t+ Preparing inputs for Inference")
self.inputs = backend.prepare_inputs(inputs=self.inputs)

self.run_model_loading_tracking(backend)

self.logger.info("\t+ Preparing inputs for Inference")
self.inputs = backend.prepare_inputs(inputs=self.inputs)

if self.config.memory:
if backend.config.task in TEXT_GENERATION_TASKS:
self.run_text_generation_memory_tracking(backend)
Expand Down

0 comments on commit b028e0f

Please sign in to comment.