Skip to content

Commit

Permalink
Allow marked axis in einx.experimental.shard only as leftmost member …
Browse files Browse the repository at this point in the history
…of composition
  • Loading branch information
fferflo committed Jun 11, 2024
1 parent 254797d commit dbe39a5
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions einx/experimental/op/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
tnp = einx.tracer.import_("numpy", as_="np")


def _is_composed(expr):
node = expr
while node is not None:
if isinstance(node, einx.expr.stage3.Composition):
return True
node = node.parent
return False


@einx.jit(
trace=lambda t, c: lambda expr_in, tensor_in, expr_out, backend=None: c(
expr_in,
Expand All @@ -26,6 +35,19 @@ def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
for expr in root.all():
if isinstance(expr, einx.expr.stage3.Concatenation):
raise ValueError("Concatenation not allowed")
if isinstance(expr, einx.expr.stage3.Marker):
child = expr
while child.parent is not None:
if (
isinstance(child.parent, einx.expr.stage3.List)
and _is_composed(child.parent)
and child is not child.parent.children[0]
):
raise ValueError(
"If device axes are used within a composition they "
"must appear as the left-most member of the composition"
)
child = child.parent

# Call tensor factories
tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend)
Expand Down Expand Up @@ -59,9 +81,10 @@ def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
# Construct partition spec
axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis))
partition_spec = [axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes]
partition_spec = tP(*partition_spec)

# Shard tensor
sharding = tNamedSharding(mesh, tP(*partition_spec))
sharding = tNamedSharding(mesh, partition_spec)
tensor_in = tjax.device_put(tensor_in, sharding)

# Unflatten output expressions
Expand Down Expand Up @@ -203,7 +226,7 @@ def shard(
representation of the operation.
"""
if backend.name != "jax":
raise NotImplementedError("einx.shard is currently only supported for Jax")
raise NotImplementedError("einx.experimental.shard is currently only supported for Jax")
expr_in, expr_out = parse(
description, einx.tracer.get_shape(tensor), mesh=mesh, cse=cse, **parameters
)
Expand Down

0 comments on commit dbe39a5

Please sign in to comment.