Feature Request: Add Weight Normalization Support (weight_norm) #1888
Comments
Is the reference code you posted working as expected? If not, what's the issue with it?
It works well for Linear, but I can't get the output to match torch for the Conv layers, even if I override the weights with the MLX ones. The Conv layers are my focus.
Example usage with conv1d:

```python
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from typing import List, Optional, Union

# Alias the torch imports so they don't shadow mlx.nn or the local weight_norm defined below
import torch.nn as torch_nn
from torch.nn.utils import weight_norm as torch_weight_norm

# Set seeds for reproducibility
mx.random.seed(42)
torch.manual_seed(42)

# ---- MLX implementation ----
def compute_norm(x: mx.array,
                 p: int,
                 dim: Optional[Union[int, List[int]]] = None,
                 keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.

    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions

    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)

    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))
def weight_norm(weight_v: mx.array,
                weight_g: mx.array,
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.

    Weight normalization reparameterizes weight vectors in a neural network
    as a magnitude scalar times a direction vector: w = g * v / ||v||

    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension to keep; the norm is computed over all other dims.
            If None, normalize over all dimensions.

    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)

    if dim is not None:
        # Adjust negative dim
        if dim < -1:
            dim += rank

        # Create list of axes to normalize over
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))

    # Compute L2 norm of v along the specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)

    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g
class WeightNormConv1d(nn.Module):
    """Conv1d layer with weight normalization"""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, dim: int = 0, transpose_weight_g: bool = False):
        super().__init__()

        # Initialize weight parameters
        weight_shape_g = (out_channels, 1, 1) if not transpose_weight_g else (in_channels, 1, 1)
        weight_shape_v = (out_channels, in_channels, kernel_size)

        # Store parameters
        self.weight_g = mx.random.normal(weight_shape_g)
        self.weight_v = mx.random.normal(weight_shape_v)
        self.bias = mx.zeros(out_channels) if bias else None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.dim = dim

        # Compute normalized weight
        self.weight = weight_norm(self.weight_v, self.weight_g, dim=0)

    def __call__(self, x: mx.array) -> mx.array:
        # Apply conv1d
        out = mx.conv1d(x, self.weight, stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=self.groups)
        if self.bias is not None:
            out = out + self.bias.reshape(1, 1, -1)
        return out
# Example usage:
layer = WeightNormConv1d(in_channels=20, out_channels=40, kernel_size=3, padding=1)
x = mx.random.normal((16, 20, 3))  # batch_size=16, channels=20, length=3
output = layer(x).swapaxes(1, 2)

print(f"Output shape: {output.shape}")            # Should be (16, 40, 3)
print(f"Weight_g shape: {layer.weight_g.shape}")  # Should be (40, 1, 1)
print(f"Weight_v shape: {layer.weight_v.shape}")  # Should be (40, 20, 3)

# ---- PyTorch comparison ----
m = torch_weight_norm(torch_nn.Conv1d(20, 40, kernel_size=3, padding=1, bias=False), name='weight')
torch_x = torch.from_dlpack(x)
torch_output = m(torch_x)
print(f"Output shape: {torch_output.shape}")      # Should be (16, 40, 3)
print(m)
print(m.weight_g.size())
print(m.weight_v.size())
print((np.allclose(np.array(output), torch_output.detach().numpy(), rtol=1e-3, atol=1e-3),
       output.sum().tolist(),
       torch_output.detach().numpy().sum()))
```

Output:

```
(False, 23.648128509521484, 18.668007)
```
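(For illustration only, not part of the original comment: a quick sketch of the Linear case, continuing from the script above, showing why only the Conv path is in question. It reuses the `weight_norm` helper and the aliased torch imports defined earlier.)

```python
# Sketch: torch's weight_norm on a Linear layer matches the MLX helper above.
lin = torch_weight_norm(torch_nn.Linear(20, 40, bias=False), name='weight')
v = mx.array(lin.weight_v.detach().numpy())
g = mx.array(lin.weight_g.detach().numpy())
w_mlx = weight_norm(v, g, dim=0)  # g * v / ||v||, norm over all dims except 0
print(np.allclose(np.array(w_mlx), lin.weight.detach().numpy(), atol=1e-5))  # True
```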
@Blaizzy I just submitted a PR that might address the problems you were facing and, if merged, could let you use weight_norm natively with a good performance boost. Here's a summary (beyond what's in the PR):

My PR includes both core API and module wrapper implementations, along with convenience classes similar to PyTorch's. Hope this helps, and feel free to check out the implementation if it's merged, or adapt it as you need regardless.
Improve the weight normalization implementation by:

- Use the optimized C++ `mx.weight_norm()` in `WeightNormWrapper.__call__`
- Add comprehensive tests for `WeightNormConv2d`
- Verify direct API usage matches module wrapper results
- Test normalization over multiple axes and edge cases
- Add a specific test for GitHub issue ml-explore#1888

This change ensures maximum performance by leveraging the C++ implementation with its optimized handling of >2 axes normalization.
Thank you very much @cavit99, indeed it does solve it and matches the torch design! ❤️
For a bit of context, I didn't update my implementation here because I was focused on the MLX releases, but I did reach out to Awni privately to let him know that I had found the solution and was going to send a PR after I had rested a bit. My updated solution that shipped with mlx-audio handles it and matches the torch implementation, as I specified here: #1921 (comment)
I like your approach here! I followed the torch implementation of normalization.

Finally, the models I'm implementing at the moment only use weight-normalized 1D convs, so I didn't think beyond 1D, but it's an absolute joy knowing that it will now work with 2D out-of-the-box. How about 3D?
Cool! Sorry, a bit of duplicate work then, but at least it's got bindings too now! The weight norm should work with 3D and any higher dims. See:

```python
# From weight_norm.py
elif "Conv" in self.wn_module_type and dim == 0:
    weight_flat = mx.reshape(weight, (weight.shape[0], -1))
    self.weight_g = mx.linalg.norm(weight_flat, axis=1, keepdims=True)
    g_shape = [weight.shape[0]] + [1] * (weight.ndim - 1)
    self.weight_g = mx.reshape(self.weight_g, g_shape)
```

This code handles a convolution of any dimensionality by flattening every axis except the output channels into a single axis, taking the per-output-channel L2 norm of that 2D view, and reshaping the result so it broadcasts against the original weight.
The C++ implementation uses the same reshape approach:

```cpp
// If we have more than 2 axes, use the reshape approach
if (norm_axes.size() > 2) {
  // Common case: keep one dimension (e.g., output channels)
  int keep_dim = keep_axes[0];
  std::vector<int> reshape_dims = {v.shape()[keep_dim], -1};
  array v_reshaped = reshape(v, reshape_dims, s);

  // Use the 2D norm kernel which is optimized
  array v_norm = linalg::norm(v_reshaped, std::vector<int>{1}, true, s);
}
```

The Metal shaders are optimized for 1D/2D data. By reshaping to a 2D tensor, with output channels as one dimension and everything else flattened into the second, we achieve the same mathematical result while using the optimized kernels.
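As a rough illustration of why this answers the 3D question too (my own sketch using only existing MLX ops, not code from the PR), the reshape-to-2D norm agrees with a direct norm over all non-channel axes even for a 5D Conv3d-style weight:

```python
import mlx.core as mx

# Conv3d-style weight in channels-last layout: (out_channels, D, H, W, in_channels)
v = mx.random.normal((8, 3, 3, 3, 4))
g = mx.random.normal((8, 1, 1, 1, 1))

# Reshape-to-2D norm (the optimized path): flatten everything but the output channels
norm_2d = mx.linalg.norm(mx.reshape(v, (v.shape[0], -1)), axis=1, keepdims=True)
norm_2d = mx.reshape(norm_2d, (v.shape[0], 1, 1, 1, 1))

# Direct norm over all non-channel axes, for comparison
norm_nd = mx.sqrt(mx.sum(v * v, axis=(1, 2, 3, 4), keepdims=True))

w = g * v / norm_2d  # the reparameterized weight
print(mx.allclose(norm_2d, norm_nd).item(), w.shape)  # True (8, 3, 3, 3, 4)
```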
No worries, I think you did a great job and gave it the necessary attention!

Thank you very much once again!
MLX currently lacks built-in support for weight normalization, which is a crucial feature for various deep learning architectures, particularly in audio processing and generative models. Weight normalization is a reparameterization technique that decouples the magnitude and direction of weight vectors, often leading to better conditioning and faster convergence.
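To make the reparameterization concrete, here is a minimal sketch (added for illustration, not part of the original request) of w = g * v / ||v|| for a Linear-style weight, normalizing over everything except the output dimension (dim=0, the default in torch.nn.utils.weight_norm):

```python
import mlx.core as mx

v = mx.random.normal((8, 16))                 # direction parameter
g = mx.linalg.norm(v, axis=1, keepdims=True)  # magnitude, initialized to ||v|| so w == v at init
w = g * v / mx.linalg.norm(v, axis=1, keepdims=True)
# During training, g and v are updated independently, decoupling magnitude from direction.
```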
Current Situation:
Proposed Solution:
I've developed a reference implementation that could serve as a starting point: