NAT = 1. / math.log(2.)

+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
# helper functions

def exists(val):
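
For context, UnetOutput is a plain collections.namedtuple, so call sites can unpack it positionally or read fields by name. A minimal sketch (the field values are illustrative):

from collections import namedtuple

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

out = UnetOutput(pred = 'a tensor', var_interp_frac_unnormalized = None)
pred, var = out                    # unpacks positionally, like a tuple
assert out.pred == 'a tensor'      # fields are also addressable by name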
@@ -2584,6 +2586,14 @@ def get_unet(self, unet_number):
        index = unet_number - 1
        return self.unets[index]

+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        assert exists(unet_number) ^ exists(unet)
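
A standalone sketch of what the new helper does, assuming torch tensors of an illustrative shape: when the variance is learned, the unet emits twice the channels, and chunk(2, dim = 1) splits them into the prediction proper and the unnormalized variance interpolation fraction.

import torch

def parse_unet_output(learned_variance, output):
    var_interp_frac_unnormalized = None

    if learned_variance:
        # first half of the channels is the prediction, second half
        # parameterizes the learned variance
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

    return output, var_interp_frac_unnormalized

out = torch.randn(4, 6, 32, 32)              # 2 * 3 channels, since learned_variance = True
pred, var = parse_unet_output(True, out)
assert pred.shape == (4, 3, 32, 32) and var.shape == (4, 3, 32, 32)

pred, var = parse_unet_output(False, out)    # without learned variance, output passes through
assert pred.shape == (4, 6, 32, 32) and var is None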
@@ -2625,10 +2635,9 @@ def dynamic_threshold(self, x):
    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

        if predict_x_start:
            x_start = pred
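
Downstream, var_interp_frac_unnormalized is typically squashed into [0, 1] and used to interpolate the model log variance between its two extremes, as in Nichol & Dhariwal (2021). A hedged sketch of that standard interpolation (the exact normalization in this repo may differ, and the helper name is illustrative):

import torch

def interpolate_log_variance(var_interp_frac_unnormalized, log_beta_t, posterior_log_variance):
    # map the raw network output from [-1, 1] into [0, 1], then interpolate
    # between the largest (beta_t) and smallest (posterior) log variance
    frac = (var_interp_frac_unnormalized + 1) * 0.5
    return frac * log_beta_t + (1 - frac) * posterior_log_variance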
@@ -2811,10 +2820,9 @@ def p_sample_loop_ddim(

            self_cond = x_start if unet.self_cond else None

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)

            if predict_x_start:
                x_start = pred
@@ -2886,16 +2894,13 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres

        if unet.self_cond and random.random() < 0.5:
            with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
-
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                self_cond = self_cond.detach()

        # forward to get model prediction

-        model_output = unet(
+        unet_output = unet(
            x_noisy,
            times,
            **unet_kwargs,
@@ -2904,10 +2909,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
            text_cond_drop_prob = self.text_cond_drop_prob,
        )

-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)

        target = noise if not predict_x_start else x_start
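
For orientation only (the target line above is unchanged context, not part of this commit): pred is then regressed against target under the usual simple diffusion objective. A hedged sketch with illustrative stand-in tensors:

import torch
import torch.nn.functional as F

# illustrative stand-ins for the variables inside p_losses
x_start = torch.randn(4, 3, 32, 32)    # clean image
noise = torch.randn_like(x_start)      # noise injected during the forward process
pred = torch.randn_like(x_start)       # the unet's prediction
predict_x_start = False

# epsilon-parameterization regresses the noise, x0-parameterization the clean image
target = noise if not predict_x_start else x_start
loss = F.mse_loss(pred, target)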
@@ -2930,7 +2932,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
        # if learning the variance, also include the extra weight kl loss

        true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)

        # kl loss with detached model predicted mean, for stability reasons as in paper
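
For reference, the kl term compares the true posterior q(x_{t-1} | x_t, x_0) with the model's p(x_{t-1} | x_t), both diagonal Gaussians. A sketch of the standard closed form from the improved-DDPM literature (the helper name is illustrative, not necessarily this repo's):

import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    # elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) in nats;
    # multiplying by NAT = 1. / math.log(2.) converts to bits
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )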