Flowedit #114

Draft · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions src/mflux/config/runtime_config.py
@@ -27,6 +27,10 @@ def width(self) -> int:
     def guidance(self) -> float:
         return self.config.guidance

+    @guidance.setter
+    def guidance(self, value: float):
+        self.config.guidance = value
+
     @property
     def num_inference_steps(self) -> int:
         return self.config.num_inference_steps
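Without this setter, the `config.guidance = src_guidance` / `config.guidance = tar_guidance` assignments in the flux.py diff below would raise AttributeError, since guidance was previously a read-only property. A runnable sketch of the pattern, with simplified stand-ins for the real classes in src/mflux/config:

# Minimal sketch of the new setter's role. Config and RuntimeConfig here are
# simplified stand-ins, not the real mflux classes.
class Config:
    def __init__(self, guidance: float):
        self.guidance = guidance


class RuntimeConfig:
    def __init__(self, config: Config):
        self.config = config

    @property
    def guidance(self) -> float:
        return self.config.guidance

    @guidance.setter
    def guidance(self, value: float):
        # Write through to the wrapped Config so later reads see the new value.
        self.config.guidance = value


runtime = RuntimeConfig(Config(guidance=1.5))
runtime.guidance = 5.5  # what the FlowEdit loop does with src_guidance / tar_guidance
assert runtime.config.guidance == 5.5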
71 changes: 52 additions & 19 deletions src/mflux/flux/flux.py
@@ -69,7 +69,11 @@ def __init__(
     def generate_image(
         self,
         seed: int,
-        prompt: str,
+        src_prompt: str,
+        tar_prompt: str,
+        src_guidance: float,
+        tar_guidance: float,
+        image_path: str,
         config: Config = Config(),
         stepwise_output_dir: Path = None,
     ) -> GeneratedImage:
@@ -80,52 +84,81 @@ def generate_image(
             flux=self,
             config=config,
             seed=seed,
-            prompt=prompt,
+            prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
             time_steps=time_steps,
             output_dir=stepwise_output_dir,
         )

-        # 1. Create the initial latents
-        latents = LatentCreator.create_for_txt2img_or_img2img(seed, config, self.vae)
-
-        # 2. Embed the prompt
-        t5_tokens = self.t5_tokenizer.tokenize(prompt)
-        clip_tokens = self.clip_tokenizer.tokenize(prompt)
-        prompt_embeds = self.t5_text_encoder(t5_tokens)
-        pooled_prompt_embeds = self.clip_text_encoder(clip_tokens)
+        image_latents = LatentCreator.encode_image(
+            init_image_path=Path(image_path),
+            width=config.width,
+            height=config.height,
+            vae=self.vae
+        )  # fmt:off
+
+        # 2a. Embed the source prompt
+        t5_tokens_src = self.t5_tokenizer.tokenize(src_prompt)
+        clip_tokens_src = self.clip_tokenizer.tokenize(src_prompt)
+        prompt_embeds_src = self.t5_text_encoder(t5_tokens_src)
+        pooled_prompt_embeds_src = self.clip_text_encoder(clip_tokens_src)
+        # 2b. Embed the target prompt
+        t5_tokens_tar = self.t5_tokenizer.tokenize(tar_prompt)
+        clip_tokens_tar = self.clip_tokenizer.tokenize(tar_prompt)
+        prompt_embeds_tar = self.t5_text_encoder(t5_tokens_tar)
+        pooled_prompt_embeds_tar = self.clip_text_encoder(clip_tokens_tar)

+        Z_FE = mx.array(image_latents)
         for gen_step, t in enumerate(time_steps, 1):
             try:
+                if config.num_inference_steps - t > 24:
+                    continue
+
+                random_noise = mx.random.normal(shape=[1, (config.height // 16) * (config.width // 16), 64])
+                Z_src = (1 - config.sigmas[t]) * image_latents + config.sigmas[t] * random_noise
+                Z_tar = Z_FE + Z_src - image_latents
+
                 # 3.t Predict the noise
-                noise = self.transformer.predict(
+                config.guidance = src_guidance
+                noise_src = self.transformer.predict(
                     t=t,
-                    prompt_embeds=prompt_embeds,
-                    pooled_prompt_embeds=pooled_prompt_embeds,
-                    hidden_states=latents,
+                    prompt_embeds=prompt_embeds_src,
+                    pooled_prompt_embeds=pooled_prompt_embeds_src,
+                    hidden_states=Z_src,
                     config=config,
                 )
+                config.guidance = tar_guidance
+                noise_tar = self.transformer.predict(
+                    t=t,
+                    prompt_embeds=prompt_embeds_tar,
+                    pooled_prompt_embeds=pooled_prompt_embeds_tar,
+                    hidden_states=Z_tar,
+                    config=config,
+                )
+
+                noise_delta = noise_tar - noise_src

                 # 4.t Take one denoise step
                 dt = config.sigmas[t + 1] - config.sigmas[t]
-                latents += noise * dt
+                Z_FE += noise_delta * dt

                 # Handle stepwise output if enabled
-                stepwise_handler.process_step(gen_step, latents)
+                stepwise_handler.process_step(gen_step, Z_FE)

                 # Evaluate to enable progress tracking
-                mx.eval(latents)
+                mx.eval(Z_FE)

             except KeyboardInterrupt:  # noqa: PERF203
                 stepwise_handler.handle_interruption()
                 raise StopImageGenerationException(f"Stopping image generation at step {t + 1}/{len(time_steps)}")

         # 5. Decode the latent array and return the image
-        latents = ArrayUtil.unpack_latents(latents=latents, height=config.height, width=config.width)
+        latents = ArrayUtil.unpack_latents(latents=Z_FE, height=config.height, width=config.width)
         decoded = self.vae.decode(latents)
         return ImageUtil.to_image(
             decoded_latents=decoded,
             seed=seed,
-            prompt=prompt,
+            prompt=f"src_prompt: {src_prompt} | tar_prompt: {tar_prompt}",
             quantization=self.bits,
             generation_time=time_steps.format_dict["elapsed"],
             lora_paths=self.lora_paths,
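Stripped of the mflux plumbing, the loop above is an Euler integration of the FlowEdit update: re-noise the clean source latents to the current sigma, shift them by the accumulated edit Z_FE to get the target-path sample, and advance Z_FE along the difference between the target and source velocity predictions. A self-contained numpy sketch of one step; the v_src/v_tar callables are hypothetical stand-ins for transformer.predict:

import numpy as np

def flowedit_step(z_fe, x_src, sigma_t, sigma_next, v_src, v_tar, rng):
    # Re-noise the clean source latents to noise level sigma_t
    noise = rng.standard_normal(x_src.shape)
    z_src = (1 - sigma_t) * x_src + sigma_t * noise
    # Shift onto the target trajectory by the accumulated edit
    z_tar = z_fe + z_src - x_src
    # Advance Z_FE along the velocity difference (one Euler step)
    dt = sigma_next - sigma_t
    return z_fe + (v_tar(z_tar, sigma_t) - v_src(z_src, sigma_t)) * dt

# Toy usage with a dummy velocity field:
rng = np.random.default_rng(2)
x_src = rng.standard_normal((1, 1024, 64))
z_fe = x_src.copy()  # Z_FE starts at the clean source latents
v = lambda z, s: -z  # stand-in for a transformer prediction
z_fe = flowedit_step(z_fe, x_src, sigma_t=0.9, sigma_next=0.85, v_src=v, v_tar=v, rng=rng)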
56 changes: 29 additions & 27 deletions src/mflux/generate.py
@@ -1,47 +1,49 @@
 import time
 from pathlib import Path

 from mflux import Config, Flux1, ModelConfig, StopImageGenerationException
-from mflux.ui.cli.parsers import CommandLineParser
+
+# image_path = "/Users/filipstrand/Desktop/lighthouse.png"
+# source_prompt = "The image features a tall white lighthouse standing prominently on a hill, with a beautiful blue sky in the background. The lighthouse is illuminated by a bright light, making it a prominent landmark in the scene."
+# target_prompt = "The image features Big ben clock tower standing prominently on a hill, with a beautiful blue sky in the background. The Big ben clock tower is illuminated by a bright light, making it a prominent landmark in the scene."

-def main():
-    # fmt: off
-    parser = CommandLineParser(description="Generate an image based on a prompt.")
-    parser.add_model_arguments(require_model_arg=False)
-    parser.add_lora_arguments()
-    parser.add_image_generator_arguments(supports_metadata_config=True)
-    parser.add_image_to_image_arguments(required=False)
-    parser.add_output_arguments()
-    args = parser.parse_args()
+image_path = "/Users/filipstrand/Desktop/gas_station.png"
+source_prompt = "A gas station with a white and red sign that reads 'CAFE' There are several cars parked in front of the gas station, including a white car and a van."
+target_prompt = "A gas station with a white and red sign that reads 'CVPR' There are several cars parked in front of the gas station, including a white car and a van."
+
+height = 512
+width = 512
+steps = 28
+seed = 2
+source_guidance = 1.5
+target_guidance = 5.5
+
+
+def main():
     # Load the model
     flux = Flux1(
-        model_config=ModelConfig.from_alias(args.model),
-        quantize=args.quantize,
-        local_path=args.path,
-        lora_paths=args.lora_paths,
-        lora_scales=args.lora_scales,
+        model_config=ModelConfig.FLUX1_DEV,
+        quantize=4,
     )

     try:
         # Generate an image
         image = flux.generate_image(
-            seed=int(time.time()) if args.seed is None else args.seed,
-            prompt=args.prompt,
-            stepwise_output_dir=Path(args.stepwise_image_output_dir) if args.stepwise_image_output_dir else None,
+            seed=seed,
+            src_prompt=source_prompt,
+            tar_prompt=target_prompt,
+            src_guidance=source_guidance,
+            tar_guidance=target_guidance,
+            image_path=image_path,
+            stepwise_output_dir=Path("/Users/filipstrand/Desktop/edit"),
             config=Config(
-                num_inference_steps=args.steps,
-                height=args.height,
-                width=args.width,
-                guidance=args.guidance,
-                init_image_path=args.init_image_path,
-                init_image_strength=args.init_image_strength,
+                num_inference_steps=steps,
+                height=height,
+                width=width,
+                guidance=0.0,
             ),
         )

         # Save the image
-        image.save(path=args.output, export_json_metadata=args.metadata)
+        image.save(path="edited.png")
     except StopImageGenerationException as stop_exc:
         print(stop_exc)

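Two things worth noting in this draft driver: guidance=0.0 in the Config appears to be a placeholder, since the loop in flux.py reassigns config.guidance (through the new RuntimeConfig setter) to src_guidance or tar_guidance before every transformer call; and the hardcoded /Users/filipstrand/... paths and fixed parameters replace the CommandLineParser wiring, presumably only while the PR is in draft.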
31 changes: 24 additions & 7 deletions src/mflux/latent_creator/latent_creator.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import mlx.core as mx
 from mlx import nn

@@ -35,21 +37,36 @@ def create_for_txt2img_or_img2img(
             return pure_noise
         else:
             # Image2Image
-            user_image = ImageUtil.load_image(runtime_conf.config.init_image_path).convert("RGB")
-            scaled_user_image = ImageUtil.scale_to_dimensions(
-                image=user_image,
-                target_width=runtime_conf.width,
-                target_height=runtime_conf.height,
+            latents = LatentCreator.encode_image(
+                init_image_path=runtime_conf.config.init_image_path,
+                height=runtime_conf.height,
+                width=runtime_conf.width,
+                vae=vae,
             )
-            encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
-            latents = ArrayUtil.pack_latents(latents=encoded, height=runtime_conf.height, width=runtime_conf.width)
             sigma = runtime_conf.sigmas[runtime_conf.init_time_step]
             return LatentCreator.add_noise_by_interpolation(
                 clean=latents,
                 noise=pure_noise,
                 sigma=sigma
             )  # fmt: off

+    @staticmethod
+    def encode_image(
+        init_image_path: Path,
+        width: int,
+        height: int,
+        vae: nn.Module,
+    ):
+        user_image = ImageUtil.load_image(init_image_path).convert("RGB")
+        scaled_user_image = ImageUtil.scale_to_dimensions(
+            image=user_image,
+            target_width=width,
+            target_height=height,
+        )
+        encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
+        latents = ArrayUtil.pack_latents(latents=encoded, height=height, width=width)
+        return latents
+
     @staticmethod
     def add_noise_by_interpolation(clean: mx.array, noise: mx.array, sigma: float) -> mx.array:
         return (1 - sigma) * clean + sigma * noise
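A side note on add_noise_by_interpolation: it is the same forward interpolation the FlowEdit loop in flux.py inlines when it builds Z_src. A small numpy sketch (arbitrary toy values) checking the two endpoints:

import numpy as np

def add_noise_by_interpolation(clean, noise, sigma):
    # Linear (rectified-flow) interpolation between clean latents and noise.
    return (1 - sigma) * clean + sigma * noise

rng = np.random.default_rng(0)
clean = rng.standard_normal((1, 4))
noise = rng.standard_normal((1, 4))
assert np.allclose(add_noise_by_interpolation(clean, noise, 0.0), clean)  # sigma=0: the image itself
assert np.allclose(add_noise_by_interpolation(clean, noise, 1.0), noise)  # sigma=1: pure noise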