Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto quantization #313

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 14 additions & 46 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@
from datasets import Dataset
from safetensors.torch import save_file
from transformers import (
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
TorchAoConfig,
Trainer,
TrainerCallback,
TrainerState,
TrainingArguments,
)
from transformers.quantizers import AutoQuantizationConfig

from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
from ..base import Backend
Expand Down Expand Up @@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:

def process_quantization_config(self) -> None:
if self.is_gptq_quantized:
self.logger.info("\t+ Processing GPTQ config")

try:
import exllamav2_kernels # noqa: F401
except ImportError:
Expand All @@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
"`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
)

self.quantization_config = GPTQConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
elif self.is_awq_quantized:
self.logger.info("\t+ Processing AWQ config")

try:
import exlv2_ext # noqa: F401
except ImportError:
Expand All @@ -316,55 +306,30 @@ def process_quantization_config(self) -> None:
"`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
)

self.quantization_config = AwqConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
elif self.is_bnb_quantized:
self.logger.info("\t+ Processing BitsAndBytes config")
self.quantization_config = BitsAndBytesConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
elif self.is_torchao_quantized:
self.logger.info("\t+ Processing TorchAO config")
self.quantization_config = TorchAoConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
else:
raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
self.logger.info("\t+ Processing AutoQuantization config")
self.quantization_config = AutoQuantizationConfig.from_dict(
dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)

@property
def is_quantized(self) -> bool:
return self.config.quantization_scheme is not None or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config.get("quant_method", None) is not None
)

@property
def is_bnb_quantized(self) -> bool:
return self.config.quantization_scheme == "bnb" or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
and self.pretrained_config.quantization_config.get("quant_method") is not None
)

@property
def is_gptq_quantized(self) -> bool:
return self.config.quantization_scheme == "gptq" or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
)

@property
def is_awq_quantized(self) -> bool:
return self.config.quantization_scheme == "awq" or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
)

@property
def is_torchao_quantized(self) -> bool:
return self.config.quantization_scheme == "torchao" or (
hasattr(self.pretrained_config, "quantization_config")
and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
and self.pretrained_config.quantization_config.get("quant_method") == "awq"
)

@property
Expand All @@ -376,11 +341,11 @@ def is_exllamav2(self) -> bool:
(
hasattr(self.pretrained_config, "quantization_config")
and hasattr(self.pretrained_config.quantization_config, "exllama_config")
and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
)
or (
"exllama_config" in self.config.quantization_config
and self.config.quantization_config["exllama_config"].get("version", None) == 2
and self.config.quantization_config["exllama_config"].get("version") == 2
)
)
)
Expand All @@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
if hasattr(torch, self.config.torch_dtype):
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
else:
kwargs["torch_dtype"] = self.config.torch_dtype

if self.is_quantized:
kwargs["quantization_config"] = self.quantization_config
Expand Down
4 changes: 0 additions & 4 deletions optimum_benchmark/backends/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ...system_utils import is_rocm_system
from ..config import BackendConfig

DEVICE_MAPS = ["auto", "sequential"]
AMP_DTYPES = ["bfloat16", "float16"]
TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]

Expand Down Expand Up @@ -60,9 +59,6 @@ def __post_init__(self):
"Please remove it from the `model_kwargs` and set it in the backend config directly."
)

if self.device_map is not None and self.device_map not in DEVICE_MAPS:
raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.")

if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")

Expand Down