diff --git a/src/para_attn/context_parallel/diffusers_adapters/cogvideox.py b/src/para_attn/context_parallel/diffusers_adapters/cogvideox.py index 2307067..2b843f8 100644 --- a/src/para_attn/context_parallel/diffusers_adapters/cogvideox.py +++ b/src/para_attn/context_parallel/diffusers_adapters/cogvideox.py @@ -110,11 +110,11 @@ def new_patch_embed_forward( def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs): original_call = pipe.__class__.__call__ - if not getattr(original_call, "is_parallelized", False): + if not getattr(original_call, "_is_parallelized", False): @functools.wraps(original_call) def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs): - if generator is None: + if generator is None and getattr(self, "_is_parallelized", False): seed = torch.seed() seed += torch.iinfo(torch.int64).min seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device) @@ -125,11 +125,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch. generator = torch.Generator(self.device).manual_seed(seed) return original_call(self, *args, generator=generator, **kwargs) - new_call.is_parallelized = True + new_call._is_parallelized = True pipe.__class__.__call__ = new_call if not shallow_patch: parallelize_transformer(pipe.transformer, **kwargs) + pipe._is_parallelized = True + return pipe diff --git a/src/para_attn/context_parallel/diffusers_adapters/flux.py b/src/para_attn/context_parallel/diffusers_adapters/flux.py index eb0e459..39b8c35 100644 --- a/src/para_attn/context_parallel/diffusers_adapters/flux.py +++ b/src/para_attn/context_parallel/diffusers_adapters/flux.py @@ -82,11 +82,11 @@ def new_forward( def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs): original_call = pipe.__class__.__call__ - if not getattr(original_call, "is_parallelized", False): + if not getattr(original_call, "_is_parallelized", False): @functools.wraps(original_call) def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs): - if generator is None: + if generator is None and getattr(self, "_is_parallelized", False): seed = torch.seed() seed += torch.iinfo(torch.int64).min seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device) @@ -97,11 +97,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch. generator = torch.Generator(self.device).manual_seed(seed) return original_call(self, *args, generator=generator, **kwargs) - new_call.is_parallelized = True + new_call._is_parallelized = True pipe.__class__.__call__ = new_call if not shallow_patch: parallelize_transformer(pipe.transformer, **kwargs) + pipe._is_parallelized = True + return pipe diff --git a/src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py b/src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py index e3416fd..8996ad0 100644 --- a/src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py +++ b/src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py @@ -183,11 +183,11 @@ def custom_forward(*inputs): def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs): original_call = pipe.__class__.__call__ - if not getattr(original_call, "is_parallelized", False): + if not getattr(original_call, "_is_parallelized", False): @functools.wraps(original_call) def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs): - if generator is None: + if generator is None and getattr(self, "_is_parallelized", False): seed = torch.seed() seed += torch.iinfo(torch.int64).min seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device) @@ -198,11 +198,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch. generator = torch.Generator(self.device).manual_seed(seed) return original_call(self, *args, generator=generator, **kwargs) - new_call.is_parallelized = True + new_call._is_parallelized = True pipe.__class__.__call__ = new_call if not shallow_patch: parallelize_transformer(pipe.transformer, **kwargs) + pipe._is_parallelized = True + return pipe diff --git a/src/para_attn/context_parallel/diffusers_adapters/mochi.py b/src/para_attn/context_parallel/diffusers_adapters/mochi.py index c4b52f9..c84d2ba 100644 --- a/src/para_attn/context_parallel/diffusers_adapters/mochi.py +++ b/src/para_attn/context_parallel/diffusers_adapters/mochi.py @@ -119,11 +119,11 @@ def new_rope_forward( def parallelize_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs): original_call = pipe.__class__.__call__ - if not getattr(original_call, "is_parallelized", False): + if not getattr(original_call, "_is_parallelized", False): @functools.wraps(original_call) def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, **kwargs): - if generator is None: + if generator is None and getattr(self, "_is_parallelized", False): seed = torch.seed() seed += torch.iinfo(torch.int64).min seed_t = torch.full([1], seed, dtype=torch.int64, device=self.device) @@ -134,11 +134,13 @@ def new_call(self, *args, generator: Optional[Union[torch.Generator, List[torch. generator = torch.Generator(self.device).manual_seed(seed) return original_call(self, *args, generator=generator, **kwargs) - new_call.is_parallelized = True + new_call._is_parallelized = True pipe.__class__.__call__ = new_call if not shallow_patch: parallelize_transformer(pipe.transformer, **kwargs) + pipe._is_parallelized = True + return pipe