From b96313274c758de7e0b976781531d6961eee3f20 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Sat, 8 Feb 2025 02:47:28 +0000 Subject: [PATCH] fix fbcache with torch.compile --- src/para_attn/first_block_cache/utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/para_attn/first_block_cache/utils.py b/src/para_attn/first_block_cache/utils.py index 0090b8f..3908bcc 100644 --- a/src/para_attn/first_block_cache/utils.py +++ b/src/para_attn/first_block_cache/utils.py @@ -207,12 +207,20 @@ def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states # hidden_states_shape = hidden_states.shape # encoder_hidden_states_shape = encoder_hidden_states.shape - # hidden_states = hidden_states.flatten().contiguous().reshape(hidden_states_shape) - # encoder_hidden_states = encoder_hidden_states.flatten().contiguous().reshape(encoder_hidden_states_shape) + hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape) + encoder_hidden_states = ( + encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape) + ) - hidden_states = hidden_states.contiguous() - encoder_hidden_states = encoder_hidden_states.contiguous() + # hidden_states = hidden_states.contiguous() + # encoder_hidden_states = encoder_hidden_states.contiguous() hidden_states_residual = hidden_states - original_hidden_states encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + + hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape) + encoder_hidden_states_residual = ( + encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape) + ) + return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual