From d1bba1f649938e14c7e4d72267940d82e488128b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 12:46:47 +0100 Subject: [PATCH 01/21] prepare inputs before and after model loading --- optimum_benchmark/backends/base.py | 13 ++++++++----- optimum_benchmark/scenarios/inference/scenario.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py index 78d0bef7..5fddcca9 100644 --- a/optimum_benchmark/backends/base.py +++ b/optimum_benchmark/backends/base.py @@ -106,20 +106,23 @@ def create_no_weights_model(self) -> None: self.logger.info("\t+ Saving no weights model's config") self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) - def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """ - This method is used to prepare and register the input shapes before using them by the model. - It can be used to pad the inputs to the correct shape, or compile it to the correct format. + This method is used to prepare and register the inputs before passing them to the model. + It can be used to move the inputs to the correct device, or rename their keys. """ - return input_shapes + return inputs - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """ This method is used to prepare and register the inputs before passing them to the model. It can be used to move the inputs to the correct device, or rename their keys. """ return inputs + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return self.prepare_inputs_after_load(self.prepare_inputs_before_load(inputs)) + def load(self) -> None: raise NotImplementedError("Backend must implement load method") diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py index 45461714..c84f27e0 100644 --- a/optimum_benchmark/scenarios/inference/scenario.py +++ b/optimum_benchmark/scenarios/inference/scenario.py @@ -118,8 +118,6 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: device_ids=self.backend.config.device_ids, ) - self.run_model_loading_tracking() - self.logger.info(f"\t+ Generating inputs for task {self.backend.config.task}") self.inputs = InputGenerator( task=self.backend.config.task, @@ -127,8 +125,14 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: model_type=self.backend.config.model_type, input_shapes=self.config.input_shapes, )() - self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name}") - self.inputs = self.backend.prepare_inputs(inputs=self.inputs) + + self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} before model loading.") + self.inputs = self.backend.prepare_inputs_before_load(inputs=self.inputs) + + self.run_model_loading_tracking() + + self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} after model loading.") + self.inputs = self.backend.prepare_inputs_after_load(inputs=self.inputs) if self.config.warmup_runs > 0: if self.backend.config.task in TEXT_GENERATION_TASKS: From 71f110050775653a0de01d17389310ef6323dedf Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 12:47:18 +0100 Subject: [PATCH 02/21] update ipex --- optimum_benchmark/backends/ipex/backend.py | 26 ++++++---------------- optimum_benchmark/backends/ipex/config.py | 8 +++---- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py index 7e4983a9..049b2c7b 100644 --- a/optimum_benchmark/backends/ipex/backend.py +++ b/optimum_benchmark/backends/ipex/backend.py @@ -45,41 +45,29 @@ def load(self) -> None: self.tmpdir.cleanup() - def _load_automodel_from_pretrained(self) -> None: - self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs) - - def _load_automodel_with_no_weights(self) -> None: - original_model, self.config.model = self.config.model, self.no_weights_model - - with fast_weights_init(): - self._load_automodel_from_pretrained() - - self.logger.info("\t+ Tying model weights") - self.pretrained_model.tie_weights() - - self.config.model = original_model - def _load_ipexmodel_from_pretrained(self) -> None: self.pretrained_model = self.ipexmodel_class.from_pretrained( self.config.model, - export=self.config.export, **self.config.model_kwargs, - **self.automodel_kwargs, + **self.ipexmodel_kwargs, ) def _load_ipexmodel_with_no_weights(self) -> None: with fast_weights_init(): + self.logger.info("\t+ Loading no weights IPEXModel") original_model, self.config.model = self.config.model, self.no_weights_model original_export, self.config.export = self.config.export, True - self.logger.info("\t+ Loading no weights IPEXModel") self._load_ipexmodel_from_pretrained() self.config.export = original_export self.config.model = original_model @property - def automodel_kwargs(self) -> Dict[str, Any]: + def ipexmodel_kwargs(self) -> Dict[str, Any]: kwargs = {} + if self.config.export: + kwargs["export"] = self.config.export + if self.config.torch_dtype is not None: kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) @@ -89,7 +77,7 @@ def automodel_kwargs(self) -> Dict[str, Any]: def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs diff --git a/optimum_benchmark/backends/ipex/config.py b/optimum_benchmark/backends/ipex/config.py index 5ee4aad1..4fb553da 100644 --- a/optimum_benchmark/backends/ipex/config.py +++ b/optimum_benchmark/backends/ipex/config.py @@ -13,17 +13,17 @@ class IPEXConfig(BackendConfig): version: Optional[str] = ipex_version() _target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend" - # load options no_weights: bool = False - torch_dtype: Optional[str] = None - # export options - export: bool = True + # ipexmodel kwargs + export: Optional[bool] = None + torch_dtype: Optional[str] = None def __post_init__(self): super().__post_init__() self.device = self.device.lower() + if self.device not in ["cpu", "gpu"]: raise ValueError(f"IPEXBackend only supports CPU devices, got {self.device}") From 6a649903ae6d99eb1278263365f120bc1fca6c7d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 12:54:43 +0100 Subject: [PATCH 03/21] fix llama cpp --- optimum_benchmark/backends/llama_cpp/backend.py | 13 ++++++------- optimum_benchmark/backends/llama_cpp/config.py | 2 ++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/optimum_benchmark/backends/llama_cpp/backend.py b/optimum_benchmark/backends/llama_cpp/backend.py index c9d6bbf8..67b2dbb2 100644 --- a/optimum_benchmark/backends/llama_cpp/backend.py +++ b/optimum_benchmark/backends/llama_cpp/backend.py @@ -28,8 +28,7 @@ def load_model_from_pretrained(self) -> None: """ self.pretrained_model = Llama.from_pretrained( - repo_id=self.config.model, - filename=self.config.filename, + self.config.model, **self.llama_cpp_kwargs, ) @@ -37,20 +36,20 @@ def load_model_from_pretrained(self) -> None: def llama_cpp_kwargs(self) -> Dict[str, Any]: return { "embedding": self.config.task == "feature-extraction", + "filename": self.config.filename, "verbose": False, "echo": False, } - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task == "text-generation": if inputs["input_ids"].shape[0] != 1: - raise ValueError("Batch size must be 1 for LlamaCpp text generation") + raise ValueError("Batch size must be 1 for Text Generation with llama-cpp-python") return {"tokens": inputs["input_ids"].squeeze(0).tolist()} - elif self.config.task == "feature-extraction": return {"input": [self.pretrained_model.detokenize(x).decode("utf-8") for x in inputs["input_ids"]]} - - raise ValueError(f"Task {self.config.task} not supported by {self.NAME}") + else: + raise ValueError(f"Task {self.config.task} not supported by {self.NAME}") def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any: self.pretrained_model.embed(**inputs) diff --git a/optimum_benchmark/backends/llama_cpp/config.py b/optimum_benchmark/backends/llama_cpp/config.py index d2902860..183a86f8 100644 --- a/optimum_benchmark/backends/llama_cpp/config.py +++ b/optimum_benchmark/backends/llama_cpp/config.py @@ -12,6 +12,8 @@ class LlamaCppConfig(BackendConfig): _target_: str = "optimum_benchmark.backends.llama_cpp.backend.LlamaCppBackend" no_weights: bool = False + + # llamamodel kwargs filename: Optional[str] = None def __post_init__(self): From 3644c449ace61f3c9ff0e491d1c8a664f5ca8829 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 12:59:24 +0100 Subject: [PATCH 04/21] txi --- optimum_benchmark/backends/py_txi/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum_benchmark/backends/py_txi/backend.py b/optimum_benchmark/backends/py_txi/backend.py index 6e637a31..cf9a6920 100644 --- a/optimum_benchmark/backends/py_txi/backend.py +++ b/optimum_benchmark/backends/py_txi/backend.py @@ -139,7 +139,7 @@ def load_model_from_pretrained(self) -> None: else: raise NotImplementedError(f"TXI does not support task {self.config.task}") - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task in TEXT_GENERATION_TASKS: inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())} elif self.config.task in TEXT_EMBEDDING_TASKS: From fb7a99ec223fd59482b362f021bf015231baa013 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 12:59:53 +0100 Subject: [PATCH 05/21] pytorch --- optimum_benchmark/backends/pytorch/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 651e6d12..3fb8d80a 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -416,7 +416,7 @@ def split_between_processes(self) -> bool: and not self.config.deepspeed_inference ) - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs From 3a21aa5ab6ddcbb43bfd67fe102dae8b4ecfd636 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:00:45 +0100 Subject: [PATCH 06/21] openvino --- .../backends/openvino/backend.py | 126 ++++-------------- optimum_benchmark/backends/openvino/config.py | 30 ++--- optimum_benchmark/backends/openvino/utils.py | 10 +- 3 files changed, 44 insertions(+), 122 deletions(-) diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py index f0aa1925..729c43c1 100644 --- a/optimum_benchmark/backends/openvino/backend.py +++ b/optimum_benchmark/backends/openvino/backend.py @@ -5,17 +5,12 @@ import torch from hydra.utils import get_class -from openvino.runtime import properties -from optimum.intel.openvino import OVConfig as OVQuantizationConfig # naming conflict -from optimum.intel.openvino import OVQuantizer -from ...generators.dataset_generator import DatasetGenerator from ...import_utils import is_accelerate_available, is_torch_distributed_available -from ...task_utils import TEXT_GENERATION_TASKS from ..base import Backend from ..transformers_utils import fast_weights_init -from .config import OVConfig -from .utils import TASKS_TO_MODEL_TYPES_TO_OVPIPELINE, TASKS_TO_OVMODEL +from .config import OVConfig as OVBackendConfig +from .utils import TASKS_OVPIPELINE, TASKS_TO_OVMODEL if is_accelerate_available(): from accelerate import Accelerator @@ -24,53 +19,26 @@ import torch.distributed -class OVBackend(Backend[OVConfig]): +class OVBackend(Backend[OVBackendConfig]): NAME: str = "openvino" - def __init__(self, config: OVConfig) -> None: + def __init__(self, config: OVBackendConfig) -> None: super().__init__(config) if self.config.task in TASKS_TO_OVMODEL: self.ovmodel_class = get_class(TASKS_TO_OVMODEL[self.config.task]) self.logger.info(f"\t+ Using OVModel class {self.ovmodel_class.__name__}") - elif self.config.task in TASKS_TO_MODEL_TYPES_TO_OVPIPELINE: - if self.config.model_type in TASKS_TO_MODEL_TYPES_TO_OVPIPELINE[self.config.task]: - self.ovmodel_class = get_class( - TASKS_TO_MODEL_TYPES_TO_OVPIPELINE[self.config.task][self.config.model_type] - ) - self.logger.info(f"\t+ Using OVPipeline class {self.ovmodel_class.__name__}") - else: - raise NotImplementedError( - f"OVBackend does not support model {self.config.model_type} for task {self.config.task}" - ) + elif self.config.task in TASKS_OVPIPELINE: + self.ovmodel_class = get_class(TASKS_OVPIPELINE[self.config.task]) + self.logger.info(f"\t+ Using OVDiffusionPipeline class {self.ovmodel_class.__name__}") else: raise NotImplementedError(f"OVBackend does not support task {self.config.task}") - if self.config.inter_op_num_threads is not None: - self.logger.info(f"\t+ Setting inter_op_num_threads to {self.config.inter_op_num_threads}") - self.config.openvino_config[properties.inference_num_threads()] = self.config.inter_op_num_threads - def load(self) -> None: self.logger.info("\t+ Creating backend temporary directory") self.tmpdir = TemporaryDirectory() - if self.config.quantization: - if self.config.no_weights: - self.logger.info("\t+ Creating no weights AutoModel") - self.create_no_weights_model() - self.logger.info("\t+ Loading no weights AutoModel") - self._load_automodel_with_no_weights() - else: - self.logger.info("\t+ Loading pretrained AutoModel") - self._load_automodel_from_pretrained() - self.logger.info("\t+ Applying post-training quantization") - self.quantize_automodel() - original_model, self.config.model = self.config.model, self.quantized_model - original_export, self.config.export = self.config.export, False - self.logger.info("\t+ Loading quantized OVModel") - self._load_ovmodel_from_pretrained() - self.config.model, self.config.export = original_model, original_export - elif self.config.no_weights: + if self.config.no_weights: self.logger.info("\t+ Creating no weights OVModel") self.create_no_weights_model() self.logger.info("\t+ Loading no weights OVModel") @@ -85,9 +53,6 @@ def load(self) -> None: for key, value in self.model_shapes.items() if key in inspect.getfullargspec(self.pretrained_model.reshape).args } - if ("sequence_length" in static_shapes) and ("height" in static_shapes) and ("width" in static_shapes): - # for vision models, sequence_length is the number of channels - static_shapes["sequence_length"] = self.model_shapes.get("num_channels") self.logger.info(f"\t+ Reshaping model with static shapes: {static_shapes}") self.pretrained_model.reshape(**static_shapes) @@ -102,26 +67,9 @@ def load(self) -> None: self.tmpdir.cleanup() - def _load_automodel_from_pretrained(self) -> None: - self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs) - - def _load_automodel_with_no_weights(self) -> None: - original_model, self.config.model = self.config.model, self.no_weights_model - - with fast_weights_init(): - self._load_automodel_from_pretrained() - - self.logger.info("\t+ Tying model weights") - self.pretrained_model.tie_weights() - - self.config.model = original_model - def _load_ovmodel_from_pretrained(self) -> None: self.pretrained_model = self.ovmodel_class.from_pretrained( self.config.model, - export=self.config.export, - ov_config=self.config.openvino_config, - device=self.config.device, **self.config.model_kwargs, **self.ovmodel_kwargs, ) @@ -135,61 +83,36 @@ def _load_ovmodel_with_no_weights(self) -> None: self.config.export = original_export self.config.model = original_model - def quantize_automodel(self) -> None: - self.logger.info("\t+ Attempting quantization") - self.quantized_model = f"{self.tmpdir.name}/quantized_model" - self.logger.info("\t+ Processing quantization config") - quantization_config = OVQuantizationConfig(**self.config.quantization_config) - self.logger.info("\t+ Creating quantizer") - quantizer = OVQuantizer.from_pretrained(self.pretrained_model, task=self.config.task, seed=self.config.seed) - - if self.config.calibration: - self.logger.info("\t+ Generating calibration dataset") - dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes} - calibration_dataset = DatasetGenerator( - task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes - )() - columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._export_input_names)) - calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed) - else: - calibration_dataset = None - - self.logger.info("\t+ Quantizing model") - quantizer.quantize( - save_directory=self.quantized_model, - quantization_config=quantization_config, - calibration_dataset=calibration_dataset, - # TODO: add support for these (maybe) - remove_unused_columns=True, - data_collator=None, - weights_only=False, - file_name=None, - batch_size=1, - ) - @property def ovmodel_kwargs(self) -> Dict[str, Any]: kwargs = {} - if self.config.task in TEXT_GENERATION_TASKS: + if self.config.export is not None: + kwargs["export"] = self.config.export + + if self.config.use_cache is not None: kwargs["use_cache"] = self.config.use_cache + + if self.config.use_merged is not None: kwargs["use_merged"] = self.config.use_merged + if self.config.load_in_8bit is not None: + kwargs["load_in_8bit"] = self.config.load_in_8bit + + if self.config.load_in_4bit is not None: + kwargs["load_in_4bit"] = self.config.load_in_4bit + return kwargs @property def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs - for key in list(inputs.keys()): - if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: - inputs.pop(key) - if "input_ids" in inputs: self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape))) @@ -200,6 +123,13 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs + def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + for key in list(inputs.keys()): + if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: + inputs.pop(key) + + return inputs + def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: return self.pretrained_model.forward(**inputs, **kwargs) diff --git a/optimum_benchmark/backends/openvino/config.py b/optimum_benchmark/backends/openvino/config.py index 7e6eac25..ffdf1566 100644 --- a/optimum_benchmark/backends/openvino/config.py +++ b/optimum_benchmark/backends/openvino/config.py @@ -11,28 +11,22 @@ class OVConfig(BackendConfig): version: Optional[str] = openvino_version() _target_: str = "optimum_benchmark.backends.openvino.backend.OVBackend" - # load options no_weights: bool = False - # export options - export: bool = True - use_cache: bool = True - use_merged: bool = False - - # openvino config - openvino_config: Dict[str, Any] = field(default_factory=dict) + # ovmodel kwargs + export: Optional[bool] = None + use_cache: Optional[bool] = None + use_merged: Optional[bool] = None + load_in_8bit: Optional[bool] = None + load_in_4bit: Optional[bool] = None # compilation options half: bool = False + compile: bool = False reshape: bool = False - # quantization options - quantization: bool = False - quantization_config: Dict[str, Any] = field(default_factory=dict) - - # calibration options - calibration: bool = False - calibration_config: Dict[str, Any] = field(default_factory=dict) + # openvino config + ov_config: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): super().__post_init__() @@ -42,7 +36,7 @@ def __post_init__(self): raise ValueError(f"OVBackend only supports CPU devices, got {self.device}") if self.intra_op_num_threads is not None: - raise NotImplementedError("OVBackend does not support intra_op_num_threads") + raise NotImplementedError("OVBackend does not support intra_op_num_threads. Please use the ov_config") - if self.quantization and not self.calibration: - raise ValueError("OpenVINO quantization requires enabling calibration.") + if self.inter_op_num_threads is not None: + raise NotImplementedError("OVBackend does not support inter_op_num_threads. Please use the ov_config") diff --git a/optimum_benchmark/backends/openvino/utils.py b/optimum_benchmark/backends/openvino/utils.py index 35518346..51f9629c 100644 --- a/optimum_benchmark/backends/openvino/utils.py +++ b/optimum_benchmark/backends/openvino/utils.py @@ -10,10 +10,8 @@ "audio-classification": "optimum.intel.openvino.OVModelForAudioClassification", "pix2struct": "optimum.intel.openvino.OVModelForPix2Struct", } -TASKS_TO_MODEL_TYPES_TO_OVPIPELINE = { - "text-to-image": { - "lcm": "optimum.intel.openvino.OVLatentConsistencyModelPipeline", - "stable-diffusion": "optimum.intel.openvino.OVStableDiffusionPipeline", - "stable-diffusion-xl": "optimum.intel.openvino.OVStableDiffusionXLPipeline", - }, +TASKS_OVPIPELINE = { + "inpainting": "optimum.intel.openvino.OVPipelineForInpainting", + "text-to-image": "optimum.intel.openvino.OVPipelineForText2Image", + "image-to-image": "optimum.intel.openvino.OVPipelineForImage2Image", } From e812e6e31a2c00d10216f7afc345992cdba4b2f2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:06:16 +0100 Subject: [PATCH 07/21] onnxruntime --- .../backends/onnxruntime/backend.py | 81 ++++++++++--------- .../backends/onnxruntime/config.py | 16 ++-- .../backends/onnxruntime/utils.py | 17 +--- .../backends/openvino/backend.py | 10 +-- 4 files changed, 57 insertions(+), 67 deletions(-) diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index 2fffcc36..feec7ec3 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -23,13 +23,12 @@ from ...generators.dataset_generator import DatasetGenerator from ...import_utils import is_accelerate_available, is_torch_distributed_available -from ...task_utils import TEXT_GENERATION_TASKS from ..base import Backend from ..transformers_utils import fast_weights_init from .config import ORTConfig from .utils import ( - TASKS_TO_MODEL_TYPES_TO_ORTPIPELINES, TASKS_TO_ORTMODELS, + TASKS_TO_ORTPIPELINE, format_calibration_config, format_quantization_config, ) @@ -47,28 +46,15 @@ class ORTBackend(Backend[ORTConfig]): def __init__(self, config: ORTConfig) -> None: super().__init__(config) - if self.config.task in TASKS_TO_ORTMODELS: + if self.config.library != "diffusers" and self.config.task in TASKS_TO_ORTMODELS: self.ort_model_loader = get_class(TASKS_TO_ORTMODELS[self.config.task]) - self.logger.info(f"Using ORT Model class {self.ort_model_loader.__name__}") - elif self.config.task in TASKS_TO_MODEL_TYPES_TO_ORTPIPELINES: - if self.config.model_type in TASKS_TO_MODEL_TYPES_TO_ORTPIPELINES[self.config.task]: - self.ort_model_loader = get_class( - TASKS_TO_MODEL_TYPES_TO_ORTPIPELINES[self.config.task][self.config.model_type] - ) - self.logger.info(f"Using ORT Pipeline class {self.ort_model_loader.__name__}") - else: - raise NotImplementedError( - f"ORTBackend does not support model {self.config.model_type} for task {self.config.task}" - ) + self.logger.info(f"Using ORTModel class {self.ort_model_loader.__name__}") + elif self.config.library == "diffusers" and self.config.task in TASKS_TO_ORTPIPELINE: + self.ort_model_loader = get_class(TASKS_TO_ORTPIPELINE[self.config.task]) + self.logger.info(f"Using ORTDiffusionPipeline class {self.ort_model_loader.__name__}") else: raise NotImplementedError(f"ORTBackend does not support task {self.config.task}") - self.session_options = SessionOptions() - if self.config.session_options: - self.logger.info("\t+ Processing session options") - for key, value in self.config.session_options.items(): - setattr(self.session_options, key, value) - def validate_execution_provider(self) -> None: if not self.pretrained_model.providers[0] == self.config.provider: raise ValueError( @@ -117,22 +103,18 @@ def load(self) -> None: def load_ortmodel_from_pretrained(self) -> None: self.pretrained_model = self.ort_model_loader.from_pretrained( self.config.model, - export=self.config.export, - session_options=self.session_options, - provider_options=self.config.provider_options, - use_io_binding=self.config.use_io_binding, - provider=self.config.provider, **self.config.model_kwargs, **self.ortmodel_kwargs, ) def load_ortmodel_with_no_weights(self) -> None: - original_model, self.config.model = self.config.model, self.no_weights_model - with fast_weights_init(): + original_model, self.config.model = self.config.model, self.no_weights_model + original_export, self.config.export = self.config.export, True + self.logger.info("\t+ Loading no weights ORTModel") self.load_ortmodel_from_pretrained() - - self.config.model = original_model + self.config.export = original_export + self.config.model = original_model @property def is_optimized(self) -> bool: @@ -146,18 +128,36 @@ def is_quantized(self) -> bool: def is_calibrated(self) -> bool: return (self.config.auto_calibration is not None) or self.config.calibration - @property - def is_dp_distributed(self) -> bool: - return is_torch_distributed_available() and torch.distributed.is_initialized() - @property def ortmodel_kwargs(self) -> Dict[str, Any]: kwargs = {} - if self.config.task in TEXT_GENERATION_TASKS: + if self.config.export is not None: + kwargs["export"] = self.config.export + + if self.config.provider is not None: + kwargs["provider"] = self.config.provider + + if self.config.use_cache is not None: kwargs["use_cache"] = self.config.use_cache + + if self.config.use_merged is not None: kwargs["use_merged"] = self.config.use_merged + if self.config.torch_dtype is not None: + kwargs["torch_dtype"] = self.config.torch_dtype + + if self.config.use_io_binding is not None: + kwargs["use_io_binding"] = self.config.use_io_binding + + if self.config.session_options: + kwargs["session_options"] = SessionOptions() + for key, value in self.config.session_options.items(): + setattr(kwargs["session_options"], key, value) + + if self.config.provider_options: + kwargs["provider_options"] = self.config.provider_options + return kwargs @property @@ -284,21 +284,24 @@ def quantize_onnx_files(self) -> None: def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs - for key in list(inputs.keys()): - if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: - inputs.pop(key) - for key, value in inputs.items(): if isinstance(value, torch.Tensor): inputs[key] = value.to(self.config.device) return inputs + def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + for key in list(inputs.keys()): + if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: + inputs.pop(key) + + return inputs + @torch.inference_mode() def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: return self.pretrained_model.forward(**inputs, **kwargs) diff --git a/optimum_benchmark/backends/onnxruntime/config.py b/optimum_benchmark/backends/onnxruntime/config.py index 07101f78..4c171a3c 100644 --- a/optimum_benchmark/backends/onnxruntime/config.py +++ b/optimum_benchmark/backends/onnxruntime/config.py @@ -35,19 +35,15 @@ class ORTConfig(BackendConfig): # load options no_weights: bool = False - # export options - export: bool = True - use_cache: bool = True - use_merged: bool = False - torch_dtype: Optional[str] = None - - # provider options + # ortmodel kwargs + export: Optional[bool] = None provider: Optional[str] = None - provider_options: Dict[str, Any] = field(default_factory=dict) - - # inference options + use_cache: Optional[bool] = None + use_merged: Optional[bool] = None + torch_dtype: Optional[str] = None use_io_binding: Optional[bool] = None session_options: Dict[str, Any] = field(default_factory=dict) + provider_options: Dict[str, Any] = field(default_factory=dict) # null, O1, O2, O3, O4 auto_optimization: Optional[str] = None diff --git a/optimum_benchmark/backends/onnxruntime/utils.py b/optimum_benchmark/backends/onnxruntime/utils.py index 6177ae8e..63598223 100644 --- a/optimum_benchmark/backends/onnxruntime/utils.py +++ b/optimum_benchmark/backends/onnxruntime/utils.py @@ -7,19 +7,10 @@ task: f"optimum.onnxruntime.{task_dict['class'][0].__name__}" for task, task_dict in ORT_SUPPORTED_TASKS.items() } -TASKS_TO_MODEL_TYPES_TO_ORTPIPELINES = { - "text-to-image": { - "stable-diffusion": "optimum.onnxruntime.ORTStableDiffusionPipeline", - "stable-diffusion-xl": "optimum.onnxruntime.ORTStableDiffusionXLPipeline", - "latent-consistency": "optimum.onnxruntime.ORTLatentConsistencyModelPipeline", - }, - "image-to-image": { - "stable-diffusion": "optimum.onnxruntime.ORTStableDiffusionImg2ImgPipeline", - "stable-diffusion-xl": "optimum.onnxruntime.ORTStableDiffusionImg2ImgXLPipeline", - }, - "inpainting": { - "stable-diffusion": "optimum.onnxruntime.ORTStableDiffusionInpaintingPipeline", - }, +TASKS_TO_ORTPIPELINE = { + "inpainting": "optimum.onnxruntime.ORTPipelineForInpainting", + "text-to-image": "optimum.onnxruntime.ORTPipelineForText2Image", + "image-to-image": "optimum.onnxruntime.ORTPipelineForImage2Image", } diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py index 729c43c1..f4ead46c 100644 --- a/optimum_benchmark/backends/openvino/backend.py +++ b/optimum_benchmark/backends/openvino/backend.py @@ -42,10 +42,10 @@ def load(self) -> None: self.logger.info("\t+ Creating no weights OVModel") self.create_no_weights_model() self.logger.info("\t+ Loading no weights OVModel") - self._load_ovmodel_with_no_weights() + self.load_ovmodel_with_no_weights() else: self.logger.info("\t+ Loading pretrained OVModel") - self._load_ovmodel_from_pretrained() + self.load_ovmodel_from_pretrained() if self.config.reshape: static_shapes = { @@ -67,19 +67,19 @@ def load(self) -> None: self.tmpdir.cleanup() - def _load_ovmodel_from_pretrained(self) -> None: + def load_ovmodel_from_pretrained(self) -> None: self.pretrained_model = self.ovmodel_class.from_pretrained( self.config.model, **self.config.model_kwargs, **self.ovmodel_kwargs, ) - def _load_ovmodel_with_no_weights(self) -> None: + def load_ovmodel_with_no_weights(self) -> None: with fast_weights_init(): original_model, self.config.model = self.config.model, self.no_weights_model original_export, self.config.export = self.config.export, True self.logger.info("\t+ Loading no weights OVModel") - self._load_ovmodel_from_pretrained() + self.load_ovmodel_from_pretrained() self.config.export = original_export self.config.model = original_model From 4d18b2481b426c0d0aa51e0b142498eceac13b3e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:28:38 +0100 Subject: [PATCH 08/21] tests --- tests/configs/_diffusers_.yaml | 11 +++++++---- tests/configs/_export_.yaml | 2 ++ tests/configs/_inc_quant_.yaml | 3 --- ...ce_neural_compressor_inc_quant_text_decoders.yaml | 12 ------------ ...ce_neural_compressor_inc_quant_text_encoders.yaml | 12 ------------ .../configs/cpu_inference_onnxruntime_diffusers.yaml | 1 + .../configs/cpu_inference_onnxruntime_ort_quant.yaml | 1 + .../cpu_inference_onnxruntime_text_decoders.yaml | 1 + .../cpu_inference_onnxruntime_text_encoders.yaml | 1 + ...inference_onnxruntime_text_encoders_decoders.yaml | 1 + tests/configs/cpu_inference_onnxruntime_timm.yaml | 1 + .../cuda_inference_onnxruntime_text_decoders.yaml | 1 + .../cuda_inference_onnxruntime_text_encoders.yaml | 1 + 13 files changed, 17 insertions(+), 31 deletions(-) create mode 100644 tests/configs/_export_.yaml delete mode 100644 tests/configs/_inc_quant_.yaml delete mode 100644 tests/configs/cpu_inference_neural_compressor_inc_quant_text_decoders.yaml delete mode 100644 tests/configs/cpu_inference_neural_compressor_inc_quant_text_encoders.yaml diff --git a/tests/configs/_diffusers_.yaml b/tests/configs/_diffusers_.yaml index 607b2502..7083e261 100644 --- a/tests/configs/_diffusers_.yaml +++ b/tests/configs/_diffusers_.yaml @@ -1,4 +1,7 @@ -backend: - library: diffusers - task: text-to-image - model: hf-internal-testing/tiny-stable-diffusion-torch +hydra: + mode: MULTIRUN + sweeper: + params: + backend.library: diffusers + backend.task: text-to-image,image-to-image,inpainting + backend.model: hf-internal-testing/tiny-stable-diffusion-torch diff --git a/tests/configs/_export_.yaml b/tests/configs/_export_.yaml new file mode 100644 index 00000000..50f1bb0f --- /dev/null +++ b/tests/configs/_export_.yaml @@ -0,0 +1,2 @@ +backend: + export: true diff --git a/tests/configs/_inc_quant_.yaml b/tests/configs/_inc_quant_.yaml deleted file mode 100644 index 1347abfc..00000000 --- a/tests/configs/_inc_quant_.yaml +++ /dev/null @@ -1,3 +0,0 @@ -backend: - ptq_quantization: true - calibration: true diff --git a/tests/configs/cpu_inference_neural_compressor_inc_quant_text_decoders.yaml b/tests/configs/cpu_inference_neural_compressor_inc_quant_text_decoders.yaml deleted file mode 100644 index 7865da6e..00000000 --- a/tests/configs/cpu_inference_neural_compressor_inc_quant_text_decoders.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - # order of inheritance, last one overrides previous ones - - _base_ # inherits from base config - - _cpu_ # inherits from cpu config - - _inference_ # inherits from inference config - - _inc_quant_ # inherits from incremental quantization config - - _text_decoders_ # inherits from text decoders config - - _no_weights_ # inherits from no weights config - - _self_ # hydra 1.1 compatibility - - override backend: neural-compressor - -name: cpu_inference_neural_compressor_text_decoders diff --git a/tests/configs/cpu_inference_neural_compressor_inc_quant_text_encoders.yaml b/tests/configs/cpu_inference_neural_compressor_inc_quant_text_encoders.yaml deleted file mode 100644 index 91451cf1..00000000 --- a/tests/configs/cpu_inference_neural_compressor_inc_quant_text_encoders.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - # order of inheritance, last one overrides previous ones - - _base_ # inherits from base config - - _cpu_ # inherits from cpu config - - _inference_ # inherits from inference config - - _inc_quant_ # inherits from incremental quantization config - - _text_encoders_ # inherits from text encoders config - - _no_weights_ # inherits from no weights config - - _self_ # hydra 1.1 compatibility - - override backend: neural-compressor - -name: cpu_inference_neural_compressor_text_encoders diff --git a/tests/configs/cpu_inference_onnxruntime_diffusers.yaml b/tests/configs/cpu_inference_onnxruntime_diffusers.yaml index 5b44c0f2..852f6a4f 100644 --- a/tests/configs/cpu_inference_onnxruntime_diffusers.yaml +++ b/tests/configs/cpu_inference_onnxruntime_diffusers.yaml @@ -4,6 +4,7 @@ defaults: - _cpu_ # inherits from cpu config - _inference_ # inherits from inference config - _diffusers_ # inherits from diffusers config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cpu_inference_onnxruntime_ort_quant.yaml b/tests/configs/cpu_inference_onnxruntime_ort_quant.yaml index 0f0e095c..628a8c40 100644 --- a/tests/configs/cpu_inference_onnxruntime_ort_quant.yaml +++ b/tests/configs/cpu_inference_onnxruntime_ort_quant.yaml @@ -5,6 +5,7 @@ defaults: - _inference_ # inherits from inference config - _ort_quant_ # inherits from ort static quant config - _no_weights_ # inherits from no weights sweep config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cpu_inference_onnxruntime_text_decoders.yaml b/tests/configs/cpu_inference_onnxruntime_text_decoders.yaml index d87c4552..48a61019 100644 --- a/tests/configs/cpu_inference_onnxruntime_text_decoders.yaml +++ b/tests/configs/cpu_inference_onnxruntime_text_decoders.yaml @@ -5,6 +5,7 @@ defaults: - _inference_ # inherits from inference config - _text_decoders_ # inherits from text decoders config - _no_weights_ # inherits from no weights config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cpu_inference_onnxruntime_text_encoders.yaml b/tests/configs/cpu_inference_onnxruntime_text_encoders.yaml index 5e9bdb9f..a846360b 100644 --- a/tests/configs/cpu_inference_onnxruntime_text_encoders.yaml +++ b/tests/configs/cpu_inference_onnxruntime_text_encoders.yaml @@ -5,6 +5,7 @@ defaults: - _inference_ # inherits from inference config - _text_encoders_ # inherits from text encoders config - _no_weights_ # inherits from no weights config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cpu_inference_onnxruntime_text_encoders_decoders.yaml b/tests/configs/cpu_inference_onnxruntime_text_encoders_decoders.yaml index a9b725d4..5c8eeb50 100644 --- a/tests/configs/cpu_inference_onnxruntime_text_encoders_decoders.yaml +++ b/tests/configs/cpu_inference_onnxruntime_text_encoders_decoders.yaml @@ -5,6 +5,7 @@ defaults: - _inference_ # inherits from inference config - _text_encoders_decoders_ # inherits from text encoders decoders config - _no_weights_ # inherits from no weights config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cpu_inference_onnxruntime_timm.yaml b/tests/configs/cpu_inference_onnxruntime_timm.yaml index 9859c540..5e487cf9 100644 --- a/tests/configs/cpu_inference_onnxruntime_timm.yaml +++ b/tests/configs/cpu_inference_onnxruntime_timm.yaml @@ -3,6 +3,7 @@ defaults: - _base_ # inherits from base config - _cpu_ # inherits from cpu config - _inference_ # inherits from inference config + - _export_ # inherits from export config - _timm_ # inherits from timm config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cuda_inference_onnxruntime_text_decoders.yaml b/tests/configs/cuda_inference_onnxruntime_text_decoders.yaml index d43725e8..b751e282 100644 --- a/tests/configs/cuda_inference_onnxruntime_text_decoders.yaml +++ b/tests/configs/cuda_inference_onnxruntime_text_decoders.yaml @@ -6,6 +6,7 @@ defaults: - _text_decoders_ # inherits from text decoders sweep config - _device_isolation_ # inherits from device isolation config - _no_weights_ # inherits from no weights config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime diff --git a/tests/configs/cuda_inference_onnxruntime_text_encoders.yaml b/tests/configs/cuda_inference_onnxruntime_text_encoders.yaml index 2ac7133b..ad36c13a 100644 --- a/tests/configs/cuda_inference_onnxruntime_text_encoders.yaml +++ b/tests/configs/cuda_inference_onnxruntime_text_encoders.yaml @@ -6,6 +6,7 @@ defaults: - _text_encoders_ # inherits from text encoders sweep config - _device_isolation_ # inherits from device isolation config - _no_weights_ # inherits from no weights config + - _export_ # inherits from export config - _self_ # hydra 1.1 compatibility - override backend: onnxruntime From 02dfe3ca89a5b0ead3baf1d0f9ae75fa1b8445de Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:29:17 +0100 Subject: [PATCH 09/21] the rest --- .../backends/tensorrt_llm/backend.py | 160 +++++++++++++----- .../backends/tensorrt_llm/config.py | 30 ++-- .../backends/torch_ort/backend.py | 14 +- .../backends/torch_ort/config.py | 3 +- optimum_benchmark/backends/vllm/backend.py | 35 ++-- optimum_benchmark/backends/vllm/config.py | 8 - 6 files changed, 157 insertions(+), 93 deletions(-) diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index a05187c3..8205b83d 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -1,10 +1,16 @@ +import os from collections import OrderedDict from tempfile import TemporaryDirectory from typing import Any, Dict +import torch +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from hydra.utils import get_class +from safetensors.torch import save_file +from ...task_utils import TEXT_GENERATION_TASKS from ..base import Backend +from ..transformers_utils import fast_weights_init from .config import TRTLLMConfig from .utils import MODEL_TYPE_TO_TRTLLMMODEL @@ -25,62 +31,122 @@ def load(self) -> None: self.logger.info("\t+ Creating backend temporary directory") self.tmpdir = TemporaryDirectory() - self.logger.info("\t+ Loading pretrained TRTLLMModel") - self.load_trtmodel_from_pretrained() + if self.config.no_weights: + self.logger.info("\t+ Creating no weights model") + self.create_no_weights_model() + self.logger.info("\t+ Loading no weights model") + self.load_trtllm_with_no_weights() + else: + self.logger.info("\t+ Downloading pretrained model") + self.download_pretrained_model() + if self.config.task in TEXT_GENERATION_TASKS: + self.logger.info("\t+ Preparing generation config") + self.prepare_generation_config() + self.logger.info("\t+ Loading pretrained model") + self.load_trtllm_from_pretrained() self.logger.info("\t+ Cleaning up backend temporary directory") self.tmpdir.cleanup() - def load_trtmodel_from_pretrained(self) -> None: + def download_pretrained_model(self) -> None: + with torch.device("meta"): + self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs) + + def prepare_generation_config(self) -> None: + self.generation_config.eos_token_id = None + self.generation_config.pad_token_id = None + + model_cache_folder = f"models/{self.config.model}".replace("/", "--") + model_cache_path = f"{HUGGINGFACE_HUB_CACHE}/{model_cache_folder}" + snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}" + snapshot_ref = open(snapshot_file, "r").read().strip() + model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}" + self.generation_config.save_pretrained(save_directory=model_snapshot_path) + + def create_no_weights_model(self) -> None: + self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model") + self.logger.info("\t+ Creating no weights model directory") + os.makedirs(self.no_weights_model, exist_ok=True) + self.logger.info("\t+ Creating no weights model state dict") + state_dict = torch.nn.Linear(1, 1).state_dict() + self.logger.info("\t+ Saving no weights model safetensors") + safetensor = os.path.join(self.no_weights_model, "model.safetensors") + save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"}) + self.logger.info("\t+ Saving no weights model pretrained config") + self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) + self.logger.info("\t+ Saving no weights model pretrained processor") + self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model) + # unlike Transformers, TRT-LLM won't accept any missing tensors so we need to materialize the model + self.logger.info(f"\t+ Loading no weights model from {self.no_weights_model}") + with fast_weights_init(): + self.pretrained_model = self.automodel_loader.from_pretrained( + self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False + ) + self.logger.info("\t+ Saving no weights model") + self.pretrained_model.save_pretrained(save_directory=self.no_weights_model) + del self.pretrained_model + torch.cuda.empty_cache() + + if self.config.task in TEXT_GENERATION_TASKS: + self.logger.info("\t+ Modifying generation config for fixed length generation") + self.generation_config.eos_token_id = None + self.generation_config.pad_token_id = None + self.logger.info("\t+ Saving new pretrained generation config") + self.generation_config.save_pretrained(save_directory=self.no_weights_model) + + def load_trtllm_with_no_weights(self) -> None: + original_model, self.config.model = self.config.model, self.no_weights_model + self.load_trtllm_from_pretrained() + self.config.model = original_model + + def load_trtllm_from_pretrained(self) -> None: self.pretrained_model = self.trtllm_loader.from_pretrained( self.config.model, - tp=self.config.tp, - pp=self.config.pp, - dtype=self.config.dtype, - use_fp8=self.config.use_fp8, - world_size=self.config.world_size, - gpus_per_node=self.config.gpus_per_node, - use_cuda_graph=self.config.use_cuda_graph, - optimization_level=self.config.optimization_level, - max_prompt_length=self.config.max_prompt_length, - max_batch_size=self.config.max_batch_size, - max_new_tokens=self.config.max_new_tokens, - max_beam_width=self.config.max_beam_width, **self.config.model_kwargs, + **self.trtllm_kwargs, ) + @property + def trtllm_kwargs(self): + kwargs = {} + + if self.config.tp is not None: + kwargs["tp"] = self.config.tp + + if self.config.pp is not None: + kwargs["pp"] = self.config.pp + + if self.config.dtype is not None: + kwargs["dtype"] = self.config.dtype + + if self.config.use_fp8 is not None: + kwargs["use_fp8"] = self.config.use_fp8 + + if self.config.world_size is not None: + kwargs["world_size"] = self.config.world_size + + if self.config.gpus_per_node is not None: + kwargs["gpus_per_node"] = self.config.gpus_per_node + + if self.config.use_cuda_graph is not None: + kwargs["use_cuda_graph"] = self.config.use_cuda_graph + + if self.config.optimization_level is not None: + kwargs["optimization_level"] = self.config.optimization_level + + if self.config.max_prompt_length is not None: + kwargs["max_prompt_length"] = self.config.max_prompt_length + + if self.config.tp is not None: + kwargs["max_new_tokens"] = self.config.max_new_tokens + + if self.config.max_beam_width is not None: + kwargs["max_beam_width"] = self.config.max_beam_width + + return kwargs + def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - return self.pretrained_model.generate( - input_ids=inputs.get("input_ids"), - attention_mask=inputs.get("attention_mask"), - min_length=kwargs.get("min_new_tokens", -1), - max_new_tokens=kwargs.get("max_new_tokens", -1), - repetition_penalty=kwargs.get("repetition_penalty", 1.0), - length_penalty=kwargs.get("length_penalty", 1.0), - pad_token_id=kwargs.get("pad_token_id", 0), - bos_token_id=kwargs.get("bos_token_id", 1), - eos_token_id=kwargs.get("eos_token_id", 2), - temperature=kwargs.get("temperature", 1.0), - num_beams=kwargs.get("num_beams", 1), - top_p=kwargs.get("top_p", 1.0), - top_k=kwargs.get("top_k", 50), - seed=kwargs.get("seed", 42), - ) + return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs) def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - return self.pretrained_model.generate( - input_ids=inputs.get("input_ids"), - attention_mask=inputs.get("attention_mask"), - min_length=kwargs.get("min_new_tokens", -1), - max_new_tokens=kwargs.get("max_new_tokens", -1), - repetition_penalty=kwargs.get("repetition_penalty", 1.0), - length_penalty=kwargs.get("length_penalty", 1.0), - pad_token_id=kwargs.get("pad_token_id", 0), - bos_token_id=kwargs.get("bos_token_id", 1), - eos_token_id=kwargs.get("eos_token_id", 2), - temperature=kwargs.get("temperature", 1.0), - num_beams=kwargs.get("num_beams", 1), - top_p=kwargs.get("top_p", 1.0), - top_k=kwargs.get("top_k", 50), - seed=kwargs.get("seed", 42), - ) + return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs) diff --git a/optimum_benchmark/backends/tensorrt_llm/config.py b/optimum_benchmark/backends/tensorrt_llm/config.py index d7f4b1cb..8a011bbd 100644 --- a/optimum_benchmark/backends/tensorrt_llm/config.py +++ b/optimum_benchmark/backends/tensorrt_llm/config.py @@ -13,21 +13,21 @@ class TRTLLMConfig(BackendConfig): version: Optional[str] = tesnorrt_llm_version() _target_: str = "optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend" - # build config - tp: int = 1 - pp: int = 1 - use_fp8: bool = False - dtype: str = "float16" - optimization_level: int = 2 - use_cuda_graph: bool = False - - world_size: int = 1 - gpus_per_node: int = 1 - - max_prompt_length: int = 128 - max_new_tokens: int = -1 - max_batch_size: int = 1 - max_beam_width: int = 1 + no_weights: bool = False + + # trtllm kwargs + tp: Optional[int] = None + pp: Optional[int] = None + dtype: Optional[str] = None + use_fp8: Optional[bool] = None + world_size: Optional[int] = None + gpus_per_node: Optional[int] = None + use_cuda_graph: Optional[bool] = None + optimization_level: Optional[int] = None + max_prompt_length: Optional[int] = None + max_new_tokens: Optional[int] = None + max_batch_size: Optional[int] = None + max_beam_width: Optional[int] = None def __post_init__(self) -> None: super().__post_init__() diff --git a/optimum_benchmark/backends/torch_ort/backend.py b/optimum_benchmark/backends/torch_ort/backend.py index 61401a75..7b7c1c04 100644 --- a/optimum_benchmark/backends/torch_ort/backend.py +++ b/optimum_benchmark/backends/torch_ort/backend.py @@ -39,19 +39,17 @@ def load(self) -> None: self.tmpdir.cleanup() def load_automodel_with_no_weights(self) -> None: - original_model, self.config.model = self.config.model, self.no_weights_model - with fast_weights_init(): + original_model, self.config.model = self.config.model, self.no_weights_model self.load_automodel_from_pretrained() - - self.logger.info("\t+ Tying model weights") - self.pretrained_model.tie_weights() - - self.config.model = original_model + self.pretrained_model.tie_weights() + self.config.model = original_model def load_automodel_from_pretrained(self) -> None: self.pretrained_model = self.automodel_loader.from_pretrained( - self.config.model, **self.automodel_kwargs, **self.config.model_kwargs + self.config.model, + **self.config.model_kwargs, + **self.automodel_kwargs, ).to(self.config.device) @property diff --git a/optimum_benchmark/backends/torch_ort/config.py b/optimum_benchmark/backends/torch_ort/config.py index adc37288..a1e129b1 100644 --- a/optimum_benchmark/backends/torch_ort/config.py +++ b/optimum_benchmark/backends/torch_ort/config.py @@ -14,8 +14,7 @@ class TorchORTConfig(BackendConfig): # load options no_weights: bool = False torch_dtype: Optional[str] = None - # sdpa, which has became default of many architectures, fails with torch ort - attn_implementation: Optional[str] = "eager" + attn_implementation: Optional[str] = None # peft options peft_type: Optional[str] = None diff --git a/optimum_benchmark/backends/vllm/backend.py b/optimum_benchmark/backends/vllm/backend.py index eadd6c0a..af5969b4 100644 --- a/optimum_benchmark/backends/vllm/backend.py +++ b/optimum_benchmark/backends/vllm/backend.py @@ -6,7 +6,10 @@ import torch from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from safetensors.torch import save_file -from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams from ...task_utils import TEXT_GENERATION_TASKS from ..base import Backend @@ -32,7 +35,7 @@ def load(self) -> None: self.logger.info("\t+ Creating no weights model") self.create_no_weights_model() self.logger.info("\t+ Loading no weights model") - self.load_model_with_no_weights() + self.load_vllm_with_no_weights() else: self.logger.info("\t+ Downloading pretrained model") self.download_pretrained_model() @@ -40,7 +43,7 @@ def load(self) -> None: self.logger.info("\t+ Preparing generation config") self.prepare_generation_config() self.logger.info("\t+ Loading pretrained model") - self.load_model_from_pretrained() + self.load_vllm_from_pretrained() self.logger.info("\t+ Cleaning up backend temporary directory") self.tmpdir.cleanup() @@ -52,13 +55,11 @@ def download_pretrained_model(self) -> None: def prepare_generation_config(self) -> None: self.generation_config.eos_token_id = None self.generation_config.pad_token_id = None - model_cache_folder = f"models/{self.config.model}".replace("/", "--") model_cache_path = f"{HUGGINGFACE_HUB_CACHE}/{model_cache_folder}" snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}" snapshot_ref = open(snapshot_file, "r").read().strip() model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}" - self.logger.info("\t+ Saving new pretrained generation config") self.generation_config.save_pretrained(save_directory=model_snapshot_path) def create_no_weights_model(self) -> None: @@ -92,19 +93,27 @@ def create_no_weights_model(self) -> None: self.logger.info("\t+ Saving new pretrained generation config") self.generation_config.save_pretrained(save_directory=self.no_weights_model) - def load_model_with_no_weights(self) -> None: + def load_vllm_with_no_weights(self) -> None: original_model, self.config.model = self.config.model, self.no_weights_model - self.logger.info("\t+ Loading no weights model") - self.load_model_from_pretrained() + self.load_vllm_from_pretrained() self.config.model = original_model - def load_model_from_pretrained(self) -> None: + def load_vllm_from_pretrained(self) -> None: if self.config.serving_mode == "offline": - self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.config.to_engine_args())) + self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.vllm_kwargs)) else: - self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.config.to_engine_args())) - - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.vllm_kwargs)) + + @property + def vllm_kwargs(self): + return { + "model": self.config.model, + "tokenizer": self.config.processor, + "device": self.config.device, + **self.config.engine_args, + } + + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task in TEXT_GENERATION_TASKS: inputs = {"prompts": self.pretrained_processor.batch_decode(inputs["input_ids"])} else: diff --git a/optimum_benchmark/backends/vllm/config.py b/optimum_benchmark/backends/vllm/config.py index 47ae475b..00157220 100644 --- a/optimum_benchmark/backends/vllm/config.py +++ b/optimum_benchmark/backends/vllm/config.py @@ -54,11 +54,3 @@ def __post_init__(self): if self.serving_mode == "online": if self.engine_args.get("disable_log_requests", None) is None: self.engine_args["disable_log_requests"] = True - - def to_engine_args(self) -> Dict[str, Any]: - return dict( - model=self.model, - tokenizer=self.processor, - device=self.device, - **self.engine_args, - ) From d5cd55cd92b58081edbafc6bdf2799958ebfbb6e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:41:05 +0100 Subject: [PATCH 10/21] fix --- tests/configs/_diffusers_.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/configs/_diffusers_.yaml b/tests/configs/_diffusers_.yaml index 7083e261..25bcc6be 100644 --- a/tests/configs/_diffusers_.yaml +++ b/tests/configs/_diffusers_.yaml @@ -3,5 +3,5 @@ hydra: sweeper: params: backend.library: diffusers - backend.task: text-to-image,image-to-image,inpainting + backend.task: text-to-image backend.model: hf-internal-testing/tiny-stable-diffusion-torch From 6738b852db7aa5d10905f2f5fc9ac5cc89d1f0b7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 13:47:12 +0100 Subject: [PATCH 11/21] fix --- optimum_benchmark/backends/tensorrt_llm/backend.py | 2 +- optimum_benchmark/backends/tensorrt_llm/config.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index 8205b83d..2dbc0fea 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -137,7 +137,7 @@ def trtllm_kwargs(self): if self.config.max_prompt_length is not None: kwargs["max_prompt_length"] = self.config.max_prompt_length - if self.config.tp is not None: + if self.config.max_new_tokens is not None: kwargs["max_new_tokens"] = self.config.max_new_tokens if self.config.max_beam_width is not None: diff --git a/optimum_benchmark/backends/tensorrt_llm/config.py b/optimum_benchmark/backends/tensorrt_llm/config.py index 8a011bbd..84d119af 100644 --- a/optimum_benchmark/backends/tensorrt_llm/config.py +++ b/optimum_benchmark/backends/tensorrt_llm/config.py @@ -4,7 +4,7 @@ from ...import_utils import tesnorrt_llm_version from ..config import BackendConfig -SUPPORTED_DTYPES = ["float16", "bfloat16", "float32"] +SUPPORTED_DTYPES = [None, "float16", "bfloat16", "float32"] @dataclass @@ -38,8 +38,13 @@ def __post_init__(self) -> None: if self.dtype not in SUPPORTED_DTYPES: raise ValueError(f"dtype must be one of float16, bfloat16, float32, got {self.dtype}") - if self.gpus_per_node != self.world_size: + if self.gpus_per_node is not None and self.world_size is not None and self.gpus_per_node != self.world_size: raise ValueError(f"gpus_per_node ({self.gpus_per_node}) != world_size ({self.world_size})") - if self.world_size != self.pp * self.tp: + if ( + self.world_size is not None + and self.pp is not None + and self.tp is not None + and self.world_size != self.pp * self.tp + ): raise ValueError(f"world_size ({self.gpus_per_node}) != pp ({self.pp}) * tp ({self.tp})") From b4b8df7e146e3f42e515818be1e0a13525364d30 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 14:14:09 +0100 Subject: [PATCH 12/21] sngle prepare inputs --- optimum_benchmark/backends/base.py | 12 +------- optimum_benchmark/backends/ipex/backend.py | 13 ++++----- .../backends/llama_cpp/backend.py | 8 ++++-- .../backends/onnxruntime/backend.py | 5 +--- .../backends/openvino/backend.py | 28 ++++--------------- optimum_benchmark/backends/openvino/config.py | 5 ++-- optimum_benchmark/backends/py_txi/backend.py | 2 +- optimum_benchmark/backends/py_txi/config.py | 2 +- optimum_benchmark/backends/pytorch/backend.py | 2 +- .../backends/tensorrt_llm/backend.py | 8 ++++-- .../backends/torch_ort/config.py | 4 ++- optimum_benchmark/backends/vllm/backend.py | 2 +- .../scenarios/inference/scenario.py | 7 ++--- 13 files changed, 36 insertions(+), 62 deletions(-) diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py index 5fddcca9..f79ea1a2 100644 --- a/optimum_benchmark/backends/base.py +++ b/optimum_benchmark/backends/base.py @@ -106,23 +106,13 @@ def create_no_weights_model(self) -> None: self.logger.info("\t+ Saving no weights model's config") self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """ - This method is used to prepare and register the inputs before passing them to the model. - It can be used to move the inputs to the correct device, or rename their keys. - """ - return inputs - - def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """ This method is used to prepare and register the inputs before passing them to the model. It can be used to move the inputs to the correct device, or rename their keys. """ return inputs - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - return self.prepare_inputs_after_load(self.prepare_inputs_before_load(inputs)) - def load(self) -> None: raise NotImplementedError("Backend must implement load method") diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py index 049b2c7b..d32a203d 100644 --- a/optimum_benchmark/backends/ipex/backend.py +++ b/optimum_benchmark/backends/ipex/backend.py @@ -38,26 +38,25 @@ def load(self) -> None: self.logger.info("\t+ Creating no weights IPEXModel") self.create_no_weights_model() self.logger.info("\t+ Loading no weights IPEXModel") - self._load_ipexmodel_with_no_weights() + self.load_ipexmodel_with_no_weights() else: self.logger.info("\t+ Loading pretrained IPEXModel") - self._load_ipexmodel_from_pretrained() + self.load_ipexmodel_from_pretrained() self.tmpdir.cleanup() - def _load_ipexmodel_from_pretrained(self) -> None: + def load_ipexmodel_from_pretrained(self) -> None: self.pretrained_model = self.ipexmodel_class.from_pretrained( self.config.model, **self.config.model_kwargs, **self.ipexmodel_kwargs, ) - def _load_ipexmodel_with_no_weights(self) -> None: + def load_ipexmodel_with_no_weights(self) -> None: with fast_weights_init(): - self.logger.info("\t+ Loading no weights IPEXModel") original_model, self.config.model = self.config.model, self.no_weights_model original_export, self.config.export = self.config.export, True - self._load_ipexmodel_from_pretrained() + self.load_ipexmodel_from_pretrained() self.config.export = original_export self.config.model = original_model @@ -77,7 +76,7 @@ def ipexmodel_kwargs(self) -> Dict[str, Any]: def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs diff --git a/optimum_benchmark/backends/llama_cpp/backend.py b/optimum_benchmark/backends/llama_cpp/backend.py index 67b2dbb2..ef888ddd 100644 --- a/optimum_benchmark/backends/llama_cpp/backend.py +++ b/optimum_benchmark/backends/llama_cpp/backend.py @@ -41,7 +41,7 @@ def llama_cpp_kwargs(self) -> Dict[str, Any]: "echo": False, } - def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task == "text-generation": if inputs["input_ids"].shape[0] != 1: raise ValueError("Batch size must be 1 for Text Generation with llama-cpp-python") @@ -55,9 +55,11 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any: self.pretrained_model.embed(**inputs) def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]: - next(self.pretrained_model.generate(**inputs)) + generator = self.pretrained_model.generate(**inputs, reset=True) + for _ in range(kwargs["max_new_tokens"]): + next(generator) def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]: - generator = self.pretrained_model.generate(**inputs) + generator = self.pretrained_model.generate(**inputs, reset=True) for _ in range(kwargs["max_new_tokens"]): next(generator) diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index feec7ec3..1af41db0 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -284,7 +284,7 @@ def quantize_onnx_files(self) -> None: def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs @@ -293,9 +293,6 @@ def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if isinstance(value, torch.Tensor): inputs[key] = value.to(self.config.device) - return inputs - - def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: for key in list(inputs.keys()): if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: inputs.pop(key) diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py index f4ead46c..be903e7d 100644 --- a/optimum_benchmark/backends/openvino/backend.py +++ b/optimum_benchmark/backends/openvino/backend.py @@ -1,4 +1,3 @@ -import inspect from collections import OrderedDict from tempfile import TemporaryDirectory from typing import Any, Dict @@ -48,14 +47,8 @@ def load(self) -> None: self.load_ovmodel_from_pretrained() if self.config.reshape: - static_shapes = { - key: value - for key, value in self.model_shapes.items() - if key in inspect.getfullargspec(self.pretrained_model.reshape).args - } - - self.logger.info(f"\t+ Reshaping model with static shapes: {static_shapes}") - self.pretrained_model.reshape(**static_shapes) + self.logger.info("\t+ Reshaping model with static shapes") + self.pretrained_model.reshape(**self.config.reshape_kwargs) if self.config.half: self.logger.info("\t+ Converting model to half precision") @@ -78,7 +71,6 @@ def load_ovmodel_with_no_weights(self) -> None: with fast_weights_init(): original_model, self.config.model = self.config.model, self.no_weights_model original_export, self.config.export = self.config.export, True - self.logger.info("\t+ Loading no weights OVModel") self.load_ovmodel_from_pretrained() self.config.export = original_export self.config.model = original_model @@ -102,28 +94,20 @@ def ovmodel_kwargs(self) -> Dict[str, Any]: if self.config.load_in_4bit is not None: kwargs["load_in_4bit"] = self.config.load_in_4bit + if self.config.ov_config: + kwargs["ov_config"] = self.config.ov_config + return kwargs @property def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs - if "input_ids" in inputs: - self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape))) - - if "pixel_values" in inputs: - self.model_shapes.update( - dict(zip(["batch_size", "num_channels", "height", "width"], inputs["pixel_values"].shape)) - ) - - return inputs - - def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: for key in list(inputs.keys()): if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: inputs.pop(key) diff --git a/optimum_benchmark/backends/openvino/config.py b/optimum_benchmark/backends/openvino/config.py index ffdf1566..e6716b86 100644 --- a/optimum_benchmark/backends/openvino/config.py +++ b/optimum_benchmark/backends/openvino/config.py @@ -19,14 +19,13 @@ class OVConfig(BackendConfig): use_merged: Optional[bool] = None load_in_8bit: Optional[bool] = None load_in_4bit: Optional[bool] = None + ov_config: Dict[str, Any] = field(default_factory=dict) # compilation options half: bool = False compile: bool = False reshape: bool = False - - # openvino config - ov_config: Dict[str, Any] = field(default_factory=dict) + reshape_kwargs: Dict[str, int] = field(default_factory=dict) def __post_init__(self): super().__post_init__() diff --git a/optimum_benchmark/backends/py_txi/backend.py b/optimum_benchmark/backends/py_txi/backend.py index cf9a6920..6e637a31 100644 --- a/optimum_benchmark/backends/py_txi/backend.py +++ b/optimum_benchmark/backends/py_txi/backend.py @@ -139,7 +139,7 @@ def load_model_from_pretrained(self) -> None: else: raise NotImplementedError(f"TXI does not support task {self.config.task}") - def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task in TEXT_GENERATION_TASKS: inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())} elif self.config.task in TEXT_EMBEDDING_TASKS: diff --git a/optimum_benchmark/backends/py_txi/config.py b/optimum_benchmark/backends/py_txi/config.py index 73b75b75..dae410c4 100644 --- a/optimum_benchmark/backends/py_txi/config.py +++ b/optimum_benchmark/backends/py_txi/config.py @@ -51,8 +51,8 @@ class PyTXIConfig(BackendConfig): num_shard: Optional[int] = None speculate: Optional[int] = None cuda_graphs: Optional[int] = None - disable_custom_kernels: Optional[bool] = None trust_remote_code: Optional[bool] = None + disable_custom_kernels: Optional[bool] = None # TEI specific pooling: Optional[str] = None diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 3fb8d80a..651e6d12 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -416,7 +416,7 @@ def split_between_processes(self) -> bool: and not self.config.deepspeed_inference ) - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index 2dbc0fea..f23d294e 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -146,7 +146,11 @@ def trtllm_kwargs(self): return kwargs def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs) + return self.pretrained_model.generate( + inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs + ) def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs) + return self.pretrained_model.generate( + inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs + ) diff --git a/optimum_benchmark/backends/torch_ort/config.py b/optimum_benchmark/backends/torch_ort/config.py index a1e129b1..17d2895d 100644 --- a/optimum_benchmark/backends/torch_ort/config.py +++ b/optimum_benchmark/backends/torch_ort/config.py @@ -14,7 +14,9 @@ class TorchORTConfig(BackendConfig): # load options no_weights: bool = False torch_dtype: Optional[str] = None - attn_implementation: Optional[str] = None + attn_implementation: Optional[str] = ( + "eager" # we pin eager because sdpa became default of many architectures, which fails with torch-ort + ) # peft options peft_type: Optional[str] = None diff --git a/optimum_benchmark/backends/vllm/backend.py b/optimum_benchmark/backends/vllm/backend.py index af5969b4..7405d4dc 100644 --- a/optimum_benchmark/backends/vllm/backend.py +++ b/optimum_benchmark/backends/vllm/backend.py @@ -113,7 +113,7 @@ def vllm_kwargs(self): **self.config.engine_args, } - def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.config.task in TEXT_GENERATION_TASKS: inputs = {"prompts": self.pretrained_processor.batch_decode(inputs["input_ids"])} else: diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py index c84f27e0..d18761b3 100644 --- a/optimum_benchmark/scenarios/inference/scenario.py +++ b/optimum_benchmark/scenarios/inference/scenario.py @@ -126,13 +126,10 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: input_shapes=self.config.input_shapes, )() - self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} before model loading.") - self.inputs = self.backend.prepare_inputs_before_load(inputs=self.inputs) - self.run_model_loading_tracking() - self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name} after model loading.") - self.inputs = self.backend.prepare_inputs_after_load(inputs=self.inputs) + self.logger.info(f"\t+ Preparing inputs for backend {self.backend.config.name}") + self.inputs = self.backend.prepare_inputs(inputs=self.inputs) if self.config.warmup_runs > 0: if self.backend.config.task in TEXT_GENERATION_TASKS: From 795badba0d0a20997df1ca8a5cec43fafc6647d0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 14:35:17 +0100 Subject: [PATCH 13/21] fixes --- examples/cpu_ipex_bert.yaml | 4 +-- examples/cpu_ipex_llama.yaml | 4 +-- examples/cpu_onnxruntime_timm.yaml | 20 ------------- examples/cpu_openvino_8bit_bert.yaml | 5 +++- examples/cuda_tgi_llama.yaml | 1 + examples/cuda_trt_llama.yaml | 1 + examples/cuda_vllm_llama.yaml | 3 +- examples/mps_pytorch_bert.yaml | 12 ++++---- .../backends/onnxruntime/backend.py | 4 ++- .../backends/tensorrt_llm/backend.py | 30 +++++++++++++++++-- 10 files changed, 49 insertions(+), 35 deletions(-) delete mode 100644 examples/cpu_onnxruntime_timm.yaml diff --git a/examples/cpu_ipex_bert.yaml b/examples/cpu_ipex_bert.yaml index 0e7ed37b..1974e9c3 100644 --- a/examples/cpu_ipex_bert.yaml +++ b/examples/cpu_ipex_bert.yaml @@ -17,8 +17,8 @@ launcher: backend: device: cpu export: true - no_weights: false # because on multi-node machines, intializing weights could harm performance - torch_dtype: float32 # but use bfloat16 on compatible Intel CPUs + no_weights: false # on multi-node machines, intializing weights in the benchmark could harm performance + torch_dtype: float32 # use bfloat16 on compatible Intel CPUs model: google-bert/bert-base-uncased scenario: diff --git a/examples/cpu_ipex_llama.yaml b/examples/cpu_ipex_llama.yaml index 898ed0df..50e23c55 100644 --- a/examples/cpu_ipex_llama.yaml +++ b/examples/cpu_ipex_llama.yaml @@ -17,8 +17,8 @@ launcher: backend: device: cpu export: true - no_weights: false # because on multi-node machines, intializing weights could harm performance - torch_dtype: float32 # but use bfloat16 on compatible Intel CPUs + no_weights: false # on multi-node machines, intializing weights in the benchmark could harm performance + torch_dtype: float32 # use bfloat16 on compatible Intel CPUs model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 scenario: diff --git a/examples/cpu_onnxruntime_timm.yaml b/examples/cpu_onnxruntime_timm.yaml deleted file mode 100644 index 963f44f0..00000000 --- a/examples/cpu_onnxruntime_timm.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - benchmark - - backend: onnxruntime - - scenario: inference - - launcher: process - - _base_ - - _self_ - -name: onnxruntime_timm - -backend: - device: cpu - export: true - model: timm/tiny_vit_21m_224.in1k - -scenario: - memory: true - latency: true - input_shapes: - batch_size: 2 diff --git a/examples/cpu_openvino_8bit_bert.yaml b/examples/cpu_openvino_8bit_bert.yaml index 73ef474d..a3c33327 100644 --- a/examples/cpu_openvino_8bit_bert.yaml +++ b/examples/cpu_openvino_8bit_bert.yaml @@ -12,8 +12,11 @@ backend: device: cpu reshape: true no_weights: true - load_in_8bit: false # enable 8bit on compatible Intel CPU machines + load_in_8bit: true model: google-bert/bert-base-uncased + reshape_kwargs: + batch_size: 1 + sequence_length: 128 scenario: memory: true diff --git a/examples/cuda_tgi_llama.yaml b/examples/cuda_tgi_llama.yaml index 297403c8..f7ac411c 100644 --- a/examples/cuda_tgi_llama.yaml +++ b/examples/cuda_tgi_llama.yaml @@ -16,6 +16,7 @@ backend: device: cuda device_ids: 0 cuda_graphs: 0 # remove for better perf but bigger memory footprint + no_weights: true model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 scenario: diff --git a/examples/cuda_trt_llama.yaml b/examples/cuda_trt_llama.yaml index c483fc2f..501280e8 100644 --- a/examples/cuda_trt_llama.yaml +++ b/examples/cuda_trt_llama.yaml @@ -15,6 +15,7 @@ launcher: backend: device: cuda device_ids: 0 + no_weights: true max_batch_size: 4 max_new_tokens: 32 max_prompt_length: 64 diff --git a/examples/cuda_vllm_llama.yaml b/examples/cuda_vllm_llama.yaml index 5ec4b5a8..4a624cc1 100644 --- a/examples/cuda_vllm_llama.yaml +++ b/examples/cuda_vllm_llama.yaml @@ -15,7 +15,8 @@ launcher: backend: device: cuda device_ids: 0 - serving_mode: online # server-like + no_weights: true + serving_mode: online model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 engine_args: enforce_eager: true # remove for better perf but bigger memory footprint diff --git a/examples/mps_pytorch_bert.yaml b/examples/mps_pytorch_bert.yaml index 27368eb1..f805abed 100644 --- a/examples/mps_pytorch_bert.yaml +++ b/examples/mps_pytorch_bert.yaml @@ -8,14 +8,14 @@ defaults: name: mps_pytorch_bert +backend: + device: mps + no_weights: true + model: bert-base-uncased + scenario: - latency: true memory: true + latency: true input_shapes: batch_size: 1 sequence_length: 128 - -backend: - device: mps - no_weights: true - model: bert-base-uncased diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index 1af41db0..b50e2258 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -223,7 +223,7 @@ def quantize_onnx_files(self) -> None: if self.is_calibrated: self.logger.info("\t+ Generating calibration dataset") - dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes} + dataset_shapes = {"dataset_size": 2, "sequence_length": 2, "num_choices": 2} calibration_dataset = DatasetGenerator( task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes )() @@ -275,8 +275,10 @@ def quantize_onnx_files(self) -> None: preprocessor=None, file_suffix="", ) + if self.pretrained_processor is not None: self.pretrained_processor.save_pretrained(self.quantized_model) + if self.pretrained_config is not None: self.pretrained_config.save_pretrained(self.quantized_model) diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index f23d294e..e0b2b1e9 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -147,10 +147,36 @@ def trtllm_kwargs(self): def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: return self.pretrained_model.generate( - inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs + input_ids=inputs.get("input_ids"), + attention_mask=inputs.get("attention_mask"), + min_length=kwargs.get("min_new_tokens", None), + max_new_tokens=kwargs.get("max_new_tokens", None), + repetition_penalty=kwargs.get("repetition_penalty", None), + length_penalty=kwargs.get("length_penalty", None), + pad_token_id=kwargs.get("pad_token_id", None), + bos_token_id=kwargs.get("bos_token_id", None), + eos_token_id=kwargs.get("eos_token_id", None), + temperature=kwargs.get("temperature", None), + num_beams=kwargs.get("num_beams", None), + top_p=kwargs.get("top_p", None), + top_k=kwargs.get("top_k", None), + seed=kwargs.get("seed", None), ) def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: return self.pretrained_model.generate( - inputs=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), **kwargs + input_ids=inputs.get("input_ids"), + attention_mask=inputs.get("attention_mask"), + min_length=kwargs.get("min_new_tokens", None), + max_new_tokens=kwargs.get("max_new_tokens", None), + repetition_penalty=kwargs.get("repetition_penalty", None), + length_penalty=kwargs.get("length_penalty", None), + pad_token_id=kwargs.get("pad_token_id", None), + bos_token_id=kwargs.get("bos_token_id", None), + eos_token_id=kwargs.get("eos_token_id", None), + temperature=kwargs.get("temperature", None), + num_beams=kwargs.get("num_beams", None), + top_p=kwargs.get("top_p", None), + top_k=kwargs.get("top_k", None), + seed=kwargs.get("seed", None), ) From 7e9ef7f3283738f77ed9afd18f7446cca7f8e13d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 15:08:32 +0100 Subject: [PATCH 14/21] fix --- examples/cuda_tgi_llama.yaml | 2 +- .../backends/tensorrt_llm/backend.py | 30 ++++--------------- setup.py | 2 +- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/examples/cuda_tgi_llama.yaml b/examples/cuda_tgi_llama.yaml index f7ac411c..ac5bcdc3 100644 --- a/examples/cuda_tgi_llama.yaml +++ b/examples/cuda_tgi_llama.yaml @@ -16,7 +16,7 @@ backend: device: cuda device_ids: 0 cuda_graphs: 0 # remove for better perf but bigger memory footprint - no_weights: true + no_weights: false # investigate later model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 scenario: diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index e0b2b1e9..5e7bad89 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -149,34 +149,16 @@ def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict return self.pretrained_model.generate( input_ids=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), - min_length=kwargs.get("min_new_tokens", None), - max_new_tokens=kwargs.get("max_new_tokens", None), - repetition_penalty=kwargs.get("repetition_penalty", None), - length_penalty=kwargs.get("length_penalty", None), - pad_token_id=kwargs.get("pad_token_id", None), - bos_token_id=kwargs.get("bos_token_id", None), - eos_token_id=kwargs.get("eos_token_id", None), - temperature=kwargs.get("temperature", None), - num_beams=kwargs.get("num_beams", None), - top_p=kwargs.get("top_p", None), - top_k=kwargs.get("top_k", None), - seed=kwargs.get("seed", None), + pad_token_id=kwargs.get("pad_token_id", 0), + eos_token_id=kwargs.get("eos_token_id", 1), + **kwargs, ) def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: return self.pretrained_model.generate( input_ids=inputs.get("input_ids"), attention_mask=inputs.get("attention_mask"), - min_length=kwargs.get("min_new_tokens", None), - max_new_tokens=kwargs.get("max_new_tokens", None), - repetition_penalty=kwargs.get("repetition_penalty", None), - length_penalty=kwargs.get("length_penalty", None), - pad_token_id=kwargs.get("pad_token_id", None), - bos_token_id=kwargs.get("bos_token_id", None), - eos_token_id=kwargs.get("eos_token_id", None), - temperature=kwargs.get("temperature", None), - num_beams=kwargs.get("num_beams", None), - top_p=kwargs.get("top_p", None), - top_k=kwargs.get("top_k", None), - seed=kwargs.get("seed", None), + pad_token_id=kwargs.get("pad_token_id", 0), + eos_token_id=kwargs.get("eos_token_id", 1), + **kwargs, ) diff --git a/setup.py b/setup.py index 46a1ed60..c7f81246 100644 --- a/setup.py +++ b/setup.py @@ -65,10 +65,10 @@ "testing": ["pytest", "hydra-joblib-launcher"], # optimum backends "ipex": [f"optimum[ipex]>={MIN_OPTIMUM_VERSION}"], + "tensorrt-llm": [f"optimum[nvidia]>={MIN_OPTIMUM_VERSION}"], "openvino": [f"optimum[openvino,nncf]>={MIN_OPTIMUM_VERSION}"], "onnxruntime": [f"optimum[onnxruntime]>={MIN_OPTIMUM_VERSION}"], "onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"], - "neural-compressor": [f"optimum[neural-compressor]>={MIN_OPTIMUM_VERSION}"], "torch-ort": ["torch-ort", "onnxruntime-training", f"optimum>={MIN_OPTIMUM_VERSION}"], # other backends "llama-cpp": ["llama-cpp-python"], From 4c506a098563a375a7334a0351e0683801d0dad2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 15:10:08 +0100 Subject: [PATCH 15/21] reduce diffusion --- examples/cpu_openvino_diffusion.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/cpu_openvino_diffusion.yaml b/examples/cpu_openvino_diffusion.yaml index 30d21935..147f160f 100644 --- a/examples/cpu_openvino_diffusion.yaml +++ b/examples/cpu_openvino_diffusion.yaml @@ -17,3 +17,7 @@ backend: scenario: input_shapes: batch_size: 1 + sequence_length: 16 + + call_kwargs: + num_inference_steps: 4 From 7faf1e6953beca93229f5cdcbef9c827c79a24a4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 15:10:16 +0100 Subject: [PATCH 16/21] diffusion task --- examples/cpu_openvino_diffusion.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/cpu_openvino_diffusion.yaml b/examples/cpu_openvino_diffusion.yaml index 147f160f..37b4ee71 100644 --- a/examples/cpu_openvino_diffusion.yaml +++ b/examples/cpu_openvino_diffusion.yaml @@ -11,6 +11,7 @@ name: openvino_diffusion backend: device: cpu export: true + task: text-to-image model: stabilityai/stable-diffusion-2-1 half: false # enable half-precision on compatible Intel CPU machines From c92db7361d13f76f572a796f9dde890413bedeef Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 15:16:03 +0100 Subject: [PATCH 17/21] fix naming --- optimum_benchmark/backends/ipex/backend.py | 6 +++--- optimum_benchmark/backends/ipex/utils.py | 2 +- optimum_benchmark/backends/onnxruntime/backend.py | 6 +++--- optimum_benchmark/backends/onnxruntime/utils.py | 2 +- optimum_benchmark/backends/openvino/backend.py | 10 +++++----- optimum_benchmark/backends/openvino/utils.py | 4 ++-- optimum_benchmark/backends/tensorrt_llm/backend.py | 6 +++--- optimum_benchmark/backends/tensorrt_llm/utils.py | 2 +- 8 files changed, 19 insertions(+), 19 deletions(-) diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py index d32a203d..a637ced0 100644 --- a/optimum_benchmark/backends/ipex/backend.py +++ b/optimum_benchmark/backends/ipex/backend.py @@ -9,7 +9,7 @@ from ..base import Backend from ..transformers_utils import fast_weights_init from .config import IPEXConfig -from .utils import TASKS_TO_IPEXMODEL +from .utils import TASKS_TO_IPEXMODELS if is_accelerate_available(): from accelerate import Accelerator @@ -24,8 +24,8 @@ class IPEXBackend(Backend[IPEXConfig]): def __init__(self, config: IPEXConfig) -> None: super().__init__(config) - if self.config.task in TASKS_TO_IPEXMODEL: - self.ipexmodel_class = get_class(TASKS_TO_IPEXMODEL[self.config.task]) + if self.config.task in TASKS_TO_IPEXMODELS: + self.ipexmodel_class = get_class(TASKS_TO_IPEXMODELS[self.config.task]) self.logger.info(f"\t+ Using IPEXModel class {self.ipexmodel_class.__name__}") else: raise NotImplementedError(f"IPEXBackend does not support task {self.config.task}") diff --git a/optimum_benchmark/backends/ipex/utils.py b/optimum_benchmark/backends/ipex/utils.py index dd68428e..7b4a83f8 100644 --- a/optimum_benchmark/backends/ipex/utils.py +++ b/optimum_benchmark/backends/ipex/utils.py @@ -1,4 +1,4 @@ -TASKS_TO_IPEXMODEL = { +TASKS_TO_IPEXMODELS = { "fill-mask": "optimum.intel.IPEXModelForMaskedLM", "text-generation": "optimum.intel.IPEXModelForCausalLM", "feature-extraction": "optimum.intel.IPEXModel", diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index b50e2258..ce7386b3 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -28,7 +28,7 @@ from .config import ORTConfig from .utils import ( TASKS_TO_ORTMODELS, - TASKS_TO_ORTPIPELINE, + TASKS_TO_ORTPIPELINES, format_calibration_config, format_quantization_config, ) @@ -49,8 +49,8 @@ def __init__(self, config: ORTConfig) -> None: if self.config.library != "diffusers" and self.config.task in TASKS_TO_ORTMODELS: self.ort_model_loader = get_class(TASKS_TO_ORTMODELS[self.config.task]) self.logger.info(f"Using ORTModel class {self.ort_model_loader.__name__}") - elif self.config.library == "diffusers" and self.config.task in TASKS_TO_ORTPIPELINE: - self.ort_model_loader = get_class(TASKS_TO_ORTPIPELINE[self.config.task]) + elif self.config.library == "diffusers" and self.config.task in TASKS_TO_ORTPIPELINES: + self.ort_model_loader = get_class(TASKS_TO_ORTPIPELINES[self.config.task]) self.logger.info(f"Using ORTDiffusionPipeline class {self.ort_model_loader.__name__}") else: raise NotImplementedError(f"ORTBackend does not support task {self.config.task}") diff --git a/optimum_benchmark/backends/onnxruntime/utils.py b/optimum_benchmark/backends/onnxruntime/utils.py index 63598223..e8cbe1eb 100644 --- a/optimum_benchmark/backends/onnxruntime/utils.py +++ b/optimum_benchmark/backends/onnxruntime/utils.py @@ -7,7 +7,7 @@ task: f"optimum.onnxruntime.{task_dict['class'][0].__name__}" for task, task_dict in ORT_SUPPORTED_TASKS.items() } -TASKS_TO_ORTPIPELINE = { +TASKS_TO_ORTPIPELINES = { "inpainting": "optimum.onnxruntime.ORTPipelineForInpainting", "text-to-image": "optimum.onnxruntime.ORTPipelineForText2Image", "image-to-image": "optimum.onnxruntime.ORTPipelineForImage2Image", diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py index be903e7d..11be6e13 100644 --- a/optimum_benchmark/backends/openvino/backend.py +++ b/optimum_benchmark/backends/openvino/backend.py @@ -9,7 +9,7 @@ from ..base import Backend from ..transformers_utils import fast_weights_init from .config import OVConfig as OVBackendConfig -from .utils import TASKS_OVPIPELINE, TASKS_TO_OVMODEL +from .utils import TASKS_TO_OVMODELS, TASKS_TO_OVPIPELINES if is_accelerate_available(): from accelerate import Accelerator @@ -24,11 +24,11 @@ class OVBackend(Backend[OVBackendConfig]): def __init__(self, config: OVBackendConfig) -> None: super().__init__(config) - if self.config.task in TASKS_TO_OVMODEL: - self.ovmodel_class = get_class(TASKS_TO_OVMODEL[self.config.task]) + if self.config.library != "diffusers" and self.config.task in TASKS_TO_OVMODELS: + self.ovmodel_class = get_class(TASKS_TO_OVMODELS[self.config.task]) self.logger.info(f"\t+ Using OVModel class {self.ovmodel_class.__name__}") - elif self.config.task in TASKS_OVPIPELINE: - self.ovmodel_class = get_class(TASKS_OVPIPELINE[self.config.task]) + elif self.config.library == "diffusers" and self.config.task in TASKS_TO_OVPIPELINES: + self.ovmodel_class = get_class(TASKS_TO_OVPIPELINES[self.config.task]) self.logger.info(f"\t+ Using OVDiffusionPipeline class {self.ovmodel_class.__name__}") else: raise NotImplementedError(f"OVBackend does not support task {self.config.task}") diff --git a/optimum_benchmark/backends/openvino/utils.py b/optimum_benchmark/backends/openvino/utils.py index 51f9629c..e382d724 100644 --- a/optimum_benchmark/backends/openvino/utils.py +++ b/optimum_benchmark/backends/openvino/utils.py @@ -1,4 +1,4 @@ -TASKS_TO_OVMODEL = { +TASKS_TO_OVMODELS = { "fill-mask": "optimum.intel.openvino.OVModelForMaskedLM", "text-generation": "optimum.intel.openvino.OVModelForCausalLM", "text2text-generation": "optimum.intel.openvino.OVModelForSeq2SeqLM", @@ -10,7 +10,7 @@ "audio-classification": "optimum.intel.openvino.OVModelForAudioClassification", "pix2struct": "optimum.intel.openvino.OVModelForPix2Struct", } -TASKS_OVPIPELINE = { +TASKS_TO_OVPIPELINES = { "inpainting": "optimum.intel.openvino.OVPipelineForInpainting", "text-to-image": "optimum.intel.openvino.OVPipelineForText2Image", "image-to-image": "optimum.intel.openvino.OVPipelineForImage2Image", diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index 5e7bad89..8cd046eb 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -12,7 +12,7 @@ from ..base import Backend from ..transformers_utils import fast_weights_init from .config import TRTLLMConfig -from .utils import MODEL_TYPE_TO_TRTLLMMODEL +from .utils import MODEL_TYPE_TO_TRTLLMMODELS class TRTLLMBackend(Backend[TRTLLMConfig]): @@ -21,8 +21,8 @@ class TRTLLMBackend(Backend[TRTLLMConfig]): def __init__(self, config: TRTLLMConfig): super().__init__(config) - if self.config.model_type in MODEL_TYPE_TO_TRTLLMMODEL: - self.trtllm_loader = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.config.model_type]) + if self.config.model_type in MODEL_TYPE_TO_TRTLLMMODELS: + self.trtllm_loader = get_class(MODEL_TYPE_TO_TRTLLMMODELS[self.config.model_type]) self.logger.info(f"\t+ Using TRTLLMModel class {self.trtllm_loader.__name__}") else: raise NotImplementedError(f"TRTLLMBackend does not support model_type {self.config.model_type}") diff --git a/optimum_benchmark/backends/tensorrt_llm/utils.py b/optimum_benchmark/backends/tensorrt_llm/utils.py index 4574da53..01b6ed0e 100644 --- a/optimum_benchmark/backends/tensorrt_llm/utils.py +++ b/optimum_benchmark/backends/tensorrt_llm/utils.py @@ -1 +1 @@ -MODEL_TYPE_TO_TRTLLMMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"} +MODEL_TYPE_TO_TRTLLMMODELS = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"} From 32d08fab62483b96c2d4b0ada6e8986b9e31be0d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 16:24:39 +0100 Subject: [PATCH 18/21] fixes --- .github/workflows/test_api_rocm.yaml | 1 + .github/workflows/test_cli_cuda_tensorrt_llm.yaml | 2 ++ .github/workflows/test_cli_rocm_pytorch.yaml | 2 ++ 3 files changed, 5 insertions(+) diff --git a/.github/workflows/test_api_rocm.yaml b/.github/workflows/test_api_rocm.yaml index f6f20aa4..a1ded080 100644 --- a/.github/workflows/test_api_rocm.yaml +++ b/.github/workflows/test_api_rocm.yaml @@ -33,6 +33,7 @@ jobs: with: machine_type: single-gpu install_extras: testing,timm,diffusers,codecarbon + test_file: tests/test_api.py pytest_keywords: api and cuda secrets: HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml index c75aac92..ea49bacc 100644 --- a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml +++ b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml @@ -45,6 +45,7 @@ jobs: - name: Install dependencies run: | pip install -e .[testing] + pip install huggingface-cli[cli] - name: Run tests run: | @@ -57,6 +58,7 @@ jobs: }} name: Run examples run: | + huggingface-cli delete-cache pytest tests/test_examples.py -x -s -k "cli and cuda and trt" diff --git a/.github/workflows/test_cli_rocm_pytorch.yaml b/.github/workflows/test_cli_rocm_pytorch.yaml index 3057b726..80299ffb 100644 --- a/.github/workflows/test_cli_rocm_pytorch.yaml +++ b/.github/workflows/test_cli_rocm_pytorch.yaml @@ -34,6 +34,7 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/optimum_benchmark_instinct_ci.yaml@testing with: machine_type: single-gpu + test_file: tests/test_cli.py install_extras: testing,diffusers,timm,peft,autoawq,auto-gptq pytest_keywords: cli and cuda and pytorch and not (dp or ddp or device_map or deepspeed) and not bnb @@ -51,5 +52,6 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/optimum_benchmark_instinct_ci.yaml@testing with: machine_type: multi-gpu + test_file: tests/test_cli.py install_extras: testing,diffusers,timm,peft pytest_keywords: cli and cuda and pytorch and (dp or ddp or device_map) From 7ed07768ee23571e6c11fcd657789bdc5c32fb0f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 16:25:17 +0100 Subject: [PATCH 19/21] fix --- .github/workflows/test_cli_cuda_tensorrt_llm.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml index ea49bacc..e84d629b 100644 --- a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml +++ b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml @@ -45,7 +45,6 @@ jobs: - name: Install dependencies run: | pip install -e .[testing] - pip install huggingface-cli[cli] - name: Run tests run: | @@ -58,7 +57,7 @@ jobs: }} name: Run examples run: | - + pip install huggingface-cli[cli] huggingface-cli delete-cache pytest tests/test_examples.py -x -s -k "cli and cuda and trt" From 18db6fdb6d71fa36580f7e9859e3b0c4df21a78f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 17:03:39 +0100 Subject: [PATCH 20/21] fix --- .github/workflows/test_api_rocm.yaml | 2 +- .github/workflows/test_cli_cuda_tensorrt_llm.yaml | 7 +++---- .github/workflows/test_cli_rocm_pytorch.yaml | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_api_rocm.yaml b/.github/workflows/test_api_rocm.yaml index a1ded080..5c795cbd 100644 --- a/.github/workflows/test_api_rocm.yaml +++ b/.github/workflows/test_api_rocm.yaml @@ -33,7 +33,7 @@ jobs: with: machine_type: single-gpu install_extras: testing,timm,diffusers,codecarbon - test_file: tests/test_api.py + test_file: test_api.py pytest_keywords: api and cuda secrets: HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml index e84d629b..5a199d5a 100644 --- a/.github/workflows/test_cli_cuda_tensorrt_llm.yaml +++ b/.github/workflows/test_cli_cuda_tensorrt_llm.yaml @@ -44,7 +44,7 @@ jobs: - name: Install dependencies run: | - pip install -e .[testing] + pip install -e .[testing,tesnsorrt-llm] - name: Run tests run: | @@ -57,8 +57,7 @@ jobs: }} name: Run examples run: | - pip install huggingface-cli[cli] - huggingface-cli delete-cache + rm -rf /root/.cache/huggingface pytest tests/test_examples.py -x -s -k "cli and cuda and trt" cli_cuda_tensorrt_llm_multi_gpu_tests: @@ -85,7 +84,7 @@ jobs: - name: Install dependencies run: | - pip install -e .[testing] + pip install -e .[testing,tesnsorrt-llm] - name: Run tests (sequential) run: | diff --git a/.github/workflows/test_cli_rocm_pytorch.yaml b/.github/workflows/test_cli_rocm_pytorch.yaml index 80299ffb..47008b46 100644 --- a/.github/workflows/test_cli_rocm_pytorch.yaml +++ b/.github/workflows/test_cli_rocm_pytorch.yaml @@ -34,8 +34,8 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/optimum_benchmark_instinct_ci.yaml@testing with: machine_type: single-gpu - test_file: tests/test_cli.py install_extras: testing,diffusers,timm,peft,autoawq,auto-gptq + test_file: test_cli.py pytest_keywords: cli and cuda and pytorch and not (dp or ddp or device_map or deepspeed) and not bnb run_cli_rocm_pytorch_multi_gpu_tests: @@ -52,6 +52,6 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/optimum_benchmark_instinct_ci.yaml@testing with: machine_type: multi-gpu - test_file: tests/test_cli.py install_extras: testing,diffusers,timm,peft + test_file: test_cli.py pytest_keywords: cli and cuda and pytorch and (dp or ddp or device_map) From a38800198210c4295bed1109ab3678bf4339061d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 11 Dec 2024 18:30:36 +0100 Subject: [PATCH 21/21] max batch size --- .../backends/tensorrt_llm/backend.py | 16 ++++++++-------- .../backends/tensorrt_llm/config.py | 9 ++++----- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index 8cd046eb..60b82675 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -128,20 +128,20 @@ def trtllm_kwargs(self): if self.config.gpus_per_node is not None: kwargs["gpus_per_node"] = self.config.gpus_per_node - if self.config.use_cuda_graph is not None: - kwargs["use_cuda_graph"] = self.config.use_cuda_graph + if self.config.max_batch_size is not None: + kwargs["max_batch_size"] = self.config.max_batch_size - if self.config.optimization_level is not None: - kwargs["optimization_level"] = self.config.optimization_level + if self.config.max_new_tokens is not None: + kwargs["max_new_tokens"] = self.config.max_new_tokens if self.config.max_prompt_length is not None: kwargs["max_prompt_length"] = self.config.max_prompt_length - if self.config.max_new_tokens is not None: - kwargs["max_new_tokens"] = self.config.max_new_tokens + if self.config.optimization_level is not None: + kwargs["optimization_level"] = self.config.optimization_level - if self.config.max_beam_width is not None: - kwargs["max_beam_width"] = self.config.max_beam_width + if self.config.use_cuda_graph is not None: + kwargs["use_cuda_graph"] = self.config.use_cuda_graph return kwargs diff --git a/optimum_benchmark/backends/tensorrt_llm/config.py b/optimum_benchmark/backends/tensorrt_llm/config.py index 84d119af..2497d5d4 100644 --- a/optimum_benchmark/backends/tensorrt_llm/config.py +++ b/optimum_benchmark/backends/tensorrt_llm/config.py @@ -22,12 +22,11 @@ class TRTLLMConfig(BackendConfig): use_fp8: Optional[bool] = None world_size: Optional[int] = None gpus_per_node: Optional[int] = None - use_cuda_graph: Optional[bool] = None - optimization_level: Optional[int] = None - max_prompt_length: Optional[int] = None - max_new_tokens: Optional[int] = None max_batch_size: Optional[int] = None - max_beam_width: Optional[int] = None + max_new_tokens: Optional[int] = None + max_prompt_length: Optional[int] = None + optimization_level: Optional[int] = None + use_cuda_graph: Optional[bool] = None def __post_init__(self) -> None: super().__post_init__()