Skip to content

Commit

Permalink
optimize parallelize_pipe
Browse files: browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Dec 25, 2024
1 parent ee9e03b commit e8e13b3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -110,11 +110,11 @@ def new_patch_embed_forward(
def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
original_call = pipe.__class__.__call__

if not getattr(original_call, "is_parallelized", False):
if not getattr(original_call, "_is_parallelized", False):

@functools.wraps(original_call)
def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs):
if generator is None:
if generator is None and getattr(self, "_is_parallelized", False):
seed = torch.seed()
seed += torch.iinfo(torch.int64).min
seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device)
Expand All @@ -125,11 +125,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.
generator = torch.Generator(self.device).manual_seed(seed)
return original_call(self, *args, generator=generator, **kwargs)

new_call.is_parallelized = True
new_call._is_parallelized = True

pipe.__class__.__call__ = new_call

if not shallow_patch:
parallelize_transformer(pipe.transformer, **kwargs)

pipe._is_parallelized = True

return pipe
8 changes: 5 additions & 3 deletions src/para_attn/context_parallel/diffusers_adapters/flux.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,11 +82,11 @@ def new_forward(
def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
original_call = pipe.__class__.__call__

if not getattr(original_call, "is_parallelized", False):
if not getattr(original_call, "_is_parallelized", False):

@functools.wraps(original_call)
def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs):
if generator is None:
if generator is None and getattr(self, "_is_parallelized", False):
seed = torch.seed()
seed += torch.iinfo(torch.int64).min
seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device)
Expand All @@ -97,11 +97,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.
generator = torch.Generator(self.device).manual_seed(seed)
return original_call(self, *args, generator=generator, **kwargs)

new_call.is_parallelized = True
new_call._is_parallelized = True

pipe.__class__.__call__ = new_call

if not shallow_patch:
parallelize_transformer(pipe.transformer, **kwargs)

pipe._is_parallelized = True

return pipe
Original file line number | Diff line number | Diff line change
Expand Up @@ -183,11 +183,11 @@ def custom_forward(*inputs):
def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
original_call = pipe.__class__.__call__

if not getattr(original_call, "is_parallelized", False):
if not getattr(original_call, "_is_parallelized", False):

@functools.wraps(original_call)
def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs):
if generator is None:
if generator is None and getattr(self, "_is_parallelized", False):
seed = torch.seed()
seed += torch.iinfo(torch.int64).min
seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device)
Expand All @@ -198,11 +198,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.
generator = torch.Generator(self.device).manual_seed(seed)
return original_call(self, *args, generator=generator, **kwargs)

new_call.is_parallelized = True
new_call._is_parallelized = True

pipe.__class__.__call__ = new_call

if not shallow_patch:
parallelize_transformer(pipe.transformer, **kwargs)

pipe._is_parallelized = True

return pipe
8 changes: 5 additions & 3 deletions src/para_attn/context_parallel/diffusers_adapters/mochi.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -119,11 +119,11 @@ def new_rope_forward(
def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
original_call = pipe.__class__.__call__

if not getattr(original_call, "is_parallelized", False):
if not getattr(original_call, "_is_parallelized", False):

@functools.wraps(original_call)
def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs):
if generator is None:
if generator is None and getattr(self, "_is_parallelized", False):
seed = torch.seed()
seed += torch.iinfo(torch.int64).min
seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device)
Expand All @@ -134,11 +134,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.
generator = torch.Generator(self.device).manual_seed(seed)
return original_call(self, *args, generator=generator, **kwargs)

new_call.is_parallelized = True
new_call._is_parallelized = True

pipe.__class__.__call__ = new_call

if not shallow_patch:
parallelize_transformer(pipe.transformer, **kwargs)

pipe._is_parallelized = True

return pipe

0 comments on commit e8e13b3

Please sign in to comment.