Several Inference Endpoint fixes (#66)
* fix(tgi): remove all the variables from entrypoint.sh

* fix(tgi): correct version

* fix(tgi): pin numpy version <2.0

* feat(tgi): entrypoint adds GKE specific command

* fix(generator): correct CachedBatch serialization when it's None

This was causing a tricky error when calling "/health" at server startup:
the health check triggers a prefill that returns None as the cached batch,
and that None value failed to be serialized.
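A minimal sketch of the kind of guard this implies (illustrative only, not the actual generator.py change; build_prefill_response is a hypothetical helper), using the generate_pb2 types already imported by the server:

from text_generation_server.pb import generate_pb2

# Hypothetical helper: only set the protobuf "batch" field when prefill
# actually produced a cached batch, so None is never handed to the serializer.
def build_prefill_response(generations, cached_batch):
    if cached_batch is None:
        # e.g. the "/health" probe, whose prefill leaves nothing to decode
        return generate_pb2.PrefillResponse(generations=generations)
    return generate_pb2.PrefillResponse(generations=generations, batch=cached_batch)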

* feat(generator): prefill input preparation is done on CPU

Doing this on the TPU appears to slow things down (possibly because of
compilation) and uses a lot of memory.
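A rough sketch of the idea (an assumed helper, not the actual generator.py code): do all padding and slicing on host tensors and move only the final tensors to the XLA device; the same applies to the decode inputs in the next item.

import torch

# Hedged sketch: inputs are assembled on CPU, then transferred to the TPU in
# one copy, instead of issuing many small on-device ops that may recompile.
def prepare_prefill_inputs(token_lists, pad_token_id, max_length, device):
    batch_size = len(token_lists)
    input_ids = torch.full((batch_size, max_length), pad_token_id, dtype=torch.int64)
    attention_mask = torch.zeros((batch_size, max_length), dtype=torch.int64)
    for i, tokens in enumerate(token_lists):
        if not tokens:
            continue
        tokens = tokens[-max_length:]                       # keep the latest tokens
        input_ids[i, -len(tokens):] = torch.tensor(tokens)  # left padding
        attention_mask[i, -len(tokens):] = 1
    return input_ids.to(device), attention_mask.to(device)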

* feat(generator): decode input preparation is done on CPU

* feat(generator): support TGI truncate parameter in Request

* fix(generator): warmup clears after prefill

This allows warmup to be handled correctly.
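In other words, warmup is assumed to prefill once to compile the graphs and then clear the generator so the warmup requests do not keep occupying slots; roughly:

# Hedged sketch of that flow (attribute names are assumptions, not the
# actual generator.py code).
def warmup(self, batch):
    self.prefill(batch)   # compiles the XLA graphs for this batch shape
    self.clear()          # then release the slots used by the warmup batch
    return self.max_total_tokens  # TGI expects the supported token budget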

* fix(tgi): correct clear implementation

This fixes a potential issue when clearing TGI requests.

When a client cancels a TGI request, two different methods can be called
on the TGI server:

- if the request is cancelled after prefill, the router asks the server to
  "filter" the corresponding request out of the decoding batch. This was
  already implemented correctly;
- if the request is cancelled during prefill, the router asks the server to
  clear the whole prefill batch. This was not implemented correctly: in that
  configuration we cleared all requests, even those that were not part of
  that prefill batch.

This is now fixed, essentially reproducing the TGI Neuron fix:
huggingface/optimum-neuron#609
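A minimal sketch of the resulting behaviour (slot attribute names are assumptions; the real change is in generator.py): clearing with a batch id only resets the slots that belong to that prefill batch, while clearing without an id resets everything.

# Hedged sketch mirroring the behaviour described above, not the exact code.
def clear(self, batch_id=None):
    for slot in self.slots:
        if batch_id is None or slot.batch_id == batch_id:
            logger.debug(f"Removing request {slot.request_id} from slot {slot.id}")
            slot.clear()

This is presumably also why Slot.assign now takes the batch id as its first argument in the test change further down: each slot has to remember which prefill batch it belongs to.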

* feat(ci): release TGI images only when release is published

* chore(generator): turn log info -> debug on clear

---------

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
tengomucho and mfuntowicz authored Jul 3, 2024
1 parent 7050cf4 commit 246fb24
Showing 8 changed files with 156 additions and 120 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tpu-tgi-release.yml
@@ -2,6 +2,7 @@ name: Release

 on:
   release:
+    types: [published]

 jobs:
   docker:
37 changes: 5 additions & 32 deletions text-generation-inference/docker/entrypoint.sh
@@ -1,44 +1,17 @@
 #!/bin/bash

+# This is required by GKE, see
+# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#privileged-mode
+ulimit -l 68719476736
+
 # Hugging Face Hub related
 if [[ -z "${MODEL_ID}" ]]; then
   echo "MODEL_ID must be set"
   exit 1
 fi
 export MODEL_ID="${MODEL_ID}"

-# TGI related
-if [[ -n "${TGI_MAX_CONCURRENT_REQUESTS}" ]]; then
-  export TGI_MAX_CONCURRENT_REQUESTS="${TGI_MAX_CONCURRENT_REQUESTS}"
-else
-  export TGI_MAX_CONCURRENT_REQUESTS=4
-fi
-
-if [[ -n "${TGI_MAX_BATCH_SIZE}" ]]; then
-  export TGI_MAX_BATCH_SIZE="${TGI_MAX_BATCH_SIZE}"
-else
-  export TGI_MAX_BATCH_SIZE=1
-fi
-
-if [[ -n "${TGI_MAX_INPUT_TOKENS}" ]]; then
-  export TGI_MAX_INPUT_TOKENS="${TGI_MAX_INPUT_TOKENS}"
-else
-  export TGI_MAX_INPUT_TOKENS=32
-fi
-
-if [[ -n "${TGI_MAX_TOTAL_TOKENS}" ]]; then
-  export TGI_MAX_TOTAL_TOKENS="${TGI_MAX_TOTAL_TOKENS}"
-else
-  export TGI_MAX_TOTAL_TOKENS=64
-fi
-
-TGI_MAX_BATCH_PREFILL_TOKENS=$(( TGI_MAX_BATCH_SIZE*TGI_MAX_INPUT_TOKENS ))
-
 text-generation-launcher --port 8080 \
-  --max-concurrent-requests ${TGI_MAX_CONCURRENT_REQUESTS} \
-  --max-batch-size ${TGI_MAX_BATCH_SIZE} \
-  --max-batch-prefill-tokens ${TGI_MAX_BATCH_PREFILL_TOKENS} \
-  --max-input-tokens ${TGI_MAX_INPUT_TOKENS} \
-  --max-total-tokens ${TGI_MAX_TOTAL_TOKENS} \
+  --max-batch-size 4 \
   --model-id ${MODEL_ID}

1 change: 1 addition & 0 deletions text-generation-inference/server/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     'transformers == 4.41.1',
     'loguru == 0.6.0',
     "sentencepiece == 0.2.0",
+    "numpy<2.0",
 ]

 [tool.setuptools]
193 changes: 109 additions & 84 deletions text-generation-inference/server/text_generation_server/generator.py

Large diffs are not rendered by default.

@@ -27,8 +27,9 @@ async def ServiceDiscovery(self, request, context):

     async def ClearCache(self, request, context):
         if request.HasField("id"):
-            logger.warning(f"Clearing all batches instead of batch {request.id} only.")
-            self.generator.clear()
+            self.generator.clear(request.id)
+        else:
+            self.generator.clear()
         return generate_pb2.ClearCacheResponse()

     async def FilterBatch(self, request, context):
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version


-__version__ = "0.1.0a1"
+__version__ = "0.1.1"
 VERSION = parse_version(__version__)
2 changes: 1 addition & 1 deletion text-generation-inference/tests/test_generator_slot.py
@@ -34,7 +34,7 @@ def test_decode_streaming(tokenizer, input_text, generated_text):
     # Note: device used is cpu to make it faster
     slot = Slot(0, tokenizer, "cpu")
     request = Request(id=0, inputs=input_text)
-    slot.assign(request, GenerationConfig())
+    slot.assign(0, request, GenerationConfig())
     assert slot.cached_text == input_text

     inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors="pt")
35 changes: 35 additions & 0 deletions text-generation-inference/tests/test_prefill_truncate.py
@@ -0,0 +1,35 @@
+from helpers import create_request, prepare_model
+from text_generation_server.generator import TpuGeneratorSingleThread as TpuGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_prefill_truncate():
+    model_id="Maykeye/TinyLLama-v0"
+    sequence_length=1024
+
+    model_path = prepare_model(model_id, sequence_length)
+    max_new_tokens = 20
+
+    generator = TpuGenerator.from_pretrained(
+        model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
+    )
+    input_text = "This is a secret part. Once upon a time,"
+
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
+    generations, _ = generator.prefill(batch)
+    assert len(generations) == 1
+    assert generations[0].tokens.ids == [635]
+    assert generations[0].tokens.texts == [" there"]
+
+    # Now re-test but with truncate
+    generator.clear()
+
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    # This will only leave 5 tokens
+    request.truncate = 5
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
+    generations, _ = generator.prefill(batch)
+    assert len(generations) == 1
+    assert generations[0].tokens.ids == [260]
+    assert generations[0].tokens.texts == [" a"]
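
The second half of the test shows that with request.truncate = 5 only the tail of the prompt survives, so the "secret part" no longer influences generation. A minimal sketch of TGI-style truncation as the generator could apply it during prefill (a hypothetical helper, not the code from generator.py):

# Hedged sketch: keep only the last `truncate` tokens of the encoded prompt;
# truncate == 0 (the protobuf default) means "no truncation requested".
def truncate_input_ids(input_ids, truncate):
    if truncate and truncate < len(input_ids):
        return input_ids[-truncate:]
    return input_ids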
