From a323642c6fcc39bcf26e23ce6204ac8e47273724 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 3 Mar 2025 13:25:44 -0800
Subject: [PATCH] Add groups support for sharding

---
 mlx/backend/cpu/primitives.cpp      |  6 +-
 mlx/backend/metal/primitives.cpp    |  6 +-
 mlx/ops.cpp                         |  3 +
 python/mlx/distributed_run.py       |  2 +
 python/mlx/nn/layers/distributed.py | 95 +++++++++++++++++++++++------
 python/src/ops.cpp                  | 19 ++++++
 6 files changed, 108 insertions(+), 23 deletions(-)

diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index b5d9d7ef35..126ebe2d54 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -141,8 +141,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
     copy(in, out, CopyType::General);
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index df8638012f..20d8409ada 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -252,8 +252,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     move_or_copy(in, out);
   } else {
     copy_gpu(in, out, CopyType::General);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 4e147487d5..5a64a78521 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -993,6 +993,9 @@ array concatenate(
     throw std::invalid_argument(
         "[concatenate] No arrays provided for concatenation");
   }
+  if (arrays.size() == 1) {
+    return arrays[0];
+  }
 
   auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");
 
diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py
index 1a749beedf..ba40bda41a 100644
--- a/python/mlx/distributed_run.py
+++ b/python/mlx/distributed_run.py
@@ -756,6 +756,8 @@ def main():
         "--cwd", help="Set the working directory on each node to the provided one"
     )
     args, rest = parser.parse_known_args()
+    if rest[0] == "--":
+        rest.pop(0)
 
     if args.print_python:
         print(sys.executable)
diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py
index ec0f8cd1a1..4e9a665c60 100644
--- a/python/mlx/nn/layers/distributed.py
+++ b/python/mlx/nn/layers/distributed.py
@@ -2,7 +2,7 @@
 
 import math
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Union
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -26,7 +26,20 @@ def f(x, dx, _):
     return f
 
 
-def _all_to_sharded(parameters: dict, group: Optional[mx.distributed.Group] = None):
+def _split(weight, groups, axis):
+    if isinstance(groups, int) or isinstance(groups[0], int):
+        return mx.split(weight, groups, axis=axis)
+
+    N = weight.shape[axis]
+    indices = [int(g * N) for g in groups]
+    return mx.split(weight, indices, axis=axis)
+
+
+def _all_to_sharded(
+    parameters: dict,
+    groups: Union[int, list] = 1,
+    group: Optional[mx.distributed.Group] = None,
+):
     group = group or mx.distributed.init()
     N = group.size()
     r = group.rank()
@@ -37,13 +50,25 @@ def _all_to_sharded(parameters: dict, group: Optional[mx.distributed.Group] = None):
         if not isinstance(parameters[k], mx.array):
             continue
 
-        step = parameters[k].shape[-2] // N
-        parameters[k] = parameters[k][..., r * step : (r + 1) * step, :] * 1
+        axis = max(parameters[k].ndim - 2, 0)
+        parameters[k] = mx.contiguous(
+            mx.concatenate(
+                [
+                    _split(part, N, axis)[r]
+                    for part in _split(parameters[k], groups, axis)
+                ],
+                axis=axis,
+            )
+        )
 
     return parameters
 
 
-def _sharded_to_all(parameters: dict, group: Optional[mx.distributed.Group] = None):
+def _sharded_to_all(
+    parameters: dict,
+    groups: Union[int, list] = 1,
+    group: Optional[mx.distributed.Group] = None,
+):
     group = group or mx.distributed.init()
     N = group.size()
     r = group.rank()
@@ -56,8 +81,12 @@ def _sharded_to_all(parameters: dict, group: Optional[mx.distributed.Group] = None):
         if k == "bias":
             continue
 
-        step = parameters[k].shape[-1] // N
-        parameters[k] = parameters[k][..., r * step : (r + 1) * step] * 1
+        parameters[k] = mx.contiguous(
+            mx.concatenate(
+                [_split(part, N, -1)[r] for part in _split(parameters[k], groups, -1)],
+                axis=-1,
+            )
+        )
 
     return parameters
 
@@ -75,18 +104,22 @@ def _check_sharding(sharding):
 def shard_inplace(
     module: Module,
     sharding: str,
+    *,
+    groups: Union[int, list] = 1,
     group: Optional[mx.distributed.Group] = None,
 ):
     _check_sharding(sharding)
     shard_function = (
         _all_to_sharded if sharding == "all-to-sharded" else _sharded_to_all
     )
-    module.update(shard_function(module.parameters(), group))
+    module.update(shard_function(module.parameters(), groups=groups, group=group))
 
 
 def shard_linear(
     module: Module,
     sharding: str,
+    *,
+    groups: Union[int, list] = 1,
     group: Optional[mx.distributed.Group] = None,
 ):
     _check_sharding(sharding)
@@ -96,7 +129,7 @@ def shard_linear(
         ("sharded-to-all", True): ShardedToAllLinear.from_linear,
         ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear,
     }
-    return fns[sharding, isinstance(module, Linear)](module, group)
+    return fns[sharding, isinstance(module, Linear)](module, groups=groups, group=group)
 
 
 class AllToShardedLinear(Module):
@@ -166,13 +199,19 @@ def __call__(self, x: mx.array) -> mx.array:
 
     @classmethod
     def from_linear(
-        cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
+        cls,
+        linear_layer: Module,
+        *,
+        groups: Union[int, list] = 1,
+        group: Optional[mx.distributed.Group] = None,
     ):
         group = group or mx.distributed.init()
         output_dims, input_dims = linear_layer.weight.shape
 
-        sl = cls(input_dims, output_dims, False, group)
-        sl.update(_all_to_sharded(linear_layer.parameters(), group))
+        sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
+        sl.update(
+            _all_to_sharded(linear_layer.parameters(), groups=groups, group=group)
+        )
 
         return sl
 
@@ -252,13 +291,19 @@ def __call__(self, x: mx.array) -> mx.array:
 
     @classmethod
     def from_linear(
-        cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
+        cls,
+        linear_layer: Module,
+        *,
+        groups: Union[int, list] = 1,
+        group: Optional[mx.distributed.Group] = None,
     ):
         group = group or mx.distributed.init()
         output_dims, input_dims = linear_layer.weight.shape
 
-        sl = cls(input_dims, output_dims, False, group)
-        sl.update(_sharded_to_all(linear_layer.parameters(), group))
+        sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
+        sl.update(
+            _sharded_to_all(linear_layer.parameters(), groups=groups, group=group)
+        )
 
         return sl
 
@@ -361,6 +406,8 @@ def __call__(self, x: mx.array) -> mx.array:
     def from_quantized_linear(
         cls,
         quantized_linear_layer: Module,
+        *,
+        groups: Union[int, list] = 1,
         group: Optional[mx.distributed.Group] = None,
     ):
         group = group or mx.distributed.init()
@@ -370,12 +417,16 @@ def from_quantized_linear(
         sl = cls(
             input_dims,
             output_dims,
-            False,
+            hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.update(_all_to_sharded(quantized_linear_layer.parameters(), group))
+        sl.update(
+            _all_to_sharded(
+                quantized_linear_layer.parameters(), groups=groups, group=group
+            )
+        )
 
         return sl
 
@@ -477,6 +528,8 @@ def __call__(self, x: mx.array) -> mx.array:
     def from_quantized_linear(
         cls,
         quantized_linear_layer: Module,
+        *,
+        groups: Union[int, list] = 1,
         group: Optional[mx.distributed.Group] = None,
     ):
         group = group or mx.distributed.init()
@@ -486,11 +539,15 @@ def from_quantized_linear(
         sl = cls(
             input_dims,
             output_dims,
-            False,
+            hasattr(quantized_linear_layer, "bias"),
             group_size=quantized_linear_layer.group_size,
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.update(_sharded_to_all(quantized_linear_layer.parameters(), group))
+        sl.update(
+            _sharded_to_all(
+                quantized_linear_layer.parameters(), groups=groups, group=group
+            )
+        )
 
         return sl
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1577cae185..6de580d1b9 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) {
            [0, 1, 0],
            [0, 1, 0]], dtype=float32)
       )pbdoc");
+  m.def(
+      "contiguous",
+      &mx::contiguous,
+      nb::arg(),
+      "allow_col_major"_a = false,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Force an array to be row contiguous. Copy if necessary.
+
+        Args:
+            a (array): The input to make contiguous
+            allow_col_major (bool): Consider column major as contiguous and don't copy
+
+        Returns:
+            array: The row or col contiguous output.
+      )pbdoc");
 }
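Usage note (illustrative, not part of the patch): the sketch below shows how the new `groups` argument might be used to shard a fused QKV projection so that every rank keeps aligned slices of the Q, K and V blocks rather than one contiguous slab. The fractional form of `groups` (cumulative split points expressed as fractions of the sharded dimension) is inferred from the `_split` helper above; the layer sizes, script name, and launcher invocation are assumptions made only for this example.

    # example.py -- hypothetical sketch; assumes this patch is applied and the
    # script is started with a distributed launcher (e.g. `mlx.launch -n 2 example.py`)
    # so that mx.distributed.init() returns a group with more than one rank.
    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.layers.distributed import shard_linear

    group = mx.distributed.init()

    # A fused projection: 2048 query + 1024 key + 1024 value output features.
    qkv = nn.Linear(2048, 2048 + 1024 + 1024, bias=False)

    # An int `groups` splits the output dimension into that many equal blocks before
    # sharding; a list of floats is read as split points (fractions of the dimension),
    # so each rank receives its own slice of Q, K and V.
    qkv_sharded = shard_linear(
        qkv,
        "all-to-sharded",
        groups=[0.5, 0.75],
        group=group,
    )

    x = mx.random.normal((2, 2048))
    print(qkv_sharded(x).shape)  # (2, 4096 // group.size()) on every rank

The split fractions 0.5 and 0.75 mark the Q/K and K/V boundaries of the 4096-wide fused output, so the per-block sizes must still divide evenly by the group size.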