feat: use dynamic batching when generating #9
Changes from 3 commits
@@ -312,7 +312,10 @@ def __init__(
         tokenizer.padding_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
+        # Slots are empty to begin with, they will be populated as new batches arrive
+        self.slots = []
+        # Note: this index will _never_ be decremented, and that's fine.
+        self.slot_index = 0
         self.past_key_values = None
         # _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
         # a static cache, otherwise it is not.

Review comment (on the removed `self.slots` preallocation): If we introduce the …

Reply: I think I need to better understand the usage of the batches: whether it can be increased/reduced often, and how this affects performance. At that point it will be easier to think about a reasonable algorithm to reduce compilation and batch-change overhead as much as possible.
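To make the new bookkeeping concrete, here is a small, self-contained sketch of the `slot_index` note above ("this index will _never_ be decremented"). `FakeSlot` and `SlotBook` are made-up stand-ins, not the PR's `Slot` class; the point is only that a monotonically increasing counter keeps slot ids unique even as slots are created and removed:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeSlot:
    # Stand-in for the real Slot: only the id matters for this illustration.
    id: int


class SlotBook:
    def __init__(self) -> None:
        self.slots: List[FakeSlot] = []
        self.slot_index = 0  # never decremented, so ids are never reused

    def new_slot(self) -> FakeSlot:
        slot = FakeSlot(self.slot_index)
        self.slot_index += 1
        self.slots.append(slot)
        return slot


book = SlotBook()
first, second = book.new_slot(), book.new_slot()  # ids 0 and 1
book.slots.remove(first)                          # slot 0 goes away with its request
book.new_slot()                                   # the next slot gets id 2, never 0 again
assert [s.id for s in book.slots] == [1, 2]
```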
@@ -350,13 +353,18 @@ def warmup(self, batch: Batch) -> int:
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.config.batch_size
+        # NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
+        if self.model.config.batch_size is not None:
+            batch_size = self.model.config.batch_size
+        else:
+            # batch size is not set, just assume it's unlimited and accept all requests
+            batch_size = len(batch.requests)
         if len(batch.requests) > batch_size:
             raise ValueError(
                 f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
             )
         self.prefill(batch)
-        return self.model.config.batch_size * self.model.config.sequence_length
+        return batch_size * self.model.config.sequence_length

     @torch.no_grad
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

Review comment (on the NOTE line): Ok maybe you want to keep it for later 🤗
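To read the new warmup logic in isolation, here is a hedged, self-contained rewrite of just the capacity check. The function and parameter names are illustrative stand-ins for `self.model.config` and `batch.requests`, not the PR's actual API:

```python
from typing import List, Optional


def warmup_token_budget(
    config_batch_size: Optional[int],
    sequence_length: int,
    requests: List[str],
) -> int:
    """Pick an effective batch size, reject oversized warmup batches,
    and return the token budget, mirroring the check in warmup()."""
    if config_batch_size is not None:
        batch_size = config_batch_size
    else:
        # Batch size not configured: assume it is unlimited and accept all requests.
        batch_size = len(requests)
    if len(requests) > batch_size:
        raise ValueError(
            f"Inconsistent server configuration: please make sure max-prefill-tokens "
            f"does not exceed {batch_size} x max-input-length."
        )
    return batch_size * sequence_length


# With a static batch size of 4 and 1024-token sequences, the budget is fixed:
assert warmup_token_budget(4, 1024, ["req"] * 3) == 4096
# With no configured batch size, the budget follows the warmup batch itself:
assert warmup_token_budget(None, 1024, ["req"] * 8) == 8192
```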
@@ -373,17 +381,20 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         for slot in self.slots:
             slots[slot.state].append(slot)
         active_slots = slots[Slot.State.READY]
+        # Delete all empty slots, no need to have them anymore
         empty_slots = slots[Slot.State.EMPTY]
-        if len(empty_slots) < len(batch.requests):
-            raise ValueError(
-                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f"Please align the number of concurrent requests with the static batch size: {self.model.batch_size}."
-            )
+        for slot in empty_slots:
+            self.slots.remove(slot)
         # Assign each request to an empty slot
-        logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)")
+        logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
+        new_slots = []
         for request in batch.requests:
-            slot = empty_slots.pop()
+            # Dynamically create a new slot for each request
+            slot = Slot(self.slot_index, self.tokenizer, self.model.device)
+            self.slot_index += 1
             slot.assign(request, self.model.generation_config)
+            new_slots.append(slot)
+            self.slots.append(slot)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:

mfuntowicz marked this conversation as resolved.

Review comment (on the new `new_slots = []` line): What is …

Reply: useless 😄

Review comment (on `new_slots.append(slot)`): Same comment about …

Reply: same answer, I'll remove it
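The heart of the change in prefill is the allocation loop: cleared slots are pruned, and every incoming request gets a freshly created slot appended to the active list. Below is a simplified, runnable approximation; `MiniSlot` and `MiniGenerator` are illustrative stand-ins for the PR's `Slot` and generator classes (and, following the review thread, `new_slots` is left out since it is slated for removal):

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class State(Enum):
    EMPTY = auto()
    READY = auto()


@dataclass
class MiniSlot:
    id: int
    state: State = State.EMPTY
    request_id: str = ""

    def assign(self, request_id: str) -> None:
        self.request_id = request_id
        self.state = State.READY


@dataclass
class MiniGenerator:
    slots: List[MiniSlot] = field(default_factory=list)
    slot_index: int = 0

    def prefill(self, request_ids: List[str]) -> None:
        # Drop slots left over from requests that already finished.
        self.slots = [s for s in self.slots if s.state is not State.EMPTY]
        for request_id in request_ids:
            # Dynamically create a new slot for each incoming request.
            slot = MiniSlot(self.slot_index)
            self.slot_index += 1
            slot.assign(request_id)
            self.slots.append(slot)


gen = MiniGenerator()
gen.prefill(["req-0", "req-1"])
gen.prefill(["req-2"])  # a later prefill simply adds to the active slots
assert [(s.id, s.request_id) for s in gen.slots] == [
    (0, "req-0"), (1, "req-1"), (2, "req-2")
]
```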
@@ -466,8 +477,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         # Reconstruct input_ids and attention_mask from slots
         input_ids = None
         attention_mask = None
+        batch_size = len(self.slots)
         position_ids = torch.zeros(
-            [self.model.config.batch_size, 1],
+            [batch_size, 1],
             dtype=torch.int64,
             device=self.model.device,
         )
@@ -476,7 +488,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
             if input_ids is None:
                 # Create blank inputs covering all slots (even empty ones)
                 input_ids = torch.full(
-                    [self.model.config.batch_size, 1],
+                    [batch_size, 1],
                     fill_value=self.tokenizer.eos_token_id,
                     dtype=torch.int64,
                     device=self.model.device,
@@ -488,7 +500,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
             if attention_mask is None:
                 # Create default mask covering all slots (even empty ones)
                 attention_mask = torch.zeros(
-                    [self.model.config.batch_size, slot.attention_mask.size(-1)],
+                    [batch_size, slot.attention_mask.size(-1)],
                     dtype=torch.int64,
                     device=self.model.device,
                 )
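Because the slot list now changes size over time, decode derives every tensor's batch dimension from `len(self.slots)` rather than a fixed `config.batch_size`. A small self-contained illustration of that shape bookkeeping follows; the slot count, mask length, and `eos_token_id` are made-up values standing in for the generator's state:

```python
import torch

# Hypothetical stand-ins for the generator's state during decode:
# three live slots whose cached attention masks have length 9.
num_slots = 3
mask_length = 9
eos_token_id = 2
device = "cpu"

# The batch dimension now follows the number of slots instead of a fixed
# config.batch_size.
batch_size = num_slots

position_ids = torch.zeros([batch_size, 1], dtype=torch.int64, device=device)
input_ids = torch.full(
    [batch_size, 1], fill_value=eos_token_id, dtype=torch.int64, device=device
)
attention_mask = torch.zeros(
    [batch_size, mask_length], dtype=torch.int64, device=device
)

assert position_ids.shape == (3, 1)
assert input_ids.shape == (3, 1)
assert attention_mask.shape == (3, 9)
```

Note that a slot which has finished generating is only removed from `self.slots` at the next prefill, so until then it still occupies a row in these tensors, which is why the comments above speak of covering "all slots (even empty ones)".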
@@ -550,7 +562,8 @@ def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor
                     text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                 )
                 logger.debug(f"Finished generating tokens for request {request_id}")
-                # mark the slot as available
+                # This slot is now empty, it will be removed from the list of
+                # active slots once a new prefill is requested
                 slot.clear()
             else:
                 active_slots = True
Review comment: Maybe we should keep a HF_MAX_BATCH_SIZE somewhere?

Reply: as mentioned below, maybe we'll use it later
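If the `HF_MAX_BATCH_SIZE` idea from this thread were picked up later, one possible shape for it would be an environment-variable cap consulted when choosing the effective batch size. This is purely illustrative and not part of this PR; the variable name is only the reviewer's suggestion:

```python
import os
from typing import List, Optional


def effective_batch_size(config_batch_size: Optional[int], requests: List[str]) -> int:
    """Illustrative only: pick a batch-size cap for the warmup/prefill checks."""
    if config_batch_size is not None:
        return config_batch_size
    # HF_MAX_BATCH_SIZE is the name floated in the review, not an existing setting.
    env_cap = os.environ.get("HF_MAX_BATCH_SIZE")
    if env_cap is not None:
        return int(env_cap)
    # No limit configured anywhere: accept every request in the batch.
    return len(requests)
```

The warmup check would then compare `len(batch.requests)` against this value exactly as in the diff above.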