@@ -937,9 +937,12 @@ def __init__(
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -967,6 +970,10 @@ def __init__(
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -988,12 +995,18 @@ def forward(
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning - fall back to a zeros token when no previous prediction is passed in
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
@@ -1043,13 +1056,16 @@ def forward(
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds (+ 1 for the self conditioning token)
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
@@ -1151,10 +1167,10 @@ def device(self):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
             x_start = pred
@@ -1168,28 +1184,33 @@ def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1
             x_start = l2norm(x_start) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
+
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1207,6 +1228,8 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
@@ -1216,7 +1239,9 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
@@ -1260,9 +1285,15 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
            cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
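
The hunks above wire in self-conditioning, the technique from "Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning" (Chen, Zhang, Hinton), the "Hinton's group" paper referenced in the code comment: during training, half the time the network first runs without conditioning and its detached x_start prediction is fed back in as an extra token; at sampling time, each step reuses the previous step's x_start prediction. A minimal sketch of the pattern in isolation, assuming hypothetical `net`, `q_sample`, and `p_sample` callables that are stand-ins, not this repo's API:

```python
import random
import torch

def train_step(net, x_start, times, q_sample):
    # noise the clean embedding at the given timesteps
    noised = q_sample(x_start, times)

    # half the time, run the network once without self conditioning and
    # feed its detached x_start prediction back in as conditioning
    self_cond = None
    if net.self_cond and random.random() < 0.5:
        with torch.no_grad():
            self_cond = net(noised, times).detach()

    return net(noised, times, self_cond = self_cond)

@torch.no_grad()
def sample(net, p_sample, shape, num_timesteps, device):
    x = torch.randn(shape, device = device)
    x_start = None  # previous step's x_start prediction

    for i in reversed(range(num_timesteps)):
        times = torch.full((shape[0],), i, device = device, dtype = torch.long)
        # condition each denoising step on the last x_start estimate
        self_cond = x_start if net.self_cond else None
        x, x_start = p_sample(x, times, self_cond = self_cond)

    return x
```

The zeros fallback in `forward` keeps the token sequence length constant whether or not a previous prediction is available, which is why `attend_padding` grows by `int(self.self_cond)`; and because the first pass is detached, no gradients flow through it, so the extra training cost is a single additional forward pass on half the batches.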