
Commit

misc
IlyasMoutawwakil committed Feb 13, 2024
1 parent 3ab987c commit f7e4069
Showing 18 changed files with 180 additions and 280 deletions.
2 changes: 1 addition & 1 deletion examples/neural_compressor_ptq_bert.yaml
@@ -7,7 +7,7 @@ defaults:
- override hydra/job_logging: colorlog # colorful logging
- override hydra/hydra_logging: colorlog # colorful logging

- experiment_name: openvino_static_quant_bert
+ experiment_name: neural_compressor_ptq_bert

backend:
device: cpu
109 changes: 0 additions & 109 deletions optimum_benchmark/aggregators/__init__.py

This file was deleted.

76 changes: 40 additions & 36 deletions optimum_benchmark/backends/neural_compressor/backend.py
@@ -4,22 +4,19 @@
from logging import getLogger
from tempfile import TemporaryDirectory

from ...generators.dataset_generator import DatasetGenerator
from ..transformers_utils import randomize_weights
from .utils import TASKS_TO_INCMODELS
from .config import INCConfig
from ..base import Backend

import torch
from hydra.utils import get_class
from transformers.utils import ModelOutput
from transformers.modeling_utils import no_init_weights
from transformers.utils.logging import set_verbosity_error
from optimum.intel.neural_compressor.quantization import INCQuantizer
from neural_compressor.config import (
PostTrainingQuantConfig,
AccuracyCriterion,
TuningCriterion,
)

from ...generators.dataset_generator import DatasetGenerator
from .utils import TASKS_TO_INCMODELS
from .config import INCConfig
from ..base import Backend
from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion, TuningCriterion

# disable transformers logging
set_verbosity_error()
@@ -34,9 +31,7 @@ def __init__(self, config: INCConfig):
super().__init__(config)
self.validate_task()

self.incmodel_class = get_class(TASKS_TO_INCMODELS[self.config.task])
LOGGER.info(f"Using INCModel class {self.incmodel_class.__name__}")

LOGGER.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()

if self.config.ptq_quantization:
@@ -52,57 +47,65 @@ def __init__(self, config: INCConfig):
else:
self.load_incmodel_from_pretrained()

self.tmpdir.cleanup()

def validate_task(self) -> None:
if self.config.task not in TASKS_TO_INCMODELS:
raise NotImplementedError(f"INCBackend does not support task {self.config.task}")

self.incmodel_class = get_class(TASKS_TO_INCMODELS[self.config.task])
LOGGER.info(f"Using INCModel class {self.incmodel_class.__name__}")

def load_automodel_from_pretrained(self) -> None:
LOGGER.info("\t+ Loading AutoModel from pretrained")
self.pretrained_model = self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)

def load_automodel_with_no_weights(self) -> None:
no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
def create_no_weights_model(self) -> None:
LOGGER.info("\t+ Creating no weights model state_dict")
state_dict = torch.nn.Linear(1, 1).state_dict()

if not os.path.exists(no_weights_model):
LOGGER.info("\t+ Creating no weights model directory")
os.makedirs(no_weights_model)
LOGGER.info("\t+ Creating no weights model directory")
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
os.makedirs(self.no_weights_model, exist_ok=True)

LOGGER.info("\t+ Saving pretrained config")
self.pretrained_config.save_pretrained(save_directory=no_weights_model)
LOGGER.info("\t+ Saving no weights model pretrained config")
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)

LOGGER.info("\t+ Creating no weights model")
state_dict = torch.nn.Linear(1, 1).state_dict()
LOGGER.info("\t+ Saving no weights model state_dict")
torch.save(state_dict, os.path.join(self.no_weights_model, "pytorch_model.bin"))

LOGGER.info("\t+ Saving no weights model")
torch.save(state_dict, os.path.join(no_weights_model, "pytorch_model.bin"))
def load_automodel_with_no_weights(self) -> None:
self.create_no_weights_model()

LOGGER.info("\t+ Loading no weights model")
with no_init_weights():
original_model = self.config.model
self.config.model = no_weights_model
self.config.model = self.no_weights_model
LOGGER.info("\t+ Loading no weights model")
self.load_automodel_from_pretrained()
self.config.model = original_model

LOGGER.info("\t+ Randomizing model weights")
randomize_weights(self.pretrained_model)
LOGGER.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()

def load_incmodel_from_pretrained(self) -> None:
LOGGER.info("\t+ Loading INCModel from pretrained")
self.pretrained_model = self.incmodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)

def load_incmodel_with_no_weights(self) -> None:
no_weights_model = os.path.join(self.tmpdir.name, "no_weights")

LOGGER.info("\t+ Loading AutoModel with no weights")
self.load_automodel_with_no_weights()
self.delete_pretrained_model()
self.create_no_weights_model()

LOGGER.info("\t+ Loading INCModel with no weights")
with no_init_weights():
original_model = self.config.model
self.config.model = no_weights_model
self.config.model = self.no_weights_model
LOGGER.info("\t+ Loading no weights model")
self.load_incmodel_from_pretrained()
self.config.model = original_model

LOGGER.info("\t+ Randomizing model weights")
randomize_weights(self.pretrained_model.model)
LOGGER.info("\t+ Tying model weights")
self.pretrained_model.model.tie_weights()

def quantize_automodel(self) -> None:
LOGGER.info("\t+ Attempting to quantize model")
quantized_model_path = f"{self.tmpdir.name}/quantized"
@@ -134,7 +137,7 @@ def quantize_automodel(self) -> None:
task=self.config.task,
dataset_shapes=dataset_shapes,
model_shapes=self.model_shapes,
).generate()
)()
columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._signature_columns))
calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
else:
@@ -169,6 +172,7 @@ def clean(self) -> None:
super().clean()

if hasattr(self, "tmpdir"):
LOGGER.info("\t+ Cleaning backend temporary directory")
self.tmpdir.cleanup()

gc.collect()
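
For readers unfamiliar with the "no weights" trick that this commit factors out into create_no_weights_model, the sketch below reproduces the idea outside the backend: save only a model config plus a tiny placeholder state_dict, load from that directory with weight initialization disabled, then tie the embeddings. The model name and paths are placeholders, not values taken from this repository.

import os
import tempfile

import torch
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import no_init_weights

with tempfile.TemporaryDirectory() as tmpdir:
    no_weights_dir = os.path.join(tmpdir, "no_weights")
    os.makedirs(no_weights_dir, exist_ok=True)

    # Save only the config and a dummy state_dict; the real weights never touch the disk.
    # "bert-base-uncased" is a placeholder checkpoint for illustration.
    AutoConfig.from_pretrained("bert-base-uncased").save_pretrained(save_directory=no_weights_dir)
    torch.save(torch.nn.Linear(1, 1).state_dict(), os.path.join(no_weights_dir, "pytorch_model.bin"))

    # Loading under no_init_weights skips initialization of the missing tensors, which is
    # what makes this fast; tie_weights then mirrors what the refactored backend does above.
    with no_init_weights():
        model = AutoModel.from_pretrained(no_weights_dir)
    model.tie_weights()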
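
quantize_automodel itself wraps optimum-intel's post-training quantization entry points (INCQuantizer plus PostTrainingQuantConfig, as imported at the top of this file). Below is a rough standalone sketch of that flow; the checkpoint and calibration dataset are placeholders, and keyword arguments may vary slightly between optimum-intel versions.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from neural_compressor.config import PostTrainingQuantConfig
from optimum.intel.neural_compressor.quantization import INCQuantizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess(examples):
    # Tokenize to fixed-length inputs for calibration.
    return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)

quantizer = INCQuantizer.from_pretrained(model)

# Static PTQ needs a small calibration set; GLUE/SST-2 is used here purely for illustration.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=preprocess,
    num_samples=100,
    dataset_split="train",
)

quantizer.quantize(
    quantization_config=PostTrainingQuantConfig(approach="static"),
    calibration_dataset=calibration_dataset,
    save_directory="inc_quantized_model",
)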