@@ -1451,6 +1451,30 @@ def Downsample(dim, *, dim_out = None):
1451
1451
dim_out = default (dim_out , dim )
1452
1452
return nn .Conv2d (dim , dim_out , 4 , 2 , 1 )
1453
1453
1454
class WeightStandardizedConv2d(nn.Conv2d):
    """
    2D convolution whose kernel is standardized on every forward pass.

    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization

    Each output channel's kernel is shifted to zero mean and scaled to unit
    (biased) variance before the convolution is applied. The stored parameter
    ``self.weight`` is never modified; the standardized kernel is recomputed
    from it on each call. Constructor arguments are exactly those of
    ``nn.Conv2d``.
    """

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``."""
        # any non-float32 input dtype gets the looser epsilon
        # (presumably aimed at half precision - confirm with callers)
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight

        # statistics per output channel, taken over all remaining kernel dims;
        # keepdim leaves them broadcastable against the kernel
        kernel_dims = tuple(range(1, weight.dim()))
        mean = weight.mean(dim=kernel_dims, keepdim=True)
        var = weight.var(dim=kernel_dims, unbiased=False, keepdim=True)

        weight = (weight - mean) * (var + eps).rsqrt()

        # delegate to nn.Conv2d's own conv path so that stride, padding,
        # dilation, groups AND padding_mode are all honoured; calling F.conv2d
        # directly would silently fall back to zero padding for non-'zeros'
        # padding modes
        return self._conv_forward(x, weight, self.bias)
1477
+
1454
1478
class SinusoidalPosEmb (nn .Module ):
1455
1479
def __init__ (self , dim ):
1456
1480
super ().__init__ ()
@@ -1469,10 +1493,13 @@ def __init__(
1469
1493
self ,
1470
1494
dim ,
1471
1495
dim_out ,
1472
- groups = 8
1496
+ groups = 8 ,
1497
+ weight_standardization = False
1473
1498
):
1474
1499
super ().__init__ ()
1475
- self .project = nn .Conv2d (dim , dim_out , 3 , padding = 1 )
1500
+ conv_klass = nn .Conv2d if not weight_standardization else WeightStandardizedConv2d
1501
+
1502
+ self .project = conv_klass (dim , dim_out , 3 , padding = 1 )
1476
1503
self .norm = nn .GroupNorm (groups , dim_out )
1477
1504
self .act = nn .SiLU ()
1478
1505
@@ -1496,6 +1523,7 @@ def __init__(
1496
1523
cond_dim = None ,
1497
1524
time_cond_dim = None ,
1498
1525
groups = 8 ,
1526
+ weight_standardization = False ,
1499
1527
cosine_sim_cross_attn = False
1500
1528
):
1501
1529
super ().__init__ ()
@@ -1521,8 +1549,8 @@ def __init__(
1521
1549
)
1522
1550
)
1523
1551
1524
- self .block1 = Block (dim , dim_out , groups = groups )
1525
- self .block2 = Block (dim_out , dim_out , groups = groups )
1552
+ self .block1 = Block (dim , dim_out , groups = groups , weight_standardization = weight_standardization )
1553
+ self .block2 = Block (dim_out , dim_out , groups = groups , weight_standardization = weight_standardization )
1526
1554
self .res_conv = nn .Conv2d (dim , dim_out , 1 ) if dim != dim_out else nn .Identity ()
1527
1555
1528
1556
def forward (self , x , time_emb = None , cond = None ):
@@ -1747,6 +1775,7 @@ def __init__(
1747
1775
init_dim = None ,
1748
1776
init_conv_kernel_size = 7 ,
1749
1777
resnet_groups = 8 ,
1778
+ resnet_weight_standardization = False ,
1750
1779
num_resnet_blocks = 2 ,
1751
1780
init_cross_embed = True ,
1752
1781
init_cross_embed_kernel_sizes = (3 , 7 , 15 ),
@@ -1894,7 +1923,7 @@ def __init__(
1894
1923
1895
1924
# prepare resnet klass
1896
1925
1897
- resnet_block = partial (ResnetBlock , cosine_sim_cross_attn = cosine_sim_cross_attn )
1926
+ resnet_block = partial (ResnetBlock , cosine_sim_cross_attn = cosine_sim_cross_attn , weight_standardization = resnet_weight_standardization )
1898
1927
1899
1928
# give memory efficient unet an initial resnet block
1900
1929
0 commit comments