Skip to content

Commit 301a971

Browse files
committed
fix self conditioning shape in diffusion prior
1 parent 9440411 commit 301a971

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

dalle2_pytorch/dalle2_pytorch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1004,9 +1004,9 @@ def forward(
10041004

10051005
# setup self conditioning
10061006

1007-
self_cond = None
10081007
if self.self_cond:
1009-
self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
1008+
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
1009+
self_cond = rearrange(self_cond, 'b d -> b 1 d')
10101010

10111011
# in section 2.2, last paragraph
10121012
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -1287,7 +1287,7 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
12871287
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
12881288

12891289
self_cond = None
1290-
if self.net.self_cond and random.random() < 0.5:
1290+
if self.net.self_cond and random.random() < 1.5:
12911291
with torch.no_grad():
12921292
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
12931293

dalle2_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.6.1'
1+
__version__ = '1.6.2'

0 commit comments

Comments
 (0)