add huber loss, timestep clamping, slightly safer txt reading
victorchall committed Apr 27, 2024
1 parent f369e53 commit 3a6fe3b
Showing 6 changed files with 94 additions and 34 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,3 +16,5 @@
.idea
/.cache
/models
/*.safetensors
/*.webp
6 changes: 1 addition & 5 deletions README.md
@@ -41,10 +41,6 @@ Covers install, setup of base models, starting training, basic tweaking, and lo

Behind the scenes look at how the trainer handles multiaspect and crop jitter

### Companion tools repo

Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream), it has a grab bag of scripts to help with your data curation prior to training. It has automatic bulk BLIP captioning for BLIP, script to web scrape based on Laion data files, script to rename generic pronouns to proper names or append artist tags to your captions, etc.

## Cloud/Docker

### [Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)
@@ -81,7 +77,7 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea

[Validation](doc/VALIDATION.md) - Use a validation split on your data to see when you are overfitting and tune hyperparameters

[Captioning](doc/CAPTION.md) - (beta) tools to automate captioning
[Captioning](doc/CAPTION.md) - tools to generate synthetic captions (recommend [Cog](doc/CAPTION_COG.md))

[Plugins](doc/PLUGINS.md) - (beta) write your own plugins to execute arbitrary code during training

38 changes: 31 additions & 7 deletions doc/ADVANCED_TWEAKING.md
@@ -36,21 +36,45 @@ This is useful if you want to dump the CKPT files directly to your webui/inferen

## Conditional dropout

Conditional dropout means the prompt or caption on the training image is dropped, and the caption is "blank". The theory is this can help with unconditional guidance, per the original paper and authors of Latent Diffusion and Stable Diffusion.
Conditional dropout means the prompt or caption on the training image is dropped, and the caption is "blank". This can help with unconditional guidance, per the original paper and authors of Latent Diffusion and Stable Diffusion. This means the CFG Scale used at inference time will respond more smoothly.

The value is defaulted at 0.04, which means 4% conditional dropout. You can set it to 0.0 to disable it, or increase it. Many users of EveryDream 1.0 have had great success tweaking this, especially for larger models. You may wish to try 0.10. This may also be useful to really "force" a style into the model with a high setting such as 0.15. However, setting it very high may lead to bleeding or overfitting to your training data, especially if your data is not very diverse, which may or may not be desirable for your project.
The value is defaulted at 0.04, which means 4% conditional dropout. You can set it to 0.0 to disable it, or increase it. For larger training sets (many tens of thousands of images), 0.10 would be my recommendation.

This may also be useful to really "force" a style into the model with a high setting such as 0.15. However, setting it very high may lead to bleeding or overfitting to your training data, especially if your data is not very diverse, which may or may not be desirable for your project.

--cond_dropout 0.1 ^
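
At 0.04, roughly 1 in 25 training examples is seen with a blank caption. As a minimal sketch of the idea (illustrative only; the function and argument names here are not the trainer's internals):

```python
import random

def maybe_drop_caption(caption: str, cond_dropout: float = 0.04) -> str:
    # With probability cond_dropout, replace the caption with an empty string
    # so that example trains the unconditional prediction instead.
    return "" if random.random() < cond_dropout else caption
```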

## LR tweaking
## Timestep clamping

Learning rate adjustment is a very important part of training.
Stable Diffusion uses 1000 possible timesteps for its noising/denoising schedule. If you wish to train on only a portion of those timesteps instead of the entire schedule, you can clamp the range.

--lr 1.0e-6 ^
Timesteps are always chosen randomly per training example, per step, but only from within the allowed range.

For instance, if you only want to train from 500 to 999, use this:

--timestep_start 500

Or if you only want to train from 0 to 449, use this:

--timestep_end 450

Possible use cases are to "focus" training on aesthetics or composition. If you train only specific timestep ranges first, you will likely need a follow-up pass over all timesteps to "clean up" the model.
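
A minimal sketch of the sampling behavior (mirroring the `torch.randint` call in train.py; the helper name is illustrative):

```python
import torch

def sample_timesteps(bsz: int, timestep_start: int = 0, timestep_end: int = 1000) -> torch.Tensor:
    # Each example in the batch draws its own timestep uniformly from
    # [timestep_start, timestep_end); the defaults cover the full schedule.
    return torch.randint(timestep_start, timestep_end, (bsz,), dtype=torch.long)

# e.g. focus training on the noisier half of the schedule
timesteps = sample_timesteps(4, timestep_start=500, timestep_end=1000)
```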

By default, the learning rate is constant for the entire training session. However, if you want it to change by itself during training, you can use cosine.
## Loss Type

General suggestion is 1e-6 for training SD1.5 at 512 resolution. For SD2.1 at 768, try a much lower value, such as 2e-7. [Validation](VALIDATION.md) can be helpful to tune learning rate.
You can change the type of loss from the standard [MSE ("L2") loss](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html) to [Huber loss](https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html), or to an interpolation between the two across timesteps. Valid values are "mse", "huber", "mse_huber", and "huber_mse".

--loss_type huber

mse_huber will use MSE at timestep 0 and Huber at timestep 999, interpolating between the two across the intermediate timesteps; huber_mse is the reverse.
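
A simplified sketch of the interpolation idea, following the wording above (MSE weighted toward timestep 0, Huber toward timestep 999); it ignores the trainer's additional per-example loss scaling, and the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def interpolated_loss(model_pred, target, timesteps, num_train_timesteps=1000):
    # Linear blend weight across the schedule: 0.0 at timestep 0, ~1.0 at timestep 999.
    w = (timesteps.float() / num_train_timesteps).view(-1, 1, 1, 1)
    mse = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
    # MSE dominates at low timesteps, Huber at high timesteps; reduce to a scalar.
    return ((1.0 - w) * mse + w * huber).mean()
```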

## LR tweaking

You should use [Optimizer config](doc/OPTIMZER.md) to tweak the learning rate instead of the primary arg here; it is kept for legacy support so the Jupyter Notebook remains easy to use in a happy-path or simplified scenario.

--lr 1.0e-6 ^

*If you set this in train.json or via the main CLI arg, it will override the value from your optimizer.json, so use with caution.* Again, it is best to use optimizer.json instead.

## Clip skip

9 changes: 5 additions & 4 deletions train.json
@@ -14,6 +14,7 @@
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"loss_type": "mse",
"max_epochs": 40,
"notebook": false,
"optimizer_config": "optimizer.json",
@@ -29,17 +30,17 @@
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"timestep_start": 0,
"timestep_end": 1000,
"shuffle_tags": false,
"validation_config": "validation_default.json",
"wandb": false,
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02,
"pyramid_noise_discount": null,
"pyramid_noise_discount": 0.03,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"min_snr_gamma": 5.0,
"ema_decay_rate": null,
"ema_strength_target": null,
"ema_update_interval": null,
55 changes: 46 additions & 9 deletions train.py
@@ -391,6 +391,11 @@ def setup_args(args):

args.aspects = aspects.get_aspect_buckets(args.resolution)

if args.timestep_start < 0:
raise ValueError("timestep_start must be >= 0")
if args.timestep_end > 1000:
raise ValueError("timestep_end must be <= 1000")

return args


@@ -727,16 +732,22 @@ def release_memory(model_to_delete, original_device):
text_encoder_ema = None

try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
#vae = torch.compile(vae)
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.allow_tf32 = True
print()
#unet = torch.compile(unet, mode="max-autotune")
#text_encoder = torch.compile(text_encoder, mode="max-autotune")
#vae = torch.compile(vae, mode="max-autotune")
#logging.info("Successfully compiled models")
except Exception as ex:
logging.warning(f"Failed to compile model, continuing anyway, ex: {ex}")
pass

try:
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.allow_tf32 = True
except Exception as ex:
logging.warning(f"Failed to set float32 matmul precision, continuing anyway, ex: {ex}")
pass

optimizer_config = None
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@@ -944,7 +955,7 @@ def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.

bsz = latents.shape[0]

timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = torch.randint(args.timestep_start, args.timestep_end, (bsz,), device=latents.device)
timesteps = timesteps.long()

cuda_caption = tokens.to(text_encoder.device)
@@ -987,9 +998,32 @@ def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.
mse_loss_weights[snr == 0] = 1.0
loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
loss = loss.mean()
loss_mse = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss_scale = loss_scale.view(-1, 1, 1, 1).expand_as(loss_mse)

if args.loss_type == "mse_huber":
early_timestep_bias = (timesteps / noise_scheduler.config.num_train_timesteps)
early_timestep_bias = torch.tensor(early_timestep_bias, dtype=torch.float).to(unet.device)
early_timestep_bias = early_timestep_bias.view(-1, 1, 1, 1).expand_as(loss_mse)
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_mse = loss_mse * loss_scale.to(unet.device) * early_timestep_bias
loss_huber = loss_huber * loss_scale.to(unet.device) * (1.0 - early_timestep_bias)
loss = loss_mse.mean() + loss_huber.mean()
elif args.loss_type == "huber_mse":
early_timestep_bias = (timesteps / noise_scheduler.config.num_train_timesteps)
early_timestep_bias = torch.tensor(early_timestep_bias, dtype=torch.float).to(unet.device)
early_timestep_bias = early_timestep_bias.view(-1, 1, 1, 1).expand_as(loss_mse)
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_mse = loss_mse * loss_scale.to(unet.device) * (1.0 - early_timestep_bias)
loss_huber = loss_huber * loss_scale.to(unet.device) * early_timestep_bias
loss = loss_huber.mean() + loss_mse.mean()
elif args.loss_type == "huber":
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_huber = loss_huber * loss_scale.to(unet.device)
loss = loss_huber.mean()
else:
loss_mse = loss_mse * loss_scale.to(unet.device)
loss = loss_mse.mean()

return model_pred, target, loss

@@ -1334,6 +1368,7 @@ def update_arg(arg: str, newValue):
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
argparser.add_argument("--loss_type", type=str, default="mse_huber", help="type of loss, 'huber', 'mse', or 'mse_huber' for interpolated (def: mse_huber)", choices=["huber", "mse", "mse_huber"])
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
@@ -1356,6 +1391,8 @@ def update_arg(arg: str, newValue):
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--timestep_start", type=int, default=0, help="Noising timestep minimum (def: 0)")
argparser.add_argument("--timestep_end", type=int, default=1000, help="Noising timestep (def: 1000)")
argparser.add_argument("--train_sampler", type=str, default="ddpm", help="noise sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, used to randomly select subset of tags when shuffling is enabled, def: 0 (shuffle all)")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
18 changes: 9 additions & 9 deletions utils/fs_helpers.py
@@ -16,16 +16,16 @@ def is_image(file):

def read_text(file):
try:
with open(file, encoding='utf-8', mode='r') as stream:
return stream.read().strip()
encodings = ['utf-8', 'iso-8859-1', 'windows-1252', 'latin-1']
for encoding in encodings:
try:
with open(file, encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
raise ValueError(f'Could not decode file with any of the provided encodings: {encodings}')
except Exception as e:
logging.warning(f" *** Error reading text file as utf-8: {file}: {e}")

try:
with open(file, encoding='latin-1', mode='r') as stream:
return stream.read().strip()
except Exception as e:
logging.warning(f" *** Error reading text file as latin-1: {file}: {e}")
logging.warning(f" *** Error reading text file: {file}: {e}")

def read_float(file):
try:
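
For reference, a hedged usage sketch of the new fallback behavior (the filename and the empty-string fallback here are hypothetical, not from the repo): if every encoding fails, read_text logs a warning and implicitly returns None, so callers may want to handle that case.

```python
caption = read_text("my_image.txt")  # hypothetical caption file path
if caption is None:
    caption = ""  # fall back to an empty caption if the file could not be decoded
```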
