The implementation matches the official implementation in the torchscale repo for the paper (https://arxiv.org/pdf/2307.08621.pdf) and the Hugging Face transformers-compatible implementation of Retention Networks.
In this repo, you can find:
- An `nX`-speed-up implementation of the retention operation of Retention Networks. Notice: this implementation is only suitable for large sequence length and small head dimension. Assume the query has shape `(B, H, S, D)`; you can enjoy a considerable speed up roughly when `S` is much larger than `D*D` (see the benchmark below). Run `python test.py` for the benchmark:
```text
python test.py  # (Platform: 3090) (e1 = fast - origin) (e2 = reduce - origin)
    B     S   D        e1        e2      fast    reduce    origin  speed_up
0   1    30   8  0.000032  0.000033  0.000435  0.000212  0.000118     0.272
1   1    30  16  0.000068  0.000073  0.000530  0.000247  0.000141     0.266
2   1    30  32  0.000178  0.000189  0.000472  0.000269  0.000123     0.261
3   1    30  64  0.000570  0.000490  0.000694  0.000221  0.000122     0.175
4   1   300   8  0.000147  0.000203  0.000469  0.000223  0.000118     0.251
5   1   300  16  0.000300  0.000440  0.000488  0.000210  0.000118     0.242
6   1   300  32  0.000681  0.001043  0.001454  0.000327  0.000240     0.165
7   1   300  64  0.001849  0.002668  0.005716  0.000982  0.000259     0.045
8   1  3000   8  0.000969       NaN  0.001877  0.002474  0.016054     8.552
9   1  3000  16  0.001753       NaN  0.005466  0.002907  0.016108     2.947
10  1  3000  32  0.003381       NaN  0.020679  0.004071  0.016576     0.802
11  1  3000  64  0.007355       NaN  0.081177  0.010387  0.017319     0.213
12  1  5000   8  0.001459       NaN  0.004210  0.005091  0.044165    10.490
13  1  5000  16  0.002664       NaN  0.012461  0.005538  0.044373     3.561
14  1  5000  32  0.005147       NaN  0.044402  0.007455  0.045701     1.029
15  1  5000  64  0.010328       NaN  0.178734  0.018047  0.047592     0.266
```
- Any-chunksize recurrent mode:
    - If `chunksize == wholelength`, it becomes the parallel mode.
    - If `chunksize == 1`, it becomes the recurrent mode.

  We reformulate the retention and change the operation order, which gives an equivalent implementation that correctly produces the kv cache and the gk cache of retention.
```python
import torch
import torch.nn as nn
import numpy as np
from self_retention import SelfRetentionV2, RetNetRelPosV2, RMSNorm
from configuration_retnet import RetNetConfig

S = 30
B = 2
H = 8
qk_dim = 32
v_dim = 64
q = torch.randn(B, H, S, qk_dim).cuda()
k = torch.randn(B, H, S, qk_dim).cuda()
v = torch.randn(B, H, S, v_dim).cuda()

config = RetNetConfig(decoder_layers=1,
                      decoder_embed_dim=256,
                      decoder_value_embed_dim=256,
                      decoder_retention_heads=8,
                      decoder_ffn_embed_dim=128)
retnet_rel_pos = RetNetRelPosV2(config).cuda()
model = SelfRetentionV2(config)
group_norm = RMSNorm(H, 0, False)
model.group_norm = nn.Identity()  ## remove the group norm which we add by ourselves

use_gk = True
mode = 'qk_first'

print(" ================= random chunksize recurrent test ====================")
partition = np.sort(np.random.choice(np.arange(2, S - 2), (5,), replace=False)).tolist() + [S]
print(f" partition: {partition}")
past_kv = None
full_rnn_state = []
last = 0
for i in partition:
    qm = q[:, :, last:i]
    km = k[:, :, last:i]
    vm = v[:, :, last:i]
    (cos, sin), (chunk_gamma, unnormlized_decay_mask, mask_normlizer) = retnet_rel_pos(
        i, recurrent_chunk_size=qm.shape[-2], forward_impl='chunkwise_recurrent')
    one_step_output, _, past_kv = model(qm, km, vm,
                                        (chunk_gamma, unnormlized_decay_mask, mask_normlizer),
                                        past_key_value=past_kv,
                                        normlize_for_stable=use_gk,
                                        mode=mode)
    full_rnn_state.append(one_step_output)
    last = i
full_rnn_state = torch.cat(full_rnn_state, dim=1)
```
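Since a chunk covering the whole sequence reduces to the parallel mode, a quick sanity check is to run the same sequence in a single chunk and compare with the concatenated multi-chunk output above. This is only a hedged sketch reusing the objects defined above (not the repo's own test):

```python
# Sanity check (sketch): one chunk over the whole sequence should reproduce
# the concatenated multi-chunk output, up to floating-point error.
(cos, sin), (chunk_gamma, unnormlized_decay_mask, mask_normlizer) = retnet_rel_pos(
    S, recurrent_chunk_size=S, forward_impl='chunkwise_recurrent')
one_shot_output, _, _ = model(q, k, v,
                              (chunk_gamma, unnormlized_decay_mask, mask_normlizer),
                              past_key_value=None,
                              normlize_for_stable=use_gk,
                              mode=mode)
print(torch.allclose(one_shot_output, full_rnn_state, atol=1e-4))  # expect True
```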
We use [torch-discounted-cumsum](https://github.com/veya2ztn/torch-discounted-cumsum) to accelerate computation.

```bash
#pip install torch-discounted-cumsum --no-build-isolation
pip install git+https://github.com/veya2ztn/torch-discounted-cumsum.git --no-build-isolation
```
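For reference, a discounted cumulative sum along the sequence dimension computes $y_n = \sum_{m \le n} \gamma^{\,n-m} x_m$, i.e. the recurrence $y_n = \gamma\, y_{n-1} + x_n$. Below is a minimal pure-PyTorch sketch of that operation (the package above provides a fast CUDA version of this kind of scan; `discounted_cumsum_ref` is just an illustrative name):

```python
import torch

def discounted_cumsum_ref(x: torch.Tensor, gamma: float) -> torch.Tensor:
    """Reference discounted cumsum along the last dim: y[..., n] = gamma * y[..., n-1] + x[..., n]."""
    y = torch.empty_like(x)
    state = torch.zeros_like(x[..., 0])
    for n in range(x.shape[-1]):
        state = gamma * state + x[..., n]
        y[..., n] = state
    return y

x = torch.randn(2, 8, 100)       # e.g. (B, H, S)
y = discounted_cumsum_ref(x, 0.9)
```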
The "Parallel" formation of retention is simple
def forward(self, q, k, v, decay_mask):
"""
q -> (B, H, S, D)
k -> (B, H, S, D)
v -> (B, H, S, D)
decay_mask-> (1, H, S, S)
"""
retention = q @ k.transpose(-1, -2) # --> (B,H,S,S)
retention = retention * decay_mask # --> (B,H,S,S)
retention = retention / retention.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1) # --> (B,H,S,S)
output = retention @ v # [b, h, t, v_dim / h] ## # --> (B,H,S,D)
return output
However, it materializes a `(B, H, S, S)` intermediate, so both memory and compute grow quadratically with the sequence length `S`. Let's derive a cheaper form step by step.
Firstly,

$$O_{nd} = \sum_{d'} \sum_{m} Q_{nd'}\, K_{md'}\, D_{nm}\, V_{md}$$

with decay mask

$$D_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m, \end{cases}$$

where $\gamma \in (0, 1)$ is a per-head decay constant. Thus

$$O_{nd} = \sum_{d'} Q_{nd'} \sum_{m \le n} \gamma^{\,n-m} K_{md'}\, V_{md}.$$
One thing the paper does not spell out is that the real implementation also performs normalization: the decay mask is normalized along its last dimension, and the retention scores are divided by `retention.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)`, as in the `forward` code above. We ignore these normalizations for now and come back to them later.
Now, we have the reduced form

$$
\begin{aligned}
O_{nd}   &= \sum_{d'} Q_{nd'}\, T_{nd'd}, \\
T_{nd'd} &= \sum_{m \le n} \gamma^{\,n-m} K_{md'}\, V_{md}, \\
T_{nd'd} &= \gamma\, T_{(n-1)d'd} + K_{nd'}\, V_{nd}.
\end{aligned}
$$

The second line is the formula for the reduced state $T$, and the third line is exactly a discounted cumsum over the sequence dimension.
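A hedged sketch of this reordered computation in plain PyTorch (no custom kernel; the recurrence in the loop is the discounted cumsum of the third line), checked against the unnormalized parallel form:

```python
import torch

B, H, S, D = 1, 2, 64, 8
gamma = 0.9
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Parallel form: O = (Q K^T * decay_mask) V, intermediate is (B, H, S, S)
idx = torch.arange(S)
decay_mask = torch.where(idx[:, None] >= idx[None, :],
                         gamma ** (idx[:, None] - idx[None, :]).float(),
                         torch.zeros(S, S))
o_parallel = (q @ k.transpose(-1, -2) * decay_mask) @ v

# Reduced form via the recurrence T_n = gamma * T_{n-1} + k_n^T v_n,
# then O_n = q_n T_n; the intermediate state is only (B, H, D, D).
T = torch.zeros(B, H, D, D)
o_reduced = torch.empty_like(o_parallel)
for n in range(S):
    T = gamma * T + torch.einsum('bha,bhc->bhac', k[:, :, n], v[:, :, n])
    o_reduced[:, :, n] = torch.einsum('bha,bhac->bhc', q[:, :, n], T)

print(torch.allclose(o_parallel, o_reduced, atol=1e-4))  # True
```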
Next, the normalization term (the row sum of the masked scores used in `forward` above, i.e. the `gk` cache) can still be computed the same way:

$$
\begin{aligned}
Z_{n}   &= \sum_{d'} Q_{nd'}\, G_{nd'}, \\
G_{nd'} &= \sum_{m \le n} \gamma^{\,n-m} K_{md'}, \\
G_{nd'} &= \gamma\, G_{(n-1)d'} + K_{nd'}.
\end{aligned}
$$

Again, the second line is the formula for the reduced state $G$, and the third line is a discounted cumsum.
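Continuing the hedged sketch above (same illustrative variables), the normalizer is the same recurrence without the value vectors:

```python
# Normalizer via the recurrence G_n = gamma * G_{n-1} + k_n (another discounted cumsum)
G = torch.zeros(B, H, D)
Z = torch.empty(B, H, S)
for n in range(S):
    G = gamma * G + k[:, :, n]
    Z[:, :, n] = torch.einsum('bha,bha->bh', q[:, :, n], G)

# Matches summing the masked scores row-wise in the parallel form
Z_parallel = (q @ k.transpose(-1, -2) * decay_mask).sum(dim=-1)
print(torch.allclose(Z, Z_parallel, atol=1e-4))  # True
```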
[TODO]: directly build the operation as one fused kernel rather than composing einsum with the discounted cumsum.

Now, the max intermediate is `(B, H, S, D, D)` instead of `(B, H, S, S)`, so the reordered form pays off roughly when `S` is much larger than `D*D`, which is consistent with the benchmark above.

[TODO]:
- broadcast operation
- bfloat16 CUDA kernel (lazy implementation for now; the precision of `pow` in bfloat16 is questionable)
Firstly, consider the parallel retention. (We use bold symbols for matrices.)

$$\mathbf{O} = (\mathbf{Q}\mathbf{K}^{T} \odot \mathbf{D})\,\mathbf{V}$$

with the decay

$$\mathbf{D}_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m. \end{cases}$$

Notice, in the real code, the decay mask is normalized along its last dimension (in torchscale, `mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()`), and there is one more normalization during the parallel forward. The final step is

$$\mathbf{O} = \mathrm{group\_norm}\big(\mathbf{R}\,\mathbf{V}\big),$$

where

$$\mathbf{R} = \mathrm{normalize}\big(\mathbf{Q}\mathbf{K}^{T} \odot \tilde{\mathbf{D}}\big)$$

and $\mathrm{normalize}(\cdot)$ divides each row by `row.detach().sum(-1).abs().clamp(min=1)`, exactly as in the `forward` code above.
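A small hedged sketch of the normalized decay mask (variable names are illustrative; the construction follows the torchscale convention, including the per-head decay schedule $\gamma_h = 1 - 2^{-5-h}$ from the paper):

```python
import torch

H, S = 8, 16
decay = torch.log(1 - 2.0 ** (-5 - torch.arange(H, dtype=torch.float)))  # per-head log(gamma)
idx = torch.arange(S)
tril = torch.tril(torch.ones(S, S, dtype=torch.bool))
exponent = (idx[:, None] - idx[None, :]).float().masked_fill(~tril, float('inf'))
decay_mask = torch.exp(exponent[None] * decay[:, None, None])   # (H, S, S): gamma^(n-m) on/below diagonal, 0 above
decay_mask = torch.nan_to_num(decay_mask)
decay_mask = decay_mask / decay_mask.sum(dim=-1, keepdim=True).sqrt()   # normalization along the last dim

# The extra normalization during the parallel forward (from forward() above):
# retention = retention / retention.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
```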
For example, let us omit the normalizations for a moment and look at one output row:

$$o_n = q_n \Big(\sum_{m \le n} \gamma^{\,n-m}\, k_m^{T} v_m\Big) = q_n\, \mathrm{kv}_n.$$

It has an obvious recurrent formulation:

$$\mathrm{kv}_n = \gamma\, \mathrm{kv}_{n-1} + k_n^{T} v_n.$$

The same holds for the normalizer state, the `gk` cache: $\mathrm{gk}_n = \gamma\, \mathrm{gk}_{n-1} + k_n$.
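This is exactly the `chunksize == 1` path. A hedged sketch of one recurrent step with plain tensors (not the repo's API; value dimension set equal to `D` for simplicity, and the score normalization taken from the parallel `forward` above):

```python
import torch

B, H, D = 1, 8, 32
gamma = torch.tensor([1 - 2.0 ** (-5 - h) for h in range(H)])   # per-head decay

kv = torch.zeros(B, H, D, D)   # recurrent kv state
gk = torch.zeros(B, H, D)      # recurrent gk state (normalizer)

def recurrent_step(q_n, k_n, v_n, kv, gk):
    """One token: kv_n = gamma*kv_{n-1} + k_n^T v_n, o_n = (q_n kv_n) / max(|q_n gk_n|, 1)."""
    kv = gamma[None, :, None, None] * kv + torch.einsum('bha,bhc->bhac', k_n, v_n)
    gk = gamma[None, :, None] * gk + k_n
    scale = torch.einsum('bha,bha->bh', q_n, gk).abs().clamp(min=1)
    o_n = torch.einsum('bha,bhac->bhc', q_n, kv) / scale[..., None]
    return o_n, kv, gk

q_n = torch.randn(B, H, D); k_n = torch.randn(B, H, D); v_n = torch.randn(B, H, D)
o_n, kv, gk = recurrent_step(q_n, k_n, v_n, kv, gk)
```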
Now consider the chunk-wise recurrent mode, which steps forward multiple tokens at once. For the $i$-th token inside the current chunk (0-indexed),

$$\mathrm{kv}^{\mathrm{cur}}_{i} = \gamma^{\,i+1}\, \mathrm{kv}^{\mathrm{past}} + \sum_{j \le i} \gamma^{\,i-j}\, k_{j}^{T} v_{j},$$

where the first term carries $\mathrm{kv}^{\mathrm{past}}$, which is the state accumulated over the previously processed chunks (the `past_kv` cache), and the second term is the within-chunk contribution weighted by the unnormalized decay mask $\gamma^{\,i-j}$; the `gk` cache is updated analogously. Thus, the recurrent-with-chunk_size formulation is
```python
current_kv = torch.einsum('Hi,BHiac->BHiac', gamma, past_kv) \
           + torch.einsum('Hij,BHja,BHjc->BHiac', mask, k_more, v_more)
current_gk = torch.einsum('Hi,BHia->BHia', gamma, past_gk) \
           + torch.einsum('Hij,BHja->BHia', mask, k_more)
```
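For completeness, a hedged sketch of how these per-position states could be consumed; `q_more` and the exact output normalization are illustrative assumptions following the recurrent step above, not necessarily the repo's internal layout:

```python
# Illustrative: read the chunk output out of the per-position states.
# q_more: (B, H, C, D); current_kv: (B, H, C, D, D); current_gk: (B, H, C, D)
numer = torch.einsum('BHia,BHiac->BHic', q_more, current_kv)
denom = torch.einsum('BHia,BHia->BHi',  q_more, current_gk).abs().clamp(min=1)
chunk_output = numer / denom.unsqueeze(-1)
```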
Notice, at present (2023.10.24), the `group_norm` uses `RMSNorm` (no bias, eps = 0 in the example above), which normalizes its input along the normalized dimension. That is

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i} x_i^{2}}},$$

so for any positive scalar $c$, $\mathrm{RMSNorm}(c\,x) = \mathrm{RMSNorm}(x)$. Given that such a gauge only affects the overall scale of the retention output, the result holds invariant under it. This means we can totally remove the `current_gk` normalizer if we finally apply the `group_norm` normalization.
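A tiny hedged demo of this scale (gauge) invariance, using a minimal RMSNorm without gain or eps (an assumption matching `RMSNorm(H, 0, False)` in the example above):

```python
import torch

def rms_norm(x: torch.Tensor) -> torch.Tensor:
    """Minimal RMSNorm without gain or eps: x / sqrt(mean(x**2)) over the last dim."""
    return x / x.pow(2).mean(dim=-1, keepdim=True).sqrt()

x = torch.randn(4, 8)
c = torch.rand(4, 1) * 10 + 0.1  # positive per-row rescaling ("gauge")
print(torch.allclose(rms_norm(x), rms_norm(c * x), atol=1e-6))  # True: the gauge cancels
```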
For numerical reasons, we may still rescale the retention state; otherwise the intermediate values can overflow to NaN or Inf.
The costs of computing `current_gk` and `current_kv` are almost the same; thus, disabling the stabilizing normalization by setting `normlize_for_stable=False` can indeed make inference roughly 2x faster.
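For example, a sketch reusing the objects from the chunk-wise example above (output equality relies on the gauge argument, so the `group_norm` normalization must still be applied afterwards):

```python
# Sketch: skip the gk normalizer for speed; the group_norm applied afterwards
# absorbs the missing rescaling, per the gauge argument above.
out_fast, _, past_kv = model(qm, km, vm,
                             (chunk_gamma, unnormlized_decay_mask, mask_normlizer),
                             past_key_value=past_kv,
                             normlize_for_stable=False,   # do not compute current_gk
                             mode=mode)
```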