diff --git a/examples/cuda_pytorch_llama_quants.py b/examples/cuda_pytorch_llama_quants.py
index 01d492cb..e6221ca4 100644
--- a/examples/cuda_pytorch_llama_quants.py
+++ b/examples/cuda_pytorch_llama_quants.py
@@ -10,23 +10,33 @@ WEIGHTS_CONFIGS = {
     "float16": {
         "torch_dtype": "float16",
-        "quantization_scheme": None,
         "quantization_config": {},
     },
     "4bit-awq-gemm": {
         "torch_dtype": "float16",
-        "quantization_scheme": "awq",
-        "quantization_config": {"bits": 4, "version": "gemm"},
+        "quantization_config": {
+            "quant_method": "awq",
+            "bits": 4,
+            "version": "gemm",
+        },
     },
     "4bit-gptq-exllama-v2": {
         "torch_dtype": "float16",
-        "quantization_scheme": "gptq",
-        "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
+        "quantization_config": {
+            "quant_method": "gptq",
+            "bits": 4,
+            "use_exllama ": True,
+            "version": 2,
+            "model_seqlen": 256,
+        },
     },
     "torchao-int4wo-128": {
         "torch_dtype": "bfloat16",
-        "quantization_scheme": "torchao",
-        "quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
+        "quantization_config": {
+            "quant_method": "torchao",
+            "quant_type": "int4_weight_only",
+            "group_size": 128,
+        },
     },
 }
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index dd11ddfd..b1c505fe 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -308,7 +308,10 @@ def process_quantization_config(self) -> None:
         self.logger.info("\t+ Processing AutoQuantization config")
 
         self.quantization_config = AutoQuantizationConfig.from_dict(
-            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+            dict(
+                getattr(self.pretrained_config, "quantization_config", {}),
+                **self.config.quantization_config,
+            )
         )
 
     @property
diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py
index 61f9dfc0..96ff44a2 100644
--- a/optimum_benchmark/backends/pytorch/config.py
+++ b/optimum_benchmark/backends/pytorch/config.py
@@ -1,14 +1,17 @@
 from dataclasses import dataclass, field
+from logging import getLogger
 from typing import Any, Dict, Optional
 
 from ...import_utils import torch_version
-from ...system_utils import is_rocm_system
 from ..config import BackendConfig
 
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
-QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}, "torchao": {}}
+QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}}
+
+
+LOGGER = getLogger(__name__)
 
 
 @dataclass
@@ -66,15 +69,17 @@ def __post_init__(self):
             raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. Got {self.autocast_dtype} instead.")
 
         if self.quantization_scheme is not None:
-            if self.quantization_scheme not in QUANTIZATION_CONFIGS:
-                raise ValueError(
-                    f"`quantization_scheme` must be one of {list(QUANTIZATION_CONFIGS.keys())}. "
-                    f"Got {self.quantization_scheme} instead."
-                )
-
-            if self.quantization_scheme == "bnb" and is_rocm_system():
-                raise ValueError("BitsAndBytes is not supported on ROCm GPUs. Please disable it.")
-
-            if self.quantization_config:
-                QUANTIZATION_CONFIG = QUANTIZATION_CONFIGS[self.quantization_scheme]
-                self.quantization_config = {**QUANTIZATION_CONFIG, **self.quantization_config}
+            LOGGER.warning(
+                "`backend.quantization_scheme` is deprecated and will be removed in a future version. "
+                "Please use `quantization_config.quant_method` instead."
+            )
+            if self.quantization_config is None:
+                self.quantization_config = {"quant_method": self.quantization_scheme}
+            else:
+                self.quantization_config["quant_method"] = self.quantization_scheme
+
+        if self.quantization_config is not None:
+            self.quantization_config = dict(
+                QUANTIZATION_CONFIGS.get(self.quantization_scheme, {}),  # default config
+                **self.quantization_config,  # user config (overwrites default)
+            )
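
A minimal, self-contained sketch of the behaviour the last hunk introduces, for reference: a legacy `quantization_scheme` is folded into `quantization_config` as `quant_method` (with a deprecation warning), and the per-scheme defaults from `QUANTIZATION_CONFIGS` are merged underneath the user-supplied values. The `PyTorchConfigSketch` class below is a hypothetical stand-in reduced to the two fields touched by the diff, not the real optimum-benchmark `PyTorchConfig`.

from dataclasses import dataclass
from logging import getLogger
from typing import Any, Dict, Optional

# Per-scheme defaults merged under user-provided values (mirrors the diff).
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}}

LOGGER = getLogger(__name__)


@dataclass
class PyTorchConfigSketch:
    # Hypothetical stand-in for the real PyTorchConfig, reduced to the relevant fields.
    quantization_scheme: Optional[str] = None  # deprecated entry point
    quantization_config: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.quantization_scheme is not None:
            LOGGER.warning("`quantization_scheme` is deprecated; use `quantization_config.quant_method` instead.")
            if self.quantization_config is None:
                self.quantization_config = {"quant_method": self.quantization_scheme}
            else:
                self.quantization_config["quant_method"] = self.quantization_scheme

        if self.quantization_config is not None:
            self.quantization_config = dict(
                QUANTIZATION_CONFIGS.get(self.quantization_scheme, {}),  # defaults first
                **self.quantization_config,  # user values overwrite defaults
            )


# Old style: scheme plus separate config still works, but warns and is rewritten.
legacy = PyTorchConfigSketch(quantization_scheme="bnb", quantization_config={"load_in_8bit": True})
assert legacy.quantization_config == {"llm_int8_threshold": 0.0, "load_in_8bit": True, "quant_method": "bnb"}

# New style: everything, including the method, lives in quantization_config.
new = PyTorchConfigSketch(quantization_config={"quant_method": "awq", "bits": 4, "version": "gemm"})
assert new.quantization_config == {"quant_method": "awq", "bits": 4, "version": "gemm"}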