Commit f8b0055

add weight standardization behind feature flag, which may potentially work well with group norm

1 parent: 3480666

3 files changed (+45, -6 lines)

README.md (+10 lines)

````diff
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Qiao2019WeightS,
+    title   = {Weight Standardization},
+    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+    journal = {ArXiv},
+    year    = {2019},
+    volume  = {abs/1903.10520}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
````

dalle2_pytorch/dalle2_pytorch.py (+34, -5 lines)

```diff
@@ -1451,6 +1451,30 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -1469,10 +1493,13 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
 
@@ -1496,6 +1523,7 @@ def __init__(
         cond_dim = None,
         time_cond_dim = None,
         groups = 8,
+        weight_standardization = False,
         cosine_sim_cross_attn = False
     ):
         super().__init__()
@@ -1521,8 +1549,8 @@ def __init__(
             )
         )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None, cond = None):
@@ -1747,6 +1775,7 @@ def __init__(
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1894,7 +1923,7 @@ def __init__(
 
         # prepare resnet klass
 
-        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
 
         # give memory efficient unet an initial resnet block
 
```
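
For context, the kernel transform the new conv performs can be sketched in isolation, using the same `einops` operations as the diff; the weight shape below is just an illustrative example, not taken from the commit:

```python
import torch
from einops import rearrange, reduce

# illustrative sketch of the standardization inside WeightStandardizedConv2d:
# each output filter is shifted to zero mean and scaled to unit variance
# before the convolution is applied
weight = torch.randn(16, 3, 3, 3)  # (out_channels, in_channels, kh, kw) - example shape
eps = 1e-5                         # the fp32 epsilon used in the commit

mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = torch.var(rearrange(weight, 'o ... -> o (...)'), dim = -1, unbiased = False)
var = rearrange(var, 'o -> o 1 1 1')

standardized = (weight - mean) * (var + eps).rsqrt()

print(standardized.mean(dim = (1, 2, 3)))                   # per-filter means, all ~0
print(standardized.var(dim = (1, 2, 3), unbiased = False))  # per-filter variances, all ~1
```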

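From the consumer side, a minimal usage sketch of the new flag might look as follows; the `Unet` keyword arguments other than `resnet_weight_standardization` are taken from the repository's README examples:

```python
from dalle2_pytorch import Unet

# hypothetical sketch: flipping the new feature flag swaps every
# ResnetBlock's 3x3 projection conv from nn.Conv2d to WeightStandardizedConv2d
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    resnet_weight_standardization = True  # feature flag added in this commit
)
```

Since the flag defaults to `False` at every level, existing configurations are unaffected unless it is explicitly opted into.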
dalle2_pytorch/version.py (+1, -1 lines)

```diff
@@ -1 +1 @@
-__version__ = '1.6.5'
+__version__ = '1.7.0'
```
