Add groups support for sharding
angeloskath committed Mar 4, 2025
1 parent b543cb4 commit a323642
Showing 6 changed files with 108 additions and 23 deletions.
6 changes: 4 additions & 2 deletions mlx/backend/cpu/primitives.cpp
@@ -141,8 +141,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
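A note on the guard above: the input buffer is now shared with the output only when it is not substantially larger than what the output needs (16 KiB of slack), presumably so a small contiguous result does not keep a much larger donated buffer alive; otherwise a general copy is made. The same check is added to the Metal backend in the next file. A minimal Python-level sketch of the condition (the names are illustrative, not the C++ API):

EXTRA_BYTES = 16384

def can_share_buffer(in_buffer_size, out_nbytes, row_contig, col_contig, allow_col_major):
    # Share only if the buffer is close to the right size and already contiguous.
    small_enough = in_buffer_size <= out_nbytes + EXTRA_BYTES
    already_contiguous = row_contig or (allow_col_major and col_contig)
    return small_enough and already_contiguous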
6 changes: 4 additions & 2 deletions mlx/backend/metal/primitives.cpp
@@ -252,8 +252,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
move_or_copy(in, out);
} else {
copy_gpu(in, out, CopyType::General);
3 changes: 3 additions & 0 deletions mlx/ops.cpp
@@ -993,6 +993,9 @@ array concatenate(
throw std::invalid_argument(
"[concatenate] No arrays provided for concatenation");
}
if (arrays.size() == 1) {
return arrays[0];
}

auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");

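With the early return above, concatenating a single array no longer goes through the general concatenate path; the input is returned as-is. A quick illustration, assuming a build that includes this change:

import mlx.core as mx

x = mx.arange(6).reshape(2, 3)
y = mx.concatenate([x], axis=0)   # early return: y is just x
print(mx.array_equal(x, y))       # True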
2 changes: 2 additions & 0 deletions python/mlx/distributed_run.py
@@ -756,6 +756,8 @@ def main():
"--cwd", help="Set the working directory on each node to the provided one"
)
args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)

if args.print_python:
print(sys.executable)
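The two lines added above strip a leading "--" separator from the unparsed arguments, so invocations that separate launcher options from the command with "--" behave as expected. A small self-contained sketch of the behaviour (the argument values are made up, and whether argparse leaves the "--" in the leftovers depends on the Python version):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cwd")
args, rest = parser.parse_known_args(["--cwd", "/tmp", "--", "python", "train.py"])
if rest and rest[0] == "--":   # mirrors the new check, with an extra emptiness guard
    rest.pop(0)
print(rest)                    # expected: ['python', 'train.py']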
95 changes: 76 additions & 19 deletions python/mlx/nn/layers/distributed.py
@@ -2,7 +2,7 @@

import math
from functools import lru_cache
from typing import Optional
from typing import Optional, Union

import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -26,7 +26,20 @@ def f(x, dx, _):
return f


def _all_to_sharded(parameters: dict, group: Optional[mx.distributed.Group] = None):
def _split(weight, groups, axis):
if isinstance(groups, int) or isinstance(groups[0], int):
return mx.split(weight, groups, axis=axis)

N = weight.shape[axis]
indices = [int(g * N) for g in groups]
return mx.split(weight, indices, axis=axis)


def _all_to_sharded(
parameters: dict,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
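The new _split helper above accepts three forms for groups: an int (number of equal parts), a list of ints (explicit split indices passed straight to mx.split), or a list of floats (fractions of the axis length, converted to indices). A short illustration; _split is a private helper and is imported here only for demonstration:

import mlx.core as mx
from mlx.nn.layers.distributed import _split

w = mx.zeros((12, 4))

a, b, c = _split(w, 3, axis=0)        # int: three equal (4, 4) parts
x, y, z = _split(w, [4, 8], axis=0)   # list of ints: split at row indices 4 and 8
p, q = _split(w, [0.5], axis=0)       # list of floats: fraction -> index 6, two (6, 4) parts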
@@ -37,13 +50,25 @@ def _all_to_sharded(parameters: dict, group: Optional[mx.distributed.Group] = No
if not isinstance(parameters[k], mx.array):
continue

step = parameters[k].shape[-2] // N
parameters[k] = parameters[k][..., r * step : (r + 1) * step, :] * 1
axis = max(parameters[k].ndim - 2, 0)
parameters[k] = mx.contiguous(
mx.concatenate(
[
_split(part, N, axis)[r]
for part in _split(parameters[k], groups, axis)
],
axis=axis,
)
)

return parameters


def _sharded_to_all(parameters: dict, group: Optional[mx.distributed.Group] = None):
def _sharded_to_all(
parameters: dict,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
N = group.size()
r = group.rank()
@@ -56,8 +81,12 @@ def _sharded_to_all(parameters: dict, group: Optional[mx.distributed.Group] = No
if k == "bias":
continue

step = parameters[k].shape[-1] // N
parameters[k] = parameters[k][..., r * step : (r + 1) * step] * 1
parameters[k] = mx.contiguous(
mx.concatenate(
[_split(part, N, -1)[r] for part in _split(parameters[k], groups, -1)],
axis=-1,
)
)

return parameters

@@ -75,18 +104,22 @@ def _check_sharding(sharding):
def shard_inplace(
module: Module,
sharding: str,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
_check_sharding(sharding)
shard_function = (
_all_to_sharded if sharding == "all-to-sharded" else _sharded_to_all
)
module.update(shard_function(module.parameters(), group))
module.update(shard_function(module.parameters(), groups=groups, group=group))


def shard_linear(
module: Module,
sharding: str,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
_check_sharding(sharding)
@@ -96,7 +129,7 @@ def shard_linear(
("sharded-to-all", True): ShardedToAllLinear.from_linear,
("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear,
}
return fns[sharding, isinstance(module, Linear)](module, group)
return fns[sharding, isinstance(module, Linear)](module, groups=groups, group=group)


class AllToShardedLinear(Module):
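For reference, a sketch of how the extended API might be used (the layer sizes and the fused layout are made up; a distributed group, possibly of size one, is assumed):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear

group = mx.distributed.init()

# Column-parallel sharding of a plain projection across the group.
proj = nn.Linear(512, 2048)
proj_sharded = shard_linear(proj, "all-to-sharded", group=group)

# A fused weight holding three equally sized projections (e.g. QKV):
# groups=3 shards each third independently rather than slicing the fused
# weight as a single block.
qkv = nn.Linear(512, 3 * 512)
shard_inplace(qkv, "all-to-sharded", groups=3, group=group)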
@@ -166,13 +199,19 @@ def __call__(self, x: mx.array) -> mx.array:

@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
cls,
linear_layer: Module,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
output_dims, input_dims = linear_layer.weight.shape

sl = cls(input_dims, output_dims, False, group)
sl.update(_all_to_sharded(linear_layer.parameters(), group))
sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
sl.update(
_all_to_sharded(linear_layer.parameters(), groups=groups, group=group)
)

return sl
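A behavioural fix in the hunk above: from_linear previously always constructed the sharded layer without a bias; it now forwards hasattr(linear_layer, "bias"), so a biased source layer produces a biased sharded layer. A quick check, assuming a single-process group:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import AllToShardedLinear

group = mx.distributed.init()
lin = nn.Linear(8, 16, bias=True)
sharded = AllToShardedLinear.from_linear(lin, group=group)
print(hasattr(sharded, "bias"))   # True with this change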

@@ -252,13 +291,19 @@ def __call__(self, x: mx.array) -> mx.array:

@classmethod
def from_linear(
cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
cls,
linear_layer: Module,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
output_dims, input_dims = linear_layer.weight.shape

sl = cls(input_dims, output_dims, False, group)
sl.update(_sharded_to_all(linear_layer.parameters(), group))
sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
sl.update(
_sharded_to_all(linear_layer.parameters(), groups=groups, group=group)
)

return sl

@@ -361,6 +406,8 @@ def __call__(self, x: mx.array) -> mx.array:
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
Expand All @@ -370,12 +417,16 @@ def from_quantized_linear(
sl = cls(
input_dims,
output_dims,
False,
hasattr(quantized_linear_layer, "bias"),
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.update(_all_to_sharded(quantized_linear_layer.parameters(), group))
sl.update(
_all_to_sharded(
quantized_linear_layer.parameters(), groups=groups, group=group
)
)

return sl

@@ -477,6 +528,8 @@ def __call__(self, x: mx.array) -> mx.array:
def from_quantized_linear(
cls,
quantized_linear_layer: Module,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
group = group or mx.distributed.init()
Expand All @@ -486,11 +539,15 @@ def from_quantized_linear(
sl = cls(
input_dims,
output_dims,
False,
hasattr(quantized_linear_layer, "bias"),
group_size=quantized_linear_layer.group_size,
bits=quantized_linear_layer.bits,
group=group,
)
sl.update(_sharded_to_all(quantized_linear_layer.parameters(), group))
sl.update(
_sharded_to_all(
quantized_linear_layer.parameters(), groups=groups, group=group
)
)

return sl
19 changes: 19 additions & 0 deletions python/src/ops.cpp
@@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) {
[0, 1, 0],
[0, 1, 0]], dtype=float32)
)pbdoc");
m.def(
"contiguous",
&mx::contiguous,
nb::arg(),
"allow_col_major"_a = false,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Force an array to be row contiguous. Copy if necessary.
Args:
a (array): The input to make contiguous
allow_col_major (bool): Consider column major as contiguous and don't copy
Returns:
array: The row or col contiguous output.
)pbdoc");
}
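The new binding exposes mx.contiguous from Python. A brief usage sketch against a build with this change:

import mlx.core as mx

x = mx.arange(12).reshape(3, 4)
xt = x.T                                      # transposed view, column contiguous
y = mx.contiguous(xt)                         # copies to make the result row contiguous
z = mx.contiguous(xt, allow_col_major=True)   # column major counts as contiguous: no copy forced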
