diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 22b940814e..8def857d3f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -473,8 +473,27 @@ array hadamard_transform(
     std::optional<float> scale_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N)
-  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1));
+  int n = a.shape(-1);
+  float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n);
   auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32;
+
+  // Support large N on GPU with multiple uploads
+  constexpr int threadgroup_mem_size = 32768;
+  bool multi_upload =
+      n * dtype.size() > threadgroup_mem_size && is_power_of_2(n);
+  if (to_stream(s).device == Device::gpu && multi_upload) {
+    int n1 = threadgroup_mem_size / dtype.size();
+    int n2 = n / n1;
+    auto b = unflatten(a, -1, {n2, n1}, s);
+    b = swapaxes(b, -1, -2, s);
+    b = hadamard_transform(b, /*scale=*/1.0, s);
+    b = swapaxes(b, -1, -2, s);
+    b = hadamard_transform(b, /*scale=*/1.0, s);
+    b = flatten(b, -2, -1, s);
+    b = multiply(b, array({scale}, dtype), s);
+    return b;
+  }
+
   return array(
       a.shape(),
       dtype,
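Note on the approach: the multi-upload path above relies on the Kronecker factorization of Sylvester Hadamard matrices, H_{n1*n2} = H_{n2} ⊗ H_{n1}. Reshaping the last axis to (n2, n1) and transforming along each of the two axes in turn is therefore equivalent to a single length-n transform, which is why the diff transforms, swaps axes, transforms again, and applies `scale` once at the end (the inner calls pass scale 1.0). Below is a standalone sketch that checks this identity numerically for a small case; it uses plain C++ with a hypothetical `wht` helper rather than the MLX API, and the sizes n1 = 4, n2 = 2 are illustrative only.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// In-place radix-2 Walsh-Hadamard transform of the length-n logical
// vector whose element k lives at x[offset + k * stride].
// n must be a power of two.
void wht(std::vector<float>& x, int offset, int n, int stride) {
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float u = x[offset + j * stride];
        float v = x[offset + (j + h) * stride];
        x[offset + j * stride] = u + v;
        x[offset + (j + h) * stride] = u - v;
      }
    }
  }
}

int main() {
  const int n1 = 4, n2 = 2, n = n1 * n2;
  std::vector<float> x(n);
  for (int i = 0; i < n; ++i) x[i] = float(i + 1);

  // Reference: one full length-n transform.
  auto ref = x;
  wht(ref, 0, n, 1);

  // Decomposed: view x as a row-major (n2 x n1) matrix, transform each
  // column (length n2, stride n1), then each row (length n1, stride 1).
  // This mirrors the unflatten/swapaxes/transform steps in the diff.
  auto dec = x;
  for (int c = 0; c < n1; ++c) wht(dec, c, n2, n1);
  for (int r = 0; r < n2; ++r) wht(dec, r * n1, n1, 1);

  for (int i = 0; i < n; ++i) assert(std::abs(ref[i] - dec[i]) < 1e-5f);
  std::printf("decomposition matches the full transform\n");
  return 0;
}
```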