Fix HunyuanVideo context parallel with torch.compile
chengzeyi committed Jan 29, 2025
Parent: 5c9965a · Commit: 46bf9e7
Showing 2 changed files with 56 additions and 45 deletions.
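The change factors the transformer-block loops out of the adapter's new_forward and into a call_transformer_blocks helper that is bound onto the transformer instance, which is what lets the context-parallel HunyuanVideo forward pass run under torch.compile. Below is a minimal, hedged usage sketch of that combination; the para_attn entry points and arguments shown (parallelize_pipe, init_context_parallel_mesh) and the checkpoint id are assumptions based on typical usage of the library, not taken from this diff.

# Hedged usage sketch (not part of this commit): context parallelism plus
# torch.compile on the HunyuanVideo transformer. The para_attn imports, their
# arguments and the model id below are assumptions, not taken from this diff.
import torch
import torch.distributed as dist
from diffusers import HunyuanVideoPipeline

from para_attn.context_parallel import init_context_parallel_mesh  # assumed import path
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe  # assumed import path

dist.init_process_group()
torch.cuda.set_device(dist.get_rank())

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Parallelize first, then compile the transformer whose forward now routes
# through the bound call_transformer_blocks helper introduced in this commit.
parallelize_pipe(pipe, mesh=init_context_parallel_mesh(pipe.device.type))
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

A script like this is typically launched with torchrun, one process per GPU, so that every rank joins the process group before the pipeline is parallelized.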
src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py (55 additions, 45 deletions)
@@ -110,51 +110,9 @@ def get_rotary_emb_chunk(freqs):
 
         with SparseKVAttnMode(), UnifiedAttnMode(mesh):
             # 4. Transformer blocks
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
-
-                for block in self.transformer_blocks:
-                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        hidden_states,
-                        encoder_hidden_states,
-                        temb,
-                        attention_mask,
-                        image_rotary_emb,
-                        **ckpt_kwargs,
-                    )
-
-                for block in self.single_transformer_blocks:
-                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        hidden_states,
-                        encoder_hidden_states,
-                        temb,
-                        attention_mask,
-                        image_rotary_emb,
-                        **ckpt_kwargs,
-                    )
-
-            else:
-                for block in self.transformer_blocks:
-                    hidden_states, encoder_hidden_states = block(
-                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
-                    )
-
-                for block in self.single_transformer_blocks:
-                    hidden_states, encoder_hidden_states = block(
-                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
-                    )
+            hidden_states, encoder_hidden_states = self.call_transformer_blocks(
+                hidden_states, encoder_hidden_states, temb, guidance, image_rotary_emb
+            )
 
         # 5. Output projection
         hidden_states = self.norm_out(hidden_states, temb)
@@ -182,6 +140,58 @@ def custom_forward(*inputs):
 
     transformer.forward = new_forward.__get__(transformer)
 
+    def call_transformer_blocks(self, hidden_states, encoder_hidden_states, temb, guidance, image_rotary_emb):
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    None,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    None,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
+                )
+
+        return hidden_states, encoder_hidden_states
+
+    transformer.call_transformer_blocks = call_transformer_blocks.__get__(transformer)
+
     transformer._is_parallelized = True
 
     return transformer
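The helper keeps the non-reentrant gradient-checkpointing path (use_reentrant=False) from the original forward. A self-contained toy illustration of that checkpoint call pattern, using a stand-in Linear block rather than a HunyuanVideo transformer block:

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Non-reentrant activation checkpointing, as used in call_transformer_blocks;
# the Linear "block" here is a hypothetical stand-in.
block = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])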
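The helper is attached with call_transformer_blocks.__get__(transformer), the same descriptor trick already used to install new_forward, so self inside it is the transformer instance. A minimal standalone illustration of that binding pattern, with hypothetical names:

class Carrier:
    def __init__(self, scale):
        self.scale = scale

def scaled(self, x):
    # Once bound, `self` is the Carrier instance, just like a regular method.
    return self.scale * x

obj = Carrier(3)
# Bind the free function as a method of this one instance, mirroring
# transformer.call_transformer_blocks = call_transformer_blocks.__get__(transformer)
obj.scaled = scaled.__get__(obj)
print(obj.scaled(2))  # prints 6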
Second changed file (1 addition):

@@ -96,6 +96,7 @@ def new_forward(
 
         encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask[0].bool()]
 
+        # 4. Transformer blocks
         hidden_states, encoder_hidden_states = self.call_transformer_blocks(
             hidden_states, encoder_hidden_states, temb, guidance, image_rotary_emb
         )
