diff --git a/OPTIONS.md b/OPTIONS.md
index 2287c987..60ad87c9 100644
--- a/OPTIONS.md
+++ b/OPTIONS.md
@@ -242,7 +242,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
                 [--soft_min_snr_sigma_data SOFT_MIN_SNR_SIGMA_DATA]
                 [--model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora}]
                 [--legacy] [--kolors] [--flux]
-                [--flux_lora_target {mmdit,all}] [--flux_fast_schedule]
+                [--flux_lora_target {mmdit,context,all,all+ffs}]
+                [--flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE]
+                [--flux_fast_schedule]
                 [--flux_guidance_mode {constant,random-range}]
                 [--flux_guidance_value FLUX_GUIDANCE_VALUE]
                 [--flux_guidance_min FLUX_GUIDANCE_MIN]
@@ -251,10 +253,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
                 [--flow_matching_loss {diffusers,compatible,diffusion}]
                 [--pixart_sigma] [--sd3]
                 [--sd3_t5_mask_behaviour {do-nothing,mask}]
-                [--weighting_scheme {sigma_sqrt,logit_normal,mode,cosmap,none}]
-                [--logit_mean LOGIT_MEAN] [--logit_std LOGIT_STD]
-                [--mode_scale MODE_SCALE] [--lora_type {Standard}]
-                [--lora_init_type {default,gaussian,loftq}]
+                [--lora_type {Standard}]
+                [--lora_init_type {default,gaussian,loftq,olora,pissa}]
                 [--lora_rank LORA_RANK] [--lora_alpha LORA_ALPHA]
                 [--lora_dropout LORA_DROPOUT] [--controlnet]
                 [--controlnet_model_name_or_path]
@@ -340,6 +340,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
                 [--adam_epsilon ADAM_EPSILON] [--adam_bfloat16]
                 [--max_grad_norm MAX_GRAD_NORM] [--push_to_hub]
                 [--push_checkpoints_to_hub] [--hub_model_id HUB_MODEL_ID]
+                [--model_card_note MODEL_CARD_NOTE]
                 [--logging_dir LOGGING_DIR]
                 [--validation_seed_source {gpu,cpu}]
                 [--validation_torch_compile VALIDATION_TORCH_COMPILE]
@@ -373,8 +374,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
                 [--noise_offset_probability NOISE_OFFSET_PROBABILITY]
                 [--validation_guidance VALIDATION_GUIDANCE]
                 [--validation_guidance_real VALIDATION_GUIDANCE_REAL]
-                [--validation_guidance_rescale VALIDATION_GUIDANCE_RESCALE]
                 [--validation_no_cfg_until_timestep VALIDATION_NO_CFG_UNTIL_TIMESTEP]
+                [--validation_guidance_rescale VALIDATION_GUIDANCE_RESCALE]
                 [--validation_randomize] [--validation_seed VALIDATION_SEED]
                 [--fully_unload_text_encoder]
                 [--freeze_encoder_before FREEZE_ENCODER_BEFORE]
@@ -420,14 +421,19 @@ options:
                         model.
   --flux                This option must be provided when training a Flux
                         model.
-  --flux_lora_target {mmdit,all}
-                        Flux has single and joint attention blocks. The single
-                        attention blocks deal with text inputs and are not
-                        transformed by LoRA by default. All attention blocks
-                        are trained by default. If 'mmdit' is provided, the
-                        text input layers will not be trained. This is roughly
-                        equivalent to not training the text encoder(s) in
-                        earlier models.
+  --flux_lora_target {mmdit,context,all,all+ffs}
+                        Flux has single and joint attention blocks. Only the
+                        multimodal 'dual stream' attention blocks are trained
+                        by default. If 'mmdit' is provided, the text input
+                        layers will not be trained. If 'context' is provided,
+                        the mmdit layers will not be trained. If 'all' is
+                        provided, all layers will be trained, minus feed-
+                        forward and norms. If 'all+ffs' is provided, all
+                        layers will be trained including feed-forward and
+                        norms.
+  --flow_matching_sigmoid_scale FLOW_MATCHING_SIGMOID_SCALE
+                        Scale factor for sigmoid timestep sampling for flow-
+                        matching models.
   --flux_fast_schedule  An experimental feature to train Flux.1S using a
                         noise schedule closer to what it was trained with,
                         which has improved results in short experiments. Thanks to
@@ -442,7 +448,10 @@ options:
                         and --flux_guidance_max.
   --flux_guidance_value FLUX_GUIDANCE_VALUE
                         When using --flux_guidance_mode=constant, this value
-                        will be used for every input sample.
+                        will be used for every input sample. Using a value of
+                        1.0 seems to preserve the CFG distillation for the Dev
+                        model, and any other value will result in a LoRA that
+                        requires CFG at inference time.
   --flux_guidance_min FLUX_GUIDANCE_MIN
   --flux_guidance_max FLUX_GUIDANCE_MAX
   --smoldit             Use the experimental SmolDiT model architecture.
@@ -474,32 +483,12 @@ options:
                         prevents expansion of SD3 Medium's prompt length, as
                         it will unnecessarily attend to every token in the
                         prompt embed, even masked positions.
-  --weighting_scheme {sigma_sqrt,logit_normal,mode,cosmap,none}
-                        Stable Diffusion 3 used either uniform sampling of
-                        timesteps with post-prediction loss weighting, or a
-                        weighted timestep selection by mode or log-normal
-                        distribution. The default for SD3 is logit_normal,
-                        though upstream Diffusers training examples use
-                        sigma_sqrt. The mode option is experimental, as it is
-                        the most difficult to implement cleanly. In
-                        experiments, logit_normal produced the best results
-                        for large-scale finetuning across many nodes. For
-                        small scale tuning, 'none' returns the best results.
-                        The default is 'none'.
-  --logit_mean LOGIT_MEAN
-                        As outlined in the Stable Diffusion 3 paper, using a
-                        logit_mean of -0.5 produced the highest quality FID
-                        results. The default here is 0.0.
-  --logit_std LOGIT_STD
-                        Stable Diffusion 3-specific training parameters.
-  --mode_scale MODE_SCALE
-                        Stable Diffusion 3-specific training parameters.
   --lora_type {Standard}
                         When training using --model_type=lora, you may specify
                         a different type of LoRA to train here. Currently,
                         only 'Standard' type is supported. This option exists
                         for compatibility with Kohya configuration files.
-  --lora_init_type {default,gaussian,loftq}
+  --lora_init_type {default,gaussian,loftq,olora,pissa}
                         The initialization type for the LoRA model. 'default'
                         will use Microsoft's initialization method, 'gaussian'
                         will use a Gaussian scaled distribution, and 'loftq'
@@ -1006,6 +995,9 @@ options:
   --hub_model_id HUB_MODEL_ID
                         The name of the repository to keep in sync with the
                         local `output_dir`.
+  --model_card_note MODEL_CARD_NOTE
+                        Add a string to the top of your model card to provide
+                        users with some additional context.
   --logging_dir LOGGING_DIR
                         [TensorBoard](https://www.tensorflow.org/tensorboard)
                         log directory. Will default to
@@ -1106,12 +1098,8 @@ options:
                         validations with a single prompt on slower systems, or
                         if you are not interested in unconditional space
                         generations.
-  --disable_compel      If provided, validation pipeline prompts will be
-                        handled using the typical prompt encoding strategy.
-                        Otherwise, the default behaviour is to use Compel for
-                        prompt embed generation. Note that the training input
-                        text embeds are not generated using Compel, and will
-                        be truncated to 77 tokens.
+  --disable_compel      This option does nothing. It is deprecated and will be
+                        removed in a future release.
   --enable_watermark    The SDXL 0.9 and 1.0 licenses both require a watermark
                         be used to identify any images created to be shared.
                         Since the images created during validation typically
@@ -1195,14 +1183,11 @@ options:
   --validation_guidance VALIDATION_GUIDANCE
                         CFG value for validation images. Default: 7.5
   --validation_guidance_real VALIDATION_GUIDANCE_REAL
-                        For flux, for any >1.0 value the validation will use
-                        classifier free guidance instead of the distilled
-                        sampling.
+                        Use real CFG sampling for Flux validation images.
                         Default: 1.0
   --validation_no_cfg_until_timestep VALIDATION_NO_CFG_UNTIL_TIMESTEP
-                        When using real CFG with flux, do not use CFG until this
-                        sampling timestep.
-                        Default: 2
+                        When using real CFG sampling for Flux validation
+                        images, skip doing CFG on these timesteps. Default: 2
   --validation_guidance_rescale VALIDATION_GUIDANCE_RESCALE
                         CFG rescale value for validation images. Default: 0.0,
                         max 1.0
diff --git a/README.md b/README.md
index f63a2075..568cb2e4 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,9 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP
 
 Preliminary training support for Flux.1 is included:
 
-- Low loss training using SD3 style loss calculations
+- Low loss training using an optimised approach
+  - Preserve the Dev model's distillation qualities
+  - Or reintroduce CFG to the model and improve its creativity at the cost of inference speed
 - LoRA or full tuning via DeepSpeed ZeRO
 - ControlNet training is not yet supported
 - Train either Schnell or Dev models
diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md
index bfaa452d..6fbf0ddf 100644
--- a/documentation/quickstart/FLUX.md
+++ b/documentation/quickstart/FLUX.md
@@ -312,8 +312,17 @@ In ComfyUI, you'll need to put Flux through another node called AdaptiveGuider.
 ### Classifier-free guidance
 
 #### Problem
-
-The Dev model arrives guidance-distilled out of the box, which means it does a very straight shot trajectory to the teacher model outputs - this isn't as extreme as what was done to the Schnell model, but it noticeably impacts training by re-introducing the classifier-free guidance objective into the model. Interestingly, this occurs whether caption dropout is set to 0.0 (disabled) or 0.1 (default).
+The Dev model arrives guidance-distilled out of the box, which means it takes a very direct trajectory toward the teacher model's outputs. This is done through a guidance vector that is fed into the model at training and inference time; the value of this vector greatly impacts what kind of LoRA you end up with:
+- A value of 1.0 will preserve the initial distillation done to the Dev model
+  - This is the most compatible mode
+  - Inference is just as fast as the original model
+  - Flow-matching distillation reduces the creativity and output variability of the model, as with the original Flux Dev model (everything keeps the same composition/look)
+- A higher value (tested around 3.5-4.5) will reintroduce the CFG objective into the model
+  - This requires the inference pipeline to support CFG
+  - Inference is either ~50% slower with no VRAM increase, or ~20% slower with a ~20% VRAM increase when CFG is batched
+  - However, this style of training improves creativity and output variability, which might be required for certain training tasks
+
+It's not yet clear whether the distillation can be restored to a de-distilled model by continuing to tune it with a guidance value of 1.0.
 
 #### Solution
 The solution for this is already enabled in the main branch; it is necessary to enable true CFG sampling at inference time when using LoRAs on Dev.
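For illustration, the two guidance modes described above map onto the quickstart's `TRAINER_EXTRA_ARGS` convention roughly as follows. This is only a sketch: it assumes `--flux_guidance_mode=constant`, and the 3.5 value is just one point from the tested 3.5-4.5 range.

```bash
# Preserve the Dev model's CFG distillation (fast, distilled inference; most compatible):
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --flux_guidance_mode=constant --flux_guidance_value=1.0"

# Or reintroduce the CFG objective (the resulting LoRA then needs real CFG sampling at inference time):
#export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --flux_guidance_mode=constant --flux_guidance_value=3.5"
```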
@@ -417,6 +426,6 @@ export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --base_model_default_dtype=bf16
 
 The users of [Terminus Research](https://huggingface.co/terminusresearch) who worked on this probably more than their day jobs to figure it out
 
-Lambda Labs for generous compute allocations that were used for tests and verifications for large scale training runs
+[Lambda Labs](https://lambdalabs.com) for generous compute allocations that were used for tests and verifications for large scale training runs
 
-Especially [@JimmyCarter](https://huggingface.co/jimmycarter) and [kaibioinfo](https://github.com/kaibioinfo) for coming up with some of the best ideas and putting them into action, offering pull requests and running exhaustive tests for analysis - even daring to use _their own faces_ for DreamBooth experimentation.
\ No newline at end of file
+Especially [@JimmyCarter](https://huggingface.co/jimmycarter) and [@kaibioinfo](https://github.com/kaibioinfo) for coming up with some of the best ideas and putting them into action, offering pull requests and running exhaustive tests for analysis - even daring to use _their own faces_ for DreamBooth experimentation.
diff --git a/helpers/arguments.py b/helpers/arguments.py
index 73cb6ec2..9e254d9a 100644
--- a/helpers/arguments.py
+++ b/helpers/arguments.py
@@ -119,10 +119,10 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--flux_sigmoid_scale",
+        "--flow_matching_sigmoid_scale",
        type=float,
        default=1.0,
-        help='Scale factor for sigmoid timestep sampling (only used when timestep_scheme is "flux").',
+        help="Scale factor for sigmoid timestep sampling for flow-matching models.",
     )
     parser.add_argument(
         "--flux_fast_schedule",
         action="store_true",
@@ -147,9 +147,11 @@
     parser.add_argument(
         "--flux_guidance_value",
         type=float,
-        default=4.0,
+        default=1.0,
         help=(
             "When using --flux_guidance_mode=constant, this value will be used for every input sample."
+            " Using a value of 1.0 seems to preserve the CFG distillation for the Dev model,"
+            " and any other value will result in a LoRA that requires CFG at inference time."
         ),
     )
     parser.add_argument(
@@ -189,17 +191,6 @@ def parse_args(input_args=None):
             " Additionally, 'diffusion' is offered as an option to reparameterise a model to v_prediction loss."
         ),
     )
-    parser.add_argument(
-        "--timestep_scheme",
-        type=str,
-        choices=["sd3", "flux"],
-        default=None,
-        help=(
-            "When training flow-matching models like SD3 or Flux, we can select timesteps based on an approximated continuous schedule"
-            " that takes the 1000 timesteps and derives pseudo-sigmas from them. This is the default behaviour."
-            " Flux training seems to benefit from a sigma schedule, and is recommended to use the 'flux' option."
-        ),
-    )
     parser.add_argument(
         "--pixart_sigma",
         action="store_true",
@@ -228,40 +219,6 @@
             " even masked positions."
         ),
     )
-    parser.add_argument(
-        "--weighting_scheme",
-        type=str,
-        default="cosmap",
-        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
-        help=(
-            "Stable Diffusion 3 used either uniform sampling of timesteps with post-prediction loss weighting, or"
-            " a weighted timestep selection by mode or log-normal distribution. The default for SD3 is logit_normal, though"
-            " upstream Diffusers training examples use sigma_sqrt. The mode option is experimental,"
-            " as it is the most difficult to implement cleanly. In experiments, logit_normal produced the best results"
-            " for large-scale finetuning across many nodes. For small scale tuning, 'none' returns the best results."
-            " The default is 'none'."
-        ),
-    )
-    parser.add_argument(
-        "--logit_mean",
-        type=float,
-        default=0.0,
-        help=(
-            "As outlined in the Stable Diffusion 3 paper, using a logit_mean of -0.5 produced the highest quality FID results. The default here is 0.0."
-        ),
-    )
-    parser.add_argument(
-        "--logit_std",
-        type=float,
-        default=1.0,
-        help=("Stable Diffusion 3-specific training parameters."),
-    )
-    parser.add_argument(
-        "--mode_scale",
-        type=float,
-        default=1.29,
-        help=("Stable Diffusion 3-specific training parameters."),
-    )
     parser.add_argument(
         "--lora_type",
         type=str,
@@ -1523,7 +1480,7 @@
         "--validation_guidance_real",
         type=float,
         default=1.0,
-        help="Use real CFG sampling for Flux validation images. Default: 1.0",
+        help="Use real CFG sampling for Flux validation images. Default: 1.0 (no CFG)",
     )
     parser.add_argument(
         "--validation_no_cfg_until_timestep",
@@ -1992,9 +1949,6 @@ def parse_args(input_args=None):
     if args.sd3:
         args.pretrained_vae_model_name_or_path = None
         args.disable_compel = True
-        if args.timestep_scheme is None:
-            args.timestep_scheme = "sd3"
-            logger.info(f"Using {args.timestep_scheme} timestep scheme.")
 
     t5_max_length = 77
     if args.sd3 and (
@@ -2020,9 +1974,6 @@
     elif "dev" in args.pretrained_model_name_or_path.lower():
         model_max_seq_length = 512
     if args.flux:
-        if args.timestep_scheme is None:
-            args.timestep_scheme = "flux"
-            logger.info(f"Using {args.timestep_scheme} timestep scheme.")
         if (
             args.tokenizer_max_length is None
             or int(args.tokenizer_max_length) > model_max_seq_length
diff --git a/helpers/models/flux/__init__.py b/helpers/models/flux/__init__.py
index 11e4d292..298d2925 100644
--- a/helpers/models/flux/__init__.py
+++ b/helpers/models/flux/__init__.py
@@ -5,7 +5,6 @@
 def update_flux_schedule_to_fast(args, noise_scheduler_to_copy):
     if args.flux_fast_schedule and args.flux:
         # 4-step noise schedule [0.7, 0.1, 0.1, 0.1] from SD3-Turbo paper
-        print(f"sigmas before: {noise_scheduler_to_copy.sigmas}")
         for i in range(0, 250):
             noise_scheduler_to_copy.sigmas[i] = 1.0
         for i in range(250, 500):
@@ -14,7 +13,6 @@
             noise_scheduler_to_copy.sigmas[i] = 0.2
         for i in range(750, 1000):
             noise_scheduler_to_copy.sigmas[i] = 0.1
-        print(f"sigmas after: {noise_scheduler_to_copy.sigmas}")
     return noise_scheduler_to_copy
diff --git a/train.py b/train.py
index 9455e405..1187e2c7 100644
--- a/train.py
+++ b/train.py
@@ -100,7 +100,6 @@
     prepare_latent_image_ids,
     pack_latents,
     unpack_latents,
-    update_flux_schedule_to_fast,
 )
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -504,12 +503,6 @@ def main():
             subfolder="scheduler",
             shift=1 if args.flux else 3,
         )
-        noise_scheduler_copy = copy.deepcopy(
-            update_flux_schedule_to_fast(
-                args=args, noise_scheduler_to_copy=noise_scheduler
-            )
-        )
-
     else:
         if args.legacy:
             args.rescale_betas_zero_snr = True
@@ -1882,56 +1875,25 @@
                 bsz = latents.shape[0]
                 training_logger.debug(f"Working on batch size: {bsz}")
 
-                if flow_matching and args.timestep_scheme == "sd3":
-                    # for weighting schemes where we sample timesteps non-uniformly
-                    # thanks to @Slickytail who implemented this correctly via #8528
-                    if args.weighting_scheme == "logit_normal":
-                        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                        u = torch.normal(
-                            mean=args.logit_mean,
-                            std=args.logit_std,
-                            size=(bsz,),
-                            device="cpu",
-                        )
-                        u = torch.nn.functional.sigmoid(u)
-                    elif args.weighting_scheme == "mode":
-                        u = torch.rand(size=(bsz,), device="cpu")
-                        u = (
-                            1
-                            - u
-                            - args.mode_scale
-                            * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+                if flow_matching:
+                    if not args.flux_fast_schedule:
+                        # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
+                        # also used by: https://github.com/XLabs-AI/x-flux/tree/main
+                        # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f
+                        sigmas = torch.sigmoid(
+                            args.flow_matching_sigmoid_scale
+                            * torch.randn((bsz,), device=accelerator.device)
                         )
+                        timesteps = sigmas * 1000.0
+                        sigmas = sigmas.view(-1, 1, 1, 1)
                     else:
-                        u = torch.rand(size=(bsz,), device="cpu")
-                    if args.flux_fast_schedule:
-                        # We need to train only timesteps [1, 0.75, 0.5, 0.25] based on SD3-Turbo paper
-                        quarter_step = int(
-                            noise_scheduler_copy.config.num_train_timesteps / 4
+                        # fast schedule can only use these sigmas, and they can be sampled up to batch size times
+                        available_sigmas = [0.7, 0.1]
+                        sigmas = torch.tensor(
+                            random.choices(available_sigmas, k=bsz),
+                            device=accelerator.device,
                         )
-                        indices = ((u * 4).long() + 1) * quarter_step - 1
-                        # indices = (u * 4).long() * quarter_step - 1
-                    else:
-                        indices = (
-                            u * noise_scheduler_copy.config.num_train_timesteps
-                        ).long()
-
-                    indices = (
-                        u * noise_scheduler_copy.config.num_train_timesteps
-                    ).long()
-                    timesteps = noise_scheduler_copy.timesteps[indices].to(
-                        device=latents.device
-                    )
-                elif flow_matching and args.timestep_scheme == "flux":
-                    # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
-                    # also used by: https://github.com/XLabs-AI/x-flux/tree/main
-                    # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f
-                    sigmas = torch.sigmoid(
-                        args.flux_sigmoid_scale
-                        * torch.randn((bsz,), device=accelerator.device)
-                    )
-                    timesteps = sigmas * 1000.0
-                    sigmas = sigmas.view(-1, 1, 1, 1)
+                        timesteps = sigmas * 1000.0
                 else:
                     # Sample a random timestep for each image, potentially biased by the timestep weights.
                     # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
@@ -1958,22 +1920,7 @@
                 timesteps_buffer.append((global_step, timestep))
 
                 if flow_matching:
-                    if args.timestep_scheme == "sd3":
-                        # Add noise according to flow matching.
-                        sigmas = get_sd3_sigmas(
-                            accelerator,
-                            noise_scheduler_copy,
-                            timesteps,
-                            n_dim=latents.ndim,
-                            dtype=latents.dtype,
-                        )
-                        noisy_latents = (
-                            1.0 - sigmas
-                        ) * latents.float() + sigmas * noise.float()
-                        # is equal to:
-                        # zt = (1 - texp) * x + texp * z1
-                    elif args.timestep_scheme == "flux":
-                        noisy_latents = (1 - sigmas) * latents + sigmas * noise
+                    noisy_latents = (1 - sigmas) * latents + sigmas * noise
                 else:
                     # Add noise to the latents according to the noise magnitude at each timestep
                     # (this is the forward diffusion process)
@@ -2152,9 +2099,13 @@
                         # Stable Diffusion 3 uses a MM-DiT model where the VAE-produced
                         # image embeds are passed in with the TE-produced text embeds.
                         model_pred = transformer(
-                            hidden_states=noisy_latents,
+                            hidden_states=noisy_latents.to(
+                                device=accelerator.device, dtype=base_weight_dtype
+                            ),
                             timestep=timesteps,
-                            encoder_hidden_states=encoder_hidden_states,
+                            encoder_hidden_states=encoder_hidden_states.to(
+                                device=accelerator.device, dtype=base_weight_dtype
+                            ),
                             pooled_projections=add_text_embeds.to(
                                 device=accelerator.device, dtype=weight_dtype
                             ),
@@ -2231,14 +2182,6 @@
                             width=latents.shape[3] * 8,
                             vae_scale_factor=16,
                         )
-                    if flow_matching and args.timestep_scheme == "sd3":
-                        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
-                        # Preconditioning of the model outputs.
-                        if args.flow_matching_loss == "diffusers":
-                            model_pred = model_pred * (-sigmas) + noisy_latents
-                        elif args.flow_matching_loss == "compatible":
-                            # we shouldn't mess with the model prediction.
-                            pass
 
                     # x-prediction requires that we now subtract the noise residual from the prediction to get the target sample.
                     if (
@@ -2249,29 +2192,12 @@
                         model_pred = model_pred - noise
 
                     if flow_matching:
-                        # upstream TODO: weighting sceme needs to be experimented with :)
-                        # these weighting schemes use a uniform timestep sampling
-                        # and instead post-weight the loss
-                        if (
-                            args.timestep_scheme == "sd3"
-                            and args.weighting_scheme != "none"
-                        ):
-                            if args.weighting_scheme == "sigma_sqrt":
-                                weighting = (sigmas**-2.0).float()
-                            elif args.weighting_scheme == "cosmap":
-                                bot = 1 - 2 * sigmas + 2 * sigmas**2
-                                weighting = 2 / (math.pi * bot)
-                            else:
-                                weighting = torch.ones_like(sigmas)
                         loss = torch.mean(
-                            (
-                                weighting.float()
-                                * (model_pred.float() - target.float()) ** 2
-                            ).reshape(target.shape[0], -1),
+                            ((model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
                             1,
-                        )
-                        loss = loss.mean()
-
+                        ).mean()
                     elif args.snr_gamma is None or args.snr_gamma == 0:
                         training_logger.debug("Calculating loss")
                         loss = args.snr_weight * F.mse_loss(
@@ -2785,7 +2711,7 @@
                 add_watermarker=args.enable_watermark,
                 torch_dtype=weight_dtype,
             )
-            if args.validation_noise_scheduler is not None:
+            if not flow_matching and args.validation_noise_scheduler is not None:
                 pipeline.scheduler = SCHEDULER_NAME_MAP[
                     args.validation_noise_scheduler
                 ].from_pretrained(