Small optimizations #5

Merged: 5 commits, Mar 22, 2024
Changes from 4 commits
@@ -1,9 +1,10 @@
import copy
import logging
import time
+import os
from abc import ABC
from enum import Enum
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict

import torch
import torch_xla.core.xla_model as xm
@@ -181,7 +182,7 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
selector: (`TokenSelector`):
An object implementing the updated token selection logic.
"""
-self._tokens = input_ids.clone()
+self._tokens = input_ids.cpu()
self._next_text_token_start = 0
self._next_text_token_end = torch.numel(self._tokens)
self._next_text = ""
@@ -210,25 +211,27 @@ def _decode_next_tokens(
self,
) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+# Copy the tokens to CPU to avoid recompilation on TPU. Post-processing is quite fast anyway.
+tokens = self._tokens.cpu()
# We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
-new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False)
+new_text = self._tokenizer.decode(tokens[self._next_text_token_start :], skip_special_tokens=False)
if new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
return ""

# Compare the generated text with the one using only the tokens producing the last one
last_text = self._tokenizer.decode(
-self._tokens[self._next_text_token_start : self._next_text_token_end],
+tokens[self._next_text_token_start : self._next_text_token_end],
skip_special_tokens=False,
)
if len(new_text) == len(last_text):
# Nothing new was actually generated
return ""
# Return the decoded text and store its token offsets
self._next_text_token_start = self._next_text_token_end
-self._next_text_token_end = torch.numel(self._tokens)
+self._next_text_token_end = torch.numel(tokens)
return new_text[len(last_text) :]

def append(self, next_token: int) -> str:
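For context, a minimal standalone sketch of the incremental-decoding trick used above: keep the token buffer on CPU so post-processing never touches the XLA device, decode including the tokens that produced the previous text so the tokenizer's cleanup rules see the same left context, and emit only the new characters. The "gpt2" tokenizer and the example strings below are illustrative, not taken from this repository.

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice

# Prompt tokens, kept on CPU.
tokens = torch.tensor(tok("Deep learning is").input_ids)
start, end = 0, torch.numel(tokens)

# Simulate appending one generated token (" the" is a single token for this tokenizer).
next_token = tok(" the").input_ids[0]
tokens = torch.cat([tokens, torch.tensor([next_token], dtype=tokens.dtype)])

# Decode from `start` (not just the new token) so spacing decisions stay consistent,
# then return only the characters that were not part of the previous text.
new_text = tok.decode(tokens[start:], skip_special_tokens=False)
last_text = tok.decode(tokens[start:end], skip_special_tokens=False)
print(new_text[len(last_text):])  # -> " the"
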
@@ -248,7 +251,7 @@ def append(self, next_token: int) -> str:
The corresponding decoded text (if any).
"""
self._tokens = torch.cat(
-[self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
+[self._tokens, torch.tensor([next_token], dtype=self._tokens.dtype)]
)
# Update mask only if it was set previously
if self._mask is not None:
@@ -304,6 +307,12 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
if model.device.type == "xla" and "DBG_COMPILE" in os.environ:
self.model_one_token = torch.compile(model, backend="openxla")
logger.debug("Model compiled for decoding")
else:
self.model_one_token = model

# Specify padding options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
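Only the presence of DBG_COMPILE is checked, so any value enables the debug compilation of the single-token decode model, while prefill keeps the eager model. A hedged sketch of one way to exercise the flag, assuming a TPU/XLA environment (the test selection is just an example):

import os
import pytest

# Any value works: the generator only checks `"DBG_COMPILE" in os.environ`.
os.environ["DBG_COMPILE"] = "1"

# Decode steps will then go through the torch.compile(backend="openxla") wrapper,
# while prefill still uses the uncompiled model.
pytest.main(["text-generation-inference/tests/test_generator.py", "-k", "test_decode_single"])
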
@@ -426,7 +435,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
-batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
+batch.id, input_ids, self.model, attention_mask=attention_mask, position_ids=position_ids, **extra_args
)

# Reactivate previously active slots for the next decode, and append
@@ -494,15 +503,17 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
else:
extra_args["attention_mask"] = attention_mask
extra_args["past_key_values"] = self.past_key_values
-return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)
+return self._generate_token(
+next_batch_id, input_ids, self.model_one_token, position_ids=position_ids, **extra_args
+)

def _generate_token(
-self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
+self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
# Add barrier to allow next graph step to always be the same
xm.mark_step()
# Forward
-outputs = self.model(
+outputs = model(
input_ids,
return_dict=True,
use_cache=True,
@@ -512,8 +523,11 @@ def _generate_token(
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
-xm.mark_step(wait=False)
+xm.mark_step()
Review comment (Member): For my knowledge: we were not waiting before, so why has this changed here?

Reply (Collaborator, author): The default is wait=False, and I did not want to give the false impression that I am changing the default behaviour, so I just removed the explicit default argument.
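
In code form, the point above (relying on the author's statement that wait defaults to False in torch_xla):

import torch_xla.core.xla_model as xm

# wait defaults to False, so these two calls behave identically;
# the diff only drops the explicitly passed default.
xm.mark_step(wait=False)
xm.mark_step()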

+ret = self._post_generate(outputs, next_batch_id, input_ids)
+return ret

+def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor) -> Tuple[List[Generation], CachedBatch]:
generations = []
active_slots = False
for i, slot in enumerate(self.slots):
@@ -62,11 +62,7 @@ def from_pretrained(
model.config.batch_size = batch_size
if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
model.config.sequence_length = sequence_length

-# Do eval, and compile
+# Do eval
model.eval()
if device == "xla" and "DBG_COMPILE" in environ:
model = torch.compile(model, backend="openxla_eval")
logger.debug("Model compiled.")

return model
15 changes: 11 additions & 4 deletions text-generation-inference/tests/test_generator.py
@@ -1,5 +1,6 @@
import pytest
import os
+from tqdm import tqdm
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
@@ -44,13 +45,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
+# For these tests we can safely set typical_p to 1.0 (default)
+typical_p = 1.0
+if do_sample == False:
Review comment (Member): use "if not do_sample".

+# Drop top_p parameter to avoid warnings
+top_p = 1.0
parameters = NextTokenChooserParameters(
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
seed=seed,
repetition_penalty=repetition_penalty,
+typical_p=typical_p,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
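
A small sketch of the same defaulting logic with the reviewer's suggestion applied (an illustrative helper, not code from this PR; the behaviour is unchanged, only the comparison style differs):

def sampling_defaults(do_sample: bool, top_p: float):
    # typical_p can safely stay at its default of 1.0 for these tests.
    typical_p = 1.0
    if not do_sample:  # reviewer's suggested form of `do_sample == False`
        top_p = 1.0    # neutral top_p avoids warnings when decoding greedily
    return top_p, typical_p

create_request would then pass the returned values to NextTokenChooserParameters exactly as in the diff above.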
@@ -121,7 +128,7 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, mo
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
-for i in range(max_new_tokens - 1):
+for _ in tqdm(range(max_new_tokens - 1), "Decoding tokens"):
assert next_batch.size == 1
assert next_batch.max_tokens == 1024
assert len(generations) == 1
@@ -152,7 +159,7 @@ def test_decode_multiple(model_path):
assert len(tokens[0]) == 1
# Decode a few tokens
gen_tokens = 4
-for _ in range(gen_tokens - 1):
+for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
generations, next_batch = generator.decode([next_batch])
assert len(generations) == 1
g = generations[0]
@@ -172,7 +179,7 @@ def test_decode_multiple(model_path):
assert len(tokens[1]) == 1
# Decode more tokens until we reach the maximum for the first request
batches = [next_batch, next_batch_1]
-for _ in range(max_new_tokens - gen_tokens):
+for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
generations, next_batch = generator.decode(batches)
for g in generations:
tokens[g.request_id].append(g.tokens.ids[0])
@@ -189,7 +196,7 @@ def test_decode_multiple(model_path):
assert output.generated_tokens == max_new_tokens
generated_text = output.text
# Continue decoding until the end of the second request
-for _ in range(gen_tokens - 1):
+for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
generations, next_batch = generator.decode([next_batch])
assert len(generations) == 1
g = generations[0]
9 changes: 8 additions & 1 deletion text-generation-inference/tests/test_generator_gemma.py
@@ -1,5 +1,6 @@
import pytest
import os
+from tqdm import tqdm
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
@@ -35,13 +36,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
+# For these tests we can safely set typical_p to 1.0 (default)
+typical_p = 1.0
+if do_sample == False:
Review comment (Member): use "if not do_sample".

+# Drop top_p parameter to avoid warnings
+top_p = 1.0
parameters = NextTokenChooserParameters(
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
seed=seed,
repetition_penalty=repetition_penalty,
+typical_p=typical_p,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
@@ -57,7 +64,7 @@ def test_decode_single(model_path):
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
-for _ in range(max_new_tokens - 1):
+for _ in tqdm(range(max_new_tokens - 1)):
assert next_batch.size == 1
assert next_batch.max_tokens == 1024
assert len(generations) == 1