
Commit dc816b1 (1 parent: 05192ff)

dry up some code around handling unet outputs with learned variance

2 files changed, 20 insertions(+), 18 deletions(-)

dalle2_pytorch/dalle2_pytorch.py (+19 -17)
@@ -38,6 +38,8 @@
 
 NAT = 1. / math.log(2.)
 
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions
 
 def exists(val):
@@ -2584,6 +2586,14 @@ def get_unet(self, unet_number):
         index = unet_number - 1
         return self.unets[index]
 
+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
@@ -2625,10 +2635,9 @@ def dynamic_threshold(self, x):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
         if predict_x_start:
             x_start = pred
@@ -2811,10 +2820,9 @@ def p_sample_loop_ddim(
 
             self_cond = x_start if unet.self_cond else None
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
             if predict_x_start:
                 x_start = pred
@@ -2886,16 +2894,13 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
 
         if unet.self_cond and random.random() < 0.5:
             with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
-
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                 self_cond = self_cond.detach()
 
         # forward to get model prediction
 
-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
@@ -2904,10 +2909,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
         target = noise if not predict_x_start else x_start
 
@@ -2930,7 +2932,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
         # if learning the variance, also include the extra weight kl loss
 
         true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 

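Taken together, the change replaces four copies of the same `if learned_variance: ... chunk(2, dim = 1)` branch with a single `parse_unet_output` helper that always returns a `UnetOutput` namedtuple, with the variance field left as `None` when variance is not learned. Below is a minimal standalone sketch of that pattern; the tensor shapes are illustrative assumptions, not taken from the commit.

# standalone sketch (not part of the commit): how the new helper unifies
# handling of unet outputs with and without learned variance
from collections import namedtuple

import torch

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

def parse_unet_output(learned_variance, output):
    # without learned variance, the second field is simply None
    var_interp_frac_unnormalized = None

    if learned_variance:
        # a learned-variance unet doubles its output channels: the first
        # half is the prediction, the second half the unnormalized fraction
        # used to interpolate the variance
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

    return UnetOutput(output, var_interp_frac_unnormalized)

# hypothetical shapes: batch of 2, 3 image channels (doubled to 6), 8x8 images
with_variance = parse_unet_output(True, torch.randn(2, 6, 8, 8))
assert with_variance.pred.shape == (2, 3, 8, 8)
assert with_variance.var_interp_frac_unnormalized.shape == (2, 3, 8, 8)

without_variance = parse_unet_output(False, torch.randn(2, 3, 8, 8))
assert without_variance.var_interp_frac_unnormalized is None

Because callers unpack the namedtuple (`pred, _ = ...`), each call site reads the same whether or not the unet learns its variance, which is the point of the refactor.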
dalle2_pytorch/version.py (+1 -1)

@@ -1 +1 @@
-__version__ = '1.6.3'
+__version__ = '1.6.4'
