Skip to content

Commit 3480666

Browse files
committed
make it so diffusion prior p_sample_loop returns unnormalized image embeddings
1 parent dc816b1 commit 3480666

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

dalle2_pytorch/dalle2_pytorch.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -1279,9 +1279,12 @@ def p_sample_loop(self, *args, timesteps = None, **kwargs):
12791279
is_ddim = timesteps < self.noise_scheduler.num_timesteps
12801280

12811281
if not is_ddim:
1282-
return self.p_sample_loop_ddpm(*args, **kwargs)
1282+
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
1283+
else:
1284+
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
12831285

1284-
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
1286+
image_embed = normalized_image_embed / self.image_embed_scale
1287+
return image_embed
12851288

12861289
def p_losses(self, image_embed, times, text_cond, noise = None):
12871290
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1350,8 +1353,6 @@ def sample(
13501353

13511354
# retrieve original unscaled image embed
13521355

1353-
image_embeds /= self.image_embed_scale
1354-
13551356
text_embeds = text_cond['text_embed']
13561357

13571358
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

dalle2_pytorch/version.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = '1.6.4'
1+
__version__ = '1.6.5'

0 commit comments

Comments
 (0)