diff --git a/.github/workflows/test_cli_cuda_torch_ort_training.yaml b/.github/workflows/test_cli_cuda_torch_ort.yaml
similarity index 100%
rename from .github/workflows/test_cli_cuda_torch_ort_training.yaml
rename to .github/workflows/test_cli_cuda_torch_ort.yaml
diff --git a/examples/neural_compressor_ptq_bert.yaml b/examples/neural_compressor_ptq_bert.yaml
index a8c83c88..c8b0ee6e 100644
--- a/examples/neural_compressor_ptq_bert.yaml
+++ b/examples/neural_compressor_ptq_bert.yaml
@@ -7,7 +7,7 @@ defaults:
   - override hydra/job_logging: colorlog # colorful logging
   - override hydra/hydra_logging: colorlog # colorful logging
 
-experiment_name: openvino_static_quant_bert
+experiment_name: neural_compressor_ptq_bert
 
 backend:
   device: cpu
diff --git a/optimum_benchmark/aggregators/__init__.py b/optimum_benchmark/aggregators/__init__.py
deleted file mode 100644
index a3015d55..00000000
--- a/optimum_benchmark/aggregators/__init__.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from pathlib import Path
-from typing import Tuple, List, Dict
-
-import pandas as pd
-from rich.table import Table
-from omegaconf import OmegaConf
-import matplotlib.pyplot as plt
-from rich.console import Console
-from flatten_dict import flatten
-from rich.terminal_theme import MONOKAI
-
-
-def gather(root_folders: List[Path]) -> pd.DataFrame:
-    configs_dfs = {}
-    results_dfs = {}
-
-    for root_folder in root_folders:
-        if not root_folder.exists():
-            raise ValueError(f"{root_folder} does not exist")
-
-        for f in root_folder.glob("**/hydra_config.yaml"):
-            parent_folder = f.parent.absolute().as_posix()
-            configs_dfs[parent_folder] = pd.DataFrame.from_dict(
-                flatten(OmegaConf.load(f), reducer="dot"), orient="index"
-            ).T
-
-        for f in root_folder.glob("**/*_results.csv"):
-            parent_folder = f.parent.absolute().as_posix()
-            results_dfs[parent_folder] = pd.read_csv(f)
-
-    if (len(results_dfs) == 0) or (len(configs_dfs) == 0):
-        raise ValueError(f"Results are missing in {root_folders}")
-
-    # Merge inference and config dataframes
-    full_dfs = {}
-    for parent_folder in results_dfs:
-        full_df = pd.concat(
-            [configs_dfs[parent_folder], results_dfs[parent_folder]],
-            axis=1,
-        )
-        full_df["parent_folder"] = parent_folder
-        full_dfs[parent_folder] = full_df
-
-    # Concatenate all dataframes
-    full_report = pd.concat(full_dfs.values(), ignore_index=True, axis=0)
-
-    return full_report
-
-
-def format_element(element):
-    if isinstance(element, float):
-        if element != element:
-            formated_element = ""
-        elif abs(element) >= 1:
-            formated_element = f"{element:.2f}"
-        elif abs(element) > 1e-6:
-            formated_element = f"{element:.2e}"
-        else:
-            formated_element = f"{element}"
-    elif element is None:
-        formated_element = ""
-    elif isinstance(element, bool):
-        if element:
-            formated_element = "[green]✔[/green]"
-        else:
-            formated_element = "[red]✘[/red]"
-    else:
-        formated_element = str(element)
-
-    return formated_element
-
-
-def display(report: pd.DataFrame) -> Table:
-    table = Table(show_header=True, show_lines=True)
-
-    for column in report.columns:
-        table.add_column(column, justify="right", header_style="bold")
-
-    for _, row in report.iterrows():
-        formated_row = []
-        for element in row.values:
-            formated_row.append(format_element(element))
-        table.add_row(*formated_row)
-
-    console = Console(record=True, theme=MONOKAI)
-    console.print(table, justify="center")
-
-    return console, table
-
-
-def rename(report: pd.DataFrame, rename_dict: Dict[str, str]):
-    summarized_report = report[list(rename_dict.keys())].rename(columns=rename_dict)
-
-    return summarized_report
-
-
-def plot(report: pd.DataFrame, x_axis: str, y_axis: str, groupby: str) -> Tuple[plt.Figure, plt.Axes]:
-    fig, ax = plt.subplots()
-
-    for group, sweep in report.groupby(groupby):
-        sorted_sweep = sweep.sort_values(by=x_axis)
-        ax.plot(sorted_sweep[x_axis], sorted_sweep[y_axis], label=group, marker="o")
-
-    ax.set_xlabel(x_axis)
-    ax.set_ylabel(y_axis)
-    ax.set_title(f"{y_axis} per {x_axis}")
-    ax.legend(fancybox=True, shadow=True)
-
-    return fig, ax
diff --git a/optimum_benchmark/backends/neural_compressor/backend.py b/optimum_benchmark/backends/neural_compressor/backend.py
index 092affff..dd2a7a82 100644
--- a/optimum_benchmark/backends/neural_compressor/backend.py
+++ b/optimum_benchmark/backends/neural_compressor/backend.py
@@ -4,22 +4,19 @@
 from logging import getLogger
 from tempfile import TemporaryDirectory
 
+from ...generators.dataset_generator import DatasetGenerator
+from ..transformers_utils import randomize_weights
+from .utils import TASKS_TO_INCMODELS
+from .config import INCConfig
+from ..base import Backend
+
 import torch
 from hydra.utils import get_class
 from transformers.utils import ModelOutput
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.logging import set_verbosity_error
 from optimum.intel.neural_compressor.quantization import INCQuantizer
-from neural_compressor.config import (
-    PostTrainingQuantConfig,
-    AccuracyCriterion,
-    TuningCriterion,
-)
-
-from ...generators.dataset_generator import DatasetGenerator
-from .utils import TASKS_TO_INCMODELS
-from .config import INCConfig
-from ..base import Backend
+from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion, TuningCriterion
 
 # disable transformers logging
 set_verbosity_error()
@@ -34,9 +31,7 @@ def __init__(self, config: INCConfig):
         super().__init__(config)
         self.validate_task()
 
-        self.incmodel_class = get_class(TASKS_TO_INCMODELS[self.config.task])
-        LOGGER.info(f"Using INCModel class {self.incmodel_class.__name__}")
-
+        LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
 
         if self.config.ptq_quantization:
@@ -52,57 +47,65 @@ def __init__(self, config: INCConfig):
         else:
             self.load_incmodel_from_pretrained()
 
-        self.tmpdir.cleanup()
-
     def validate_task(self) -> None:
         if self.config.task not in TASKS_TO_INCMODELS:
             raise NotImplementedError(f"INCBackend does not support task {self.config.task}")
 
+        self.incmodel_class = get_class(TASKS_TO_INCMODELS[self.config.task])
+        LOGGER.info(f"Using INCModel class {self.incmodel_class.__name__}")
+
     def load_automodel_from_pretrained(self) -> None:
         LOGGER.info("\t+ Loading AutoModel from pretrained")
         self.pretrained_model = self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)
 
-    def load_automodel_with_no_weights(self) -> None:
-        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+    def create_no_weights_model(self) -> None:
+        LOGGER.info("\t+ Creating no weights model state_dict")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
 
-        if not os.path.exists(no_weights_model):
-            LOGGER.info("\t+ Creating no weights model directory")
-            os.makedirs(no_weights_model)
+        LOGGER.info("\t+ Creating no weights model directory")
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+        os.makedirs(self.no_weights_model, exist_ok=True)
 
-        LOGGER.info("\t+ Saving pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=no_weights_model)
+        LOGGER.info("\t+ Saving no weights model pretrained config")
+        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
-        LOGGER.info("\t+ Creating no weights model")
-        state_dict = torch.nn.Linear(1, 1).state_dict()
+        LOGGER.info("\t+ Saving no weights model state_dict")
+        torch.save(state_dict, os.path.join(self.no_weights_model, "pytorch_model.bin"))
 
-        LOGGER.info("\t+ Saving no weights model")
-        torch.save(state_dict, os.path.join(no_weights_model, "pytorch_model.bin"))
+    def load_automodel_with_no_weights(self) -> None:
+        self.create_no_weights_model()
 
-        LOGGER.info("\t+ Loading no weights model")
         with no_init_weights():
             original_model = self.config.model
-            self.config.model = no_weights_model
+            self.config.model = self.no_weights_model
+            LOGGER.info("\t+ Loading no weights model")
             self.load_automodel_from_pretrained()
             self.config.model = original_model
 
+        LOGGER.info("\t+ Randomizing model weights")
+        randomize_weights(self.pretrained_model)
+        LOGGER.info("\t+ Tying model weights")
+        self.pretrained_model.tie_weights()
+
     def load_incmodel_from_pretrained(self) -> None:
         LOGGER.info("\t+ Loading INCModel from pretrained")
         self.pretrained_model = self.incmodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)
 
     def load_incmodel_with_no_weights(self) -> None:
-        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
-
-        LOGGER.info("\t+ Loading AutoModel with no weights")
-        self.load_automodel_with_no_weights()
-        self.delete_pretrained_model()
+        self.create_no_weights_model()
 
-        LOGGER.info("\t+ Loading INCModel with no weights")
         with no_init_weights():
             original_model = self.config.model
-            self.config.model = no_weights_model
+            self.config.model = self.no_weights_model
+            LOGGER.info("\t+ Loading no weights model")
             self.load_incmodel_from_pretrained()
             self.config.model = original_model
 
+        LOGGER.info("\t+ Randomizing model weights")
+        randomize_weights(self.pretrained_model.model)
+        LOGGER.info("\t+ Tying model weights")
+        self.pretrained_model.model.tie_weights()
+
     def quantize_automodel(self) -> None:
         LOGGER.info("\t+ Attempting to quantize model")
         quantized_model_path = f"{self.tmpdir.name}/quantized"
@@ -134,7 +137,7 @@ def quantize_automodel(self) -> None:
                 task=self.config.task,
                 dataset_shapes=dataset_shapes,
                 model_shapes=self.model_shapes,
-            ).generate()
+            )()
             columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._signature_columns))
             calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
         else:
@@ -169,6 +172,7 @@ def clean(self) -> None:
         super().clean()
 
         if hasattr(self, "tmpdir"):
+            LOGGER.info("\t+ Cleaning backend temporary directory")
             self.tmpdir.cleanup()
 
         gc.collect()
diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py
index 9983ead2..07d5d860 100644
--- a/optimum_benchmark/backends/onnxruntime/backend.py
+++ b/optimum_benchmark/backends/onnxruntime/backend.py
@@ -4,26 +4,21 @@
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List
-from ...generators.dataset_generator import DatasetGenerator
-from ...task_utils import TEXT_GENERATION_TASKS
-from .config import ORTConfig
+
 from ..base import Backend
-from .utils import (
-    format_calibration_config,
-    format_quantization_config,
-    TASKS_TO_ORTMODELS,
-    TASKS_TO_ORTSD,
-)
+from .config import ORTConfig
+from ...task_utils import TEXT_GENERATION_TASKS
+from ...generators.dataset_generator import DatasetGenerator
+from .utils import format_calibration_config, format_quantization_config, TASKS_TO_ORTMODELS, TASKS_TO_ORTSD
 
 import torch
 from datasets import Dataset
 from hydra.utils import get_class
 from onnxruntime import SessionOptions
 from safetensors.torch import save_file
-from transformers import TrainerCallback, TrainerState
+from transformers import TrainerCallback
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.logging import set_verbosity_error
-from optimum.onnxruntime import ONNX_DECODER_WITH_PAST_NAME, ONNX_DECODER_NAME, ORTOptimizer, ORTQuantizer
 from optimum.onnxruntime.configuration import (
     AutoOptimizationConfig,
     AutoQuantizationConfig,
@@ -32,7 +27,14 @@
     QuantizationConfig,
     CalibrationConfig,
 )
-
+from optimum.onnxruntime import (
+    ONNX_DECODER_WITH_PAST_NAME,
+    ONNX_DECODER_NAME,
+    ORTTrainingArguments,
+    ORTOptimizer,
+    ORTQuantizer,
+    ORTTrainer,
+)
 
 # disable transformers logging
 set_verbosity_error()
@@ -56,15 +58,19 @@ def __init__(self, config: ORTConfig) -> None:
         else:
             raise NotImplementedError(f"ORTBackend does not support task {self.config.task}")
 
-        self.set_session_options()
+        LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
 
+        self.session_options = SessionOptions()
+        for key, value in self.config.session_options.items():
+            setattr(self.session_options, key, value)
+
         if self.config.no_weights:
             self.load_ortmodel_with_no_weights()
         else:
             self.load_ortmodel_from_pretrained()
 
-        if self.is_deferred_trt_loading():
+        if self.is_trt_text_generation:
             return
 
         if self.is_optimized or self.is_quantized:
@@ -94,35 +100,30 @@ def validate_provider(self) -> None:
             self.pretrained_model.providers[0] == self.config.provider
         ), f"{self.config.provider} is not first in providers list: {self.pretrained_model.providers}"
 
-    def is_deferred_trt_loading(self) -> bool:
-        return self.config.provider == "TensorrtExecutionProvider" and self.config.task in TEXT_GENERATION_TASKS
-
-    def set_session_options(self) -> None:
-        self.session_options = SessionOptions()
-        for key, value in self.config.session_options.items():
-            setattr(self.session_options, key, value)
-
-    def load_ortmodel_with_no_weights(self) -> None:
+    def create_no_weights_model(self) -> None:
         LOGGER.info("\t+ Creating no weights model directory")
-        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
-        os.makedirs(no_weights_model, exist_ok=True)
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+        os.makedirs(self.no_weights_model, exist_ok=True)
 
         LOGGER.info("\t+ Saving pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=no_weights_model)
+        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
-        LOGGER.info("\t+ Creating no weights model weights")
+        LOGGER.info("\t+ Creating no weights model state dict")
         state_dict = torch.nn.Linear(1, 1).state_dict()
 
-        LOGGER.info("\t+ Saving no weights model weights")
+        LOGGER.info("\t+ Saving no weights model state dict")
         save_file(
-            filename=os.path.join(no_weights_model, "model.safetensors"),
+            filename=os.path.join(self.no_weights_model, "model.safetensors"),
             metadata={"format": "pt"},
             tensors=state_dict,
         )
 
+    def load_ortmodel_with_no_weights(self) -> None:
+        self.create_no_weights_model()
+
         with no_init_weights():
             original_model = self.config.model
-            self.config.model = no_weights_model
+            self.config.model = self.no_weights_model
             LOGGER.info("\t+ Loading no weights model")
             self.load_ortmodel_from_pretrained()
             self.config.model = original_model
@@ -139,6 +140,10 @@ def load_ortmodel_from_pretrained(self) -> None:
             **self.ortmodel_kwargs,
         )
 
+    @property
+    def is_trt_text_generation(self) -> bool:
+        return self.config.provider == "TensorrtExecutionProvider" and self.config.task in TEXT_GENERATION_TASKS
+
     @property
     def is_optimized(self) -> bool:
         return (self.config.auto_optimization is not None) or self.config.optimization
@@ -247,7 +252,7 @@ def quantize_onnx_files(self) -> None:
                 task=self.config.task,
                 dataset_shapes=dataset_shapes,
                 model_shapes=self.model_shapes,
-            ).generate()
+            )()
             columns_to_be_removed = list(set(calibration_dataset.column_names) - set(self.inputs_names))
             calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
 
@@ -309,7 +314,7 @@ def quantize_onnx_files(self) -> None:
         self.config.model = quantized_model_path
 
     def prepare_for_inference(self, **kwargs) -> None:
-        if self.is_deferred_trt_loading():
+        if self.is_trt_text_generation:
             LOGGER.info("\t+ Creating dynamic shapes for Tensorrt engine. Engine creation might take a while.")
             batch_size = kwargs["batch_size"]
             max_new_tokens = kwargs["max_new_tokens"]
@@ -363,9 +368,7 @@ def train(
         training_arguments: Dict[str, Any],
         training_callbacks: List[TrainerCallback],
         training_data_collator: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
-    ) -> TrainerState:
-        from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
-
+    ) -> None:
         LOGGER.info("\t+ Setting dataset format to `torch`")
         training_dataset.set_format(type="torch", columns=list(training_dataset.features.keys()))
         LOGGER.info("\t+ Wrapping training arguments with optimum.onnxruntime.ORTTrainingArguments")
@@ -382,13 +385,11 @@ def train(
         trainer.train()
         LOGGER.info("\t+ Training finished successfully")
 
-        return trainer.state
-
     def clean(self) -> None:
         super().clean()
 
         if hasattr(self, "tmpdir"):
-            LOGGER.info("\t+ Cleaning temporary directory")
+            LOGGER.info("\t+ Cleaning backend temporary directory")
             self.tmpdir.cleanup()
 
         gc.collect()
diff --git a/optimum_benchmark/backends/onnxruntime/config.py b/optimum_benchmark/backends/onnxruntime/config.py
index 0f9262cc..e0191b88 100644
--- a/optimum_benchmark/backends/onnxruntime/config.py
+++ b/optimum_benchmark/backends/onnxruntime/config.py
@@ -38,6 +38,7 @@ class ORTConfig(BackendConfig):
     version: Optional[str] = onnxruntime_version()
     _target_: str = "optimum_benchmark.backends.onnxruntime.backend.ORTBackend"
 
+    # load options
     no_weights: bool = False
 
     # export options
diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py
index 2fdb97e3..73cbd63d 100644
--- a/optimum_benchmark/backends/openvino/backend.py
+++ b/optimum_benchmark/backends/openvino/backend.py
@@ -6,6 +6,13 @@
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
 
+from ..base import Backend
+from .config import OVConfig
+from .utils import TASKS_TO_OVMODEL
+from ...task_utils import TEXT_GENERATION_TASKS
+from ..transformers_utils import randomize_weights
+from ...generators.dataset_generator import DatasetGenerator
+
 import torch
 from hydra.utils import get_class
 from openvino.runtime import properties
@@ -15,14 +22,6 @@
 from transformers.utils.logging import set_verbosity_error
 from optimum.intel.openvino import OVConfig as OVQuantizationConfig  # naming conflict
 
-from ..base import Backend
-from .config import OVConfig
-from .utils import TASKS_TO_OVMODEL
-from ...task_utils import TEXT_GENERATION_TASKS
-from ..transformers_utils import randomize_weights
-from ...generators.dataset_generator import DatasetGenerator
-
-
 # disable transformers logging
 set_verbosity_error()
 
@@ -149,7 +148,11 @@ def quantize_automodel(self) -> None:
                 "sequence_length": 1,
                 **self.model_shapes,
             }
-            calibration_dataset = DatasetGenerator(task=self.config.task, dataset_shapes=dataset_shapes).generate()
+            calibration_dataset = DatasetGenerator(
+                task=self.config.task,
+                dataset_shapes=dataset_shapes,
+                model_shapes=self.model_shapes,
+            )()
             columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._export_input_names))
             calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
         else:
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index de1b2327..0dc32371 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -38,18 +38,7 @@ class PyTorchBackend(Backend[PyTorchConfig]):
 
     def __init__(self, config: PyTorchConfig):
         super().__init__(config)
-
-        if self.config.library == "timm":
-            LOGGER.info("\t+ Using method timm.create_model")
-        else:
-            automodel = self.automodel_class.__name__
-            if self.config.library == "diffusers":
-                LOGGER.info(f"\t+ Using Pipeline class {automodel}")
-            else:
-                LOGGER.info(f"\t+ Using AutoModel class {automodel}")
-
-        # Mixed precision
-        self.amp_dtype = getattr(torch, self.config.amp_dtype) if self.config.amp_dtype is not None else None
+        self.validate_library()
 
         # Threads
         if self.config.inter_op_num_threads is not None:
@@ -59,6 +48,13 @@ def __init__(self, config: PyTorchConfig):
             LOGGER.info(f"\t+ Setting pytorch intra_op_num_threads({self.config.intra_op_num_threads}))")
             torch.set_num_interop_threads(self.config.intra_op_num_threads)
 
+        # Mixed precision
+        if self.config.amp_dtype:
+            LOGGER.info(f"\t+ Setting mixed precision dtype to {self.config.amp_dtype}")
+            self.amp_dtype = getattr(torch, self.config.amp_dtype)
+        else:
+            self.amp_dtype = None
+
         # Quantization
         if self.is_quantized:
             LOGGER.info("\t+ Processing quantization config")
@@ -66,7 +62,9 @@ def __init__(self, config: PyTorchConfig):
         else:
             self.quantization_config = None
 
+        LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
+
         if self.config.no_weights and self.config.library == "diffusers":
             raise ValueError("Diffusion pipelines are not supported with no_weights=True")
         elif self.config.no_weights:
@@ -81,7 +79,7 @@ def __init__(self, config: PyTorchConfig):
             self.pretrained_model.generation_config.cache_implementation = self.config.cache_implementation
 
         # Eval mode
-        if self.config.eval_mode and not self.config.library == "diffusers":
+        if self.config.eval_mode and self.config.library != "diffusers":
             LOGGER.info("\t+ Turning on model's eval mode")
             self.pretrained_model.eval()
 
@@ -120,7 +118,15 @@ def __init__(self, config: PyTorchConfig):
                 dtype=getattr(self.pretrained_model, "dtype", None),
             )
 
-        self.tmpdir.cleanup()
+    def validate_library(self) -> None:
+        if self.config.library == "timm":
+            LOGGER.info(f"\t+ Using Timm method {self.automodel_class.__name__}")
+        elif self.config.library == "diffusers":
+            LOGGER.info(f"\t+ Using Pipeline class {self.automodel_class.__name__}")
+        elif self.config.library == "transformers":
+            LOGGER.info(f"\t+ Using AutoModel class {self.automodel_class.__name__}")
+        else:
+            raise ValueError(f"Library {self.config.library} not supported")
 
     def load_model_from_pretrained(self) -> None:
         if self.config.library == "timm":
@@ -132,8 +138,8 @@ def load_model_from_pretrained(self) -> None:
             self.pretrained_model = self.automodel_class.from_pretrained(
                 pretrained_model_name_or_path=self.config.model,
                 device_map=self.config.device_map,
-                **self.automodel_kwargs,
                 **self.config.hub_kwargs,
+                **self.automodel_kwargs,
             )
             if self.config.device_map is None:
                 LOGGER.info(f"\t+ Moving pipeline to device: {self.config.device}")
@@ -151,7 +157,7 @@ def load_model_from_pretrained(self) -> None:
             self.pretrained_model = self.automodel_class.from_pretrained(
                 pretrained_model_name_or_path=self.config.model,
                 # for gptq, we need to specify the device_map to either auto
-                # or a cuda adevice to avoid any modules being assigned to cpu
+                # or a cuda adevice to avoid any modules being assigned to cpu ¯\_(ツ)_/¯
                 device_map=self.config.device_map or torch.device(self.config.device),
                 **self.config.hub_kwargs,
                 **self.automodel_kwargs,
@@ -166,39 +172,39 @@ def load_model_from_pretrained(self) -> None:
             )
         else:
             # this is the fastest way to load a model on a specific device
+            # but not compatible with all quantization methods (and pipelines)
             LOGGER.info(f"\t+ Loading model directly on device: {self.config.device}")
             with torch.device(self.config.device):
                 self.pretrained_model = self.automodel_class.from_pretrained(
                     pretrained_model_name_or_path=self.config.model,
-                    **self.automodel_kwargs,
                     **self.config.hub_kwargs,
+                    **self.automodel_kwargs,
                 )
 
     def create_no_weights_model(self) -> None:
-        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
-
-        LOGGER.info("\t+ Creating no weights model directory")
-        os.makedirs(self.no_weights_model, exist_ok=True)
-
-        if self.is_quantized:
-            # tricking from_pretrained to load the model as if it was quantized
-            self.pretrained_config.quantization_config = self.quantization_config.to_dict()
-
-        LOGGER.info("\t+ Saving pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
-
         LOGGER.info("\t+ Creating no weights model state_dict")
         state_dict = torch.nn.Linear(1, 1).state_dict()
 
         if self.is_exllamav2:
-            # for exllamav2 we need to add g_idx to the state_dict
+            # for exllamav2 we need to add g_idx to the state_dict which
+            # requires some information about linear layers dimensions
             with torch.device("meta"):
                 meta_model = self.automodel_class.from_config(self.pretrained_config)
-
             for name, module in meta_model.named_modules():
                 if hasattr(module, "in_features"):
                     state_dict[name + ".g_idx"] = torch.ones((module.in_features,), dtype=torch.int32)
 
+        if self.is_quantized:
+            # tricking from_pretrained to load the model as if it was quantized
+            self.pretrained_config.quantization_config = self.quantization_config.to_dict()
+
+        LOGGER.info("\t+ Creating no weights model directory")
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+        os.makedirs(self.no_weights_model, exist_ok=True)
+
+        LOGGER.info("\t+ Saving no weights model pretrained config")
+        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
         LOGGER.info("\t+ Saving no weights model state_dict")
         save_file(
             filename=os.path.join(self.no_weights_model, "model.safetensors"),
@@ -284,8 +290,8 @@ def is_awq_quantized(self) -> bool:
     def is_exllamav2(self) -> bool:
         return (
             self.is_gptq_quantized
-            and "exllama_config" in self.config.quantization_config
-            and self.config.quantization_config["exllama_config"]["version"] == 2
+            and "exllama_config" in self.quantization_config
+            and self.quantization_config["exllama_config"].get("version", None) == 2
         )
 
     @property
@@ -345,8 +351,8 @@ def train(
         training_arguments = TrainingArguments(**training_arguments)
         LOGGER.info("\t+ Wrapping model with transformers.Trainer")
         trainer = Trainer(
-            model=self.pretrained_model,
             args=training_arguments,
+            model=self.pretrained_model,
             callbacks=training_callbacks,
             train_dataset=training_dataset,
             data_collator=training_data_collator,
@@ -366,7 +372,7 @@ def clean(self) -> None:
         super().clean()
 
         if hasattr(self, "tmpdir"):
-            LOGGER.info("\t+ Cleaning temporary directory")
+            LOGGER.info("\t+ Cleaning backend temporary directory")
             self.tmpdir.cleanup()
 
         gc.collect()
diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py
index 3dd11e75..d8089f60 100644
--- a/optimum_benchmark/backends/pytorch/config.py
+++ b/optimum_benchmark/backends/pytorch/config.py
@@ -56,7 +56,6 @@ class PyTorchConfig(BackendConfig):
     quantization_config: Dict[str, Any] = field(default_factory=dict)
 
     # distributed inference options
-    data_parallel: bool = False
     deepspeed_inference: bool = False
     deepspeed_inference_config: Dict[str, Any] = field(default_factory=dict)
 
diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py
index 43a5fd75..7c86adeb 100644
--- a/optimum_benchmark/backends/tensorrt_llm/backend.py
+++ b/optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -1,13 +1,13 @@
 from logging import getLogger
 from typing import Any, Dict
 
-from hydra.utils import get_class
-from transformers.utils import ModelOutput
-
 from ..base import Backend
 from .config import TRTLLMConfig
 from .utils import MODEL_TYPE_TO_TRTLLMMODEL
 
+from hydra.utils import get_class
+from transformers.utils import ModelOutput
+
 LOGGER = getLogger("tensorrt-llm")
 
 
@@ -18,15 +18,15 @@ def __init__(self, config: TRTLLMConfig):
         super().__init__(config)
         self.validate_model_type()
 
-        self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.model_type])
-        LOGGER.info(f"\t+ Using TRTLLMModel class {self.trtmodel_class.__name__}")
-
         self.load_trtmodel_from_pretrained()
 
     def validate_model_type(self) -> None:
         if self.model_type not in MODEL_TYPE_TO_TRTLLMMODEL:
             raise NotImplementedError(f"TRTLLMBackend does not support model_type {self.model_type}")
 
+        self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.model_type])
+        LOGGER.info(f"\t+ Using TRTLLMModel class {self.trtmodel_class.__name__}")
+
     def load_trtmodel_from_pretrained(self) -> None:
         self.pretrained_model = self.trtmodel_class.from_pretrained(
             self.config.model,
diff --git a/optimum_benchmark/backends/text_generation_inference/backend.py b/optimum_benchmark/backends/text_generation_inference/backend.py
index fbd3d1de..538de53c 100644
--- a/optimum_benchmark/backends/text_generation_inference/backend.py
+++ b/optimum_benchmark/backends/text_generation_inference/backend.py
@@ -6,6 +6,11 @@
 from tempfile import TemporaryDirectory
 from concurrent.futures import ThreadPoolExecutor
 
+from ..base import Backend
+from .config import TGIConfig
+from ...task_utils import TEXT_GENERATION_TASKS
+from ..transformers_utils import randomize_weights
+
 import torch
 import docker
 import docker.types
@@ -14,10 +19,6 @@
 from huggingface_hub import InferenceClient, snapshot_download
 from huggingface_hub.inference._text_generation import TextGenerationResponse
 
-from ..base import Backend
-from .config import TGIConfig
-from ..transformers_utils import randomize_weights
-
 # bachend logger
 LOGGER = getLogger("text-generation-inference")
 
@@ -29,8 +30,7 @@ def __init__(self, config: TGIConfig) -> None:
         super().__init__(config)
         self.validate_task()
 
-        LOGGER.info(f"Using AutoModel class {self.automodel_class.__name__}")
-
+        LOGGER.info("\t+ Creating backend temporary directory")
         self.tmp_dir = TemporaryDirectory()
 
         if self.config.no_weights:
@@ -40,9 +40,11 @@ def __init__(self, config: TGIConfig) -> None:
             self.load_model_from_pretrained()
 
     def validate_task(self) -> None:
-        if self.config.task not in ["text-generation", "text2text-generation"]:
+        if self.config.task not in TEXT_GENERATION_TASKS:
             raise NotImplementedError(f"TGI does not support task {self.config.task}")
 
+        LOGGER.info(f"Using AutoModel class {self.automodel_class.__name__}")
+
     def download_pretrained_model(self) -> None:
         LOGGER.info("\t+ Downloading pretrained model")
         snapshot_download(self.config.model, **self.config.hub_kwargs)
@@ -93,7 +95,7 @@ def create_no_weights_model(self) -> None:
         self.pretrained_model = self.automodel_class.from_pretrained(
             self.no_weights_model,
             **self.config.hub_kwargs,
-            device_map="auto",
+            device_map="auto",  # for faster/safer loading
         )
 
         LOGGER.info("\t+ Randomizing weights")
diff --git a/optimum_benchmark/backends/text_generation_inference/config.py b/optimum_benchmark/backends/text_generation_inference/config.py
index edf37ba3..8b73617e 100644
--- a/optimum_benchmark/backends/text_generation_inference/config.py
+++ b/optimum_benchmark/backends/text_generation_inference/config.py
@@ -11,6 +11,9 @@ class TGIConfig(BackendConfig):
     version: Optional[str] = "0.0.1"
     _target_: str = "optimum_benchmark.backends.text_generation_inference.backend.TGIBackend"
 
+    # optimum benchmark specific
+    no_weights: bool = False
+
     # docker options
     image: str = "ghcr.io/huggingface/text-generation-inference:latest"
     volume: str = f"{os.path.expanduser('~')}/.cache/huggingface/hub"
@@ -28,9 +31,6 @@ class TGIConfig(BackendConfig):
     sharded: Optional[bool] = None  # None, True, False
     num_shard: Optional[int] = None  # None, 1, 2, 4, 8, 16, 32, 64
 
-    # optimum benchmark specific
-    no_weights: bool = False  # True, False
-
     def __post_init__(self):
         super().__post_init__()
 
diff --git a/optimum_benchmark/backends/torch_ort/backend.py b/optimum_benchmark/backends/torch_ort/backend.py
index aefce8ea..a7515d2f 100644
--- a/optimum_benchmark/backends/torch_ort/backend.py
+++ b/optimum_benchmark/backends/torch_ort/backend.py
@@ -4,6 +4,11 @@
 from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List
 
+from ..transformers_utils import randomize_weights
+from ..peft_utils import get_peft_config_class
+from .config import TorchORTConfig
+from ..base import Backend
+
 import torch
 from datasets import Dataset
 from safetensors.torch import save_file
@@ -12,11 +17,6 @@
 from transformers.utils.logging import set_verbosity_error
 from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
 
-from ..transformers_utils import randomize_weights
-from ..peft_utils import get_peft_config_class
-from .config import TorchORTConfig
-from ..base import Backend
-
 # disable transformers logging
 set_verbosity_error()
 
@@ -28,9 +28,9 @@ class TorchORTBackend(Backend[TorchORTConfig]):
 
     def __init__(self, config: TorchORTConfig):
         super().__init__(config)
+        self.validate_library()
 
-        LOGGER.info(f"Using AutoModel: {self.automodel_class.__name__}")
-
+        LOGGER.info("\t+ Creating backend temporary directory")
         self.tmpdir = TemporaryDirectory()
 
         if self.config.no_weights:
@@ -46,7 +46,11 @@ def __init__(self, config: TorchORTConfig):
             peft_config = peft_config_class(**self.config.peft_config)
             self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config)
 
-        self.tmpdir.cleanup()
+    def validate_library(self) -> None:
+        if self.config.library == "transformers":
+            LOGGER.info(f"Using AutoModel: {self.automodel_class.__name__}")
+        else:
+            raise NotImplementedError(f"TorchORTBackend does not support {self.config.library} library")
 
     def create_no_weights_model(self) -> None:
         LOGGER.info("\t+ Creating no weights model directory")
@@ -76,9 +80,9 @@ def load_automodel_with_no_weights(self) -> None:
             self.load_automodel_from_pretrained()
             self.config.model = original_model
 
-        LOGGER.info("\t+ Randomizing weights")
+        LOGGER.info("\t+ Randomizing model weights")
         randomize_weights(self.pretrained_model)
-        LOGGER.info("\t+ Tying model weights after randomization")
+        LOGGER.info("\t+ Tying model weights")
         self.pretrained_model.tie_weights()
 
     def load_automodel_from_pretrained(self) -> None:
@@ -126,7 +130,7 @@ def clean(self) -> None:
         super().clean()
 
         if hasattr(self, "tmpdir"):
-            LOGGER.info("\t+ Cleaning temporary directory")
+            LOGGER.info("\t+ Cleaning backend temporary directory")
             self.tmpdir.cleanup()
 
         gc.collect()
diff --git a/optimum_benchmark/backends/transformers_utils.py b/optimum_benchmark/backends/transformers_utils.py
index 5ba169fc..1d7ad410 100644
--- a/optimum_benchmark/backends/transformers_utils.py
+++ b/optimum_benchmark/backends/transformers_utils.py
@@ -127,16 +127,15 @@ def extract_transformers_shapes_from_artifacts(
 
 
 def randomize_weights(model: "torch.nn.Module") -> None:
     for param in model.parameters():
-        if param.data.dtype in (torch.float32, torch.float16, torch.bfloat16):
+        if param.data.is_floating_point():
             if torch.cuda.is_available() and param.device.type != "cuda":
                 param.data.cuda().normal_(mean=0.0, std=0.2).cpu()
             elif torch.backends.mps.is_available() and param.device.type != "mps":
-                param.data.to("mps").normal_(mean=0.0, std=0.2).to("cpu")
+                param.data.to("mps").normal_(mean=0.0, std=0.2).cpu()
             else:
                 param.data.normal_(mean=0.0, std=0.2)
-        # TODO: enable quantized weights randomization
-        elif param.data.dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
+        elif param.data.dtype in (torch.int32, torch.int16, torch.int8):
             if torch.cuda.is_available() and param.device.type != "cuda":
                 param.data.copy_(torch.randint(-127, 127, param.data.shape, device="cuda"))
             elif torch.backends.mps.is_available() and param.device.type != "mps":
diff --git a/optimum_benchmark/benchmarks/training/benchmark.py b/optimum_benchmark/benchmarks/training/benchmark.py
index 204feaf1..90c231d0 100644
--- a/optimum_benchmark/benchmarks/training/benchmark.py
+++ b/optimum_benchmark/benchmarks/training/benchmark.py
@@ -4,9 +4,9 @@
 from ..base import Benchmark
 from .config import TrainingConfig
 from .report import TrainingReport
-from .callback import MeasurementCallback
 from ...trackers.memory import MemoryTracker
 from ...trackers.energy import EnergyTracker
+from .callback import LatencyTrainerCallback
 from ...backends.base import Backend, BackendConfigT
 from ...generators.dataset_generator import DatasetGenerator
 
@@ -43,7 +43,7 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
         training_callbackes = []
         if self.config.latency:
             LOGGER.info("\t+ Adding latency measuring callback")
-            latency_callback = MeasurementCallback(device=backend.config.device, backend=backend.config.name)
+            latency_callback = LatencyTrainerCallback(device=backend.config.device, backend=backend.config.name)
             training_callbackes.append(latency_callback)
 
         training_trackers = []
diff --git a/optimum_benchmark/benchmarks/training/callback.py b/optimum_benchmark/benchmarks/training/callback.py
index 6517f337..88026d79 100644
--- a/optimum_benchmark/benchmarks/training/callback.py
+++ b/optimum_benchmark/benchmarks/training/callback.py
@@ -1,15 +1,11 @@
 import time
 from typing import List
 
-from ...import_utils import is_torch_available
-
-if is_torch_available():
-    import torch
-
+import torch
 from transformers import TrainerCallback
 
 
-class MeasurementCallback(TrainerCallback):
+class LatencyTrainerCallback(TrainerCallback):
     def __init__(self, device: str, backend: str) -> None:
         self.device = device
         self.backend = backend
diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py
index 28c59a59..f327e85c 100644
--- a/optimum_benchmark/launchers/torchrun/launcher.py
+++ b/optimum_benchmark/launchers/torchrun/launcher.py
@@ -70,9 +70,9 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]:
         if len(outputs) == 1:
             report: BenchmarkReport = outputs[0]
         else:
+            LOGGER.info(f"\t+ Merging benchmark reports from {len(outputs)} workers")
             report: BenchmarkReport = sum(outputs[1:], outputs[0])
-
-        report.log_all()
+        report.log_all()
 
         return report
 
diff --git a/optimum_benchmark/logging_utils.py b/optimum_benchmark/logging_utils.py
index 2725584c..72f76889 100644
--- a/optimum_benchmark/logging_utils.py
+++ b/optimum_benchmark/logging_utils.py
@@ -7,7 +7,7 @@
 from omegaconf import OmegaConf
 
 
-JOB_LOGGING = {
+API_JOB_LOGGING = {
     "version": 1,
     "formatters": {
         "simple": {"format": "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"},
@@ -29,13 +29,8 @@
             "stream": "ext://sys.stdout",
             "class": "logging.StreamHandler",
         },
-        # "file": {
-        #     "filename": "api.log",
-        #     "formatter": "simple",
-        #     "class": "logging.FileHandler",
-        # },
     },
-    "root": {"level": "INFO", "handlers": ["console"]},  # "file"]},
+    "root": {"level": "INFO", "handlers": ["console"]},
     "disable_existing_loggers": False,
 }
 
@@ -48,8 +43,7 @@ def setup_logging(level: str = "INFO", prefix: Optional[str] = None):
             resolve=True,
         )
     else:
-        job_logging = JOB_LOGGING.copy()
-        job_logging["root"]["handlers"] = ["console"]
+        job_logging = API_JOB_LOGGING.copy()
 
     job_logging["root"]["level"] = level