diff --git a/docs/source/api.rst b/docs/source/api.rst
index b90e32b..a5ff4c7 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -122,4 +122,9 @@ Keras
 
 .. autoclass:: einx.nn.keras.Norm
 .. autoclass:: einx.nn.keras.Dropout
-.. autofunction:: einx.nn.keras.param
\ No newline at end of file
+.. autofunction:: einx.nn.keras.param
+
+Experimental
+============
+
+.. autofunction:: einx.experimental.shard
\ No newline at end of file
diff --git a/einx/__init__.py b/einx/__init__.py
index 876c514..59bc6f2 100644
--- a/einx/__init__.py
+++ b/einx/__init__.py
@@ -7,3 +7,4 @@
 from . import expr
 from .op import *
 from . import nn
+from . import experimental
diff --git a/einx/experimental/__init__.py b/einx/experimental/__init__.py
new file mode 100644
index 0000000..3c4b4b4
--- /dev/null
+++ b/einx/experimental/__init__.py
@@ -0,0 +1 @@
+from .op import *
diff --git a/einx/experimental/op/__init__.py b/einx/experimental/op/__init__.py
new file mode 100644
index 0000000..21f5ab7
--- /dev/null
+++ b/einx/experimental/op/__init__.py
@@ -0,0 +1 @@
+from .shard import *
diff --git a/einx/experimental/op/shard.py b/einx/experimental/op/shard.py
new file mode 100644
index 0000000..6354806
--- /dev/null
+++ b/einx/experimental/op/shard.py
@@ -0,0 +1,214 @@
+import einx
+import einx.op.util as util
+import numpy as np
+from functools import partial
+from typing import Callable, Union, Any
+import numpy.typing as npt
+
+tP = einx.tracer.import_("PartitionSpec", "P", from_="jax.sharding")
+tNamedSharding = einx.tracer.import_("NamedSharding", from_="jax.sharding")
+tMesh = einx.tracer.import_("Mesh", from_="jax.sharding")
+tjax = einx.tracer.import_("jax")
+tnp = einx.tracer.import_("numpy", as_="np")
+
+
+@einx.jit(
+    trace=lambda t, c: lambda expr_in, tensor_in, expr_out, backend=None: c(
+        expr_in,
+        t(tensor_in),
+        expr_out,
+    )
+)
+def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
+    import jax
+
+    for root in [expr_in, expr_out]:
+        for expr in root.all():
+            if isinstance(expr, einx.expr.stage3.Concatenation):
+                raise ValueError("Concatenation not allowed")
+
+    # Call tensor factories
+    tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend)
+    (tensor_in,) = backend.all_to_tensor([tensor_in])
+
+    # Flatten expressions
+    (expr_in,), (tensor_in,) = util.flatten([expr_in], [tensor_in], backend=backend)
+    marked_axes = tuple(
+        axis
+        for axis in expr_in
+        if isinstance(axis, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(axis)
+    )
+
+    if mesh is None:
+        # Construct new mesh
+        devices = tnp.array(tjax.devices()).reshape(tuple(a.value for a in marked_axes))
+        mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))
+    elif isinstance(mesh, jax.sharding.Mesh):
+        # Got mesh -> check that marked axes match mesh
+        marked_names = set(a.name for a in marked_axes)
+        mesh_names = set(str(a) for a in mesh.axis_names)
+        if not marked_names.issubset(mesh_names):
+            raise ValueError(
+                f"Marked axes must be a subset of the mesh axes. "
+                f"Got marked axes {marked_names} and mesh axes {mesh_names}"
+            )
+    else:
+        # Got list of devices -> construct new mesh
+        devices = tnp.array(mesh).reshape(tuple(a.value for a in marked_axes))
+        mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))
+
+    # Construct partition spec
+    axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis))
+    partition_spec = [axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes]
+
+    # Shard tensor
+    sharding = tNamedSharding(mesh, tP(*partition_spec))
+    tensor_in = tjax.device_put(tensor_in, sharding)
+
+    # Unflatten output expressions
+    (tensor_in,) = util.unflatten([expr_in], [tensor_in], [expr_out], backend=backend)
+
+    return tensor_in, expr_in
+
+
+@einx.lru_cache
+def parse(description, tensor_shape, cse=True, mesh=None, jax_devices=None, **parameters):
+    import jax
+
+    description, parameters = einx.op.util._clean_description_and_parameters(
+        description, parameters
+    )
+
+    op = einx.expr.stage1.parse_op(description)
+
+    if len(op) != 1:
+        raise ValueError(f"Expected exactly one expression, got {len(op)}")
+
+    def solve(eqs):
+        return einx.expr.solve(
+            [einx.expr.Equation(op[0][0], tensor_shape)]
+            + eqs
+            + [
+                einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None)
+                for k, v in parameters.items()
+            ],
+            cse=cse,
+        )[0]
+
+    if mesh is None:
+        # If no mesh is given, create new mesh of all devices
+        try:
+            expr_in = solve([])
+        except einx.expr.SolveException as e:
+            # Try with additional constraint of total number of devices
+            expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
+            mesh_eq = einx.expr.Equation(expr_mesh, [len(jax.devices())])
+            try:
+                expr_in = solve([mesh_eq])
+            except einx.expr.SolveException:
+                # If it still fails, reraise original exception
+                raise e
+    elif isinstance(mesh, jax.sharding.Mesh):
+        # Add constraints for existing mesh axes
+        expr_mesh = einx.expr.stage1.Marker(
+            einx.expr.stage1.List.maybe([
+                einx.expr.stage1.NamedAxis(name) for name in mesh.axis_names
+            ])
+        )
+        mesh_eq = einx.expr.Equation(expr_mesh, mesh.devices.shape)
+
+        expr_in = solve([mesh_eq])
+    elif isinstance(mesh, (list, tuple)):
+        # Add constraint for number of devices
+        expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
+        mesh_eq = einx.expr.Equation(expr_mesh, [len(mesh)])
+        expr_in = solve([mesh_eq])
+
+    expr_out = expr_in.__deepcopy__()
+
+    return expr_in, expr_out
+
+
+@einx.traceback_util.filter
+@einx.jit(
+    trace=lambda t, c: lambda description, tensor, mesh=None, backend=None, **kwargs: c(
+        description, t(tensor), mesh=mesh, **kwargs
+    )
+)
+def shard(
+    description: str,
+    tensor: einx.Tensor,
+    mesh: Any = None,
+    backend: Union[einx.Backend, str, None] = "jax",
+    cse: bool = True,
+    **parameters: npt.ArrayLike,
+) -> einx.Tensor:
+    """Shards a tensor over a mesh of devices.
+
+    *This function is currently only supported for Jax: A sharding is created
+    based on the given expression and applied to the tensor using* ``jax.device_put``.
+
+    The tensor is sharded across the marked axes in the input expression.
+    The marked axes match the axis names and shape of the mesh:
+
+    >>> x = jnp.ones((2, 4, 128))
+    >>> x = einx.experimental.shard("[d1 d2] c", x)
+    >>> x.sharding
+    NamedSharding(mesh=Mesh('d1': 2, 'd2': 4), spec=PartitionSpec('d1', 'd2', None))
+
+    Axis compositions can be used to apply the
+    `sharding rules of Jax `_,
+    where tensor axes are evenly divided by the number of shards:
+
+    >>> x = jnp.ones((128, 640, 480, 3))
+    >>> x = einx.experimental.shard("([batch] _) ...", x)
+    >>> x.sharding
+    NamedSharding(mesh=Mesh('batch': 8), spec=PartitionSpec('batch',))
+
+    If possible, the sharding is created over all devices. ``_`` is a regular axis name,
+    and its value is determined by :doc:`einx's expression solver `.
+
+    Optionally, an existing mesh can be passed:
+
+    >>> from jax.sharding import Mesh
+    >>> devices = np.asarray(jax.devices()).reshape(4, 2)
+    >>> mesh = Mesh(devices, axis_names=("d1", "d2"))
+    >>> x = jnp.ones((4, 1024, 1024))
+    >>> x = einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh)
+    >>> x.sharding
+    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd2', 'd1'))
+
+    The array is replicated over all mesh axes that are not part of the expression:
+
+    >>> x = jnp.ones((1024, 1024))
+    >>> x = einx.experimental.shard("a ([d1] b)", x, mesh=mesh)
+    >>> x.sharding
+    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd1',))
+
+    **This function is currently experimental and will likely change in future versions.**
+
+    Args:
+        description: Description string in Einstein notation (see above).
+        tensor: Input tensor or tensor factory matching the description string.
+        mesh: Mesh or list of devices to shard the tensor over. If not given, a new mesh over
+            all available devices will be created matching the axes in the given expression.
+            Defaults to ``None``.
+        cse: Whether to apply common subexpression elimination to the expressions. Defaults
+            to True.
+        graph: Whether to return the graph representation of the operation instead of
+            computing the result. Defaults to False.
+        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.
+
+    Returns:
+        The sharded tensor if ``graph=False``, otherwise the graph
+        representation of the operation.
+    """
+    if backend.name != "jax":
+        raise NotImplementedError("einx.experimental.shard is currently only supported for Jax")
+    expr_in, expr_out = parse(
+        description, einx.tracer.get_shape(tensor), mesh=mesh, cse=cse, **parameters
+    )
+    tensor, expr_out = shard_stage3(expr_in, tensor, expr_out, mesh=mesh, backend=backend)
+    return tensor
+
+
+shard.parse = parse
diff --git a/einx/expr/__init__.py b/einx/expr/__init__.py
index 7bc8c62..396b927 100644
--- a/einx/expr/__init__.py
+++ b/einx/expr/__init__.py
@@ -1,2 +1,3 @@
 from . import stage1, stage2, stage3
 from .util import *
+from .solver import SolveException
diff --git a/einx/expr/solver.py b/einx/expr/solver.py
index 5da72fa..3811415 100644
--- a/einx/expr/solver.py
+++ b/einx/expr/solver.py
@@ -141,7 +141,7 @@ def to_term(x):
 
 class SolveException(Exception):
     def __init__(self, message):
-        self.message = message
+        super().__init__(message)
 
 
 def solve(equations):
diff --git a/einx/expr/stage2.py b/einx/expr/stage2.py
index 05cfcff..edb4264 100644
--- a/einx/expr/stage2.py
+++ b/einx/expr/stage2.py
@@ -189,7 +189,7 @@ def all(self):
         yield from self.inner.all()
 
 
-class SolveDepthException(Exception):
+class SolveDepthException(solver.SolveException):
     def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, message):
         assert (
             len({
@@ -208,22 +208,23 @@ def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, m
         self.expansions2 = expansions2
         self.depths1 = depths1
         self.depths2 = depths2
-        self.message = (
+        message_in = message
+        message = (
             "Failed to solve for the depth of axes, i.e. the number of outer ellipses.\n"
             "Equations:\n"
         )
         for expr1, expr2 in zip(exprs1, exprs2):
             if expr1 is not None and expr2 is not None:
-                self.message += " "
-                self.message += f"{einx.expr.util._to_str(expr1)}"
-                self.message += " = "
-                self.message += f"{einx.expr.util._to_str(expr2)}"
-                self.message += "\n"
-        self.message += f"Reason: {message}"
-        super().__init__(self.message)
+                message += " "
+                message += f"{einx.expr.util._to_str(expr1)}"
+                message += " = "
+                message += f"{einx.expr.util._to_str(expr2)}"
+                message += "\n"
+        message += f"Reason: {message_in}"
+        super().__init__(message)
 
 
-class SolveExpansionException(Exception):
+class SolveExpansionException(solver.SolveException):
     def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, message):
         assert (
             len({
@@ -242,16 +243,17 @@ def __init__(self, exprs1, exprs2, expansions1, expansions2, depths1, depths2, m
         self.expansions2 = expansions2
         self.depths1 = depths1
         self.depths2 = depths2
-        self.message = "Failed to solve for the number of axes in the expressions.\nEquations:\n"
+        message_in = message
+        message = "Failed to solve for the number of axes in the expressions.\nEquations:\n"
         for expr1, expr2 in zip(exprs1, exprs2):
             if expr1 is not None and expr2 is not None:
-                self.message += " "
-                self.message += f"{einx.expr.util._to_str(expr1)}"
-                self.message += " = "
-                self.message += f"{einx.expr.util._to_str(expr2)}"
-                self.message += "\n"
-        self.message += f"Reason: {message}"
-        super().__init__(self.message)
+                message += " "
+                message += f"{einx.expr.util._to_str(expr1)}"
+                message += " = "
+                message += f"{einx.expr.util._to_str(expr2)}"
+                message += "\n"
+        message += f"Reason: {message_in}"
+        super().__init__(message)
 
 
 def solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2):
diff --git a/einx/expr/stage3.py b/einx/expr/stage3.py
index 9027136..b37ed25 100644
--- a/einx/expr/stage3.py
+++ b/einx/expr/stage3.py
@@ -223,16 +223,14 @@ def all(self):
         yield from self.inner.all()
 
 
-class SolveValueException(Exception):
+class SolveValueException(solver.SolveException):
     def __init__(self, exprs1, exprs2, message):
         self.exprs1 = exprs1
         self.exprs2 = exprs2
-        self.message = f"Failed to solve values of expressions. {message}\nInput:\n"
+        message = f"Failed to solve values of expressions. {message}\nInput:\n"
         for expr1, expr2 in zip(exprs1, exprs2):
-            self.message += (
-                f" '{einx.expr.util._to_str(expr1)} = {einx.expr.util._to_str(expr2)}'\n"
-            )
-        super().__init__(self.message)
+            message += f" '{einx.expr.util._to_str(expr1)} = {einx.expr.util._to_str(expr2)}'\n"
+        super().__init__(message)
 
 
 def solve(exprs1, exprs2):
diff --git a/einx/tracer/compile.py b/einx/tracer/compile.py
index 37c209f..eaee72a 100644
--- a/einx/tracer/compile.py
+++ b/einx/tracer/compile.py
@@ -213,11 +213,13 @@ def execute_application(self, application):
         use_dynamic_output_check = False
 
         if isinstance(application.op, Import):
-            import_str = f"import {application.op.module}"
-            name = application.op.module
-            if not application.op.shorthand is None:
-                import_str += f" as {application.op.shorthand}"
-                name = application.op.shorthand
+            import_str = f"import {application.op.import_}"
+            name = application.op.import_
+            if not application.op.as_ is None:
+                import_str = f"{import_str} as {application.op.as_}"
+                name = application.op.as_
+            if not application.op.from_ is None:
+                import_str = f"from {application.op.from_} {import_str}"
 
             # Import only once
             if not any(
diff --git a/einx/tracer/tracer.py b/einx/tracer/tracer.py
index a5eb370..3099a3b 100644
--- a/einx/tracer/tracer.py
+++ b/einx/tracer/tracer.py
@@ -131,17 +131,18 @@ def __copy__(self):
 
 class Import(Tracer):
-    def __init__(self, module, shorthand=None):
+    def __init__(self, import_, as_, from_):
         Tracer.__init__(self, origin="constant")
-        self.module = module
-        self.shorthand = shorthand
+        self.import_ = import_
+        self.as_ = as_
+        self.from_ = from_
 
     def __call__(self):
         # Overwrite allowed arguments
         return apply(self)
 
 
-def import_(module, shorthand=None):
-    return Import(module, shorthand)()
+def import_(import_, as_=None, from_=None):
+    return Import(import_, as_, from_)()
 
 
 class MemberAccess(Tracer):
diff --git a/test/conftest.py b/test/conftest.py
index f76990c..c418557 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -4,6 +4,7 @@
 import einx
 import threading
 import multiprocessing
+import os
 
 tests = []
 
@@ -121,6 +122,10 @@ def run(result, exception):
 
 
 if importlib.util.find_spec("jax"):
+    os.environ["XLA_FLAGS"] = (
+        os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
+    )
+
     import jax
     import jax.numpy as jnp
 
diff --git a/test/test_experimental.py b/test/test_experimental.py
new file mode 100644
index 0000000..0d69aec
--- /dev/null
+++ b/test/test_experimental.py
@@ -0,0 +1,63 @@
+import importlib
+import einx
+import numpy as np
+
+if importlib.util.find_spec("jax"):
+    import jax
+    import jax.numpy as jnp
+    from jax.sharding import Mesh
+
+    def assert_sharding(x, mesh=None, partition=None):
+        assert {**x.sharding.mesh.shape} == mesh
+        assert tuple(x.sharding.spec) == partition
+
+    def test_sharding():
+        mesh24 = Mesh(np.asarray(jax.devices("cpu")).reshape(2, 4), axis_names=("d1", "d2"))
+        mesh42 = Mesh(np.asarray(jax.devices("cpu")).reshape(4, 2), axis_names=("d1", "d2"))
+        mesh4 = Mesh(np.asarray(jax.devices("cpu"))[:4], axis_names=("d1",))
+
+        # Pass mesh=jax.devices("cpu") instead of mesh=None since we cannot set
+        # global device to cpu here
+        x = jnp.ones((128, 64))
+        assert_sharding(
+            einx.experimental.shard("([d1] a) b", x, mesh=jax.devices("cpu")), {"d1": 8}, ("d1",)
+        )
+        assert_sharding(
+            einx.experimental.shard("([d1] a) ([d2] b)", x, d2=2, mesh=jax.devices("cpu")),
+            {"d1": 4, "d2": 2},
+            ("d1", "d2"),
+        )
+        assert_sharding(
+            einx.experimental.shard("([batch] _) ...", x, d2=2, mesh=jax.devices("cpu")),
+            {"batch": 8},
+            ("batch",),
+        )
+        assert_sharding(
+            einx.experimental.shard("([d1] a) ([d2] b)", x, mesh=mesh24),
+            {"d1": 2, "d2": 4},
+            ("d1", "d2"),
+        )
+        assert_sharding(einx.experimental.shard("([d1] a) b", x, mesh=mesh4), {"d1": 4}, ("d1",))
+        assert_sharding(
+            einx.experimental.shard("b ([d1] a)", x, mesh=mesh4),
+            {"d1": 4},
+            (
+                None,
+                "d1",
+            ),
+        )
+        assert_sharding(
+            einx.experimental.shard("a ([d1] b)", x, mesh=mesh42),
+            {"d1": 4, "d2": 2},
+            (
+                None,
+                "d1",
+            ),
+        )
+
+        x = jnp.ones((4, 1024, 1024))
+        assert_sharding(
+            einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh42),
+            {"d1": 4, "d2": 2},
+            (None, "d2", "d1"),
+        )