Feature Request: Add Weight Normalization Support (weight_norm) #1888

Open
Blaizzy opened this issue Feb 19, 2025 · 8 comments · May be fixed by #1921

Comments

Blaizzy commented Feb 19, 2025

MLX currently lacks built-in support for weight normalization, which is a crucial feature for various deep learning architectures, particularly in audio processing and generative models. Weight normalization is a reparameterization technique that decouples the magnitude and direction of weight vectors, often leading to better conditioning and faster convergence.
Current Situation:

  • No built-in equivalent to PyTorch's torch.nn.utils.weight_norm
  • Users need to implement custom solutions, which may not be optimal or consistent

Proposed Solution:
I've developed a reference implementation that could serve as a starting point:

import mlx.core as mx
from typing import List, Optional, Union

def compute_norm(x: mx.array, 
                p: int, 
                dim: Optional[Union[int, List[int]]] = None, 
                keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.
    
    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions
    
    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")
    
    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)
    
    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array, 
                weight_g: mx.array, 
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.
    
    Weight normalization reparameterizes weight vectors in a neural network 
    as a magnitude scalar times a direction vector: w = g * v/||v||
    
    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension to keep (excluded from the norm). If None or -1, the norm
            is computed over the entire tensor
    
    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)
    
    if dim is not None:
        # Adjust negative dim
        if dim < -1:
            dim += rank
            
        # Create list of axes to normalize over
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))
    
    # Compute L2 norm of v along specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)
    
    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g

# Example usage:
def test_weight_norm():
    # Create sample tensors
    v = mx.random.normal((64, 3, 3))  # Direction tensor
    g = mx.random.normal((64, 1, 1))  # Magnitude tensor
    
    # Apply weight normalization
    w = weight_norm(v, g, dim=0)
    
    # Verify shape
    assert w.shape == v.shape
    
    # Verify norm along specified dimension
    norm_w = compute_norm(w, p=2, dim=[1, 2], keepdim=True)
    mx.eval(norm_w)  # Force computation
    
    return w, norm_w

if __name__ == "__main__":
    normalized_weight, norm_w = test_weight_norm()  # avoid shadowing the weight_norm function

awni (Member) commented Feb 19, 2025

Is the reference code you posted working as expected? If not, what's the issue with it?

Blaizzy (Author) commented Feb 20, 2025

It works well for Linear, but I can't get it to match torch's output for Conv layers, even if I override the torch weights with the MLX ones.

The Conv layers are my focus.

Blaizzy (Author) commented Feb 20, 2025

Example usage with conv1d:

MLX

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from typing import Optional, List, Union

import torch
import torch.nn as tnn  # aliased so torch's nn does not shadow mlx.nn
from torch.nn.utils import weight_norm as torch_weight_norm  # aliased so it does not shadow the local weight_norm below

# Set seeds for reproducibility
mx.random.seed(42)
torch.manual_seed(42)

def compute_norm(x: mx.array,
                p: int,
                dim: Optional[Union[int, List[int]]] = None,
                keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.

    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions

    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)

    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array,
                weight_g: mx.array,
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.

    Weight normalization reparameterizes weight vectors in a neural network
    as a magnitude scalar times a direction vector: w = g * v/||v||

    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension to keep (excluded from the norm). If None or -1, the norm
            is computed over the entire tensor

    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)

    if dim is not None:
        # Adjust negative dim
        if dim < -1:
            dim += rank

        # Create list of axes to normalize over
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))

    # Compute L2 norm of v along specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)

    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g


class WeightNormConv1d(nn.Module):
    """Conv1d layer with weight normalization"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, dim: int = 0, transpose_weight_g: bool = False):
        super().__init__()

        # Initialize weight parameters
        weight_shape_g = (out_channels, 1, 1) if not transpose_weight_g else (in_channels, 1, 1)
        weight_shape_v = (out_channels, in_channels, kernel_size)

        # Store parameters
        self.weight_g = mx.random.normal(weight_shape_g)
        self.weight_v = mx.random.normal(weight_shape_v)
        self.bias = mx.zeros(out_channels) if bias else None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.dim = dim


        # Compute normalized weight
        self.weight = weight_norm(self.weight_v, self.weight_g, dim=0)


    def __call__(self, x: mx.array) -> mx.array:
        # Apply conv1d
        out = mx.conv1d(x, self.weight, stride=self.stride, padding=self.padding,
                       dilation=self.dilation, groups=self.groups)
        if self.bias is not None:
            out = out + self.bias.reshape(1, 1, -1)
        return out

# Example usage:
layer = WeightNormConv1d(in_channels=20, out_channels=40, kernel_size=3, padding=1)
x = mx.random.normal((16, 20, 3))  # batch_size=16, channels=20, length=3
output = layer(x).swapaxes(1, 2)
print(f"Output shape: {output.shape}")  # Should be (16, 40, 3)
print(f"Weight_g shape: {layer.weight_g.shape}")  # Should be (40, 1, 1)
print(f"Weight_v shape: {layer.weight_v.shape}")  # Should be (40, 20, 3)

torch

m = torch_weight_norm(tnn.Conv1d(20, 40, kernel_size=3, padding=1, bias=False), name='weight')
torch_x = torch.from_dlpack(x)
torch_output = m(torch_x)
print(f"Output shape: {torch_output.shape}")  # Should be (16, 40, 3)
print(m)
print(m.weight_g.size())
print(m.weight_v.size())

print(np.allclose(output, torch_output.detach().numpy(), rtol=1e-3, atol=1e-3), output.sum().tolist(), torch_output.detach().numpy().sum())

>> (False, 23.648128509521484, 18.668007)

cavit99 commented Mar 4, 2025

@Blaizzy
Inspired by your recent mlx-audio work, I looked into this issue based on our X conversation.

I just submitted a PR that might address the problems you were facing and, if merged, could hopefully let you use weight_norm natively with a good performance boost. Here's a summary (beyond what's in the PR):

  1. Dimension Ordering: MLX uses channel-last format while PyTorch uses channel-first. For Conv1d, the weight shapes are:

    • PyTorch: [out_channels, in_channels, kernel_size]
    • MLX: [out_channels, kernel_size, in_channels]
      This ordering difference needs special handling when normalizing, which I think may be missed in your script above (see the sketch after this list)
  2. linalg::norm Limitation: MLX's linalg::norm can only handle up to 2 axes. For Conv2d weights (with 3 axes to normalize), I implemented a reshape-based approach that:

    • Identifies dimensions to keep vs. normalize
    • Reshapes the tensor to use the optimized 2D norm kernel
    • Reshapes back for proper broadcasting
  3. Module Wrapper vs. Custom Layer: Rather than implementing a custom layer from scratch, I used a module wrapper approach that applies weight normalization to existing MLX layers. This:

    • Works with all of MLX's optimized layer implementations
    • Maintains compatibility with MLX's dimension ordering
    • Shows better performance (my benchmarks show a 1.5-5x speedup over PyTorch MPS on realistic audio workloads like yours; I haven't benchmarked against your Python-only version, which uses mx.sum and mx.sqrt instead of mx.linalg.norm, but I suspect there would be improvements there too)
  4. Testing Approach: When comparing with PyTorch, I found two important insights:

    • Independent Implementations: Using common seeds still shows expected differences (up to ~5.0) which seems normal between frameworks
    • Direct Weight Transfer: Exact equivalence (differences < 1e-5) can be achieved when weights are properly transposed between frameworks
      I tested both approaches thoroughly, confirming that the mathematical properties are preserved even when numeric values differ slightly. This explains why direct output comparison might fail even when both implementations are correct.
      I do recommend checking out my test script in particular, test_weight_norm.py, as it was a good amount of learning for me
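
To make the layout point concrete, here is a minimal standalone sketch (illustrative only, not code from the PR) showing how a PyTorch Conv1d weight maps onto MLX's channel-last layout, and that the per-output-channel norms agree once the weight is transposed:

import numpy as np
import torch
import mlx.core as mx

out_channels, in_channels, kernel_size = 40, 20, 3

# PyTorch Conv1d weight layout: [out_channels, in_channels, kernel_size]
torch_w = torch.randn(out_channels, in_channels, kernel_size)

# MLX conv1d expects channel-last weights: [out_channels, kernel_size, in_channels]
mlx_w = mx.transpose(mx.array(torch_w.numpy()), (0, 2, 1))

# Per-output-channel L2 norms are layout-independent, so they should match
torch_norms = torch_w.flatten(1).norm(dim=1)
mlx_norms = mx.linalg.norm(mx.reshape(mlx_w, (out_channels, -1)), axis=1)
print(np.allclose(torch_norms.numpy(), np.array(mlx_norms), atol=1e-5))  # expected: True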

My PR includes both core API and module wrapper implementations, along with convenience classes similar to PyTorch's. Hope this helps, and feel free to check out the implementation if it's merged, or adopt as you need regardless

cavit99 added a commit to cavit99/mlx that referenced this issue on Mar 4, 2025:

Improve the weight normalization implementation by:
  • Use the optimized C++ mx.weight_norm() in WeightNormWrapper.call
  • Add comprehensive tests for WeightNormConv2d
  • Verify direct API usage matches module wrapper results
  • Test normalization over multiple axes and edge cases
  • Add specific test for GitHub issue ml-explore#1888
This change ensures maximum performance by leveraging the C++
implementation with its optimized handling of >2 axes normalization.

Blaizzy (Author) commented Mar 4, 2025

Thank you very much @cavit99, indeed it does solve it and matches the torch design! ❤️

> MLX: [out_channels, kernel_size, in_channels]
> This ordering difference needs special handling when normalizing, which I think may be missed in your above script

For a bit of context, I didn't update my implementation here because I was focused on the MLX releases, but I did reach out to Awni privately to let him know that I had found the solution and was going to send a PR after I had rested a bit.

My updated solution that shipped with mlx-audio handles it and matches the torch implementation as I specified here: #1921 (comment)

> linalg::norm Limitation: MLX's linalg::norm can only handle up to 2 axes. For Conv2d weights (with 3 axes to normalize)...

I like your approach here! I followed the torch implementation of normalization.

https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L5738
https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L3276

Finally, the models I'm implementing at the moment only use weight-normalized 1D convs, so I didn't think beyond 1D, but it's an absolute joy knowing that it will now work with 2D out of the box. How about 3D?

cavit99 commented Mar 4, 2025

Cool! Sorry, a bit of duplicate work then, but at least it's got bindings too now!

The weight norm should work with 3D and any higher dims

See:

# From weight_norm.py
elif "Conv" in self.wn_module_type and dim == 0:
    weight_flat = mx.reshape(weight, (weight.shape[0], -1))
    self.weight_g = mx.linalg.norm(weight_flat, axis=1, keepdims=True)
    g_shape = [weight.shape[0]] + [1] * (weight.ndim - 1)
    self.weight_g = mx.reshape(self.weight_g, g_shape)

This code handles any convolution of any dimensionality by:

  1. Reshaping the weight tensor to a 2D matrix (output channels × flattened everything else)
  2. Computing the norm along the flattened dimension
  3. Reshaping the result back to match the original tensor's dimensions
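
As a standalone Python illustration of those three steps (a sketch only, not the PR code itself):

import mlx.core as mx

def norm_except_dim0(weight: mx.array) -> mx.array:
    # 1. Flatten everything except the output-channel axis into one dimension
    flat = mx.reshape(weight, (weight.shape[0], -1))
    # 2. Compute the norm along the flattened axis using the optimized 2D kernel
    norms = mx.linalg.norm(flat, axis=1, keepdims=True)
    # 3. Reshape back so the result broadcasts against the original weight
    return mx.reshape(norms, [weight.shape[0]] + [1] * (weight.ndim - 1))

# Works for conv weights of any rank, e.g. a 3D conv weight in MLX layout
w = mx.random.normal((16, 3, 3, 3, 8))
g = norm_except_dim0(w)   # shape: (16, 1, 1, 1, 1)
unit_w = w / g            # each output-channel slice now has unit L2 norm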

The C++ implementation in ops.cpp explicitly handles higher dimensions:

// If we have more than 2 axes, use the reshape approach
if (norm_axes.size() > 2) {
  // Common case: keep one dimension (e.g., output channels)
  int keep_dim = keep_axes[0];
  std::vector<int> reshape_dims = {v.shape()[keep_dim], -1};
  array v_reshaped = reshape(v, reshape_dims, s);
  
  // Use the 2D norm kernel which is optimized
  array v_norm = linalg::norm(v_reshaped, std::vector<int>{1}, true, s);
}

The Metal shaders are optimized for 1D/2D data. By reshaping to a 2D tensor, with output channels as one dimension and everything else flattened into the second, we achieve the same mathematical result while using the optimized kernels.

Blaizzy (Author) commented Mar 4, 2025

No worries, I think you did a great job and gave it the necessary attention!

Blaizzy (Author) commented Mar 4, 2025

Thank you very much once again!
