
Commit

Add weight normalization to MLX core and nn.layers (fixes ml-explore#…
cavit99 committed Mar 4, 2025
1 parent 6bcd6bc commit b9d29f7
Showing 6 changed files with 1,222 additions and 0 deletions.
90 changes: 90 additions & 0 deletions mlx/ops.cpp
@@ -11,6 +11,7 @@
#include <sstream>

#include "mlx/fast.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
@@ -5003,4 +5004,93 @@ array contiguous(
{a});
}

array weight_norm(
    const array& v,
    const array& g,
    const std::vector<int>& axes,
    float eps /* = 1e-5 */,
    StreamOrDevice s /* = {} */) {
  // If no axes provided, normalize over all axes
  std::vector<int> norm_axes = axes;
  if (norm_axes.empty()) {
    for (int i = 0; i < v.ndim(); ++i) {
      norm_axes.push_back(i);
    }
  }

  // If we have more than 2 axes, use the reshape approach
  if (norm_axes.size() > 2) {
    // Find the dimensions to keep (not in norm_axes)
    std::vector<int> keep_axes;
    for (int i = 0; i < v.ndim(); ++i) {
      if (std::find(norm_axes.begin(), norm_axes.end(), i) ==
          norm_axes.end()) {
        keep_axes.push_back(i);
      }
    }

    // Handle based on dimensions to keep
    if (keep_axes.empty()) {
      // If normalizing over all dimensions, reshape to 1D
      array v_flat = reshape(v, {-1}, s);
      array v_norm = linalg::norm(v_flat, std::vector<int>{0}, true, s);
      v_norm = reshape(v_norm, std::vector<int>(v.ndim(), 1), s);

      // Add epsilon for numerical stability
      v_norm = maximum(v_norm, array(eps), s);

      // Normalize v and scale by g
      return multiply(g, divide(v, v_norm, s), s);
    } else if (keep_axes.size() == 1) {
      // Common case: keep one dimension (e.g., output channels).
      // Move the kept axis to the front before flattening so that each row
      // of the 2D view holds exactly the elements sharing a kept index.
      int keep_dim = keep_axes[0];
      std::vector<int> perm = keep_axes;
      perm.insert(perm.end(), norm_axes.begin(), norm_axes.end());
      std::vector<int> reshape_dims = {v.shape()[keep_dim], -1};
      array v_reshaped = reshape(transpose(v, perm, s), reshape_dims, s);

      // Use the 2D norm kernel which is optimized
      array v_norm = linalg::norm(v_reshaped, std::vector<int>{1}, true, s);

      // Reshape for broadcasting against the original layout of v
      std::vector<int> norm_shape(v.ndim(), 1);
      norm_shape[keep_dim] = v.shape()[keep_dim];
      v_norm = reshape(v_norm, norm_shape, s);

      // Add epsilon for numerical stability
      v_norm = maximum(v_norm, array(eps), s);

      // Normalize v and scale by g
      return multiply(g, divide(v, v_norm, s), s);
    } else {
      // Multiple kept dimensions - more complex case. As above, move the
      // kept axes to the front before flattening the normalized axes.
      int prod_keep_dims = 1;
      for (auto dim : keep_axes) {
        prod_keep_dims *= v.shape()[dim];
      }
      std::vector<int> perm = keep_axes;
      perm.insert(perm.end(), norm_axes.begin(), norm_axes.end());

      std::vector<int> reshape_dims = {prod_keep_dims, -1};
      array v_reshaped = reshape(transpose(v, perm, s), reshape_dims, s);

      array v_norm = linalg::norm(v_reshaped, std::vector<int>{1}, true, s);

      // Reshape back for correct broadcasting
      std::vector<int> norm_shape(v.ndim(), 1);
      for (auto dim : keep_axes) {
        norm_shape[dim] = v.shape()[dim];
      }
      v_norm = reshape(v_norm, norm_shape, s);

      // Add epsilon for numerical stability
      v_norm = maximum(v_norm, array(eps), s);

      // Normalize v and scale by g
      return multiply(g, divide(v, v_norm, s), s);
    }
  } else {
    // Use direct approach for 1-2 axes (leveraging optimized kernels)
    array v_norm = linalg::norm(v, norm_axes, true, s);
    v_norm = maximum(v_norm, array(eps), s);
    return multiply(g, divide(v, v_norm, s), s);
  }
}

} // namespace mlx::core
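For context, here is a minimal standalone sketch (not part of the commit) of how the reshape-based multi-axis path above could be exercised; the umbrella header mlx/mlx.h, the shapes, and the expected values are illustrative assumptions.

// Hypothetical check: normalize a conv-style weight over its per-filter axes.
#include <iostream>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // 4 output filters, each of shape 3 x 5 x 5.
  array v = random::normal({4, 3, 5, 5});
  // One scale per output filter, broadcastable against v.
  array g = full({4, 1, 1, 1}, 2.0f);

  // Normalizing over axes {1, 2, 3} (more than 2 axes) takes the
  // reshape-based path in ops.cpp above.
  array w = weight_norm(v, g, std::vector<int>{1, 2, 3});

  // Each output filter of w should now have norm close to g, i.e. ~2.
  array per_filter =
      linalg::norm(reshape(w, {4, -1}), std::vector<int>{1}, false, {});
  std::cout << per_filter << std::endl;
  return 0;
}
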
47 changes: 47 additions & 0 deletions mlx/ops.h
@@ -1510,6 +1510,53 @@ array contiguous(
    bool allow_col_major = false,
    StreamOrDevice s = {});

/**
 * Performs weight normalization on a tensor.
 *
 * Weight normalization reparameterizes a tensor as:
 *   weight = g * (v / ||v||)
 *
 * Where:
 * - g is a scalar or vector scaling factor
 * - v is the unnormalized weight
 * - ||v|| is the norm of v along the specified dimensions
 *
 * Args:
 *   v: Input tensor to be normalized
 *   g: Scaling factor (should match the shape of v, with singleton
 *      dimensions for the normalized axes)
 *   axes: Axes along which to normalize; for more than 2 axes, a
 *      reshape-based approach is used
 *   eps: Small constant for numerical stability
 *   s: Stream or device
 *
 * Returns:
 *   The normalized weight tensor
 */
array weight_norm(
    const array& v,
    const array& g,
    const std::vector<int>& axes,
    float eps = 1e-5,
    StreamOrDevice s = {});

/** Weight normalization along a single axis */
inline array weight_norm(
    const array& v,
    const array& g,
    int axis,
    float eps = 1e-5,
    StreamOrDevice s = {}) {
  return weight_norm(v, g, std::vector<int>{axis}, eps, s);
}

/** Weight normalization along all axes */
inline array weight_norm(
    const array& v,
    const array& g,
    float eps = 1e-5,
    StreamOrDevice s = {}) {
  return weight_norm(v, g, std::vector<int>{}, eps, s);
}

/** @} */

} // namespace mlx::core
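To make the three overloads above concrete, here is a brief hypothetical usage sketch (not part of the diff); the shapes and the function name are illustrative, and the declarations are assumed to be visible via mlx/mlx.h.

#include "mlx/mlx.h"

using namespace mlx::core;

void weight_norm_overloads_example() {
  array v = random::normal({8, 16}); // e.g. a linear layer weight
  array g = ones({8, 1});            // one scale per output row

  // Explicit axes vector: normalize each row of v along axis 1.
  array w1 = weight_norm(v, g, std::vector<int>{1});

  // Single-axis convenience overload, equivalent to the call above.
  array w2 = weight_norm(v, g, 1);

  // All-axes overload: a single scalar g scales the fully normalized v.
  array w3 = weight_norm(v, array(2.0f));

  // Materialize the lazy computations.
  eval({w1, w2, w3});
}
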
6 changes: 6 additions & 0 deletions python/mlx/nn/layers/__init__.py
@@ -90,3 +90,9 @@
    TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample
from mlx.nn.layers.weight_norm import (
    WeightNormConv1d,
    WeightNormConv2d,
    WeightNormLinear,
    weight_norm,
)