feat(cache): use optimized StaticCache class for XLA (#70)
This is essentially a port of work originally done as a contribution
to transformers:

huggingface/transformers#31129

The original contribution has not been merged yet, but it shows lower
memory usage and better performance on XLA, so it is worth adding here
to be integrated into optimum-tpu.
tengomucho authored Jul 9, 2024
1 parent 7cce24c commit 77bebf8
Showing 2 changed files with 63 additions and 2 deletions.
optimum/tpu/static_cache_xla.py (60 additions & 0 deletions)
@@ -0,0 +1,60 @@
from typing import Any, Dict, Optional, Tuple

import torch
from transformers import StaticCache


class StaticCacheXla(StaticCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position`
                input to know where to write in the cache.

        Return:
            A tuple containing the updated key and value states.
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # `index_copy_(dim, index, source)` functions similarly to `tensor[index] = source`,
        # but it is used for better generality and it uses less memory on XLA.
        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on
        # compute, limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        key_cache = self.key_cache[layer_idx]
        device = key_cache.device

        # `index_select(dim, index)` performs the same operation as `item = tensor[..., index, ...]`,
        # but it is used for better generality and it uses less memory on XLA.
        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
        item = key_cache.index_select(0, torch.tensor(0, device=device))
        head = item.index_select(1, torch.tensor(0, device=device))

        return head.any(dim=-1).sum()
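
For context, here is a small self-contained sketch (not part of the commit) of what the two XLA-friendly ops above do. The toy shapes are assumptions chosen to mirror the (batch, kv_heads, max_cache_len, head_dim) layout that StaticCache uses; on CPU the result is identical to plain tensor indexing, the difference is only in how it lowers on XLA:

import torch

# Toy cache buffer shaped like one StaticCache entry: (batch, kv_heads, max_cache_len, head_dim).
k_cache = torch.zeros(1, 2, 8, 4)
new_keys = torch.ones(1, 2, 3, 4)
cache_position = torch.tensor([0, 1, 2])

# Same result as `k_cache[:, :, cache_position] = new_keys`, but expressed as an explicit
# in-place op on dim 2, which is what avoids the extra device copy on XLA.
k_cache.index_copy_(2, cache_position, new_keys)

# get_seq_length-style check: pick batch member 0 and head 0, then count the sequence
# slots that hold any non-zero value.
item = k_cache.index_select(0, torch.tensor(0))
head = item.index_select(1, torch.tensor(0))
print(head.any(dim=-1).sum())  # tensor(3)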
Second changed file (3 additions & 2 deletions):
@@ -12,12 +12,13 @@
import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
-from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig

import optimum.tpu.xla_logger as logger
from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector
+from optimum.tpu.static_cache_xla import StaticCacheXla
from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox

from .generator_base import Generator
@@ -529,7 +530,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

        extra_args = {}
        if self._supports_static_cache:
-            self.past_key_values = StaticCache(
+            self.past_key_values = StaticCacheXla(
                config=self.model.config,
                max_batch_size=len(self.slots),
                max_cache_len=self.model.config.sequence_length,
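
Outside the server, the new cache class can be exercised on its own. The following is a hedged sketch, not code from the repository: it assumes the StaticCache constructor signature of this transformers generation (config, max_batch_size, max_cache_len, device, dtype), and the tiny llama config is purely illustrative:

import torch
from transformers import AutoConfig

from optimum.tpu.static_cache_xla import StaticCacheXla

# Tiny illustrative config; any config exposing hidden_size, num_attention_heads,
# num_key_value_heads and num_hidden_layers works the same way.
config = AutoConfig.for_model(
    "llama",
    hidden_size=64,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_hidden_layers=2,
    intermediate_size=128,
    vocab_size=1000,
)

cache = StaticCacheXla(
    config=config,
    max_batch_size=2,
    max_cache_len=16,
    device="cpu",
    dtype=torch.float32,
)

# Write 3 prefill positions into layer 0, then read back the seen sequence length.
k = torch.randn(2, 4, 3, 16)  # (batch, kv_heads, seq, head_dim), head_dim = 64 / 4
v = torch.randn(2, 4, 3, 16)
cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(3)})
print(cache.get_seq_length(0))  # tensor(3)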
