Skip to content

Commit

Permalink
match new names
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Apr 8, 2024
1 parent a6d508f commit b219ffd
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig
from optimum.tpu.modeling import TpuModelForCausalLM
from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector

from .pb.generate_pb2 import (
Expand Down Expand Up @@ -301,7 +301,7 @@ class TpuGenerator(Generator):

def __init__(
self,
model: TpuModelForCausalLM,
model,
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
Expand Down Expand Up @@ -633,7 +633,7 @@ def from_pretrained(
"""
logger.info("Loading model (this can take a few minutes).")
start = time.time()
model = TpuModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
end = time.time()
logger.info(f"Model successfully loaded in {end - start:.2f} s.")
tokenizer = AutoTokenizer.from_pretrained(model_path)
Expand Down

0 comments on commit b219ffd

Please sign in to comment.