Skip to content

Commit

Permalink
fix(generation): eos_token_id can be a list in configs
Browse files Browse the repository at this point in the history
This essentially copies commit 8a4a98d2472b8e0180eb9bd4a1824f983e220811
from optimum-neuron, that fixed the same problem.
  • Loading branch information
tengomucho committed Apr 30, 2024
1 parent c7a4bde commit 584ffda
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
19 changes: 10 additions & 9 deletions optimum/tpu/generation/token_selector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import logging
from typing import Optional
from typing import List, Optional, Union

import torch
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -42,15 +42,15 @@ def __init__(
mode: GenerationMode,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
eos_token_id: int,
eos_token_ids: Union[int,List[int]],
pad_token_id: int,
logits_warper: Optional[LogitsProcessorList] = None,
seed: Optional[int] = 0,
):
self.mode = mode
self.logits_processor = logits_processor
self.stopping_criteria = stopping_criteria
self.eos_token_id = eos_token_id
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
xm.set_rng_state(seed)
Expand Down Expand Up @@ -132,13 +132,14 @@ def create(
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)

# The generation requires special tokens
eos_token_id = generation_config.eos_token_id
# This is not supposed to happen for any of the models we support
assert eos_token_id is not None and not isinstance(eos_token_id, list)
eos_token_id = generation_config.eos_token_id
assert eos_token_id is not None
# The generation requires special tokens
eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
if generation_config.pad_token_id is None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-end generation.")
generation_config.pad_token_id = eos_token_ids[0]

generation_mode = generation_config.get_generation_mode()
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
Expand All @@ -153,7 +154,7 @@ def create(
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
logits_warper=logits_warper,
eos_token_id=eos_token_id,
eos_token_ids=eos_token_ids,
pad_token_id=generation_config.pad_token_id,
seed=seed,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
for i, slot in enumerate(self.slots):
if slot.state != Slot.State.EMPTY:
if input_ids is None:
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is None:
if isinstance(self.tokenizer.eos_token_id, list):
pad_token_id = self.tokenizer.eos_token_id[0]
else:
pad_token_id = self.tokenizer.eos_token_id
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[batch_size, 1],
Expand Down

0 comments on commit 584ffda

Please sign in to comment.