From 4a633a3e731154f7f9e90a62665a5f7c80dd0012 Mon Sep 17 00:00:00 2001
From: bghira
Date: Sat, 14 Sep 2024 12:36:18 -0600
Subject: [PATCH 1/2] fix #972 by unwrapping model

---
 helpers/training/trainer.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index dca93610..75d1a93d 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -822,26 +822,40 @@ def init_post_load_freeze(self):
 
             if self.unet is not None:
                 logger.info("Applying BitFit freezing strategy to the U-net.")
-                self.unet = apply_bitfit_freezing(self.unet, self.config)
+                self.unet = apply_bitfit_freezing(
+                    unwrap_model(self.accelerator, self.unet), self.config
+                )
             if self.transformer is not None:
                 logger.warning(
                     "Training DiT models with BitFit is not yet tested, and unexpected results may occur."
                 )
-                self.transformer = apply_bitfit_freezing(self.transformer, self.config)
+                self.transformer = apply_bitfit_freezing(
+                    unwrap_model(self.accelerator, self.transformer), self.config
+                )
 
         if self.config.gradient_checkpointing:
             if self.unet is not None:
-                self.unet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.unet
+                ).enable_gradient_checkpointing()
             if self.transformer is not None and self.config.model_family != "smoldit":
-                self.transformer.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.transformer
+                ).enable_gradient_checkpointing()
             if self.config.controlnet:
-                self.controlnet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.controlnet
+                ).enable_gradient_checkpointing()
             if (
                 hasattr(self.config, "train_text_encoder")
                 and self.config.train_text_encoder
             ):
-                self.text_encoder_1.gradient_checkpointing_enable()
-                self.text_encoder_2.gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_1
+                ).gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_2
+                ).gradient_checkpointing_enable()
 
     def _recalculate_training_steps(self):
         # Scheduler and math around the number of training steps.

From a997ec49c03f4aa731a79e072ade222890ab7c35 Mon Sep 17 00:00:00 2001
From: bghira
Date: Sat, 14 Sep 2024 12:38:44 -0600
Subject: [PATCH 2/2] fix nonetype reference when ctrl+c

---
 train.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/train.py b/train.py
index cb63532b..b3c72a18 100644
--- a/train.py
+++ b/train.py
@@ -8,8 +8,6 @@
 logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
 
 if __name__ == "__main__":
-    global bf
-    bf = None
     trainer = None
     try:
         import multiprocessing
@@ -64,4 +62,4 @@
         print(e)
         print(traceback.format_exc())
         if trainer is not None and trainer.bf is not None:
-            bf.stop_fetching()
+            trainer.bf.stop_fetching()
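
Note on the first patch: every hunk routes the model through SimpleTuner's `unwrap_model(accelerator, model)` helper before calling methods like `enable_gradient_checkpointing()`. The helper itself is not shown in this series; a minimal sketch of the assumed pattern, following the convention used in diffusers/accelerate training scripts (names and placement here are illustrative, not the exact SimpleTuner source):

    from diffusers.utils.torch_utils import is_compiled_module

    def unwrap_model(accelerator, model):
        # accelerator.prepare() replaces the module with a wrapper
        # (e.g. DistributedDataParallel); peel that off so methods
        # defined on the real module resolve again.
        model = accelerator.unwrap_model(model)
        # torch.compile adds another layer; _orig_mod holds the
        # original module underneath the compiled wrapper.
        model = model._orig_mod if is_compiled_module(model) else model
        return model

This is presumably the failure mode behind #972: once `accelerator.prepare()` has wrapped the model, attribute calls such as `enable_gradient_checkpointing()` no longer exist on the wrapper, so the un-unwrapped calls raise AttributeError under distributed training.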