diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py index 9c6265fa..e9e7e847 100644 --- a/scripts/train_stage1.py +++ b/scripts/train_stage1.py @@ -16,6 +16,7 @@ """ import argparse +import copy import logging import math import os @@ -211,6 +212,7 @@ def log_validation( logger.info("Running validation... ") ori_net = accelerator.unwrap_model(net) + ori_net = copy.deepcopy(ori_net) reference_unet = ori_net.reference_unet denoising_unet = ori_net.denoising_unet face_locator = ori_net.face_locator @@ -278,6 +280,7 @@ def log_validation( canvas.save(out_file) del pipe + del ori_net torch.cuda.empty_cache() return pil_images