Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 15, 2024
1 parent 4289798 commit ecaa6c8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
27 changes: 10 additions & 17 deletions optimum_benchmark/backends/py_txi/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,33 @@ def prepare_generation_config(self) -> None:

def create_no_weights_model(self) -> None:
self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
self.logger.info("\t+ Creating no weights model directory")
os.makedirs(self.no_weights_model, exist_ok=True)
self.logger.info("\t+ Creating no weights model state dict")
state_dict = torch.nn.Linear(1, 1).state_dict()
self.logger.info("\t+ Saving no weights model safetensors")
safetensor = os.path.join(self.no_weights_model, "model.safetensors")
save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})
self.logger.info("\t+ Saving no weights model pretrained config")
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
self.logger.info("\t+ Saving no weights model pretrained processor")
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
# unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
self.logger.info(f"\t+ Loading no weights model from {self.no_weights_model}")
with fast_weights_init():
# unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
self.pretrained_model = self.automodel_loader.from_pretrained(
self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False
)
self.logger.info("\t+ Saving no weights model")
self.pretrained_model.save_pretrained(save_directory=self.no_weights_model)
save_file(tensors=self.pretrained_model.state_dict(), filename=safetensor, metadata={"format": "pt"})
del self.pretrained_model
torch.cuda.empty_cache()

if self.config.task in TEXT_GENERATION_TASKS:
self.logger.info("\t+ Modifying generation config for fixed length generation")
if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
if self.pretrained_processor is not None:
self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
if self.generation_config is not None:
self.generation_config.eos_token_id = None
self.generation_config.pad_token_id = None
self.logger.info("\t+ Saving new pretrained generation config")
self.generation_config.save_pretrained(save_directory=self.no_weights_model)

def load_model_with_no_weights(self) -> None:
self.config.volumes[self.tmpdir.name] = {"bind": "/no_weights_data/", "mode": "rw"}
original_model, self.config.model = self.config.model, "/no_weights_data/no_weights_model/"
self.config.volumes = (self.config.volumes, {self.tmpdir.name: {"bind": self.tmpdir.name, "mode": "rw"}})
original_model, self.config.model = self.config.model, self.no_weights_model
self.load_model_from_pretrained()
self.config.model = original_model
self.config.model, self.config.volumes = original_model

def load_model_from_pretrained(self) -> None:
if self.config.task in TEXT_GENERATION_TASKS:
Expand Down
3 changes: 3 additions & 0 deletions tests/configs/cpu_inference_py_txi_gpt2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ defaults:
- override backend: py-txi

name: cpu_inference_py_txi_gpt2

backend:
cuda_graphs: 0
3 changes: 3 additions & 0 deletions tests/configs/cuda_inference_py_txi_gpt2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ defaults:
- override backend: py-txi

name: cuda_inference_py_txi_gpt2

backend:
cuda_graphs: 0

0 comments on commit ecaa6c8

Please sign in to comment.