From dbe39a5334d5a549a31d0f12702e7c1eca707812 Mon Sep 17 00:00:00 2001 From: Florian Fervers Date: Tue, 11 Jun 2024 12:56:32 +0200 Subject: [PATCH] Allow marked axis in einx.experimental.shard only as leftmost member of composition --- einx/experimental/op/shard.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/einx/experimental/op/shard.py b/einx/experimental/op/shard.py index 2453374..aad6215 100644 --- a/einx/experimental/op/shard.py +++ b/einx/experimental/op/shard.py @@ -12,6 +12,15 @@ tnp = einx.tracer.import_("numpy", as_="np") +def _is_composed(expr): + node = expr + while node is not None: + if isinstance(node, einx.expr.stage3.Composition): + return True + node = node.parent + return False + + @einx.jit( trace=lambda t, c: lambda expr_in, tensor_in, expr_out, backend=None: c( expr_in, @@ -26,6 +35,19 @@ def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None): for expr in root.all(): if isinstance(expr, einx.expr.stage3.Concatenation): raise ValueError("Concatenation not allowed") + if isinstance(expr, einx.expr.stage3.Marker): + child = expr + while child.parent is not None: + if ( + isinstance(child.parent, einx.expr.stage3.List) + and _is_composed(child.parent) + and child is not child.parent.children[0] + ): + raise ValueError( + "If device axes are used within a composition they " + "must appear as the left-most member of the composition" + ) + child = child.parent # Call tensor factories tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend) @@ -59,9 +81,10 @@ def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None): # Construct partition spec axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis)) partition_spec = [axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes] + partition_spec = tP(*partition_spec) # Shard tensor - sharding = tNamedSharding(mesh, tP(*partition_spec)) + sharding = tNamedSharding(mesh, partition_spec) tensor_in = tjax.device_put(tensor_in, sharding) # Unflatten output expressions @@ -203,7 +226,7 @@ def shard( representation of the operation. """ if backend.name != "jax": - raise NotImplementedError("einx.shard is currently only supported for Jax") + raise NotImplementedError("einx.experimental.shard is currently only supported for Jax") expr_in, expr_out = parse( description, einx.tracer.get_shape(tensor), mesh=mesh, cse=cse, **parameters )