Few more Inference Endpoints fixes #69

Merged: 8 commits on Jul 8, 2024
optimum/tpu/version.py (1 addition, 1 deletion)
@@ -15,5 +15,5 @@
from pkg_resources import parse_version


__version__ = "0.1.1"
__version__ = "0.1.2"
VERSION = parse_version(__version__)
@@ -40,7 +40,7 @@
optimum_logger.setLevel("CRITICAL")

# These will do some bucketing on prefill lengths to avoid too many different sizes
-PREFILL_LENGTHS = [
+PREFILL_LENGTHS = list(range(6, 16)) + [
16,
32,
64,
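The change above adds exact buckets for very short prompts (lengths 6 through 15) instead of padding them all up to 16. A rough sketch of how such bucketing usually behaves; the helper name, the rounding rule, and the tail of the list are assumptions, not code from this PR:

```python
import bisect

# Assumed bucket list matching the new definition; the tail beyond 64 is illustrative.
PREFILL_LENGTHS = list(range(6, 16)) + [16, 32, 64, 128, 256, 512]

def bucket_prefill_length(length: int) -> int:
    # Round a prefill length up to the nearest bucket so the number of compiled
    # shapes stays bounded; lengths beyond the largest bucket clamp to it
    # (hypothetical helper, not from the PR).
    idx = bisect.bisect_left(PREFILL_LENGTHS, length)
    return PREFILL_LENGTHS[min(idx, len(PREFILL_LENGTHS) - 1)]

print(bucket_prefill_length(7))    # 7: short prompts now get an exact small bucket
print(bucket_prefill_length(100))  # 128: longer prompts still round up
```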
@@ -446,6 +446,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
active_slots = slots[Slot.State.READY]
# Delete all empty slots, no need to have them anymore
empty_slots = slots[Slot.State.EMPTY]
+model_batch_size = self.model.config.batch_size
+if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
+    raise ValueError(
+        f"Cannot prefill {len(batch.requests)} new request(s)."
+        f" Maximum batch size supported is: {model_batch_size}."
+    )
for slot in empty_slots:
self.slots.remove(slot)
# Assign each request to an empty slot
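The guard added above makes an over-committed prefill fail fast with a descriptive error instead of silently competing for slots that do not exist. A tiny, self-contained rendering of the check with hypothetical numbers:

```python
# Hypothetical numbers illustrating the new guard: two requests are already being
# decoded on a model compiled for a static batch size of two, so one more cannot join.
model_batch_size = 2   # stands in for self.model.config.batch_size
active_slots = 2       # slots already in the READY state
new_requests = 1       # len(batch.requests) arriving in this prefill call
try:
    if model_batch_size is not None and model_batch_size < active_slots + new_requests:
        raise ValueError(
            f"Cannot prefill {new_requests} new request(s)."
            f" Maximum batch size supported is: {model_batch_size}."
        )
except ValueError as err:
    print(err)  # Cannot prefill 1 new request(s). Maximum batch size supported is: 2.
```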
@@ -836,7 +842,8 @@ def return_to_caller(*data):
cached_batch = generator.filter(batch_id, request_ids)
return_to_caller(cached_batch.SerializeToString())
if command == GeneratorCommand.CLEAR:
-generator.clear()
+batch_id = data[0]
+generator.clear(batch_id)
if command == GeneratorCommand.DELETE:
if rank == 0:
# Set agent to ready
@@ -902,8 +909,8 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
return CachedBatch.FromString(s_cached_batch)

-def clear(self):
-    self.mailbox.send(GeneratorCommand.CLEAR)
+def clear(self, batch_id: Optional[int] = None):
+    self.mailbox.send(GeneratorCommand.CLEAR, batch_id)

def leave(self):
if self.mailbox is None:
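CLEAR now carries an optional batch id through the mailbox, which suggests a single cached batch can be dropped without flushing every slot (that per-batch behaviour is inferred from the widened signature, not shown in this diff). A toy stand-in illustrating both call shapes:

```python
from typing import Optional

class DummyGenerator:
    # Toy stand-in for the real TpuGenerator, only to show the widened clear() signature.
    def __init__(self):
        self.batches = {1: ["req-a"], 2: ["req-b"]}

    def clear(self, batch_id: Optional[int] = None):
        if batch_id is None:
            self.batches.clear()              # old behaviour: drop everything
        else:
            self.batches.pop(batch_id, None)  # assumed new behaviour: drop one cached batch

gen = DummyGenerator()
gen.clear(1)   # only batch 1 is removed
gen.clear()    # remaining batches are removed, as before
```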
@@ -56,7 +56,7 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
"""Remove requests that are not listed from the specified batch"""
raise NotImplementedError

-def clear(self):
+def clear(self, batch_id: Optional[int] = None):
"""Remove all requests from the generator"""
raise NotImplementedError

@@ -1,5 +1,5 @@
from pkg_resources import parse_version


__version__ = "0.1.1"
__version__ = "0.1.2"
VERSION = parse_version(__version__)
text-generation-inference/tests/test_decode.py (1 addition, 1 deletion)
@@ -44,7 +44,7 @@ def test_decode_single(params):
DecodeTestParams(
model_id="google/gemma-7b",
sequence_length=128,
expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
expected_text="\n\nThe first line of George Orwell’s <em>1984</em> is a perfect example",
),
DecodeTestParams(
model_id="mistralai/Mistral-7B-v0.3",
text-generation-inference/tests/test_gpt2.py (4 additions, 1 deletion)
@@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_


def test_decode_multiple(model_path):
-generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
+generator = TpuGenerator.from_pretrained(model_path,
+                                         revision="",
+                                         max_batch_size=2,
+                                         max_sequence_length=SEQUENCE_LENGTH)
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token