diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py index 7e76d922..a8aba3c0 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py @@ -1,7 +1,7 @@ # Import torch_xla2 first import torch_xla2 # isort:skip -from typing import Any +from typing import TYPE_CHECKING, Any import jax from jetstream_pt import fetch_models, torchjax @@ -11,13 +11,17 @@ QuantizationConfig, ) from loguru import logger -from transformers import AutoConfig, PretrainedConfig + + +if TYPE_CHECKING: + from transformers import PretrainedConfig +from transformers import AutoConfig from .engine import HfEngine from .llama_model_exportable_hf import TransformerHf -def load_llama_model_info(config: PretrainedConfig) -> Any: +def load_llama_model_info(config: "PretrainedConfig") -> Any: num_layers = config.num_hidden_layers num_heads = config.num_attention_heads head_dim = config.hidden_size // num_heads