diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py
deleted file mode 100644
index 92406664..00000000
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import Any, Callable, Optional, Tuple
-
-import jax
-import jax.numpy as jnp
-import torch
-from jetstream.engine import engine_api
-from jetstream_pt import engine
-
-
-class HfEngine(engine.PyTorchEngine):
-    def __init__(
-        self,
-        pt_model: torch.nn.Module,
-        env: engine.JetEngineEnvironment,
-        weights=None,
-    ):
-        super().__init__(pt_model, env, weights)
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 7517c0e9..91cf9784 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -5,6 +5,7 @@
 
 import jax
 from jetstream_pt import fetch_models, torchjax
+from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.environment import (
     JetEngineEnvironment,
     JetEngineEnvironmentData,
@@ -17,7 +18,6 @@
 from transformers import PretrainedConfig
 from transformers import AutoConfig
 
-from .engine import HfEngine
 from .llama_model_exportable_hf import TransformerHf
 
 
@@ -136,7 +136,7 @@ def create_engine(
     sequence_length: int,
     max_input_tokens: int,
     max_output_tokens: int,
-) -> HfEngine:
+) -> PyTorchEngine:
     # NOTE: for now no quantization is done
     env_data = create_engine_env_data(model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens)
     if env_data is None:
@@ -147,7 +147,7 @@
     weight_shardings = model.get_sharding_annotations()
     sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
 
-    return HfEngine(
+    return PyTorchEngine(
         pt_model=model,
         env=env,
         weights=torchjax.from_torch_with_copy(sharded_weights),
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index cefbebe8..b3bc00a2 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -10,6 +10,7 @@
 import torch
 import torch_xla2
 from jetstream.engine.token_utils import pad_tokens, take_nearest_length
+from jetstream_pt.engine import PyTorchEngine
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from transformers.generation import GenerationConfig
@@ -27,7 +28,6 @@
     StoppingCriteriaParameters,
     Tokens,
 )
-from .engine import HfEngine
 from .engine_loader import create_engine
 from .token_selector import TokenSelector
 
@@ -243,7 +243,7 @@ class TpuGeneratorJetStream(Generator):
 
     def __init__(
         self,
-        engine: HfEngine,
+        engine: PyTorchEngine,
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.engine = engine