Skip to content

Commit

Permalink
fix fbcache with torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Feb 8, 2025
1 parent ffef121 commit b963132
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/para_attn/first_block_cache/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,20 @@ def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states

# hidden_states_shape = hidden_states.shape
# encoder_hidden_states_shape = encoder_hidden_states.shape
# hidden_states = hidden_states.flatten().contiguous().reshape(hidden_states_shape)
# encoder_hidden_states = encoder_hidden_states.flatten().contiguous().reshape(encoder_hidden_states_shape)
hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
encoder_hidden_states = (
encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
)

hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
# hidden_states = hidden_states.contiguous()
# encoder_hidden_states = encoder_hidden_states.contiguous()

hidden_states_residual = hidden_states - original_hidden_states
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
encoder_hidden_states_residual = (
encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
)

return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual

0 comments on commit b963132

Please sign in to comment.