From c897a7f9f00a7acff6cec1ea0c385128cf7d40a6 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 18 Sep 2024 09:06:57 +0000
Subject: [PATCH] feat(debug): add env var to skip warmup

---
 .../text_generation_server/jetstream_pt_support/generator.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 8251c3df..b47a5c92 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -2,6 +2,7 @@
 import logging
 import time
 from enum import Enum
+import os
 from typing import List, Optional, Tuple
 
 import jax
@@ -330,6 +331,9 @@ def warmup(self, batch: Batch) -> int:
         # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
         # batch sizes and sequence lengths.
         seq_len = self.model.config.sequence_length
+        if os.environ.get("SKIP_WARMUP", "0") == "1":
+            logger.debug("Skipping warmup")
+            return batch_size * seq_len
         bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len)
         decode_done = False
         for l in reversed(DEFAULT_PREFILL_BUCKETS):