Fix auto quantization and deprecate quantization scheme (#314)
IlyasMoutawwakil authored Jan 31, 2025
1 parent 4eb7a37 commit e0b65f8
Showing 3 changed files with 40 additions and 22 deletions.
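
In short: the PyTorch backend no longer uses a separate `quantization_scheme` option to select the quantization method. The method is now taken from `quantization_config["quant_method"]` (e.g. "awq", "gptq", "torchao", "bnb"), which is also what `AutoQuantizationConfig.from_dict` uses to pick the concrete quantization config class. `quantization_scheme` is still accepted, but it only emits a deprecation warning and is folded into `quantization_config`.
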
24 changes: 17 additions & 7 deletions examples/cuda_pytorch_llama_quants.py
@@ -10,23 +10,33 @@
 WEIGHTS_CONFIGS = {
     "float16": {
         "torch_dtype": "float16",
-        "quantization_scheme": None,
         "quantization_config": {},
     },
     "4bit-awq-gemm": {
         "torch_dtype": "float16",
-        "quantization_scheme": "awq",
-        "quantization_config": {"bits": 4, "version": "gemm"},
+        "quantization_config": {
+            "quant_method": "awq",
+            "bits": 4,
+            "version": "gemm",
+        },
     },
     "4bit-gptq-exllama-v2": {
         "torch_dtype": "float16",
-        "quantization_scheme": "gptq",
-        "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
+        "quantization_config": {
+            "quant_method": "gptq",
+            "bits": 4,
+            "use_exllama ": True,
+            "version": 2,
+            "model_seqlen": 256,
+        },
     },
     "torchao-int4wo-128": {
         "torch_dtype": "bfloat16",
-        "quantization_scheme": "torchao",
-        "quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
+        "quantization_config": {
+            "quant_method": "torchao",
+            "quant_type": "int4_weight_only",
+            "group_size": 128,
+        },
     },
 }

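For context, a sketch of how one of the updated WEIGHTS_CONFIGS entries is meant to be consumed. The call mirrors the rest of examples/cuda_pytorch_llama_quants.py, but the model id, device settings and chosen entry are illustrative assumptions, not part of this commit:

# Illustrative sketch only (not part of this commit).
from optimum_benchmark import PyTorchConfig

backend_config = PyTorchConfig(
    model="gpt2",  # assumed model id; the example script benchmarks a llama checkpoint
    device="cuda",
    device_ids="0",
    no_weights=True,  # benchmark with randomly initialized weights
    **WEIGHTS_CONFIGS["4bit-awq-gemm"],  # torch_dtype + quantization_config with quant_method
)
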
5 changes: 4 additions & 1 deletion optimum_benchmark/backends/pytorch/backend.py
@@ -308,7 +308,10 @@ def process_quantization_config(self) -> None:
 
         self.logger.info("\t+ Processing AutoQuantization config")
         self.quantization_config = AutoQuantizationConfig.from_dict(
-            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+            dict(
+                getattr(self.pretrained_config, "quantization_config", {}),
+                **self.config.quantization_config,
+            )
         )
 
     @property
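
The reformatted call above merges the quantization config embedded in the model's pretrained_config with the user-supplied backend quantization_config before handing the result to AutoQuantizationConfig.from_dict. A minimal sketch of the dict-merge semantics, with invented values (user keys overwrite the pretrained ones):

# Minimal sketch of the merge semantics; values are invented for illustration.
pretrained_quant = {"quant_method": "gptq", "bits": 4, "use_exllama": False}  # as shipped with the model
user_quant = {"use_exllama": True}  # from the backend config

merged = dict(pretrained_quant, **user_quant)  # user config overwrites pretrained keys
assert merged == {"quant_method": "gptq", "bits": 4, "use_exllama": True}
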
33 changes: 19 additions & 14 deletions optimum_benchmark/backends/pytorch/config.py
@@ -1,14 +1,17 @@
 from dataclasses import dataclass, field
+from logging import getLogger
 from typing import Any, Dict, Optional
 
 from ...import_utils import torch_version
 from ...system_utils import is_rocm_system
 from ..config import BackendConfig
 
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
-QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}, "torchao": {}}
+QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}}
 
 
+LOGGER = getLogger(__name__)
+
+
 @dataclass
@@ -66,15 +69,17 @@ def __post_init__(self):
                 raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. Got {self.autocast_dtype} instead.")
 
         if self.quantization_scheme is not None:
-            if self.quantization_scheme not in QUANTIZATION_CONFIGS:
-                raise ValueError(
-                    f"`quantization_scheme` must be one of {list(QUANTIZATION_CONFIGS.keys())}. "
-                    f"Got {self.quantization_scheme} instead."
-                )
-
-            if self.quantization_scheme == "bnb" and is_rocm_system():
-                raise ValueError("BitsAndBytes is not supported on ROCm GPUs. Please disable it.")
-
-            if self.quantization_config:
-                QUANTIZATION_CONFIG = QUANTIZATION_CONFIGS[self.quantization_scheme]
-                self.quantization_config = {**QUANTIZATION_CONFIG, **self.quantization_config}
+            LOGGER.warning(
+                "`backend.quantization_scheme` is deprecated and will be removed in a future version. "
+                "Please use `quantization_config.quant_method` instead."
+            )
+            if self.quantization_config is None:
+                self.quantization_config = {"quant_method": self.quantization_scheme}
+            else:
+                self.quantization_config["quant_method"] = self.quantization_scheme
+
+        if self.quantization_config is not None:
+            self.quantization_config = dict(
+                QUANTIZATION_CONFIGS.get(self.quantization_scheme, {}),  # default config
+                **self.quantization_config,  # user config (overwrites default)
+            )

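A sketch of the resulting behaviour for a legacy configuration that still sets quantization_scheme, following the __post_init__ logic above. Plain dicts stand in for the dataclass fields and the concrete values are illustrative assumptions:

# Sketch of the new __post_init__ behaviour for a legacy config; plain dicts
# stand in for the dataclass fields, values are illustrative.
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}}

quantization_scheme = "bnb"  # deprecated option, still accepted
quantization_config = {"load_in_8bit": True}  # user-provided config

if quantization_scheme is not None:
    # the real code logs the deprecation warning here
    if quantization_config is None:
        quantization_config = {"quant_method": quantization_scheme}
    else:
        quantization_config["quant_method"] = quantization_scheme

if quantization_config is not None:
    quantization_config = dict(
        QUANTIZATION_CONFIGS.get(quantization_scheme, {}),  # default config
        **quantization_config,  # user config (overwrites default)
    )

# quantization_config is now:
# {"llm_int8_threshold": 0.0, "load_in_8bit": True, "quant_method": "bnb"}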