feat: use dynamic batching when generating #9

Merged 4 commits on Apr 2, 2024

Changes from 3 commits
text-generation-inference/integration-tests/test_gpt2.py (2 changes: 0 additions & 2 deletions)
@@ -5,13 +5,11 @@


MODEL_ID = "openai-community/gpt2"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024


@pytest.fixture(scope="module")
def model_name_or_path():
os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
Member: Maybe we should keep a HF_MAX_BATCH_SIZE somewhere?

Collaborator (author): as mentioned below, maybe we'll use it later

os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
yield MODEL_ID

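On the HF_MAX_BATCH_SIZE idea from the review thread above: a minimal sketch of what an optional environment-variable cap could look like. The variable name, get_max_batch_size and check_capacity are illustrative assumptions, not code from this PR.

import os
from typing import Optional


def get_max_batch_size() -> Optional[int]:
    """Hypothetical helper: read an optional HF_MAX_BATCH_SIZE cap from the environment.

    Returns None when the variable is unset, meaning the generator accepts any number
    of concurrent requests (the dynamic-batching behaviour introduced by this PR).
    """
    value = os.environ.get("HF_MAX_BATCH_SIZE")
    return int(value) if value is not None else None


def check_capacity(num_active_slots: int, num_new_requests: int) -> None:
    """Hypothetical guard a prefill could run before creating new slots."""
    max_batch_size = get_max_batch_size()
    if max_batch_size is not None and num_active_slots + num_new_requests > max_batch_size:
        raise ValueError(
            f"Cannot add {num_new_requests} request(s): HF_MAX_BATCH_SIZE ({max_batch_size}) "
            f"would be exceeded with {num_active_slots} slot(s) already active."
        )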
@@ -312,7 +312,10 @@ def __init__(
tokenizer.padding_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
Member: If we introduce the HF_MAX_BATCH_SIZE maybe we can initialize this list, wdyt?

Collaborator (author): I think I need to better understand the usage of the batches, whether it can be increased/reduced often, and how this affects performance. At that point it will be easier to think about a reasonable algorithm to reduce compilation and batch-change overhead as much as possible.

# Slots are empty to begin with, they will be populated as new batches arrive
self.slots = []
# Note: this index will _never_ be decremented, and that's fine.
self.slot_index = 0
self.past_key_values = None
# _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
# a static cache, otherwise it is not.
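On the compilation and batch-change overhead discussed in the thread above: one option (a sketch only, under assumed bucket sizes; nothing here is part of this PR) is to round the dynamic batch up to a small set of bucket shapes, so the compiler only ever sees a few distinct batch sizes.

# Assumed bucket sizes; the extra rows would be filled like the blank inputs the
# decode path already creates for empty slots.
BATCH_BUCKETS = [1, 2, 4, 8, 16]


def pad_to_bucket(num_slots: int) -> int:
    """Return the smallest bucket >= num_slots, or num_slots itself beyond the largest bucket."""
    for bucket in BATCH_BUCKETS:
        if num_slots <= bucket:
            return bucket
    return num_slots


# Usage sketch: size decode tensors with pad_to_bucket(len(self.slots)) instead of
# len(self.slots), trading a little padding for far fewer recompilations.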
@@ -350,13 +353,18 @@ def warmup(self, batch: Batch) -> int:
The maximum number of tokens the model supports.
"""
# Just check that the warmup request parameters match the model capacity
batch_size = self.model.config.batch_size
# NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
Member: Ok, maybe you want to keep it for later 🤗

if self.model.config.batch_size is not None:
batch_size = self.model.config.batch_size
else:
# batch size is not set, just assume it's unlimited and accept all requests
batch_size = len(batch.requests)
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
self.prefill(batch)
return self.model.config.batch_size * self.model.config.sequence_length
return batch_size * self.model.config.sequence_length

@torch.no_grad
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
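Worked numbers for the capacity check above (illustrative values, not taken from this PR), assuming the router builds the warmup batch from max-prefill-tokens divided by max-input-length:

MAX_INPUT_LENGTH = 1024   # assumed, matching SEQUENCE_LENGTH in the tests
STATIC_BATCH_SIZE = 4     # a configured batch size, when one is set

# max-prefill-tokens = 8192 would imply a warmup batch of 8192 // 1024 = 8 requests,
# which exceeds the static batch size of 4 and raises the ValueError above.
max_prefill_tokens = 8192
warmup_requests = max_prefill_tokens // MAX_INPUT_LENGTH
assert warmup_requests > STATIC_BATCH_SIZE

# With batch_size unset (dynamic batching), batch_size falls back to len(batch.requests),
# the check always passes, and the token budget is len(batch.requests) * sequence_length.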
@@ -373,17 +381,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
for slot in self.slots:
slots[slot.state].append(slot)
active_slots = slots[Slot.State.READY]
# Delete all empty slots, no need to have them anymore
empty_slots = slots[Slot.State.EMPTY]
if len(empty_slots) < len(batch.requests):
raise ValueError(
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
f"Please align the number of concurrent requests with the static batch size: {self.model.batch_size}."
)
for slot in empty_slots:
self.slots.remove(slot)
mfuntowicz marked this conversation as resolved.
# Assign each request to an empty slot
logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)")
logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
new_slots = []
Member: What is new_slots used for?

Collaborator (author): useless 😄

for request in batch.requests:
slot = empty_slots.pop()
# Dynamically create a new slot for each request
slot = Slot(self.slot_index, self.tokenizer, self.model.device)
self.slot_index += 1
slot.assign(request, self.model.generation_config)
new_slots.append(slot)
Member: Same comment about new_slots

Collaborator (author): same answer, I'll remove it

self.slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
@@ -466,8 +477,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
batch_size = len(self.slots)
position_ids = torch.zeros(
[self.model.config.batch_size, 1],
[batch_size, 1],
dtype=torch.int64,
device=self.model.device,
)
@@ -476,7 +488,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
if input_ids is None:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[self.model.config.batch_size, 1],
[batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
dtype=torch.int64,
device=self.model.device,
@@ -488,7 +500,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.config.batch_size, slot.attention_mask.size(-1)],
[batch_size, slot.attention_mask.size(-1)],
dtype=torch.int64,
device=self.model.device,
)
@@ -550,7 +562,8 @@ def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor
text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
)
logger.debug(f"Finished generating tokens for request {request_id}")
# mark the slot as available
# This slot is now empty, it will be removed from the list of
# active slots once a new prefill is requested
slot.clear()
else:
active_slots = True
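
To tie the pieces of this diff together, a condensed, runnable sketch of the slot lifecycle under dynamic batching. MiniSlot and the module-level functions are simplified stand-ins for the real Slot and TpuGenerator methods, not code from this PR.

from enum import Enum


class SlotState(Enum):
    EMPTY = 0
    READY = 1


class MiniSlot:
    """Simplified stand-in for the real Slot class."""

    def __init__(self, slot_id: int, request_id: str):
        self.id = slot_id
        self.request_id = request_id
        self.state = SlotState.READY

    def clear(self):
        self.state = SlotState.EMPTY


slots = []      # starts empty, populated as batches arrive
slot_index = 0  # never decremented, mirroring self.slot_index


def prefill(request_ids):
    global slot_index
    # Drop slots that finished since the last prefill, then add one slot per new request.
    slots[:] = [s for s in slots if s.state is not SlotState.EMPTY]
    for request_id in request_ids:
        slots.append(MiniSlot(slot_index, request_id))
        slot_index += 1


def decode_batch_size():
    # Decode tensors are sized from the current number of slots.
    return len(slots)


prefill(["a", "b"])   # two active slots
slots[0].clear()      # request "a" finishes: slot marked empty, kept until next prefill
prefill(["c"])        # empty slot dropped, a new slot added -> still two active slots
assert decode_batch_size() == 2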
text-generation-inference/tests/test_gemma.py (2 changes: 0 additions & 2 deletions)
@@ -12,14 +12,12 @@


MODEL_ID = "google/gemma-2b"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024


@pytest.fixture(scope="module")
def model_path():
# Add variables to environment so they can be used in TpuModelForCausalLM
os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
path = fetch_model(MODEL_ID)
return path
text-generation-inference/tests/test_gpt2.py (8 changes: 2 additions & 6 deletions)
@@ -12,14 +12,12 @@


MODEL_ID = "openai-community/gpt2"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024


@pytest.fixture(scope="module")
def model_path():
# Add variables to environment so they can be used in TpuModelForCausalLM
os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
path = fetch_model(MODEL_ID)
return path
@@ -84,7 +82,6 @@ def create_request(
@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
generator = TpuGenerator.from_pretrained(model_path)
assert generator.model.config.batch_size >= batch_size
requests = []
max_new_tokens = 20
for i in range(batch_size):
@@ -110,13 +107,13 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
[
"It was a bright cold day in April, and the clocks were striking thirteen.",
20,
" The sun was setting, and the wind was blowing. The sun was shining, and the wind was",
" The sun was setting, and the wind was blowing. The sun was setting, and the wind was",
False,
],
[
"It was a bright cold day in April, and the clocks were striking thirteen.",
20,
" We sat outside the house, drinking coffee, listening to the orchestra playing through the window. We could",
" We sat at the front door and watched the clock on a box of yellow paper and found it almost",
True,
],
],
@@ -144,7 +141,6 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, model_path):

def test_decode_multiple(model_path):
generator = TpuGenerator.from_pretrained(model_path)
assert generator.model.config.batch_size > 1
input_text = "Once upon a time"
max_new_tokens = 20
# Prefill a single request, remembering the generated token