Skip to content

Commit

Permalink
review: fix imports for type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
tengomucho committed Sep 10, 2024
1 parent b323794 commit 6857b5f
Showing 1 changed file with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Import torch_xla2 first
import torch_xla2 # isort:skip

from typing import Any
from typing import TYPE_CHECKING, Any

import jax
from jetstream_pt import fetch_models, torchjax
Expand All @@ -11,13 +11,17 @@
QuantizationConfig,
)
from loguru import logger
from transformers import AutoConfig, PretrainedConfig


if TYPE_CHECKING:
from transformers import PretrainedConfig
from transformers import AutoConfig

from .engine import HfEngine
from .llama_model_exportable_hf import TransformerHf


def load_llama_model_info(config: PretrainedConfig) -> Any:
def load_llama_model_info(config: "PretrainedConfig") -> Any:
num_layers = config.num_hidden_layers
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads
Expand Down

0 comments on commit 6857b5f

Please sign in to comment.