diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 17fa439f..5df4bd4e 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -11,7 +11,7 @@ from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
 from transformers.generation import GenerationConfig
 
-from optimum.tpu.modeling import TpuModelForCausalLM
+from optimum.tpu import AutoModelForCausalLM
 from optimum.tpu.generation import TokenSelector
 
 from .pb.generate_pb2 import (
@@ -301,7 +301,7 @@ class TpuGenerator(Generator):
 
     def __init__(
         self,
-        model: TpuModelForCausalLM,
+        model,
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
@@ -633,7 +633,7 @@ def from_pretrained(
     """
     logger.info("Loading model (this can take a few minutes).")
     start = time.time()
-    model = TpuModelForCausalLM.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path)
     end = time.time()
     logger.info(f"Model successfully loaded in {end - start:.2f} s.")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
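
The diff swaps the loader from `optimum.tpu.modeling.TpuModelForCausalLM` to the `AutoModelForCausalLM` class exported at the top of the `optimum.tpu` package, and drops the now-obsolete type annotation on `TpuGenerator.__init__`. As a rough illustration of the resulting loading path, here is a minimal sketch of standalone usage; the entry point and `from_pretrained` call match what the diff shows, but the checkpoint id is a placeholder, not something taken from this change:

```python
# Sketch of the new loading path, assuming only what the diff shows:
# optimum.tpu exposes AutoModelForCausalLM with a from_pretrained method.
from transformers import AutoTokenizer
from optimum.tpu import AutoModelForCausalLM

model_path = "google/gemma-2b"  # placeholder model id, hypothetical

# Load model weights for TPU, then the matching tokenizer (mirrors
# the sequence in from_pretrained() in generator.py above).
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
```

Relaxing the `model` parameter to an unannotated argument keeps `TpuGenerator` agnostic to whichever concrete model class the auto-loader resolves to.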