from torch import nn, einsum
import torchvision.transforms as T

- from einops import rearrange, repeat, reduce
+ from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
- from einops_exts import rearrange_many, repeat_many, check_shape
- from einops_exts.torch import EinopsToAndFrom

from kornia.filters import gaussian_blur2d
import kornia.augmentation as K
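This commit drops the `einops_exts` dependency: `rearrange_many`/`repeat_many` become plain `map` plus `rearrange`/`repeat`, `check_shape` becomes an `assert`, and `EinopsToAndFrom` is replaced by the `pack`/`unpack` pair that einops itself ships since 0.6. A minimal sketch of the flatten-and-restore round trip used throughout the commit (shapes are illustrative):

```python
import torch
from einops import rearrange, pack, unpack

x = torch.randn(2, 64, 16, 16)          # (batch, channels, height, width)

x = rearrange(x, 'b c ... -> b ... c')  # channels last: (2, 16, 16, 64)
x, ps = pack([x], 'b * c')              # flatten spatial axes: (2, 256, 64)

# ... any sequence module (e.g. attention) runs here on (b, n, c) ...

x, = unpack(x, ps, 'b * c')             # restore (2, 16, 16, 64)
x = rearrange(x, 'b ... c -> b c ...')  # back to (2, 64, 16, 16)
```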
@@ -669,6 +667,23 @@ def p2_reweigh_loss(self, loss, times):
            return loss

        return loss * extract(self.p2_loss_weight, times, loss.shape)

+ # rearrange image to sequence
+
+ class RearrangeToSequence(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x):
+         x = rearrange(x, 'b c ... -> b ... c')
+         x, ps = pack([x], 'b * c')
+
+         x = self.fn(x)
+
+         x, = unpack(x, ps, 'b * c')
+         x = rearrange(x, 'b ... c -> b c ...')
+         return x
+

# diffusion prior

class LayerNorm(nn.Module):
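A quick usage sketch for the new wrapper (assumes the `RearrangeToSequence` class added above is in scope; the `Linear` is a stand-in for the `Residual(Attention(...))` it actually wraps in `create_self_attn` later in this diff):

```python
import torch
from torch import nn

# stand-in sequence module: anything mapping (b, n, c) -> (b, n, c) works
attend = RearrangeToSequence(nn.Linear(32, 32))

feats_2d = torch.randn(4, 32, 8, 8)        # (b, c, h, w)
assert attend(feats_2d).shape == feats_2d.shape

# the 'b c ... -> b ... c' / 'b * c' patterns make it rank-agnostic:
feats_3d = torch.randn(4, 32, 5, 8, 8)     # (b, c, f, h, w)
assert attend(feats_3d).shape == feats_3d.shape
```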
@@ -867,7 +882,7 @@ def forward(self, x, mask = None, attn_bias = None):
        # add null key / value for classifier free guidance in prior net

-         nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
+         nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
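`repeat_many` and `rearrange_many` were thin conveniences for applying one einops pattern across several tensors; a plain `map` does the same job, here and in the later hunks. An illustrative equivalence (the `null_kv` stand-in mirrors the real buffer's `(2, d)` shape: one null key, one null value):

```python
import torch
from einops import repeat

null_kv = torch.randn(2, 16)     # stand-in for self.null_kv
b = 4

# old: nk, nv = repeat_many(null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), null_kv.unbind(dim = -2))

assert nk.shape == (b, 1, 16) and nv.shape == (b, 1, 16)
```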
@@ -1629,14 +1644,10 @@ def __init__(
        self.cross_attn = None

        if exists(cond_dim):
-             self.cross_attn = EinopsToAndFrom(
-                 'b c h w',
-                 'b (h w) c',
-                 CrossAttention(
-                     dim = dim_out,
-                     context_dim = cond_dim,
-                     cosine_sim = cosine_sim_cross_attn
-                 )
+             self.cross_attn = CrossAttention(
+                 dim = dim_out,
+                 context_dim = cond_dim,
+                 cosine_sim = cosine_sim_cross_attn
            )

        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1655,8 +1666,15 @@ def forward(self, x, time_emb = None, cond = None):
        if exists(self.cross_attn):
            assert exists(cond)
+
+             h = rearrange(h, 'b c ... -> b ... c')
+             h, ps = pack([h], 'b * c')
+
            h = self.cross_attn(h, context = cond) + h

+             h, = unpack(h, ps, 'b * c')
+             h = rearrange(h, 'b ... c -> b c ...')
+
        h = self.block2(h)
        return h + self.res_conv(x)
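Note that the reshaping `EinopsToAndFrom` used to do at construction time now happens inline in `forward`, so `CrossAttention` always sees a plain `(b, n, c)` sequence and the block is no longer tied to `'b c h w'` inputs. A sketch of the equivalence on a 2D feature map (shapes illustrative):

```python
import torch
from einops import rearrange, pack

h = torch.randn(2, 64, 8, 8)                    # (b, c, h, w)

# old (einops_exts): hard-coded 2D pattern pair
flat_old = rearrange(h, 'b c x y -> b (x y) c')

# new: rank-agnostic, same result for 2D inputs
flat_new, ps = pack([rearrange(h, 'b c ... -> b ... c')], 'b * c')

assert torch.equal(flat_old, flat_new)          # both are (2, 64, 64)
```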
@@ -1702,11 +1720,11 @@ def forward(self, x, context, mask = None):
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

-         q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

-         nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+         nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
@@ -1759,7 +1777,7 @@ def forward(self, fmap):
        fmap = self.norm(fmap)
        q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
-         q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)
@@ -1993,7 +2011,7 @@ def __init__(
        self_attn = cast_tuple(self_attn, num_stages)

-         create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+         create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))

        # resnet block klass
@@ -3230,7 +3248,7 @@ def forward(
        learned_variance = self.learned_variance[unet_index]
        b, c, h, w, device, = *image.shape, image.device

-         check_shape(image, 'b c h w', c = self.channels)
+         assert image.shape[1] == self.channels
        assert h >= target_image_size and w >= target_image_size

        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
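`check_shape` verified both the rank and the channel count; the replacement keeps the channel check as a plain `assert`, while the 4D rank is already enforced by the `b, c, h, w` unpacking two lines up. A minimal sketch of the replaced check (tensor and channel count are illustrative):

```python
import torch

image = torch.randn(4, 3, 64, 64)     # illustrative (b, c, h, w) batch
channels = 3

# old: check_shape(image, 'b c h w', c = channels)   -- rank + channel check
b, c, h, w = image.shape              # raises if image is not 4D
assert image.shape[1] == channels     # the check this commit keeps
```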