From fd32ad53f52d22637d34c00b3f523c5a3d17af76 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 8 Mar 2024 12:58:04 +0100 Subject: [PATCH] Support Py-TXI (TGI and TEI) (#147) --- ...u_py_tgi.yaml => test_cli_cpu_py_txi.yaml} | 11 +- examples/tei_bert.yaml | 36 ++++ examples/tgi_llama.yaml | 16 +- optimum_benchmark/backends/py_tgi/backend.py | 153 -------------- optimum_benchmark/backends/py_tgi/config.py | 55 ------ .../backends/{py_tgi => py_txi}/__init__.py | 0 optimum_benchmark/backends/py_txi/backend.py | 186 ++++++++++++++++++ optimum_benchmark/backends/py_txi/config.py | 80 ++++++++ .../backends/transformers_utils.py | 2 +- .../benchmarks/inference/benchmark.py | 2 - optimum_benchmark/benchmarks/report.py | 2 +- optimum_benchmark/cli.py | 4 +- optimum_benchmark/experiment.py | 2 +- optimum_benchmark/hub_utils.py | 39 +++- optimum_benchmark/import_utils.py | 12 +- optimum_benchmark/task_utils.py | 2 + setup.py | 19 +- tests/configs/_bert_.yaml | 2 +- tests/configs/cpu_inference_py_txi_bert.yaml | 13 ++ ...gpt.yaml => cpu_inference_py_txi_gpt.yaml} | 4 +- 20 files changed, 391 insertions(+), 249 deletions(-) rename .github/workflows/{test_cli_cpu_py_tgi.yaml => test_cli_cpu_py_txi.yaml} (77%) create mode 100644 examples/tei_bert.yaml delete mode 100644 optimum_benchmark/backends/py_tgi/backend.py delete mode 100644 optimum_benchmark/backends/py_tgi/config.py rename optimum_benchmark/backends/{py_tgi => py_txi}/__init__.py (100%) create mode 100644 optimum_benchmark/backends/py_txi/backend.py create mode 100644 optimum_benchmark/backends/py_txi/config.py create mode 100644 tests/configs/cpu_inference_py_txi_bert.yaml rename tests/configs/{cpu_inference_py_tgi_gpt.yaml => cpu_inference_py_txi_gpt.yaml} (81%) diff --git a/.github/workflows/test_cli_cpu_py_tgi.yaml b/.github/workflows/test_cli_cpu_py_txi.yaml similarity index 77% rename from .github/workflows/test_cli_cpu_py_tgi.yaml rename to .github/workflows/test_cli_cpu_py_txi.yaml index b7fc1c5a..384e4189 100644 --- a/.github/workflows/test_cli_cpu_py_tgi.yaml +++ b/.github/workflows/test_cli_cpu_py_txi.yaml @@ -1,4 +1,4 @@ -name: CLI CPU Py-TGI Tests +name: CLI CPU Py-TXI Tests on: workflow_dispatch: @@ -12,7 +12,7 @@ concurrency: cancel-in-progress: true jobs: - run_cli_cpu_py_tgi_tests: + run_cli_cpu_py_txi_tests: runs-on: ubuntu-latest steps: - name: Free disk space @@ -35,10 +35,13 @@ jobs: - name: Install requirements run: | pip install --upgrade pip - pip install -e .[testing,py-tgi] + pip install -e .[testing,py-txi] - name: Pull TGI docker image run: docker pull ghcr.io/huggingface/text-generation-inference:latest + - name: Pull TEI docker image + run: docker pull ghcr.io/huggingface/text-embeddings-inference:cpu-latest + - name: Run tests - run: pytest -k "cli and cpu and py_tgi" + run: pytest -k "cli and cpu and py_txi" diff --git a/examples/tei_bert.yaml b/examples/tei_bert.yaml new file mode 100644 index 00000000..90d16f43 --- /dev/null +++ b/examples/tei_bert.yaml @@ -0,0 +1,36 @@ +defaults: + - backend: py-txi + - launcher: inline # default launcher + - benchmark: inference # default benchmark + - experiment # inheriting experiment schema + - _self_ # for hydra 1.1 compatibility + - override hydra/job_logging: colorlog # colorful logging + - override hydra/hydra_logging: colorlog # colorful logging + +experiment_name: tei_bert + +backend: + device: cpu + pooling: cls + model: bert-base-uncased + +benchmark: + input_shapes: + batch_size: 64 + sequence_length: 
128 + +# hydra/cli specific settings +hydra: + run: + # where to store run results + dir: runs/${experiment_name} + sweep: + # where to store sweep results + dir: sweeps/${experiment_name} + job: + # change working directory to the run directory + chdir: true + env_set: + # set environment variable OVERRIDE_BENCHMARKS to 1 + # to not skip benchmarks that have been run before + OVERRIDE_BENCHMARKS: 1 diff --git a/examples/tgi_llama.yaml b/examples/tgi_llama.yaml index a23c5c55..8e4557e7 100644 --- a/examples/tgi_llama.yaml +++ b/examples/tgi_llama.yaml @@ -1,7 +1,7 @@ defaults: - - backend: text-generation-inference # default backend - - benchmark: inference # default benchmark + - backend: py-txi - launcher: inline # default launcher + - benchmark: inference # default benchmark - experiment # inheriting experiment schema - _self_ # for hydra 1.1 compatibility - override hydra/job_logging: colorlog # colorful logging @@ -10,18 +10,16 @@ defaults: experiment_name: tgi_llama backend: - device: cuda + device: cpu device_ids: 0,1 - device_map: true - model: TheBloke/Llama-2-7B-AWQ - quantization_scheme: awq - sharded: false + no_weights: true + model: NousResearch/Nous-Hermes-llama-2-7b benchmark: input_shapes: - batch_size: 1 + batch_size: 4 sequence_length: 256 - new_tokens: 1000 + new_tokens: 100 # hydra/cli specific settings hydra: diff --git a/optimum_benchmark/backends/py_tgi/backend.py b/optimum_benchmark/backends/py_tgi/backend.py deleted file mode 100644 index 42e1b9e9..00000000 --- a/optimum_benchmark/backends/py_tgi/backend.py +++ /dev/null @@ -1,153 +0,0 @@ -import gc -import os -from logging import getLogger -from tempfile import TemporaryDirectory -from typing import Any, Dict, List - -import torch -from py_tgi import TGI -from safetensors.torch import save_file -from transformers import GenerationConfig - -from ...task_utils import TEXT_GENERATION_TASKS -from ..base import Backend -from ..transformers_utils import random_init_weights -from .config import PyTGIConfig - -# bachend logger -LOGGER = getLogger("py-tgi") - - -class PyTGIBackend(Backend[PyTGIConfig]): - NAME: str = "py-tgi" - - def __init__(self, config: PyTGIConfig) -> None: - super().__init__(config) - self.validate_task() - - if self.generation_config is None: - self.generation_config = GenerationConfig() - - LOGGER.info("\t+ Creating backend temporary directory") - self.tmpdir = TemporaryDirectory() - - if self.config.no_weights: - LOGGER.info("\t+ Loading no weights model") - self.load_model_with_no_weights() - else: - LOGGER.info("\t+ Downloading pretrained model") - self.download_pretrained_model() - LOGGER.info("\t+ Preparing generation config") - self.prepare_generation_config() - LOGGER.info("\t+ Loading pretrained model") - self.load_model_from_pretrained() - - self.tmpdir.cleanup() - - def validate_task(self) -> None: - if self.config.task not in TEXT_GENERATION_TASKS: - raise NotImplementedError(f"TGI does not support task {self.config.task}") - - def download_pretrained_model(self) -> None: - LOGGER.info("\t+ Downloading pretrained model") - with torch.device("meta"): - self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs) - - def prepare_generation_config(self) -> None: - LOGGER.info("\t+ Modifying generation config for fixed length generation") - self.generation_config.eos_token_id = None - self.generation_config.pad_token_id = None - model_cache_folder = f"models/{self.config.model}".replace("/", "--") - model_cache_path = f"{self.config.volume}/{model_cache_folder}" - 
snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}" - snapshot_ref = open(snapshot_file, "r").read().strip() - model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}" - LOGGER.info("\t+ Saving new pretrained generation config") - self.generation_config.save_pretrained(save_directory=model_snapshot_path) - - def create_no_weights_model(self) -> None: - self.no_weights_model = os.path.join(self.tmp_dir.name, "no_weights_model") - LOGGER.info("\t+ Creating no weights model directory") - os.makedirs(self.no_weights_model, exist_ok=True) - LOGGER.info("\t+ Creating no weights model state dict") - state_dict = torch.nn.Linear(1, 1).state_dict() - LOGGER.info("\t+ Saving no weights model safetensors") - safetensor = os.path.join(self.no_weights_model, "model.safetensors") - save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"}) - # unlike Transformers api, TGI won't accept any missing tensors - # so we need to materialize the model and resave it - LOGGER.info(f"\t+ Loading no weights model from {self.no_weights_model}") - with random_init_weights(): - self.pretrained_model = self.automodel_class.from_pretrained( - self.no_weights_model, **self.config.hub_kwargs, device_map="auto", _fast_init=False - ) - LOGGER.info("\t+ Saving no weights model") - self.pretrained_model.save_pretrained(save_directory=self.no_weights_model) - LOGGER.info("\t+ Saving no weights model pretrained config") - self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) - LOGGER.info("\t+ Saving no weights model pretrained processor") - self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model) - LOGGER.info("\t+ Modifying generation config for fixed length generation") - self.generation_config.eos_token_id = None - self.generation_config.pad_token_id = None - LOGGER.info("\t+ Saving new pretrained generation config") - self.generation_config.save_pretrained(save_directory=self.no_weights_model) - - def load_model_with_no_weights(self) -> None: - LOGGER.info("\t+ Creating no weights model") - self.create_no_weights_model() - - original_volume, self.config.volume = self.config.volume, self.tmp_dir.name - original_model, self.config.model = self.config.model, "/data/no_weights_model" - LOGGER.info("\t+ Loading no weights model") - self.load_model_from_pretrained() - self.config.model, self.config.volume = original_model, original_volume - - def load_model_from_pretrained(self) -> None: - self.pretrained_model = TGI( - # model - model=self.config.model, - dtype=self.config.dtype, - quantize=self.config.quantize, - # docker - image=self.config.image, - shm_size=self.config.shm_size, - address=self.config.address, - volume=self.config.volume, - port=self.config.port, - # device - gpus=self.config.gpus, - devices=self.config.devices, - # sharding - sharded=self.config.sharded, - num_shard=self.config.num_shard, - # other - disable_custom_kernels=self.config.disable_custom_kernels, - trust_remote_code=self.config.hub_kwargs.get("trust_remote_code", False), - revision=self.config.hub_kwargs.get("revision", "main"), - ) - - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - if "inputs" in inputs: - return {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())} - elif "input_ids" in inputs: - return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())} - else: - raise ValueError("inputs must contain either input_ids or inputs") - - def forward(self, 
inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]: - return self.pretrained_model.generate(**inputs, **kwargs, max_new_tokens=1) - - def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]: - return self.pretrained_model.generate( - **inputs, do_sample=kwargs.get("do_sample", False), max_new_tokens=kwargs.get("max_new_tokens", 1) - ) - - def clean(self) -> None: - super().clean() - - if hasattr(self, "tmpdir"): - LOGGER.info("\t+ Cleaning temporary directory") - self.tmpdir.cleanup() - - gc.collect() diff --git a/optimum_benchmark/backends/py_tgi/config.py b/optimum_benchmark/backends/py_tgi/config.py deleted file mode 100644 index 62e91321..00000000 --- a/optimum_benchmark/backends/py_tgi/config.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -from dataclasses import dataclass -from typing import List, Optional - -from ...import_utils import py_tgi_version -from ...system_utils import is_nvidia_system, is_rocm_system -from ..config import BackendConfig - - -@dataclass -class PyTGIConfig(BackendConfig): - name: str = "py-tgi" - version: Optional[str] = py_tgi_version() - _target_: str = "optimum_benchmark.backends.py_tgi.backend.PyTGIBackend" - - # optimum benchmark specific - no_weights: bool = False - - # docker options - image: str = "ghcr.io/huggingface/text-generation-inference:latest" - volume: str = os.path.expanduser("~/.cache/huggingface/hub") - address: str = "127.0.0.1" - shm_size: str = "1g" - port: int = 1111 - - gpus: Optional[str] = None # "0,1,2,3" - devices: Optional[List[str]] = None # ["/dev/dri/renderD128", "/dev/dri/renderD129"] - - # sharding options - sharded: Optional[bool] = None # None, True, False - num_shard: Optional[int] = None # None, 1, 2, 4, 8, 16, 32, 64 - # torch options - dtype: Optional[str] = None # None, float32, float16, bfloat16 - quantize: Optional[str] = None # None, bitsandbytes-nf4, bitsandbytes-fp4 - # optimization options - disable_custom_kernels: bool = False # True, False - - def __post_init__(self): - super().__post_init__() - - if self.dtype is not None: - if self.dtype not in ["float32", "float16", "bfloat16"]: - raise ValueError(f"Invalid value for dtype: {self.dtype}") - - if self.quantize is not None: - if self.quantize not in ["bitsandbytes-nf4", "bitsandbytes-fp4", "awq", "gptq"]: - raise ValueError(f"Invalid value for quantize: {self.quantize}") - - if self.gpus is None and self.device == "cuda" and is_nvidia_system(): - self.gpus = self.device_ids - - if self.devices is None and self.device == "cuda" and is_rocm_system(): - device_ids = list(map(int, self.device_ids.split(","))) - renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")] - self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in device_ids] diff --git a/optimum_benchmark/backends/py_tgi/__init__.py b/optimum_benchmark/backends/py_txi/__init__.py similarity index 100% rename from optimum_benchmark/backends/py_tgi/__init__.py rename to optimum_benchmark/backends/py_txi/__init__.py diff --git a/optimum_benchmark/backends/py_txi/backend.py b/optimum_benchmark/backends/py_txi/backend.py new file mode 100644 index 00000000..d7b50e89 --- /dev/null +++ b/optimum_benchmark/backends/py_txi/backend.py @@ -0,0 +1,186 @@ +import gc +import os +from logging import getLogger +from tempfile import TemporaryDirectory +from typing import Any, Dict, List + +import torch +from py_txi import TEI, TGI, TEIConfig, TGIConfig +from safetensors.torch import save_file + +from ...task_utils import TEXT_EMBEDDING_TASKS, 
TEXT_GENERATION_TASKS +from ..base import Backend +from ..transformers_utils import random_init_weights +from .config import PyTXIConfig + +# bachend logger +LOGGER = getLogger("py-txi") + + +class PyTXIBackend(Backend[PyTXIConfig]): + NAME: str = "py-txi" + + def __init__(self, config: PyTXIConfig) -> None: + super().__init__(config) + + LOGGER.info("\t+ Creating backend temporary directory") + self.tmpdir = TemporaryDirectory() + self.volume = list(self.config.volumes.keys())[0] + + if self.config.no_weights: + LOGGER.info("\t+ Loading no weights model") + self.load_model_with_no_weights() + else: + LOGGER.info("\t+ Downloading pretrained model") + self.download_pretrained_model() + + if self.config.task in TEXT_GENERATION_TASKS: + LOGGER.info("\t+ Preparing generation config") + self.prepare_generation_config() + + LOGGER.info("\t+ Loading pretrained model") + self.load_model_from_pretrained() + + self.tmpdir.cleanup() + + def download_pretrained_model(self) -> None: + # directly downloads pretrained model in volume (/data) to change generation config before loading model + with torch.device("meta"): + self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs, cache_dir=self.volume) + + def prepare_generation_config(self) -> None: + self.generation_config.eos_token_id = -100 + self.generation_config.pad_token_id = -100 + self.generation_config.temperature = 1.0 + self.generation_config.top_p = 1.0 + self.generation_config.top_k = 50 + + model_cache_folder = f"models/{self.config.model}".replace("/", "--") + model_cache_path = f"{self.volume}/{model_cache_folder}" + snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}" + snapshot_ref = open(snapshot_file, "r").read().strip() + model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}" + LOGGER.info("\t+ Saving new pretrained generation config") + self.generation_config.save_pretrained(save_directory=model_snapshot_path) + + def create_no_weights_model(self) -> None: + self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model") + LOGGER.info("\t+ Creating no weights model directory") + os.makedirs(self.no_weights_model, exist_ok=True) + LOGGER.info("\t+ Creating no weights model state dict") + state_dict = torch.nn.Linear(1, 1).state_dict() + LOGGER.info("\t+ Saving no weights model safetensors") + safetensor = os.path.join(self.no_weights_model, "model.safetensors") + save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"}) + LOGGER.info("\t+ Saving no weights model pretrained config") + self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) + LOGGER.info("\t+ Saving no weights model pretrained processor") + self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model) + # unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model + LOGGER.info(f"\t+ Loading no weights model from {self.no_weights_model}") + with random_init_weights(): + self.pretrained_model = self.automodel_class.from_pretrained( + self.no_weights_model, **self.config.hub_kwargs, device_map="auto", _fast_init=False + ) + LOGGER.info("\t+ Saving no weights model") + self.pretrained_model.save_pretrained(save_directory=self.no_weights_model) + del self.pretrained_model + torch.cuda.empty_cache() + + if self.config.task in TEXT_GENERATION_TASKS: + LOGGER.info("\t+ Modifying generation config for fixed length generation") + self.generation_config.eos_token_id = -100 + 
self.generation_config.pad_token_id = -100 + self.generation_config.temperature = 1.0 + self.generation_config.top_p = 1.0 + self.generation_config.top_k = 50 + + LOGGER.info("\t+ Saving new pretrained generation config") + self.generation_config.save_pretrained(save_directory=self.no_weights_model) + + def load_model_with_no_weights(self) -> None: + LOGGER.info("\t+ Creating no weights model") + self.create_no_weights_model() + + original_volumes, self.config.volumes = self.config.volumes, {self.tmpdir.name: {"bind": "/data", "mode": "rw"}} + original_model, self.config.model = self.config.model, "/data/no_weights_model" + LOGGER.info("\t+ Loading no weights model") + self.load_model_from_pretrained() + self.config.model, self.config.volumes = original_model, original_volumes + + def load_model_from_pretrained(self) -> None: + if self.config.task in TEXT_GENERATION_TASKS: + self.pretrained_model = TGI( + config=TGIConfig( + model_id=self.config.model, + gpus=self.config.gpus, + devices=self.config.devices, + volumes=self.config.volumes, + ports=self.config.ports, + environment=self.config.environment, + dtype=self.config.dtype, + sharded=self.config.sharded, + quantize=self.config.quantize, + num_shard=self.config.num_shard, + enable_cuda_graphs=self.config.enable_cuda_graphs, + disable_custom_kernels=self.config.disable_custom_kernels, + trust_remote_code=self.config.trust_remote_code, + max_concurrent_requests=self.config.max_concurrent_requests, + ), + ) + + elif self.config.task in TEXT_EMBEDDING_TASKS: + self.pretrained_model = TEI( + config=TEIConfig( + model_id=self.config.model, + gpus=self.config.gpus, + devices=self.config.devices, + volumes=self.config.volumes, + ports=self.config.ports, + environment=self.config.environment, + dtype=self.config.dtype, + pooling=self.config.pooling, + max_concurrent_requests=self.config.max_concurrent_requests, + ), + ) + else: + raise NotImplementedError(f"TXI does not support task {self.config.task}") + + def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + input_key = "inputs" if "inputs" in inputs else "input_ids" + inputs = self.pretrained_processor.batch_decode(inputs[input_key].tolist()) + + if self.config.task in TEXT_GENERATION_TASKS: + return {"prompt": inputs} + elif self.config.task in TEXT_EMBEDDING_TASKS: + return {"text": inputs} + else: + raise NotImplementedError(f"TXI does not support task {self.config.task}") + + def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]: + if self.config.task in TEXT_GENERATION_TASKS: + return self.pretrained_model.generate( + **inputs, + do_sample=kwargs.get("do_sample", False), + max_new_tokens=kwargs.get("max_new_tokens", 1), + ) + elif self.config.task in TEXT_EMBEDDING_TASKS: + return self.pretrained_model.encode(**inputs, **kwargs) + else: + raise NotImplementedError(f"TXI does not support task {self.config.task}") + + def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]: + return self.pretrained_model.generate( + **inputs, + do_sample=kwargs.get("do_sample", False), + max_new_tokens=kwargs.get("max_new_tokens", 1), + ) + + def clean(self) -> None: + super().clean() + + if hasattr(self, "tmpdir"): + LOGGER.info("\t+ Cleaning temporary directory") + self.tmpdir.cleanup() + + gc.collect() diff --git a/optimum_benchmark/backends/py_txi/config.py b/optimum_benchmark/backends/py_txi/config.py new file mode 100644 index 00000000..030f02bb --- /dev/null +++ b/optimum_benchmark/backends/py_txi/config.py @@ -0,0 +1,80 @@ +import os 
+from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +from ...import_utils import py_txi_version +from ...system_utils import is_nvidia_system, is_rocm_system +from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS +from ..config import BackendConfig + + +@dataclass +class PyTXIConfig(BackendConfig): + name: str = "py-txi" + version: Optional[str] = py_txi_version() + _target_: str = "optimum_benchmark.backends.py_txi.backend.PyTXIBackend" + + # optimum benchmark specific + no_weights: bool = False + + # Image to use for the container + image: Optional[str] = None + # Shared memory size for the container + shm_size: str = "1g" + # List of custom devices to forward to the container e.g. ["/dev/kfd", "/dev/dri"] for ROCm + devices: Optional[List[str]] = None + # NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count) + gpus: Optional[Union[str, int]] = None + # Things to forward to the container + ports: Dict[str, Any] = field( + default_factory=lambda: {"80/tcp": ("127.0.0.1", 0)}, + metadata={"help": "Dictionary of ports to expose from the container."}, + ) + volumes: Dict[str, Any] = field( + default_factory=lambda: {os.path.expanduser("~/.cache/huggingface/hub"): {"bind": "/data", "mode": "rw"}}, + metadata={"help": "Dictionary of volumes to mount inside the container."}, + ) + environment: Dict[str, str] = field( + default_factory=lambda: {"HUGGING_FACE_HUB_TOKEN": os.environ.get("HUGGING_FACE_HUB_TOKEN", "")}, + metadata={"help": "Dictionary of environment variables to forward to the container."}, + ) + + # Common options + dtype: Optional[str] = None + max_concurrent_requests: Optional[int] = None + + # TGI specific + sharded: Optional[str] = None + quantize: Optional[str] = None + num_shard: Optional[int] = None + enable_cuda_graphs: Optional[bool] = None + disable_custom_kernels: Optional[bool] = None + trust_remote_code: Optional[bool] = None + + # TEI specific + pooling: Optional[str] = None + + def __post_init__(self): + super().__post_init__() + + if self.task not in TEXT_GENERATION_TASKS + TEXT_EMBEDDING_TASKS: + raise NotImplementedError(f"TXI does not support task {self.task}") + + if self.task in TEXT_GENERATION_TASKS: + self.image = "ghcr.io/huggingface/text-generation-inference:latest" + elif self.task in TEXT_EMBEDDING_TASKS: + self.image = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest" + + if self.max_concurrent_requests is None: + if self.task in TEXT_GENERATION_TASKS: + self.max_concurrent_requests = 128 + elif self.task in TEXT_EMBEDDING_TASKS: + self.max_concurrent_requests = 512 + + if self.device_ids is not None and is_nvidia_system() and self.gpus is None: + self.gpus = self.device_ids + + if self.device_ids is not None and is_rocm_system() and self.devices is None: + ids = list(map(int, self.device_ids.split(","))) + renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")] + self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in ids] diff --git a/optimum_benchmark/backends/transformers_utils.py b/optimum_benchmark/backends/transformers_utils.py index 93c35560..fcd87741 100644 --- a/optimum_benchmark/backends/transformers_utils.py +++ b/optimum_benchmark/backends/transformers_utils.py @@ -32,7 +32,7 @@ def get_transformers_generation_config(model: str, **kwargs) -> Optional["Genera # sometimes contains information about the model's input shapes that are not available in the config return GenerationConfig.from_pretrained(model, 
**kwargs) except Exception: - return None + return GenerationConfig() def get_transformers_pretrained_processor(model: str, **kwargs) -> Optional["PretrainedProcessor"]: diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py index cd54fee3..5cc4830d 100644 --- a/optimum_benchmark/benchmarks/inference/benchmark.py +++ b/optimum_benchmark/benchmarks/inference/benchmark.py @@ -145,8 +145,6 @@ def run(self, backend: Backend[BackendConfigT][BackendConfigT]) -> None: LOGGER.info("\t+ Creating inference latency tracker") self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device) if backend.config.task in TEXT_GENERATION_TASKS: - LOGGER.info("\t+ Creating latency logits processor tracker") - self.run_text_generation_latency_tracking(backend) elif backend.config.task in IMAGE_DIFFUSION_TASKS: self.run_image_diffusion_latency_tracking(backend) diff --git a/optimum_benchmark/benchmarks/report.py b/optimum_benchmark/benchmarks/report.py index 73ce26a9..831c92e2 100644 --- a/optimum_benchmark/benchmarks/report.py +++ b/optimum_benchmark/benchmarks/report.py @@ -107,5 +107,5 @@ def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport": return cls(**aggregated_measurements) @property - def file_name(self) -> str: + def default_file_name(self) -> str: return "benchmark_report.json" diff --git a/optimum_benchmark/cli.py b/optimum_benchmark/cli.py index 0806c09e..867e68a3 100644 --- a/optimum_benchmark/cli.py +++ b/optimum_benchmark/cli.py @@ -10,7 +10,7 @@ from .backends.neural_compressor.config import INCConfig from .backends.onnxruntime.config import ORTConfig from .backends.openvino.config import OVConfig -from .backends.py_tgi.config import PyTGIConfig +from .backends.py_txi.config import PyTXIConfig from .backends.pytorch.config import PyTorchConfig from .backends.tensorrt_llm.config import TRTLLMConfig from .backends.torch_ort.config import TorchORTConfig @@ -34,7 +34,7 @@ cs.store(group="backend", name=TorchORTConfig.name, node=TorchORTConfig) cs.store(group="backend", name=TRTLLMConfig.name, node=TRTLLMConfig) cs.store(group="backend", name=INCConfig.name, node=INCConfig) -cs.store(group="backend", name=PyTGIConfig.name, node=PyTGIConfig) +cs.store(group="backend", name=PyTXIConfig.name, node=PyTXIConfig) cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig) # benchmarks configurations cs.store(group="benchmark", name=TrainingConfig.name, node=TrainingConfig) diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py index 593a4ea7..92523395 100644 --- a/optimum_benchmark/experiment.py +++ b/optimum_benchmark/experiment.py @@ -46,7 +46,7 @@ class ExperimentConfig(PushToHubMixin): environment: Dict = field(default_factory=lambda: {**get_system_info(), **get_hf_libs_info()}) @property - def file_name(self) -> str: + def default_file_name(self) -> str: return "experiment_config.json" diff --git a/optimum_benchmark/hub_utils.py b/optimum_benchmark/hub_utils.py index e02cb5dd..1218dd65 100644 --- a/optimum_benchmark/hub_utils.py +++ b/optimum_benchmark/hub_utils.py @@ -1,6 +1,6 @@ import os import tempfile -from dataclasses import asdict +from dataclasses import asdict, dataclass from json import dump from logging import getLogger from typing import Any, Dict, Optional, Union @@ -12,13 +12,15 @@ LOGGER = getLogger(__name__) +@dataclass class PushToHubMixin: """ A Mixin to push artifacts to the Hugging Face Hub """ def to_dict(self) -> Dict[str, 
Any]: - return asdict(self) + data_dict = asdict(self) + return data_dict def to_flat_dict(self) -> Dict[str, Any]: report_dict = self.to_dict() @@ -39,11 +41,24 @@ def to_dataframe(self) -> pd.DataFrame: def to_csv(self, path: str) -> None: self.to_dataframe().to_csv(path, index=False) + def save_pretrained( + self, + save_path: Optional[Union[str, os.PathLike]] = None, + file_name: Optional[Union[str, os.PathLike]] = None, + flat: bool = False, + ) -> None: + save_path = save_path or self.default_save_path + file_name = file_name or self.default_file_name + + file_path = os.path.join(save_path, file_name) + os.makedirs(save_path, exist_ok=True) + self.to_json(file_path, flat=flat) + def push_to_hub( self, repo_id: str, - file_name: Optional[Union[str, os.PathLike]] = None, - path_in_repo: Optional[str] = None, + save_path: Optional[str] = None, + file_name: Optional[str] = None, flat: bool = False, **kwargs, ) -> str: @@ -52,18 +67,24 @@ def push_to_hub( repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id with tempfile.TemporaryDirectory() as tmpdir: - file_name = file_name or self.file_name + save_path = save_path or self.default_save_path + file_name = file_name or self.default_file_name + path_or_fileobj = os.path.join(tmpdir, file_name) - path_in_repo = path_in_repo or file_name + path_in_repo = os.path.join(save_path, file_name) self.to_json(path_or_fileobj, flat=flat) upload_file( - path_or_fileobj=path_or_fileobj, - path_in_repo=path_in_repo, repo_id=repo_id, + path_in_repo=path_in_repo, + path_or_fileobj=path_or_fileobj, **kwargs, ) @property - def file_name(self) -> str: + def default_file_name(self) -> str: return "config.json" + + @property + def default_save_path(self) -> str: + return "benchmarks" diff --git a/optimum_benchmark/import_utils.py b/optimum_benchmark/import_utils.py index 75cfec66..be0a0890 100644 --- a/optimum_benchmark/import_utils.py +++ b/optimum_benchmark/import_utils.py @@ -27,7 +27,7 @@ _tensorrt_llm_available = importlib.util.find_spec("tensorrt_llm") is not None _psutil_available = importlib.util.find_spec("psutil") is not None _optimum_benchmark_available = importlib.util.find_spec("optimum_benchmark") is not None -_py_tgi_available = importlib.util.find_spec("py_tgi") is not None +_py_txi_available = importlib.util.find_spec("py_txi") is not None _pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None _llm_swarm_available = importlib.util.find_spec("llm_swarm") is not None @@ -40,8 +40,8 @@ def is_pyrsmi_available(): return _pyrsmi_available -def is_py_tgi_available(): - return _py_tgi_available +def is_py_txi_available(): + return _py_txi_available def is_psutil_available(): @@ -198,9 +198,9 @@ def optimum_benchmark_version(): return importlib.metadata.version("optimum_benchmark") -def py_tgi_version(): - if _py_tgi_available: - return importlib.metadata.version("py_tgi") +def py_txi_version(): + if _py_txi_available: + return importlib.metadata.version("py_txi") def llm_swarm_version(): diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py index dfa3f808..8d0210ba 100644 --- a/optimum_benchmark/task_utils.py +++ b/optimum_benchmark/task_utils.py @@ -99,6 +99,8 @@ TEXT_GENERATION_TASKS = ["image-to-text", "text-generation", "text2text-generation", "automatic-speech-recognition"] +TEXT_EMBEDDING_TASKS = ["feature-extraction", "fill-mask"] + def map_from_synonym(task: str) -> str: if task in _SYNONYM_TASK_MAP: diff --git a/setup.py b/setup.py index 50dc0528..eb2987bc 100644 --- 
a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import importlib.util import os +import subprocess from setuptools import find_packages, setup @@ -22,8 +23,20 @@ "pandas", ] -USE_CUDA = os.environ.get("USE_CUDA", None) == "1" -USE_ROCM = os.environ.get("USE_ROCM", None) == "1" +try: + subprocess.run(["nvidia-smi"], check=True) + IS_NVIDIA_SYSTEM = True +except Exception: + IS_NVIDIA_SYSTEM = False + +try: + subprocess.run(["rocm-smi"], check=True) + IS_ROCM_SYSTEM = True +except Exception: + IS_ROCM_SYSTEM = False + +USE_CUDA = (os.environ.get("USE_CUDA", None) == "1") or IS_NVIDIA_SYSTEM +USE_ROCM = (os.environ.get("USE_ROCM", None) == "1") or IS_ROCM_SYSTEM if USE_CUDA: INSTALL_REQUIRES.append("nvidia-ml-py") @@ -48,7 +61,7 @@ "torch-ort": ["torch-ort", "onnxruntime-training", f"optimum>={MIN_OPTIMUM_VERSION}"], # other backends "llm-swarm": ["llm-swarm@git+https://github.com/huggingface/llm-swarm.git"], - "py-tgi": ["py-tgi==0.1.3"], + "py-txi": ["py-txi@git+https://github.com/IlyasMoutawwakil/py-txi.git"], # optional dependencies "codecarbon": ["codecarbon"], "deepspeed": ["deepspeed"], diff --git a/tests/configs/_bert_.yaml b/tests/configs/_bert_.yaml index 53dcc824..7d9ec98e 100644 --- a/tests/configs/_bert_.yaml +++ b/tests/configs/_bert_.yaml @@ -1,3 +1,3 @@ backend: - task: text-classification + task: fill-mask model: bert-base-uncased diff --git a/tests/configs/cpu_inference_py_txi_bert.yaml b/tests/configs/cpu_inference_py_txi_bert.yaml new file mode 100644 index 00000000..c9a49422 --- /dev/null +++ b/tests/configs/cpu_inference_py_txi_bert.yaml @@ -0,0 +1,13 @@ +defaults: + - backend: py-txi + # order of inheritance, last one overrides previous ones + - _base_ # inherits from base config + - _inference_ # inherits from inference config + - _bert_ # inherits from gpt config + - _cpu_ # inherits from cpu config + - _self_ # hydra 1.1 compatibility + +experiment_name: cpu_inference_py_txi_bert + +backend: + pooling: cls diff --git a/tests/configs/cpu_inference_py_tgi_gpt.yaml b/tests/configs/cpu_inference_py_txi_gpt.yaml similarity index 81% rename from tests/configs/cpu_inference_py_tgi_gpt.yaml rename to tests/configs/cpu_inference_py_txi_gpt.yaml index c0805b71..9b63a685 100644 --- a/tests/configs/cpu_inference_py_tgi_gpt.yaml +++ b/tests/configs/cpu_inference_py_txi_gpt.yaml @@ -1,5 +1,5 @@ defaults: - - backend: py-tgi + - backend: py-txi # order of inheritance, last one overrides previous ones - _base_ # inherits from base config - _inference_ # inherits from inference config @@ -7,4 +7,4 @@ defaults: - _cpu_ # inherits from cpu config - _self_ # hydra 1.1 compatibility -experiment_name: cpu_inference_py_tgi_gpt +experiment_name: cpu_inference_py_txi_gpt
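
Usage sketch for the new backend (hedged: this assumes the project's `optimum-benchmark` Hydra CLI entry point; the extra, Docker images, and example configs referenced below are the ones added or renamed in this patch):

    # install the new extra and pull the serving images used by the CI workflow
    pip install -e .[py-txi]
    docker pull ghcr.io/huggingface/text-generation-inference:latest
    docker pull ghcr.io/huggingface/text-embeddings-inference:cpu-latest

    # text generation through a TGI container (examples/tgi_llama.yaml)
    optimum-benchmark --config-dir examples/ --config-name tgi_llama

    # text embeddings through a TEI container (examples/tei_bert.yaml)
    optimum-benchmark --config-dir examples/ --config-name tei_bert

The backend chooses the server from the task: text-generation tasks start a TGI container, feature-extraction and fill-mask tasks start a TEI container, and any other task is rejected in PyTXIConfig.__post_init__, which also selects the default image and max_concurrent_requests per task.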