diff --git a/optimum_benchmark/backends/py_txi/backend.py b/optimum_benchmark/backends/py_txi/backend.py
index 81f1f6c3..e4c00f27 100644
--- a/optimum_benchmark/backends/py_txi/backend.py
+++ b/optimum_benchmark/backends/py_txi/backend.py
@@ -46,9 +46,14 @@ def download_pretrained_model(self) -> None:
 
     def create_no_weights_model(self) -> None:
         self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
-        filename = os.path.join(self.no_weights_model, "model.safetensors")
         os.makedirs(self.no_weights_model, exist_ok=True)
 
+        if self.pretrained_config is not None:
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+        if self.pretrained_processor is not None:
+            self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
+
+        filename = os.path.join(self.no_weights_model, "model.safetensors")
         save_file(tensors=torch.nn.Linear(1, 1).state_dict(), filename=filename, metadata={"format": "pt"})
         with fast_weights_init():
             # unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
@@ -59,11 +64,6 @@ def create_no_weights_model(self) -> None:
         del self.pretrained_model
         torch.cuda.empty_cache()
 
-        if self.pretrained_config is not None:
-            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
-        if self.pretrained_processor is not None:
-            self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
-
         if self.config.task in TEXT_GENERATION_TASKS:
             self.generation_config.eos_token_id = None
             self.generation_config.pad_token_id = None
@@ -78,11 +78,11 @@ def load_model_with_no_weights(self) -> None:
     def load_model_from_pretrained(self) -> None:
         if self.config.task in TEXT_GENERATION_TASKS:
             self.pretrained_model = TGI(
-                config=TGIConfig(self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
+                config=TGIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
             )
         elif self.config.task in TEXT_EMBEDDING_TASKS:
             self.pretrained_model = TEI(
-                config=TEIConfig(self.config.model, **self.txi_kwargs, **self.tei_kwargs),
+                config=TEIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tei_kwargs),
             )
         else:
             raise NotImplementedError(f"TXI does not support task {self.config.task}")