
Commit 9440411

make self conditioning technique work with diffusion prior
1 parent 981d407 commit 9440411

File tree

dalle2_pytorch/dalle2_pytorch.py
dalle2_pytorch/version.py

2 files changed: +42 -10 lines changed

dalle2_pytorch/dalle2_pytorch.py (+41 -9)
```diff
@@ -937,9 +937,12 @@ def __init__(
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -967,6 +970,10 @@ def __init__(
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
```
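
For context on what the new `self_cond` flag does: self conditioning comes from Chen, Zhang and Hinton's "Analog Bits" paper, where the denoiser also receives its own previous estimate of x0 as an input, with zeros standing in before any estimate exists. Below is a minimal, hypothetical sketch of that input path (a toy MLP, not the repository's `DiffusionPriorNetwork`; the commit itself feeds the estimate to the causal transformer as an extra token rather than concatenating channels):

```python
import torch
from torch import nn

# Hypothetical toy denoiser, not the repository's DiffusionPriorNetwork:
# the previous x0 estimate enters as an extra input, with zeros standing
# in when no estimate exists yet.

class ToyDenoiser(nn.Module):
    def __init__(self, dim, self_cond = False):
        super().__init__()
        self.self_cond = self_cond
        in_dim = dim * (2 if self_cond else 1) + 1  # +1 for the timestep
        self.net = nn.Sequential(
            nn.Linear(in_dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, t, self_cond = None):
        if self.self_cond:
            # fall back to zeros when no previous prediction is available
            self_cond = self_cond if self_cond is not None else torch.zeros_like(x)
            x = torch.cat((x, self_cond), dim = -1)

        t = t.float().unsqueeze(-1) / 1000  # crude timestep conditioning
        return self.net(torch.cat((x, t), dim = -1))
```
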
```diff
@@ -988,12 +995,19 @@ def forward(
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
```
```diff
@@ -1043,13 +1057,16 @@ def forward(
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
```
```diff
@@ -1151,10 +1168,10 @@ def device(self):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
             x_start = pred
@@ -1168,28 +1185,33 @@ def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1
             x_start = l2norm(x_start) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
+
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
```
```diff
@@ -1207,6 +1229,8 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
@@ -1216,7 +1240,9 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
```
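
Both sampling loops now thread the previous step's `x_start` prediction back in as `self_cond`. Below is a hedged sketch of that bookkeeping, reusing the hypothetical `ToyDenoiser` above with a deliberately crude, made-up re-noising rule (the real loops use the repository's noise scheduler posterior or the DDIM update):

```python
import torch

# Sketch of the sampling-side bookkeeping, assuming the hypothetical
# ToyDenoiser above; toy_sample is an illustrative stand-in for the
# repository's p_sample_loop_ddpm / p_sample_loop_ddim.

@torch.no_grad()
def toy_sample(model, shape, timesteps = 1000):
    x = torch.randn(shape)
    x_start = None  # no prediction exists yet on the first step

    for i in reversed(range(timesteps)):
        t = torch.full((shape[0],), i, dtype = torch.long)

        # feed the previous x0 estimate back in, mirroring the commit's loops
        self_cond = x_start if model.self_cond else None
        x_start = model(x, t, self_cond = self_cond)

        # crude re-noising stand-in for the scheduler's reverse step
        x = x_start if i == 0 else x_start + torch.randn_like(x) * (i / timesteps)

    return x

# usage: embeds = toy_sample(ToyDenoiser(512, self_cond = True), (4, 512))
```
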
```diff
@@ -1260,9 +1286,15 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
```
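
The training side mirrors the paper's recipe: roughly half the time, run the network once without gradients to get an x0 estimate, detach it, and condition the real forward pass (the one that trains) on it. A minimal sketch, again assuming the hypothetical `ToyDenoiser`, with `toy_self_cond_loss` as an illustrative stand-in for `p_losses`:

```python
import random
import torch
import torch.nn.functional as F

# Hedged sketch of the self-conditioning training step, assuming the
# hypothetical ToyDenoiser above.

def toy_self_cond_loss(model, x_start, x_noisy, t):
    self_cond = None

    # with probability 0.5, get a detached x0 estimate from a no-grad pass
    if model.self_cond and random.random() < 0.5:
        with torch.no_grad():
            self_cond = model(x_noisy, t).detach()

    # the second pass, conditioned on that estimate, is the one that trains
    pred = model(x_noisy, t, self_cond = self_cond)
    return F.mse_loss(pred, x_start)
```
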

dalle2_pytorch/version.py (+1 -1)

```diff
@@ -1 +1 @@
-__version__ = '1.6.0'
+__version__ = '1.6.1'
```
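
From the user's side, enabling the technique should only require the new flag. A sketch following the README's example prior-network settings (`dim`, `depth`, `dim_head`, `heads` as shown there):

```python
from dalle2_pytorch import DiffusionPriorNetwork

# settings follow the README's example prior network; the only addition
# this commit asks of a user is the self_cond flag

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    self_cond = True  # enable self conditioning for the diffusion prior
)
```
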
