From b219ffd7d7dc9a2de2170d5f232f5f8dc38550d8 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 8 Apr 2024 10:07:57 +0200
Subject: [PATCH] match new names

---
 .../server/text_generation_server/generator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 17fa439f..5df4bd4e 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -11,7 +11,7 @@
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
 from transformers.generation import GenerationConfig
-from optimum.tpu.modeling import TpuModelForCausalLM
+from optimum.tpu import AutoModelForCausalLM
 from optimum.tpu.generation import TokenSelector
 
 from .pb.generate_pb2 import (
@@ -301,7 +301,7 @@ class TpuGenerator(Generator):
 
     def __init__(
         self,
-        model: TpuModelForCausalLM,
+        model,
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
@@ -633,7 +633,7 @@ def from_pretrained(
     """
     logger.info("Loading model (this can take a few minutes).")
     start = time.time()
-    model = TpuModelForCausalLM.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path)
     end = time.time()
     logger.info(f"Model successfully loaded in {end - start:.2f} s.")
     tokenizer = AutoTokenizer.from_pretrained(model_path)