Inversion #111

Draft · wants to merge 1 commit into main
16 changes: 10 additions & 6 deletions src/mflux/callbacks/instances/stepwise_handler.py
@@ -7,6 +7,7 @@
from mflux import StopImageGenerationException
from mflux.callbacks.callback import BeforeLoopCallback, InLoopCallback, InterruptCallback
from mflux.config.runtime_config import RuntimeConfig
from mflux.flux.v_cache import VCache
from mflux.post_processing.array_util import ArrayUtil
from mflux.post_processing.image_util import ImageUtil

@@ -16,9 +17,11 @@ def __init__(
self,
flux,
output_dir: str,
forward: bool,
):
self.flux = flux
self.output_dir = Path(output_dir)
self.forward = forward
self.step_wise_images = []

if self.output_dir:
@@ -93,12 +96,13 @@ def _save_image(
lora_scales=self.flux.lora_scales,
generation_time=generation_time,
)
stepwise_img.save(
path=self.output_dir / f"seed_{seed}_step{step}of{config.num_inference_steps}.png",
export_json_metadata=False,
)
self.step_wise_images.append(stepwise_img)
self._save_composite(seed=seed)
if self.forward != VCache.is_inverting:
stepwise_img.save(
path=self.output_dir / f"seed_{seed}_step{step}of{config.num_inference_steps}.png",
export_json_metadata=False,
)
self.step_wise_images.append(stepwise_img)
self._save_composite(seed=seed)

def _save_composite(self, seed: int) -> None:
if self.step_wise_images:
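
The new `forward` flag pairs with `VCache.is_inverting`: the `forward != VCache.is_inverting` guard means each handler only writes images during its own phase. Spelled out as a standalone predicate (a hypothetical helper, not part of the diff):

def should_save(forward: bool, is_inverting: bool) -> bool:
    # A forward handler saves while generating; a backward handler saves while inverting.
    return forward != is_inverting

assert should_save(forward=True, is_inverting=False)        # generation phase, forward handler
assert should_save(forward=False, is_inverting=True)        # inversion phase, backward handler
assert not should_save(forward=True, is_inverting=True)     # inversion phase, forward handler: skip
assert not should_save(forward=False, is_inverting=False)   # generation phase, backward handler: skip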
2 changes: 1 addition & 1 deletion src/mflux/config/runtime_config.py
@@ -53,7 +53,7 @@ def init_image_strength(self) -> float:

@property
def init_time_step(self) -> int:
is_txt2img = self.config.init_image_path is None or self.config.init_image_strength == 0.0
is_txt2img = True

if is_txt2img:
return 0
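
With `is_txt2img` hard-coded to `True`, `init_time_step` always returns 0, so both the inversion pass and the edit pass iterate over the full schedule even when an `init_image_path` is supplied.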
152 changes: 109 additions & 43 deletions src/mflux/flux/flux.py
@@ -43,29 +43,96 @@ def __init__(
lora_scales=lora_scales,
)

def generate_image(
def invert(
self,
seed: int,
prompt: str,
config: Config,
) -> GeneratedImage:
) -> mx.array:
# 0. Create a new runtime config based on the model type and input parameters
config = RuntimeConfig(config, self.model_config)
time_steps = tqdm(range(config.init_time_step, config.num_inference_steps))

# 1. Create the initial latents
latents = LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=config.height,
# 1. Create the initial latents from the image
latents = LatentCreator.encode_image(
width=config.width,
height=config.height,
img2img=Img2Img(
vae=self.vae,
sigmas=config.sigmas,
init_time_step=config.init_time_step,
init_image_path=config.init_image_path,
),
)
) # fmt: off

# 2. Encode the prompt
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.prompt_cache,
t5_tokenizer=self.t5_tokenizer,
clip_tokenizer=self.clip_tokenizer,
t5_text_encoder=self.t5_text_encoder,
clip_text_encoder=self.clip_text_encoder,
)

# (Optional) Call subscribers at beginning of loop
Callbacks.before_loop(
seed=seed,
prompt=prompt,
latents=latents,
config=config,
) # fmt: off

for t in time_steps:
dt = config.sigmas[config.num_inference_steps - 1 - t] - config.sigmas[config.num_inference_steps - t]

# 3.t Predict the noise with higher-order terms
noise1 = self.transformer(
t=float(config.num_inference_steps - t),
sigma_t=config.sigmas[config.num_inference_steps - t],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents,
config=config,
)
noise2 = self.transformer(
t=float(config.num_inference_steps - t) - 0.5,
sigma_t=config.sigmas[config.num_inference_steps - t] + 0.5 * dt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents + noise1 * 0.5 * dt,
config=config,
)

# 4.t Take one inversion step (latents move toward noise)
latents += dt * noise2

# (Optional) Call subscribers at end of loop
Callbacks.in_loop(
t=config.num_inference_steps - t,
seed=seed,
prompt=prompt,
latents=latents,
config=config,
time_steps=time_steps,
) # fmt: off

# Evaluate to enable progress tracking
mx.eval(latents)

return latents

def generate_image(
self,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
) -> GeneratedImage:
# 0. Create a new runtime config based on the model type and input parameters
config = RuntimeConfig(config, self.model_config)
time_steps = tqdm(range(config.init_time_step, config.num_inference_steps))

# 2. Encode the prompt
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
@@ -85,42 +152,41 @@ def generate_image(
) # fmt: off

for t in time_steps:
try:
# 3.t Predict the noise
noise = self.transformer(
t=t,
config=config,
hidden_states=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)

# 4.t Take one denoise step
dt = config.sigmas[t + 1] - config.sigmas[t]
latents += noise * dt

# (Optional) Call subscribers at end of loop
Callbacks.in_loop(
t=t,
seed=seed,
prompt=prompt,
latents=latents,
config=config,
time_steps=time_steps,
) # fmt: off

# (Optional) Evaluate to enable progress tracking
mx.eval(latents)

except KeyboardInterrupt: # noqa: PERF203
Callbacks.interruption(
t=t,
seed=seed,
prompt=prompt,
latents=latents,
config=config,
time_steps=time_steps,
)
dt = config.sigmas[t + 1] - config.sigmas[t]

# 3.t Predict the noise with higher-order terms
noise1 = self.transformer(
t=float(t),
sigma_t=config.sigmas[t],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents,
config=config,
)
noise2 = self.transformer(
t=float(t) + 0.5,
sigma_t=config.sigmas[t] + 0.5 * dt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents + noise1 * 0.5 * dt,
config=config,
)

# 4.t Take one denoise step
latents += dt * noise2

# (Optional) Call subscribers at end of loop
Callbacks.in_loop(
t=t,
seed=seed,
prompt=prompt,
latents=latents,
config=config,
time_steps=time_steps,
) # fmt: off

# Evaluate to enable progress tracking
mx.eval(latents)

# 7. Decode the latent array and return the image
latents = ArrayUtil.unpack_latents(latents=latents, height=config.height, width=config.width)
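
Both the new `invert` loop and the reworked `generate_image` loop swap the previous single-evaluation Euler update for a midpoint (second-order) update: take a half-step along the first noise prediction, re-evaluate the noise at that midpoint, then take the full step with the re-evaluated slope. A minimal sketch of the shared update, with a hypothetical `predict_noise` standing in for the transformer call:

import mlx.core as mx

def midpoint_step(latents: mx.array, sigmas, t: int, predict_noise) -> mx.array:
    dt = sigmas[t + 1] - sigmas[t]                          # signed step along the sigma schedule
    noise1 = predict_noise(latents, sigmas[t])              # first slope estimate at sigma_t
    midpoint = latents + 0.5 * dt * noise1                  # half-step along the first estimate
    noise2 = predict_noise(midpoint, sigmas[t] + 0.5 * dt)  # slope re-evaluated at the midpoint
    return latents + dt * noise2                            # full step using the midpoint slope

The `invert` loop runs the same arithmetic with the schedule indexed in reverse (`config.num_inference_steps - t`), so `dt` flips sign and the trajectory runs from the clean image back toward noise.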
18 changes: 18 additions & 0 deletions src/mflux/flux/v_cache.py
@@ -0,0 +1,18 @@
import mlx.core as mx
import numpy as np


class VCache:
# Global state shared between the inversion and generation passes:
# `is_inverting` is toggled by the caller (see generate.py) before each phase
# and read by callbacks such as StepwiseHandler; `t_max` is likewise set by the caller.
is_inverting = True
v_cache = {}
t_max = 5

@staticmethod
def save_dict(data_dict, filename):
# Convert mx.array values to plain lists so numpy can serialize them
np_dict = {k: v.tolist() for k, v in data_dict.items()}
np.savez_compressed(filename, **np_dict)

@staticmethod
def load_dict(filename):
# Read the compressed .npz archive back, restoring each entry as an mx.array
data = np.load(filename)
return {k: mx.array(v) for k, v in data.items()}
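
A minimal round-trip sketch for the serialization helpers; the filename and shapes here are made up:

import mlx.core as mx
from mflux.flux.v_cache import VCache

cache = {"step_0": mx.zeros((4, 4)), "step_1": mx.ones((4, 4))}
VCache.save_dict(cache, "v_cache.npz")      # values stored as plain lists in a compressed .npz
restored = VCache.load_dict("v_cache.npz")  # values come back as mx.array
assert mx.allclose(restored["step_1"], cache["step_1"])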
91 changes: 53 additions & 38 deletions src/mflux/generate.py
@@ -1,52 +1,67 @@
from mflux import Config, Flux1, ModelConfig, StopImageGenerationException
from mflux.callbacks.callback_registry import CallbackRegistry
from mflux.callbacks.instances.stepwise_handler import StepwiseHandler
from mflux.ui.cli.parsers import CommandLineParser
from mflux.flux.v_cache import VCache

image_path = "/Users/filipstrand/Desktop/cat.png"
source_prompt = "A cat"
target_prompt = "A sleeping cat"
height = 256
width = 256
steps = 20
seed = 2
source_guidance = 1.5
target_guidance = 5.5
VCache.t_max = 10


def main():
# 0. Parse command line arguments
parser = CommandLineParser(description="Generate an image based on a prompt.")
parser.add_model_arguments(require_model_arg=False)
parser.add_lora_arguments()
parser.add_image_generator_arguments(supports_metadata_config=True)
parser.add_image_to_image_arguments(required=False)
parser.add_output_arguments()
args = parser.parse_args()

# 1. Load the model
# Load the model
flux = Flux1(
model_config=ModelConfig.from_name(model_name=args.model, base_model=args.base_model),
quantize=args.quantize,
local_path=args.path,
lora_paths=args.lora_paths,
lora_scales=args.lora_scales,
model_config=ModelConfig.dev(),
quantize=4,
)

# 2. Register the optional callbacks
if args.stepwise_image_output_dir:
handler = StepwiseHandler(flux=flux, output_dir=args.stepwise_image_output_dir)
CallbackRegistry.register_before_loop(handler)
CallbackRegistry.register_in_loop(handler)
CallbackRegistry.register_interrupt(handler)
# 2a. Register the optional callbacks - Backwards direction
handler_backward = StepwiseHandler(flux=flux, output_dir="/Users/filipstrand/Desktop/backward", forward=False)
CallbackRegistry.register_before_loop(handler_backward)
CallbackRegistry.register_in_loop(handler_backward)
# 2b. Register the optional callbacks - Forwards direction
handler_forward = StepwiseHandler(flux=flux, output_dir="/Users/filipstrand/Desktop/forward", forward=True)
CallbackRegistry.register_before_loop(handler_forward)
CallbackRegistry.register_in_loop(handler_forward)

try:
for seed in args.seed:
# 3. Generate an image for each seed value
image = flux.generate_image(
seed=seed,
prompt=args.prompt,
config=Config(
num_inference_steps=args.steps,
height=args.height,
width=args.width,
guidance=args.guidance,
init_image_path=args.init_image_path,
init_image_strength=args.init_image_strength,
),
)
# 4. Save the image
image.save(path=args.output.format(seed=seed), export_json_metadata=args.metadata)
# 1. Invert an existing image
VCache.is_inverting = True
inverted_latents = flux.invert(
seed=seed,
prompt=source_prompt,
config=Config(
num_inference_steps=steps,
height=height,
width=width,
guidance=source_guidance,
init_image_path=image_path,
),
)

# 2. Generate a new image based on the inverted one
VCache.is_inverting = False
image = flux.generate_image(
seed=seed,
prompt=target_prompt,
latents=inverted_latents,
config=Config(
num_inference_steps=steps,
height=height,
width=width,
guidance=target_guidance,
),
)

# 3. Save the image
image.save(path="edited.png")
except StopImageGenerationException as stop_exc:
print(stop_exc)
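
Because `VCache.is_inverting` is module-level state, the script toggles it by hand before each phase. A hypothetical context manager (not part of this diff) would keep the flag and the phase it guards paired:

from contextlib import contextmanager

from mflux.flux.v_cache import VCache

@contextmanager
def inversion_mode(flag: bool):
    # Set the global direction flag for one phase, then restore the previous value.
    previous = VCache.is_inverting
    VCache.is_inverting = flag
    try:
        yield
    finally:
        VCache.is_inverting = previous

With this, `with inversion_mode(True): flux.invert(...)` would replace the bare assignments.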

19 changes: 12 additions & 7 deletions src/mflux/latent_creator/latent_creator.py
@@ -57,13 +57,7 @@ def create_for_txt2img_or_img2img(
)

# 2. Encode the image
scaled_user_image = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(img2img.init_image_path).convert("RGB"),
target_width=width,
target_height=height,
)
encoded = img2img.vae.encode(ImageUtil.to_array(scaled_user_image))
latents = ArrayUtil.pack_latents(latents=encoded, height=height, width=width)
latents = LatentCreator.encode_image(height=height, width=width, img2img=img2img)

# 3. Find the appropriate sigma value
sigma = img2img.sigmas[img2img.init_time_step]
@@ -75,6 +69,17 @@ def create_for_txt2img_or_img2img(
sigma=sigma
) # fmt: off

@staticmethod
def encode_image(height: int, width: int, img2img: Img2Img):
scaled_user_image = ImageUtil.scale_to_dimensions(
image=ImageUtil.load_image(img2img.init_image_path).convert("RGB"),
target_width=width,
target_height=height,
)
encoded = img2img.vae.encode(ImageUtil.to_array(scaled_user_image))
latents = ArrayUtil.pack_latents(latents=encoded, height=height, width=width)
return latents

@staticmethod
def add_noise_by_interpolation(clean: mx.array, noise: mx.array, sigma: float) -> mx.array:
return (1 - sigma) * clean + sigma * noise
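
`add_noise_by_interpolation` is the linear interpolation used by rectified-flow models: at `sigma = 0` it returns the clean latents unchanged, at `sigma = 1` pure noise. A quick endpoint check with illustrative values:

import mlx.core as mx

clean = mx.full((2, 2), 1.0)
noise = mx.full((2, 2), 5.0)

def blend(sigma: float) -> mx.array:
    return (1 - sigma) * clean + sigma * noise

assert mx.allclose(blend(0.0), clean)                   # sigma = 0: clean latents
assert mx.allclose(blend(1.0), noise)                   # sigma = 1: pure noise
assert mx.allclose(blend(0.25), mx.full((2, 2), 2.0))   # 0.75 * 1.0 + 0.25 * 5.0 = 2.0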