use auto quantization dispatcher
IlyasMoutawwakil committed Jan 31, 2025
1 parent abc587c commit 1c5e33e
Showing 1 changed file with 12 additions and 44 deletions.
56 changes: 12 additions & 44 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -8,15 +8,12 @@
 from datasets import Dataset
 from safetensors.torch import save_file
 from transformers import (
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    TorchAoConfig,
     Trainer,
     TrainerCallback,
     TrainerState,
     TrainingArguments,
 )
+from transformers.quantizers import AutoQuantizationConfig
 
 from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
 from ..base import Backend
@@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:
 
     def process_quantization_config(self) -> None:
         if self.is_gptq_quantized:
-            self.logger.info("\t+ Processing GPTQ config")
-
             try:
                 import exllamav2_kernels  # noqa: F401
             except ImportError:
@@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = GPTQConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
         elif self.is_awq_quantized:
-            self.logger.info("\t+ Processing AWQ config")
-
             try:
                 import exlv2_ext  # noqa: F401
             except ImportError:
@@ -316,21 +306,10 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = AwqConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_bnb_quantized:
-            self.logger.info("\t+ Processing BitsAndBytes config")
-            self.quantization_config = BitsAndBytesConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_torchao_quantized:
-            self.logger.info("\t+ Processing TorchAO config")
-            self.quantization_config = TorchAoConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        else:
-            raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
+        self.logger.info("\t+ Processing AutoQuantization config")
+        self.quantization_config = AutoQuantizationConfig.from_dict(
+            {**getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config}
+        )
 
     @property
     def is_quantized(self) -> bool:
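Note: the deleted per-scheme branches collapse into a single call because AutoQuantizationConfig.from_dict dispatches on the "quant_method" key of the merged dict and instantiates the matching config class (GPTQConfig, AwqConfig, BitsAndBytesConfig, TorchAoConfig, ...). A minimal sketch of that dispatch, with illustrative config values:

    from transformers.quantizers import AutoQuantizationConfig

    # from_dict reads "quant_method" and returns an instance of the matching
    # quantization config class, replacing the old if/elif chain.
    config = AutoQuantizationConfig.from_dict({"quant_method": "gptq", "bits": 4})
    print(type(config).__name__)  # GPTQConfig

The merge keeps the model's embedded quantization_config as the base and lets values from the benchmark config override it. (The original line built the merged dict with dict.update(), which returns None; the dict-unpacking merge above preserves the intended behavior.)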
@@ -339,13 +318,6 @@ def is_quantized(self) -> bool:
             and self.pretrained_config.quantization_config.get("quant_method", None) is not None
         )
 
-    @property
-    def is_bnb_quantized(self) -> bool:
-        return self.config.quantization_scheme == "bnb" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
-        )
-
     @property
     def is_gptq_quantized(self) -> bool:
         return self.config.quantization_scheme == "gptq" or (
@@ -360,13 +332,6 @@ def is_awq_quantized(self) -> bool:
             and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
         )
 
-    @property
-    def is_torchao_quantized(self) -> bool:
-        return self.config.quantization_scheme == "torchao" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
-        )
-
     @property
     def is_exllamav2(self) -> bool:
         return (
@@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
 
         if self.config.torch_dtype is not None:
-            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            if hasattr(torch, self.config.torch_dtype):
+                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            else:
+                kwargs["torch_dtype"] = self.config.torch_dtype
 
         if self.is_quantized:
             kwargs["quantization_config"] = self.quantization_config
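The new torch_dtype handling resolves names that exist on the torch module ("float16", "bfloat16", ...) to real dtype objects and passes anything else (e.g. "auto") through as a string for from_pretrained to interpret. A standalone sketch of the fallback; the helper name is hypothetical:

    import torch

    def resolve_dtype(name: str):
        # Known dtype names become torch objects; unknown strings such as
        # "auto" are forwarded unchanged for from_pretrained to interpret.
        return getattr(torch, name) if hasattr(torch, name) else name

    print(resolve_dtype("bfloat16"))  # torch.bfloat16
    print(resolve_dtype("auto"))      # auto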
@@ -436,9 +404,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
 
     @torch.inference_mode()
     def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, (
-            "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
-        )
+        assert (
+            kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1
+        ), "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
         return self.pretrained_model.generate(**inputs, **kwargs)
 
     @torch.inference_mode()
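The reflowed assert in prefill still pins generation to exactly one new token, which is what isolates the prefill phase when timing. A hypothetical call that satisfies it:

    # backend is a hypothetical PyTorchBackend instance; pinning both bounds
    # to 1 stops generate() after the first token, so only prefill is timed.
    backend.prefill(inputs, {"max_new_tokens": 1, "min_new_tokens": 1})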
