From 73dd98c0437b93e16e360a0b21dbb5309bf7fc82 Mon Sep 17 00:00:00 2001 From: Florian Fervers Date: Tue, 4 Jun 2024 20:54:46 +0200 Subject: [PATCH] Overhaul of docs and readme --- README.md | 29 +- docs/source/api.rst | 16 +- docs/source/conf.py | 2 +- docs/source/faq/backend.rst | 5 +- docs/source/faq/einops.rst | 50 ++- docs/source/faq/flatten.rst | 5 +- docs/source/faq/solver.rst | 18 +- docs/source/faq/universal.rst | 63 +++ docs/source/gettingstarted/cheatsheet.rst | 65 --- docs/source/gettingstarted/commonnnops.rst | 10 +- docs/source/gettingstarted/gpt2.rst | 4 +- docs/source/gettingstarted/introduction.rst | 42 +- .../gettingstarted/tensormanipulation.rst | 404 ------------------ ...tworks.rst => tutorial_neuralnetworks.rst} | 10 +- ...teinnotation.rst => tutorial_notation.rst} | 107 +++-- docs/source/gettingstarted/tutorial_ops.rst | 342 +++++++++++++++ .../gettingstarted/tutorial_overview.rst | 35 ++ docs/source/index.rst | 22 +- .../{gettingstarted => more}/gotchas.rst | 13 +- docs/source/{gettingstarted => more}/jit.rst | 2 +- docs/source/more/related.rst | 20 + einx/experimental/op/shard.py | 4 +- einx/op/arange.py | 4 +- einx/op/dot.py | 29 +- einx/op/elementwise.py | 37 +- einx/op/index.py | 29 +- einx/op/rearrange.py | 11 +- einx/op/reduce.py | 32 +- einx/op/solve.py | 22 +- einx/op/util.py | 27 -- einx/op/vmap.py | 16 +- einx/op/vmap_with_axis.py | 21 +- 32 files changed, 715 insertions(+), 781 deletions(-) create mode 100644 docs/source/faq/universal.rst delete mode 100644 docs/source/gettingstarted/cheatsheet.rst delete mode 100644 docs/source/gettingstarted/tensormanipulation.rst rename docs/source/gettingstarted/{neuralnetworks.rst => tutorial_neuralnetworks.rst} (97%) rename docs/source/gettingstarted/{einsteinnotation.rst => tutorial_notation.rst} (67%) create mode 100644 docs/source/gettingstarted/tutorial_ops.rst create mode 100644 docs/source/gettingstarted/tutorial_overview.rst rename docs/source/{gettingstarted => more}/gotchas.rst (85%) rename docs/source/{gettingstarted => more}/jit.rst (96%) create mode 100644 docs/source/more/related.rst diff --git a/README.md b/README.md index ef6d967..8b85962 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,23 @@ -# *einx* - Tensor Operations in Einstein-Inspired Notation +# *einx* - Universal Tensor Operations in Einstein-Inspired Notation [![pytest](https://github.com/fferflo/einx/actions/workflows/run_pytest.yml/badge.svg)](https://github.com/fferflo/einx/actions/workflows/run_pytest.yml) [![Documentation](https://img.shields.io/badge/documentation-link-blue.svg)](https://einx.readthedocs.io) [![PyPI version](https://badge.fury.io/py/einx.svg)](https://badge.fury.io/py/einx) [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) -einx is a Python library that allows formulating many tensor operations as concise expressions using Einstein notation. It is inspired by [einops](https://github.com/arogozhnikov/einops), but follows a novel and unique design: +einx is a Python library that provides a universal interface to formulate tensor operations in frameworks such as Numpy, PyTorch, Jax and Tensorflow. The design is based on the following principles: -- Fully composable and powerful Einstein expressions with `[]`-notation. -- Support for many tensor operations (`einx.{sum|max|where|add|dot|flip|get_at|...}`) with Numpy-like naming. -- Easy integration and mixing with existing code. Supports tensor frameworks Numpy, PyTorch, Tensorflow, Jax and others. -- Just-in-time compilation of all operations into regular Python functions using Python's [`exec()`](https://docs.python.org/3/library/functions.html#exec). +1. **Provide a set of elementary tensor operations** following Numpy-like naming: `einx.{sum|max|where|add|dot|flip|get_at|...}` +2. **Use einx notation to express vectorization of the elementary operations.** einx notation is inspired by [einops](https://github.com/arogozhnikov/einops), but introduces several novel concepts such as `[]`-bracket notation and full composability that allow using it as a universal language for tensor operations. -*Optional:* - -- Generalized neural network layers in Einstein notation. Supports PyTorch, Flax, Haiku, Equinox and Keras. +einx can be integrated and mixed with existing code seamlessly. All operations are [just-in-time compiled](https://einx.readthedocs.io/en/latest/more/jit.html) into regular Python functions using Python's [exec()](https://docs.python.org/3/library/functions.html#exec) and invoke operations from the respective framework. **Getting started:** -* [Tutorial](https://einx.readthedocs.io/en/latest/gettingstarted/einsteinnotation.html) -* [Example: GPT-2/ Mamba with einx](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) -* [How does einx compare with einops?](https://einx.readthedocs.io/en/latest/faq/einops.html) +* [Tutorial](https://einx.readthedocs.io/en/latest/gettingstarted/tutorial_overview.html) +* [Example: GPT-2 with einx](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) +* [How is einx different from einops?](https://einx.readthedocs.io/en/latest/faq/einops.html) +* [How is einx notation universal?](https://einx.readthedocs.io/en/latest/faq/universal.html) * [API reference](https://einx.readthedocs.io/en/latest/api.html) ## Installation @@ -50,8 +47,6 @@ einx.get_at("b [h w] c, b i [2] -> b i c", x, y) # Gather values at coordinates einx.rearrange("b (q + k) -> b q, b k", x, q=2) # Split einx.rearrange("b c, 1 -> b (c + 1)", x, [42]) # Append number to each channel -einx.dot("... [c1->c2]", x, y) # Matmul = linear map from c1 to c2 channels - # Apply custom operations: einx.vmap("b [s...] c -> b c", x, op=np.mean) # Global mean-pooling einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # Matmul @@ -84,7 +79,7 @@ einx.dot("b [s...->s2] c", x, w) # - Spatial mixing as in MLP- See [Common neural network ops](https://einx.readthedocs.io/en/latest/gettingstarted/commonnnops.html) for more examples. -#### Deep learning modules +#### Optional: Deep learning modules ```python import einx.nn.{torch|flax|haiku|equinox|keras} as einn @@ -105,7 +100,7 @@ spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2) droppath = einn.Dropout("[b] ...", drop_rate=0.2) ``` -See `examples/train_{torch|flax|haiku|equinox|keras}.py` for example trainings on CIFAR10, [GPT-2](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) and [Mamba](https://github.com/fferflo/weightbridge/blob/master/examples/mamba2flax.py) for working example implementations of language models using einx, and [Tutorial: Neural networks](https://einx.readthedocs.io/en/latest/gettingstarted/neuralnetworks.html) for more details. +See `examples/train_{torch|flax|haiku|equinox|keras}.py` for example trainings on CIFAR10, [GPT-2](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) and [Mamba](https://github.com/fferflo/weightbridge/blob/master/examples/mamba2flax.py) for working example implementations of language models using einx, and [Tutorial: Neural networks](https://einx.readthedocs.io/en/latest/gettingstarted/tutorial_neuralnetworks.html) for more details. #### Just-in-time compilation @@ -122,4 +117,4 @@ def op0(i0): return x1 ``` -See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/gettingstarted/jit.html) for more details. \ No newline at end of file +See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/more/jit.html) for more details. \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst index a5ff4c7..54dfbe9 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -2,27 +2,16 @@ einx API ######## -Abstractions -============ - Main ---- .. autofunction:: einx.rearrange .. autofunction:: einx.vmap_with_axis .. autofunction:: einx.vmap -.. autofunction:: einx.dot - -Partial specializations ------------------------ - .. autofunction:: einx.reduce .. autofunction:: einx.elementwise .. autofunction:: einx.index -Numpy-like functions -==================== - Reduction operations -------------------- @@ -76,6 +65,11 @@ Miscellaneous operations .. autofunction:: einx.log_softmax .. autofunction:: einx.arange +General dot-product +------------------- + +.. autofunction:: einx.dot + Deep Learning Modules ===================== diff --git a/docs/source/conf.py b/docs/source/conf.py index b34f024..d297183 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,7 +7,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "einx" -copyright = "2023, Florian Fervers" +copyright = "2024, Florian Fervers" author = 'Florian Fervers' # -- General configuration --------------------------------------------------- diff --git a/docs/source/faq/backend.rst b/docs/source/faq/backend.rst index 4989731..05b44cd 100644 --- a/docs/source/faq/backend.rst +++ b/docs/source/faq/backend.rst @@ -1,8 +1,7 @@ How does einx support different tensor frameworks? ################################################## -einx provides interfaces for tensor frameworks in the ``einx.backend.*`` namespace. For each framework, a backend object is implemented that -provides a numpy-like interface for all necessary tensor operations using the framework's own functions. Every einx function accepts a ``backend`` argument +einx provides interfaces for tensor frameworks in the ``einx.backend.*`` namespace. einx functions accept a ``backend`` argument that defines which backend to use for the computation. For ``backend=None`` (the default case), the backend is implicitly determined from the input tensors. .. code:: python @@ -22,7 +21,7 @@ Numpy cannot be mixed in the same operation. einx.dot("a [c1->c2]", x, jnp.asarray(y)) # Uses jax einx.dot("a [c1->c2]", torch.from_numpy(x), jnp.asarray(y)) # Raises exception -Unkown tensor objects and python sequences are converted using ``np.asarray`` and used as numpy backend tensors. +Unkown tensor objects and python sequences are converted to tensors using calls from the respective backend if possible (e.g. ``np.asarray``, ``torch.asarray``). .. code:: python diff --git a/docs/source/faq/einops.rst b/docs/source/faq/einops.rst index d600616..f9916c1 100644 --- a/docs/source/faq/einops.rst +++ b/docs/source/faq/einops.rst @@ -1,23 +1,49 @@ -How does einx compare with einops? +How is einx different from einops? ################################## -einx uses Einstein notation that is inspired by and compatible with the notation used in `einops `_, -but follows a novel design: +einx uses Einstein-inspired notation that is based on and compatible with the notation used in `einops `_, +but introduces several novel concepts that allow using it as a universal language for tensor operations: -* Full composability of Einstein expressions: Axis lists, compositions, ellipses and concatenations can be nested arbitrarily (e.g. ``(a b)...`` or +* Introduction of ``[]``-notation to express vectorization of elementary operations (see :ref:`Bracket notation `). +* Ellipses repeat the preceding expression rather than an anonymous axis. This allows expressing multi-dimensional operations more concisely + (e.g. ``(a b)...`` or ``b (s [ds])... c``) +* Full composability of expressions: Axis lists, compositions, ellipses, brackets and concatenations can be nested arbitrarily (e.g. ``(a b)...`` or ``b (1 + (s...)) c``). -* Introduction of ``[]``-notation that allows expressing vectorization in an intuitive and concise way, similar to the ``axis`` argument in Numpy functions (see :ref:`Bracket notation `). -* Introduction of concatenations as first-class expressions in Einstein notation. +* Introduction of concatenations as first-class expressions. -When combined, these features allow for a concise and expressive formulation of a large variety of tensor operations. +The library provides the following additional features based on the einx notation: -The einx library provides the following additional features: +* Support for many more tensor operations, for example: + + .. code:: + + einx.flip("... (g [c])", x, c=2) # Flip pairs of values + einx.add("a, b -> a b", x, y) # Outer sum + einx.get_at("b [h w] c, b i [2] -> b i c", x, indices) # Gather values + einx.softmax("b q [k] h", attn) # Part of attention operation + +* Simpler notation for existing tensor operations: + + .. code:: + + einx.sum("a [b]", x) + # same op as + einops.reduce(x, "a b -> a", reduction="sum") + + einx.mean("b (s [ds])... c", x, ds=2) + # einops does not support named ellipses. Alternative for 2D case: + einops.reduce(x, "b (h h2) (w w2) c -> b h w c", reduction="mean", h2=2, w2=2) * Full support for rearranging expressions in all operations (see :doc:`How does einx handle input and output tensors? `). -* ``einx.vmap`` and ``einx.vmap_with_axis`` allow applying arbitrary operations using Einstein notation. -* Specializations provide ease-of-use for main abstractions using Numpy naming convention, e.g. ``einx.sum`` and ``einx.multiply``. -* Several generalized deep learning modules in the ``einx.nn.*`` namespace (see :doc:`Tutorial: Neural networks `). -* Support for inspecting the backend calls made by einx in index-based notation (see :doc:`Just-in-time compilation `). + + .. code:: + + einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=16) + # Axis composition not supported e.g. in einops.einsum. + +* ``einx.vmap`` and ``einx.vmap_with_axis`` allow applying arbitrary operations using einx notation. +* Several generalized deep learning modules in the ``einx.nn.*`` namespace (see :doc:`Tutorial: Neural networks `). +* Support for inspecting the backend calls made by einx in index-based notation (see :doc:`Just-in-time compilation `). A non-exhaustive comparison of operations expressed in einx-notation and einops-notation: diff --git a/docs/source/faq/flatten.rst b/docs/source/faq/flatten.rst index 98af700..5ddbbb6 100644 --- a/docs/source/faq/flatten.rst +++ b/docs/source/faq/flatten.rst @@ -1,7 +1,7 @@ How does einx handle input and output tensors? ############################################## -einx functions accept an operation string that specifies Einstein expressions for the input and output tensors. The expressions potentially +einx functions accept an operation string that specifies einx expressions for the input and output tensors. The expressions potentially contain nested compositions and concatenations that prevent the backend functions from directly accessing the required axes. To resolve this, einx first flattens the input tensors in each operation such that they contain only a flat list of axes. After the backend operation is applied, the resulting tensors are unflattened to match the requested output expressions. @@ -22,9 +22,6 @@ Concatenations are flattened by splitting the input tensor into multiple tensors # same as np.split(x, [10], axis=0) -Using a concatenated tensor as input performs the same operation as passing the split tensors as separate inputs to the operation. einx handles -expressions with multiple nested compositions and concatenations gracefully. - After the operation is applied to the flattened tensors, the results are reshaped and concatenated and missing axes are inserted and broadcasted to match the requested output expressions. diff --git a/docs/source/faq/solver.rst b/docs/source/faq/solver.rst index daee7eb..7832777 100644 --- a/docs/source/faq/solver.rst +++ b/docs/source/faq/solver.rst @@ -1,16 +1,16 @@ -How does einx parse Einstein expressions? -######################################### +How does einx parse expressions? +################################ Overview -------- -einx functions accept a operation string that specifies the shapes of input and output tensors and the requested operation in Einstein notation. For example: +einx functions accept a operation string that specifies the shapes of input and output tensors and the requested operation in einx notation. For example: .. code:: einx.mean("b (s [r])... c -> b s... c", x, r=4) # Mean-pooling with stride 4 -To identify the backend operations that are required to execute this statement, einx first parses the operation string and determines an *Einstein expression tree* +To identify the backend operations that are required to execute this statement, einx first parses the operation string and determines an *expression tree* for each input and output tensor. The tree represents a full description of the tensor's shape and axes marked with brackets. The nodes represent different types of subexpressions such as axis lists, compositions, ellipses and concatenations. The leaves of the tree are the named and unnamed axes of the tensor. The expression trees are used to determine the required rearranging steps and axes along which backend operations are applied. @@ -20,12 +20,12 @@ einx uses a multi-step process to convert expression strings into expression tre * **Stage 0**: Split the operation string into separate expression strings for each tensor. * **Stage 1**: Parse the expression string for each tensor and return a (stage-1) tree of nodes representing the nested subexpressions. * **Stage 2**: Expand all ellipses by repeating the respective subexpression, resulting in a stage-2 tree. -* **Stage 3**: Determine a value for each axis (i.e. the axis length) using the provided constraints, resulting in a stage-3 tree, i.e. the final Einstein expression tree. +* **Stage 3**: Determine a value for each axis (i.e. the axis length) using the provided constraints, resulting in a stage-3 tree, i.e. the final expression tree. For a given operation string and signature of input arguments, the required backend operations are traced into graph representation and just-in-time compiled using Python's `exec() `_. Every subsequent call with the same signature will reuse the cached function and therefore incur no additional overhead other than for cache lookup (see -:doc:`Just-in-time compilation `). +:doc:`Just-in-time compilation `). Stage 0: Splitting the operation string --------------------------------------- @@ -53,7 +53,7 @@ Another example of shorthand notation in :func:`einx.dot`: # same as einx.dot("a [b->c]", x, y) -See :doc:`Tutorial: Tensor manipulation ` and the documentation of the respective functions for allowed shorthand notation. +See :doc:`Tutorial: Operations ` and the documentation of the respective functions for allowed shorthand notation. Stage 1: Parsing the expression string -------------------------------------- @@ -67,7 +67,7 @@ subexpressions: Stage-1 tree for ``b (s [r])... c``. -This includes several semantic checks, e.g. to ensure that axis names do not appear more than once per expression. +This includes semantic checks, e.g. to ensure that axis names do not appear more than once per expression. Stage 2: Expanding ellipses --------------------------- @@ -103,7 +103,7 @@ Stage 3: Determining axis values -------------------------------- In the last step, the values of all axes (i.e. their lengths) are determined using the constraints provided by the input tensors and additional parameters. For example, the above -expression with an input tensor of shape ``(2, 4, 8, 3)`` and additional constraint ``r=4`` results in the following final Einstein expression tree: +expression with an input tensor of shape ``(2, 4, 8, 3)`` and additional constraint ``r=4`` results in the following final expression tree: .. figure:: /images/stage3-tree.png :height: 240 diff --git a/docs/source/faq/universal.rst b/docs/source/faq/universal.rst new file mode 100644 index 0000000..92f4098 --- /dev/null +++ b/docs/source/faq/universal.rst @@ -0,0 +1,63 @@ +How is einx notation universal? +############################### + +To address this question, let's first look at how tensor operations are commonly expressed in existing tensor frameworks. + +Classical notation +------------------ + +Tensor operations can be dissected into two distinct components: + +1. An **elementary operation** that is performed. + + * Example: ``np.sum`` computes a sum-reduction. + +2. A division of the input tensor into sub-tensors. The elementary operation is applied to each sub-tensor independently. We refer to this as **vectorization**. + + * Example: Sub-tensors in ``np.sum`` span the dimensions specified by the ``axis`` parameter. The sum-reduction is repeated over all other dimensions. + +In common tensor frameworks like Numpy, PyTorch, Tensorflow or Jax, different elementary operations are implemented with different vectorization rules. +For example, to express vectorization + +* ``np.sum`` uses the ``axis`` parameter, +* ``np.add`` follows `implicit broadcasting rules `_ (e.g. in combination with ``np.newaxis``), and +* ``np.matmul`` provides `an implicit and custom set of rules `_. + +Furthermore, an elementary operation is sometimes implemented in multiple APIs in order to offer vectorization rules for different use cases. +For example, the retrieve-at-index operation can be implemented in PyTorch using ``tensor[coords]``, ``torch.gather``, ``torch.index_select``, ``torch.take``, +``torch.take_along_dim``, which conceptually apply the same low-level operation, but follow different vectorization rules. +Still, these interfaces sometimes do not cover all desirable use cases. + +einx notation +------------- + +einx provides an interface to tensor operations where vectorization is expressed entirely using einx notation, and each elementary operation +is represented by exactly one API. The einx notation is: + +* **Consistent**: The same type of notation is used for all elementary operations. +* **Unique**: Each elementary operation is represented by exactly one API. +* **Complete**: Any operation that can be expressed with existing vectorization tools such as + `jax.vmap `_ can also be expressed in einx notation. + +The following table shows examples of universal einx functions that implement the same elementary operations as a variety of existing tensor operations: + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + + * - einx API + - Classical API + * - ``einx.get_at`` + - ``torch.gather`` ``torch.index_select`` ``torch.take`` ``torch.take_along_dim`` ``tf.gather`` ``tf.gather_nd`` ``tensor[coords]`` + * - ``einx.dot`` (similar to einsum) + - ``np.matmul`` ``np.dot`` ``np.tensordot`` ``np.inner`` ``np.outer`` + * - ``einx.add`` + - ``np.add`` with ``np.newaxis`` + * - ``einx.rearrange`` + - ``np.reshape`` ``np.transpose`` ``np.squeeze`` ``np.expand_dims`` ``tensor[np.newaxis]`` ``np.stack`` ``np.hstack`` ``np.concatenate`` + * - ``einx.flip`` + - ``np.flip`` ``np.fliplr`` ``np.flipud`` + +While elementary operations and vectorization are decoupled conceptually to provide a universal API, the implementation of the operations +in the respective backend do not necessarily follow the same decoupling. For example, a matrix multiplication is represented as a vectorized +dot-product in einx (using ``einx.dot``), but still invokes an efficient matmul operation instead of a vectorized evaluation of the dot product. \ No newline at end of file diff --git a/docs/source/gettingstarted/cheatsheet.rst b/docs/source/gettingstarted/cheatsheet.rst deleted file mode 100644 index d991dd6..0000000 --- a/docs/source/gettingstarted/cheatsheet.rst +++ /dev/null @@ -1,65 +0,0 @@ -Cheatsheet -########## - -**Simple tensor operations in Einstein notation and index-based notation**. - -.. list-table:: - :widths: 10 48 45 - :header-rows: 1 - - * - - - einx - - Numpy - * - Transpose - - ``einx.rearrange("a b c -> a c b", x)`` - - ``np.transpose(x, (0, 2, 1))`` - * - Compose - - ``einx.rearrange("a b c -> a (b c)", x)`` - - ``np.reshape(x, (2, -1))`` - * - Decompose - - ``einx.rearrange("a (b c) -> a b c", x, b=3)`` - - ``np.reshape(x, (2, 3, 4))`` - * - Concatenate - - ``einx.rearrange("a, b -> (a + b)", x, y)`` - - ``np.concatenate([x, y], axis=0)`` - * - Split - - ``einx.rearrange("(a + b) -> a, b", x, a=5)`` - - ``np.split(x, [5], axis=0)`` - * - Reduce - - ``einx.sum("[a] ...", x)`` - - ``np.sum(x, axis=0)`` - * - - - ``einx.sum("... [a]", x)`` - - ``np.sum(x, axis=-1)`` - * - - - ``einx.sum("a [...]", x)`` - - ``np.sum(x, axis=tuple(range(1, x.ndim)))`` - * - - - ``einx.sum("[...] a", x)`` - - ``np.sum(x, axis=tuple(range(0, x.ndim - 1)))`` - * - Elementwise - - | ``einx.add("a b, b -> a b", x, y)`` - | ``einx.add("a b, b", x, y)`` - | ``einx.add("a [b]", x, y)`` - - ``x + y[np.newaxis, :]`` - * - - - | ``einx.add("a b, a -> a b", x, y)`` - | ``einx.add("a b, a", x, y)`` - | ``einx.add("[a] b", x, y)`` - - ``x + y[:, np.newaxis]`` - * - Dot - - | ``einx.dot("a b, b c -> a c", x, y)`` - | ``einx.dot("a [b] -> a [c]", x, y)`` - | ``einx.dot("a [b->c]", x, y)`` - - ``np.einsum("ab,bc->ac", x, y)`` - * - - - | ``einx.dot("a b, a b c -> a c", x, y)`` - | ``einx.dot("[a b] -> [a c]", x, y)`` - | ``einx.dot("[a b -> a c]", x, y)`` - - ``np.einsum("ab,abc->ac", x, y)`` - * - Indexing - - ``einx.get_at("[h w] c, p [2] -> p c", x, y)`` - - ``x[y[:, 0], y[:, 1]]`` - * - - - ``einx.set_at("[h] c, p, p c -> [h] c", x, y, z)`` - - ``x[y] = z`` diff --git a/docs/source/gettingstarted/commonnnops.rst b/docs/source/gettingstarted/commonnnops.rst index b413ce9..9da124e 100644 --- a/docs/source/gettingstarted/commonnnops.rst +++ b/docs/source/gettingstarted/commonnnops.rst @@ -1,7 +1,7 @@ Example: Common neural network operations ######################################### -einx allows formulating many common operations of deep learning models in a concise and elegant way. This page provides a few examples. +einx allows formulating many common operations of deep learning models as concise expressions. This page provides a few examples. .. code-block:: python @@ -71,7 +71,7 @@ This can similarly be achieved using the ``einn.Norm`` layer: .. code-block:: python - import einx.nn.{torch|flax|haiku} as einn + import einx.nn.{torch|flax|haiku|...} as einn x = einn.Norm("... [c]")(x) Reference: `Layer normalization explained `_ @@ -83,9 +83,9 @@ Compute multihead attention for the queries ``q``, keys ``k`` and values ``v`` w .. code-block:: python - attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8) - attn = einx.softmax("b q [k] h", attn) - x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v) + a = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8) + a = einx.softmax("b q [k] h", a) + x = einx.dot("b q k h, b k (h c) -> b q (h c)", a, v) Reference: `Multi-Head Attention `_ diff --git a/docs/source/gettingstarted/gpt2.rst b/docs/source/gettingstarted/gpt2.rst index 1ab78c8..1efa84d 100644 --- a/docs/source/gettingstarted/gpt2.rst +++ b/docs/source/gettingstarted/gpt2.rst @@ -46,7 +46,7 @@ layer normalization at the beginning of the residual block: # Apply causal mask mask = jnp.tril(jnp.ones((q.shape[1], q.shape[1]), dtype=bool)) - attn = einx.where("q k, b q k h, ", mask, attn, -jnp.inf) + attn = einx.where("q k, b q k h,", mask, attn, -jnp.inf) # Apply softmax and compute weighted average over the input tokens attn = einx.softmax("b q [k] h", attn) @@ -118,7 +118,7 @@ logits using a linear layer: return x We use tensor factories with ``einn.param`` to construct the word and positional embeddings (see -:doc:`Tutorial: Neural networks `). +:doc:`Tutorial: Neural networks `). With this, we're done with the model definition. Next, we'll define some input data that the model will be applied to and encode it to token representation: diff --git a/docs/source/gettingstarted/introduction.rst b/docs/source/gettingstarted/introduction.rst index 9e3402c..8757484 100644 --- a/docs/source/gettingstarted/introduction.rst +++ b/docs/source/gettingstarted/introduction.rst @@ -5,43 +5,17 @@ Introduction ############ -einx is a Python library that allows formulating many tensor operations as concise expressions using few powerful abstractions. It is inspired by -`einops `_. +einx is a Python library that provides a universal interface to formulate tensor operations in frameworks such as Numpy, PyTorch, Jax and Tensorflow. +The design is based on the following principles: -*Main features:* +1. **Provide a set of elementary tensor operations** following Numpy-like naming: ``einx.{sum|max|where|add|dot|flip|get_at|...}`` +2. **Use einx notation to express vectorization of the elementary operations.** The notation is inspired by `einops `_, + but introduces several novel concepts such as ``[]``-bracket notation and full composability that allow using it as a universal language for tensor operations. -- Fully composable and powerful Einstein expressions with ``[]``-notation. -- Support for many tensor operations (``einx.{sum|max|where|add|dot|flip|get_at|...}``) with Numpy-like naming. -- Easy integration and mixing with existing code. Supports tensor frameworks Numpy, PyTorch, Tensorflow, Jax and others. -- Just-in-time compilation of all operations into regular Python functions using Python's `exec() `_. - -*Optional:* - -- Generalized neural network layers in Einstein notation. Supports PyTorch, Flax, Haiku, Equinox and Keras. +einx can be integrated and mixed with existing code seamlessly. All operations are :doc:`just-in-time compiled ` +into regular Python functions using Python's `exec() `_ and invoke operations from the respective framework. **Next steps:** - :doc:`Installation ` -- :doc:`Tutorial: Einstein notation ` -- :doc:`Tutorial: Tensor manipulation ` -- :doc:`Tutorial: Neural networks ` - -Related resources ------------------ - -* `einops `_ -* `einsum in Numpy `_ -* `eindex `_ -* `torchdim `_ -* `einindex `_ -* `einshape `_ -* `einop `_ -* `eingather `_ -* `einshard `_ -* `eins `_ -* `Named axes in PyTorch `_ -* `Named axes in Jax `_ -* `Named axes in Penzai `_ -* `Dex `_ -* `Named Tensor Notation `_ -* `Tensor Considered Harmful `_ \ No newline at end of file +- :doc:`Tutorial ` diff --git a/docs/source/gettingstarted/tensormanipulation.rst b/docs/source/gettingstarted/tensormanipulation.rst deleted file mode 100644 index 492f63b..0000000 --- a/docs/source/gettingstarted/tensormanipulation.rst +++ /dev/null @@ -1,404 +0,0 @@ -Tutorial: Tensor manipulation -############################# - -Overview --------- - -einx supports a wide variety of tensor operations. In each function, Einstein expressions are used to specify how the operation should be performed. -einx internally parses the expressions to determine the required steps and forwards computation to the respective backend, e.g. by -calling `np.reshape `_, -`np.transpose `_ or -`np.sum `_ with the appropriate arguments. - -The most basic operation in einx is :func:`einx.rearrange` which transforms tensors between Einstein expressions by reshaping, permuting axes, inserting new -broadcasted axes, concatenating and splitting as required. All other functions support the same rearranging of expressions, but additionally perform some -operation on the tensors' values. For ease-of-use, most of these follow a Numpy-like naming convention: - -* ``einx.{sum|prod|mean|any|all|max|min|count_nonzero|...}`` (see :func:`einx.reduce`). -* ``einx.{add|multiply|logical_and|where|equal|...}`` (see :func:`einx.elementwise`). -* ``einx.{flip|roll|softmax|...}`` (see :func:`einx.vmap_with_axis`). -* ``einx.{get_at|set_at|add_at|...}`` (see :func:`einx.index`). - -Most functions in einx are specializations of the two main abstractions :func:`einx.vmap` and :func:`einx.vmap_with_axis` which allow applying arbitrary operations -using Einstein notation. Each class of functions additionally introduces shorthand notations that allow for a concise and expressive formulation of the respective -operations. - -This tutorial gives an overview of most functions and their usage. For a complete list of available functions, see the :doc:`API reference `. - -Rearranging ------------ - -The function :func:`einx.rearrange` transforms tensors between Einstein expressions by determining and applying the required backend operations. For example: - ->>> x = np.ones((4, 256, 17)) ->>> y, z = einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8) ->>> y.shape, z.shape -((128, 8, 16), (32, 32, 1)) - -Using :func:`einx.rearrange` often produces more readable and concise code than specifying backend operations in index-based notation directly. The index-based calls can be -inspected using the just-in-time compiled function that einx creates for this expression (see :doc:`Just-in-time compilation `): - ->>> print(einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8, graph=True)) -import numpy as np -def op0(i0): - x0 = np.reshape(i0, (4, 32, 8, 17)) - x1 = np.reshape(x0[:, :, :, 0:16], (128, 8, 16)) - x2 = np.reshape(x0[:, :, :, 16:17], (4, 32, 8)) - x3 = np.transpose(x2, (0, 2, 1)) - x4 = np.reshape(x3, (32, 32, 1)) - return [x1, x4] - -Reduction ops -------------- - -einx provides a family of functions that reduce tensors along one or more axes. For example: - -.. code:: - - einx.sum("a [b]", x) - # same as - np.sum(x, axis=1) - - einx.mean("a [...]", x) - # same as - np.mean(x, axis=tuple(range(1, x.ndim))) - -These functions are specializations of :func:`einx.reduce` and use backend operations like `np.sum `_, -`np.prod `_ or `np.any `_ as the ``op`` argument: - -.. code:: - - einx.reduce("a [b]", x, op=np.sum) - # same as - einx.sum("a [b]", x) - -In ``einx.sum``, the respective backend is determined implicitly from the input tensor (see :doc:`How does einx support different tensor frameworks? `). - -In the most general case, the operation string represents both input and output expressions, and marks reduced axes using brackets: - ->>> x = np.ones((16, 8, 4)) ->>> einx.sum("a [b] c -> a c", x).shape -(16,) - -:func:`einx.reduce` supports shorthand notation as follows. When no brackets are found, brackets are placed implicitly around all axes that do not appear in the output: - -.. code:: - - einx.sum("a b c -> a c", x) # Expands to: "a [b] c -> a c" - -When no output is given, it is determined implicitly by removing marked subexpressions from the input: - -.. code:: - - einx.sum("a [b] c", x) # Expands to: "a [b] c -> a c" - -:func:`einx.reduce` also allows custom reduction operations that accept the ``axis`` argument similar to `np.sum `_: - -.. code:: - - def custom_mean(x, axis): - return np.sum(x, axis=axis) / x.shape[axis] - einx.reduce("a [b] c", x, op=custom_mean) - -:func:`einx.reduce` fully supports Einstein expression rearranging: - ->>> x = np.ones((16, 8)) ->>> einx.prod("a (b [c]) -> b a", x, c=2).shape -(4, 16) - -Element-by-element ops ----------------------- - -einx provides a family of functions that apply element-by-element operations to tensors. For example: - -.. code:: - - einx.add("a b, b -> a b", x, y) - # same as - x + y[np.newaxis, :] - - einx.multiply("a, a b -> a b", x, y) - # same as - x[:, np.newaxis] * y - - einx.subtract("a, (a b) -> b a", x, y) - # requires reshape and transpose in index-based notation - -Internally, the inputs are rearranged such that the operation can be applied using `Numpy broadcasting rules `_. -These functions are specializations of :func:`einx.elementwise` and use backend operations like `np.add `_, -`np.logical_and `_ and `np.where `_ -as the ``op`` argument: - -.. code:: - - einx.elementwise("a b, b -> a b", x, y, op=np.add) - # same as - einx.add("a b, b -> a b", x, y) - -In the most general case, the operation string of :func:`einx.elementwise` represents all input and output expressions explicitly: - ->>> x = np.ones((16, 8)) ->>> y = np.ones((16,)) ->>> einx.add("a b, a -> a b", x, y).shape -(16, 8) - -The output is determined implicitly if one of the input expressions contains the named axes of all other inputs and if this choice is unique: - -.. code:: - - einx.add("a b, a", x, y) # Expands to: "a b, a -> a b" - - einx.where("b a, b, a", x, y, z) # Expands to "b a, b, a -> b a" - - einx.subtract("a b, b a", x, y) # Raises an exception - - einx.add("a b, a b", x, y) # Expands to: "a b, a b -> a b" - -Bracket notation can be used to indicate that the second input is a subexpression of the first: - -.. code:: - - einx.add("a [b]", x, y) # Expands to: "a b, b" - -:func:`einx.elementwise` fully supports Einstein expression rearranging: - ->>> x = np.ones((16, 16, 32)) ->>> bias = np.ones((4,)) ->>> einx.add("b... (g [c])", x, bias).shape -(16, 16, 32) - -Indexing ops ------------- - -einx provides a family of functions that perform multi-dimensional indexing and update/retrieve values from tensors at specific coordinates: - -.. code:: - - image = np.ones((256, 256, 3)) - coordinates = np.ones((100, 2), dtype=np.int32) - updates = np.ones((100, 3)) - - # Retrieve values at specific locations in an image - y = einx.get_at("[h w] c, i [2] -> i c", image, coordinates) - # same as - y = image[coordinates[:, 0], coordinates[:, 1]] - - # Update values at specific locations in an image - y = einx.set_at("[h w] c, i [2], i c -> [h w] c", image, coordinates, updates) - # same as - image[coordinates[:, 0], coordinates[:, 1]] = updates - y = image - -Brackets in the first input indicate axes that are indexed, and a single bracket in the second input indicates the coordinate axis. The length of the coordinate axis should equal -the number of indexed axes in the first input. Coordinates can also be passed in separate tensors: - -.. code:: - - coordinates_x = np.ones((100,), dtype=np.int32) - coordinates_y = np.ones((100,), dtype=np.int32) - - y = einx.get_at("[h w] c, i, i -> i c", image, coordinates_x, coordinates_y) - -Indexing functions are specializations of :func:`einx.index` and fully support Einstein expression rearranging: - -.. code:: - - einx.add_at("b ([h w]) c, ([2] b) i, c i -> c [h w] b", image, coordinates, updates) - -Vectorization -------------- - -Both :func:`einx.reduce` and :func:`einx.elementwise` are adaptations of :func:`einx.vmap_with_axis`. The purpose of :func:`einx.vmap_with_axis` -is to augment backend functions providing a numpy-like interface (e.g. ``np.sum``) such that they can be called using Einstein notation. -For exmaple, :func:`einx.sum` wraps ``np.sum`` using :func:`einx.vmap_with_axis`: - -.. code:: - - y = einx.sum("a [b]", x) - # internally calls - y = np.sum(x, axis=1) - -Functions such as ``np.sum`` can be used with :func:`einx.vmap_with_axis` if they accept the ``axis`` argument (or work on scalars) -and follow `Numpy broadcasting rules `_ for multiple inputs. - -The ``axis`` argument specifies axes that the operation is applied to, and the operation is repeated implicitly over all other dimensions. -In the above example, the sum is computed over elements in a row, and this is repeated for all rows. - -A naive implementation without ``np.sum`` could simply loop over the first dimension manually to perform the same operation: - -.. code:: - - for r in range(x.shape[0]): - y[r] = sum(x[r, :]) - -However, since Python loops are notoriously slow, Numpy provides the highly optimized *vectorized* implementation ``np.sum`` that allows specifying which dimensions to apply the operation -to, and which dimensions to vectorize/ "loop" over. - -The bracket notation in Einstein expressions serves a similar purpose as the ``axis`` parameter: Operations are applied to -axes that are marked with ``[]``, and other axes are vectorized over. :func:`einx.vmap_with_axis` takes care of vectorization by -rearranging the inputs and outputs as required and determining the correct ``axis`` argument to pass to the backend function. This allows -applying operations to tensors with arbitrary Einstein expressions: - -.. code:: - - y = einx.sum("a ([b] c)", x, c=2) - # cannot be expressed in a single call to np.sum - y = np.sum(x, axis="?") - -:func:`einx.vmap` allows for more general vectorization than :func:`einx.vmap_with_axis` by applying arbitrary functions in vectorized form. Consider a function that accepts two tensors -and computes the mean and max: - -.. code:: - - def op(x, y): # c, d -> 2 - return np.stack([np.mean(x), np.max(y)]) - -This function can be vectorized over a batch dimension as follows: - ->>> x = np.ones((4, 16)) ->>> y = np.ones((4, 8)) ->>> einx.vmap("b [c], b [d] -> b [2]", x, y, op=op).shape -(4, 2) - -:func:`einx.vmap` takes care of vectorization automatically such that the arguments arriving at ``op`` always match the marked subexpressions in the inputs. Analogously, the return -value of ``op`` should match the marked subexpressions in the output. :func:`einx.vmap` is implemented using efficient automatic vectorization in the respective backend (e.g. -`jax.vmap `_, `torch.vmap `_). - -.. note:: - - einx implements a simple ``vmap`` function for the Numpy backend for testing/ debugging purposes using a Python loop. - -Analogous to other einx functions, :func:`einx.vmap` fully supports Einstein expression rearranging: - ->>> x = np.ones((4, 16)) ->>> y = np.ones((5, 8 * 4)) ->>> einx.vmap("b1 [c], b2 ([d] b1) -> [2] b1 b2", x, y, op=op).shape -(2, 4, 5) - -Since most backend operations that accept an ``axis`` argument operate on the entire input tensor when ``axis`` is not given, :func:`einx.vmap_with_axis` can often -analogously be expressed using :func:`einx.vmap`: - ->>> x = np.ones((4, 16)) ->>> einx.vmap_with_axis("a [b] -> a", x, op=np.sum).shape -(4,) ->>> einx.vmap ("a [b] -> a", x, op=np.sum).shape -(4,) - ->>> x = np.ones((4, 16)) ->>> y = np.ones((4,)) ->>> einx.vmap_with_axis("a b, a -> a b", x, y, op=np.add).shape -(4, 16) ->>> einx.vmap ("a b, a -> a b", x, y, op=np.add).shape -(4, 16) - -:func:`einx.vmap` provides more general vectorization capabilities than :func:`einx.vmap_with_axis`, but might in some cases be slower if the latter relies on a -specialized implementation. - -Composability of ``->`` and ``,`` ---------------------------------- - -The operators ``->`` and ``,`` that delimit input and output expressions can optionally be composed with other Einstein operations. If -they appear within a nested expression, the expression is expanded -`according to distributive law `_ such that ``->`` and ``,`` appear only at the root -of the expression tree. For example: - -.. code:: - - einx.vmap("a [b -> c]", x, op=..., c=16) - # expands to - einx.vmap("a [b] -> a [c]", x, op=..., c=16) - - einx.get_at("b p [i,->]", x, y) - # expands to - einx.get_at("b p [i], b p -> b p", x, y) - -General dot-product -------------------- - -The function :func:`einx.dot` computes general dot-products similar to `np.einsum `_. It represents a special case -of vectorization since matrix multiplications using ``einsum`` are highly optimized in the respective backends. - -In the most general case, the operation string is similar to that of ``einsum``. The inputs and outputs expressions are specified explicitly, and axes that appear in the input, but -not the output are reduced via a dot-product: - ->>> # Matrix multiplication between x and y ->>> x = np.ones((4, 16)) ->>> y = np.ones((16, 8)) ->>> einx.dot("a b, b c -> a c", x, y).shape -(4, 8) - -.. note:: - - ``einx.dot`` is not called ``einx.einsum`` despite providing einsum-like functionality to avoid confusion with ``einx.sum``. The name is - motivated by the fact that the function computes a generalized dot-product, and is in line with expressing the same operation using :func:`einx.vmap`: - - .. code:: - - einx.dot("a b, b c -> a c", x, y) - einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) - -:func:`einx.dot` fully supports Einstein expression rearranging: - ->>> # Simple grouped linear layer ->>> x = np.ones((20, 16)) ->>> w = np.ones((8, 4)) ->>> einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2).shape -(20, 8) - -The graph representation shows that the inputs and output are rearranged as required and the dot-product is forwarded to the ``einsum`` function of the backend: - ->>> print(einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2, graph=True)) -import numpy as np -def op0(i0, i1): - x0 = np.reshape(i0, (20, 2, 8)) - x1 = np.einsum("abc,cd->abd", x0, i1) - x2 = np.reshape(x1, (20, 8)) - return x2 - -Shorthand notation in :func:`einx.dot` is supported as follows. When given two input tensors, the expression of the second input is determined implicitly by marking -its components in the input and output expression: - -.. code:: - - einx.dot("a [b] -> a [c]", x, y) # Expands to: "a b, b c -> a c" - -This dot-product can be interpreted as a linear map that maps from ``b`` to ``c`` channels and is repeated over dimension ``a``, which motivates the -usage of bracket notation in this manner. - -Axes marked multiple times appear only once in the implicit second input expression: - -.. code:: - - einx.dot("[a b] -> [a c]", x, y) # Expands to: "a b, a b c -> a c" - -The graph representation shows that the expression forwarded to the ``einsum`` call is as expected: - ->>> x = np.ones((4, 8)) ->>> y = np.ones((8, 5)) ->>> print(einx.dot("a [b->c]", x, y, graph=True)) -import numpy as np -def op0(i0, i1): - x0 = np.einsum("ab,bc->ac", i0, i1) - return x0 - -.. _lazytensorconstruction: - -Tensor factories ----------------- - -All einx operations also accept tensor factories instead of tensors as arguments. A tensor factory is a function that accepts a ``shape`` -argument and returns a tensor with that shape. This allows deferring the construction of a tensor to the point inside -an einx operation where its shape has been resolved, and avoids having to manually determine the shape in advance: - -.. code:: - - einx.dot("b... c1, c1 c2 -> b... c2", x, lambda shape: np.random.uniform(shape), c2=32) - -In this example, the shape of ``x`` is used by the expression solver to determine the values of ``b...`` and ``c1``. Since the tensor factory provides no shape -constraints to the solver, the remaining axis values have to be specified explicitly, i.e. ``c2=32``. - -Tensor factories are particularly useful in the context of deep learning modules: The shapes of a layer's weights are typically chosen to align with the shapes -of the layer input and outputs (e.g. the number of input channels in a linear layer must match the corresponding axis in the layer's weight matrix). -This can be achieved implicitly by constructing layer weights using tensor factories. - -The following tutorial describes in more detail how this is used in einx to implement deep learning models. \ No newline at end of file diff --git a/docs/source/gettingstarted/neuralnetworks.rst b/docs/source/gettingstarted/tutorial_neuralnetworks.rst similarity index 97% rename from docs/source/gettingstarted/neuralnetworks.rst rename to docs/source/gettingstarted/tutorial_neuralnetworks.rst index 1f8c67e..16be585 100644 --- a/docs/source/gettingstarted/neuralnetworks.rst +++ b/docs/source/gettingstarted/tutorial_neuralnetworks.rst @@ -3,7 +3,7 @@ Tutorial: Neural networks einx provides several neural network layer types for deep learning frameworks (`PyTorch `_, `Flax `_, `Haiku `_, `Equinox `_, `Keras `_) in the ``einx.nn.*`` namespace -based on the functions in ``einx.*``. These layers provide abstractions that can implement a wide variety of deep learning operations using Einstein notation. +based on the functions in ``einx.*``. These layers provide abstractions that can implement a wide variety of deep learning operations using einx notation. The ``einx.nn.*`` namespace is entirely optional, and is imported as follows: .. code:: @@ -92,7 +92,7 @@ The utility of ``einn.param`` comes from providing several useful default argume elif init == "dot": init = nn.initializers.lecun_normal(kwargs["in_axis"], kwargs["out_axis"], kwargs["batch_axis"]) - :func:`einx.dot` additionally determines ``in_axis``, ``out_axis`` and ``batch_axis`` from the Einstein expression and forwards them as optional arguments + :func:`einx.dot` additionally determines ``in_axis``, ``out_axis`` and ``batch_axis`` from the einx expression and forwards them as optional arguments to tensor factories. In this case, they allow ``nn.initializers.lecun_normal`` to determine the fan-in of the layer and choose the initialization accordingly. * **Default argument for** ``name`` @@ -267,8 +267,8 @@ Layers einx provides the layer types ``einn.{Linear|Norm|Dropout}`` that are implemented as outlined above. -**einn.Norm** implements a normalization layer with optional exponential moving average (EMA) over the computed statistics. The first parameter is an Einstein expression for -the axes along which the statistics for normalization are computed. The second parameter is an Einstein expression for the axes corresponding to the bias and scale terms, and +**einn.Norm** implements a normalization layer with optional exponential moving average (EMA) over the computed statistics. The first parameter is an einx expression for +the axes along which the statistics for normalization are computed. The second parameter is an einx expression for the axes corresponding to the bias and scale terms, and defaults to ``b... [c]``. The different sub-steps can be toggled by passing ``True`` or ``False`` for the ``mean``, ``var``, ``scale`` and ``bias`` parameters. The EMA is used only if ``decay_rate`` is passed. @@ -292,7 +292,7 @@ A bias is added corresponding to the marked output expressions, and is disabled spatial_mix2 = einn.Linear("b [s2->s...] c", s=(64, 64)) patch_embed = einn.Linear("b (s [s2->])... [c1->c2]", s2=4, c2=64) -**einn.Dropout** implements a stochastic dropout. The first parameter specifies the shape of the mask in Einstein notation that is applied to the input tensor. +**einn.Dropout** implements a stochastic dropout. The first parameter specifies the shape of the mask in einx notation that is applied to the input tensor. .. code:: diff --git a/docs/source/gettingstarted/einsteinnotation.rst b/docs/source/gettingstarted/tutorial_notation.rst similarity index 67% rename from docs/source/gettingstarted/einsteinnotation.rst rename to docs/source/gettingstarted/tutorial_notation.rst index cead1aa..bf92cbf 100644 --- a/docs/source/gettingstarted/einsteinnotation.rst +++ b/docs/source/gettingstarted/tutorial_notation.rst @@ -1,16 +1,17 @@ -Tutorial: Einstein notation -########################### +Tutorial: Notation +####################### -This tutorial introduces the Einstein notation that is used in einx. It is inspired by and compatible with the notation used in `einops `_, -but follows a novel design based on a full composability of expressions, and the introduction of ``[]``-notation and intuitive shorthands. When combined, these features -allow for a concise and expressive formulation of a large variety of tensor operations. (See :doc:`How does einx compare with einops? ` for a complete list -of differences.) +This tutorial introduces the Einstein-inspired notation that is used in einx. It is based on and +compatible with the notation used in `einops `_, but +introduces several new concepts such as ``[]``-bracket notation, composable ellipses and axis +concatenations. See :doc:`How is einx different from einops? ` for a complete list +of differences. Introduction ------------ -An Einstein expression provides a description of the dimensions of a given tensor. In the simplest case, each dimension is given a unique name (``a``, ``b``, ``c``), and the names -are listed to form an Einstein expression: +An einx expression provides a description of the axes of a given tensor. In the simplest case, each dimension is given a unique name (``a``, ``b``, ``c``), and the names +are listed to form an einx expression: >>> x = np.ones((2, 3, 4)) >>> einx.matches("a b c", x) # Check whether expression matches the tensor's shape @@ -18,21 +19,21 @@ True >>> einx.matches("a b", x) False -One application of Einstein expressions is to formulate tensor operations such as reshaping and permuting axes in an intuitive way. Instead of defining an +einx expressions are used to formulate tensor operations such as reshaping and permuting axes in an intuitive way. Instead of defining an operation in classical index-based notation >>> y = np.transpose(x, (0, 2, 1)) >>> y.shape (2, 4, 3) -we instead provide the input and output expressions in Einstein notation and let einx determine the necessary operations: +we instead provide the input and output expressions in einx notation and let einx determine the necessary operations: >>> y = einx.rearrange("a b c -> a c b", x) >>> y.shape (2, 4, 3) -The purpose of :func:`einx.rearrange` is to map tensors between different Einstein expressions. It does not perform any computation itself, but rather forwards the computation -to the respective backend, e.g. Numpy. +The purpose of :func:`einx.rearrange` is to map tensors between different einx expressions. It does not perform any computation itself, +but rather forwards the computation to the respective backend, e.g. Numpy. To verify that the correct backend calls are made, the just-in-time compiled function that einx invokes for this expression can be printed using ``graph=True``: @@ -48,8 +49,9 @@ The function shows that einx performs the expected call to ``np.transpose``. .. note:: einx traces the backend calls made for a given operation and just-in-time compiles them into a regular Python function using Python's - `exec() `_. When the function is called with the same signature of arguments, the compiled function is reused and - therefore incurs no additional overhead other than for cache lookup (see :doc:`Just-in-time compilation `) + `exec() `_. When the function is called with the same signature of arguments, + the compiled function is reused and therefore incurs no additional overhead other than for cache lookup + (see :doc:`Just-in-time compilation `) .. _axiscomposition: @@ -66,7 +68,7 @@ The composition ``(a b)`` is an axis itself and comprises the subaxes ``a`` and `row-major order `_. This corresponds to ``a`` chunks of ``b`` elements each. The length of the composed axis is the product of the subaxis lengths. -We can use :func:`einx.rearrange` to compose and decompose axes in a tensor by passing the respective Einstein expressions: +We can use :func:`einx.rearrange` to compose and decompose axes in a tensor by passing the respective einx expressions: >>> # Stack 2 chunks of 3 elements into a single dimension with length 6 >>> x = np.ones((2, 3, 4)) @@ -78,8 +80,8 @@ We can use :func:`einx.rearrange` to compose and decompose axes in a tensor by p >>> einx.rearrange("(a b) c -> a b c", x, a=2).shape (2, 3, 4) -Since the decomposition is ambiguous w.r.t. the values of ``a`` and ``b`` (for example ``a=2 b=3`` and ``a=1 b=6`` would be valid), additional constraints have to be passed -to find unique axis values, e.g. ``a=2`` as in the example above. +Since the decomposition is ambiguous w.r.t. the values of ``a`` and ``b`` (for example ``a=2 b=3`` and ``a=1 b=6`` would be valid), +additional constraints have to be passed to find unique axis values, e.g. ``a=2`` as in the example above. Composing and decomposing axes is a cheap operation and e.g. preferred over calling ``np.split``. The graph of these functions shows that it uses a `np.reshape `_ @@ -99,8 +101,8 @@ def op0(i0): .. note:: - See `this great einops tutorial `_ for hands-on illustrations of axis - composition using a batch of images. + See `this great einops tutorial `_ for hands-on + illustrations of axis composition using a batch of images. Axis compositions are used for example to divide the channels of a tensor into equally sized groups (as in multi-headed attention), or to divide an image into patches by decomposing the spatial dimensions (if the image resolution is evenly divisible by the patch size). @@ -120,7 +122,8 @@ The number of repetitions is determined from the rank of the input tensors: >>> einx.matches("a b...", x) # Expands to "a b.0 b.1 b.2" True -Using ellipses e.g. for spatial dimensions often results in simpler and more readable expressions, and allows using the same expression for tensors with different dimensionality: +Using ellipses e.g. for spatial dimensions often results in simpler and more readable expressions, and allows using the same expression +for tensors with different dimensionality: >>> # Divide an image into a list of patches with size p=8 >>> x = np.ones((256, 256, 3), dtype="uint8") @@ -132,8 +135,8 @@ Using ellipses e.g. for spatial dimensions often results in simpler and more rea >>> einx.rearrange("(s p)... c -> (s...) p... c", x, p=8).shape (32768, 8, 8, 8, 3) -This operation requires multiple backend calls in index-based notation that might be difficult to understand on first glance. The einx call on the other hand clearly conveys -the intent of the operation and requires less code: +This operation requires multiple backend calls in index-based notation that might be difficult to understand on first glance. +The einx call on the other hand clearly conveys the intent of the operation and requires less code: >>> print(einx.rearrange("(s p)... c -> (s...) p... c", x, p=8, graph=True)) import numpy as np @@ -143,8 +146,8 @@ def op0(i0): x2 = np.reshape(x1, (1024, 8, 8, 3)) return x2 -In einops-style notation, an ellipsis can only appear once at root level without a preceding expression. To be fully compatible with einops notation, einx implicitly -converts anonymous ellipses by adding an axis in front: +In einops-style notation, an ellipsis always appears at root-level and is anonymous, i.e. does not have a preceding expression. +To be fully compatible with einops notation, einx implicitly converts anonymous ellipses by adding an axis in front: .. code:: @@ -155,7 +158,7 @@ converts anonymous ellipses by adding an axis in front: Unnamed axes ------------ -An *unnamed axis* is a number in the Einstein expression and similar to using a new unique axis name with an additional constraint specifying its length: +An *unnamed axis* is a number in the einx expression and similar to using a new unique axis name with an additional constraint specifying its length: >>> x = np.ones((2, 3, 4)) >>> einx.matches("2 b c", x) @@ -165,7 +168,7 @@ True >>> einx.matches("a 1 c", x) False -Unnamed axes can be used for example as an alternative to ``np.expand_dims``, ``np.squeeze``, ``np.newaxis``, ``np.broadcast_to``: +Unnamed axes is used for example as an alternative to ``np.expand_dims``, ``np.squeeze``, ``np.newaxis``, ``np.broadcast_to``: >>> x = np.ones((2, 1, 3)) >>> einx.rearrange("a 1 b -> 1 1 a b 1 5 6", x).shape @@ -180,18 +183,19 @@ Since each unnamed axis is given a unique name, multiple unnamed axes do not ref Concatenation ------------- -A *concatenation* represents an axis in Einstein notation along which two or more subtensors are concatenated. Using axis concatenations, we can describe operations such as +A *concatenation* represents an axis in einx notation along which two or more subtensors are concatenated. Using axis concatenations, +we can describe operations such as `np.concatenate `_, `np.split `_, `np.stack `_, -`einops.pack and einops.unpack `_ in pure Einstein notation. A concatenation axis is marked with ``+`` and wrapped in parentheses, -and its length is the sum of the subaxis lengths. +`einops.pack and einops.unpack `_ in pure einx notation. A concatenation axis is marked with +``+`` and wrapped in parentheses, and its length is the sum of the subaxis lengths. >>> x = np.ones((5, 4)) >>> einx.matches("(a + b) c", x) True -This can be used for example to concatenate tensors that do not have compatible dimensions: +This is used for example to concatenate tensors that do not have compatible dimensions: >>> x = np.ones((256, 256, 3)) >>> y = np.ones((256, 256)) @@ -224,9 +228,11 @@ Unlike the index-based `np.concatenate `_-based solver to determine the values of named axes in Einstein expressions (see :doc:`How does einx parse Einstein expressions? `). -In many cases, the shapes of the input tensors provide enough constraints to determine the values of all named axes in the solver. For other cases, einx functions accept -``**parameters`` that can be used to specify the values of some or all named axes and provide additional constraints to the solver: +einx uses a `SymPy `_-based solver to determine the values of named axes in Einstein expressions +(see :doc:`How does einx parse expressions? `). +In many cases, the shapes of the input tensors provide enough constraints to determine the values of all named axes in the solver. +For other cases, einx functions accept ``**parameters`` that are used to specify the values of some or all named axes and provide +additional constraints to the solver: .. code:: @@ -242,7 +248,7 @@ In many cases, the shapes of the input tensors provide enough constraints to det Bracket notation ---------------- -einx introduces the ``[]``-notation to denote axes that an operation is applied on. This corresponds to the ``axis`` argument in index-based notation: +einx introduces the ``[]``-notation to denote axes that an operation is applied to. This corresponds to the ``axis`` argument in index-based notation: .. code:: @@ -254,10 +260,14 @@ einx introduces the ``[]``-notation to denote axes that an operation is applied # same as np.sum(x, axis=tuple(range(1, x.ndim))) -The usage of brackets in all einx functions follows a general principle: +In general, brackets define which sub-tensors the given elementary operation is applied to. For example, the expression ``"a [b c] d"`` indicates +that the elementary operation ``einx.sum`` is applied to sub-tensors with shape ``b c`` and vectorized over axes ``a`` and ``d``: -**Brackets mark axes that an operation is applied to, while all other axes -are batch axes that the operation is repeated over.** +.. code:: + + einx.sum ("a [b c] d", x) + # ^^^^^^^^ ^ ^^^^^ ^ + # elementary operation vectorized axis sub-tensor axes vectorized axis Some other examples: @@ -268,7 +278,7 @@ Some other examples: einx.get_at("b [h w] c, b i [2] -> b i c", x, indices) # Gather values einx.softmax("b q [k] h", attn) # Part of attention operation -Bracket notation is fully compatible with expression rearranging and can therefore be placed anywhere inside a nested Einstein expression: +Bracket notation is fully compatible with expression rearranging and can therefore be placed anywhere inside a nested einx expression: >>> # Compute sum over pairs of values along the last axis >>> x = np.ones((2, 2, 16)) @@ -289,7 +299,7 @@ def op0(i0): .. note:: - See :doc:`How does einx handle input and output tensors? ` for details on how operations are applied to tensors with nested Einstein expressions. + See :doc:`How does einx handle input and output tensors? ` for details on how operations are applied to tensors with nested einx expressions. Operations are sensitive to the positioning of brackets, e.g. allowing for flexible ``keepdims=True`` behavior out-of-the-box: @@ -303,5 +313,22 @@ Operations are sensitive to the positioning of brackets, e.g. allowing for flexi In the second example, ``c`` is reduced within the composition ``(c)``, resulting in an empty composition ``()``, i.e. a trivial axis with size 1. -einx provides a wide range of tensor operations that accept arguments in Einstein notation as described in this document. +Composability of ``->`` and ``,`` +--------------------------------- + +The operators ``->`` and ``,`` that delimit input and output expressions in an operation can optionally be composed with the einx expressions themselves. If +they appear within a nested expression, the expression is expanded such that ``->`` and ``,`` appear only at the root +of the expression tree. For example: + +.. code:: + + einx.{...}("a [b -> c]", x) + # expands to + einx.{...}("a [b] -> a [c]", x) + + einx.{...}("b p [i,->]", x, y) + # expands to + einx.{...}("b p [i], b p -> b p", x, y) + +einx provides a wide range of elementary tensor operations that accept arguments in einx notation as described in this document. The following tutorial gives an overview of these functions and their usage. diff --git a/docs/source/gettingstarted/tutorial_ops.rst b/docs/source/gettingstarted/tutorial_ops.rst new file mode 100644 index 0000000..4350023 --- /dev/null +++ b/docs/source/gettingstarted/tutorial_ops.rst @@ -0,0 +1,342 @@ +Tutorial: Operations +#################### + +einx represents tensor operations using a set of elementary operations that are vectorized according to the given einx expressions. +Internally, einx does not implement the operations from scratch, but forwards computation to the respective backend, e.g. by +calling `np.reshape `_, +`np.transpose `_ or +`np.sum `_ with the appropriate arguments. + +This tutorial gives an overview of these operations and their usage. For a complete list of provided functions, see the :doc:`API reference `. + +Rearranging +----------- + +The function :func:`einx.rearrange` transforms tensors between einx expressions by determining and applying the required backend operations. For example: + +>>> x = np.ones((4, 256, 17)) +>>> y, z = einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8) +>>> y.shape, z.shape +((128, 8, 16), (32, 32, 1)) + +Conceptually, this corresponds with a vectorized identity mapping. Using :func:`einx.rearrange` often produces more readable and concise code than +specifying backend operations in index-based notation directly. The index-based calls can be +inspected using the just-in-time compiled function that einx creates for this expression (see :doc:`Just-in-time compilation `): + +>>> print(einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8, graph=True)) +import numpy as np +def op0(i0): + x0 = np.reshape(i0, (4, 32, 8, 17)) + x1 = np.reshape(x0[:, :, :, 0:16], (128, 8, 16)) + x2 = np.reshape(x0[:, :, :, 16:17], (4, 32, 8)) + x3 = np.transpose(x2, (0, 2, 1)) + x4 = np.reshape(x3, (32, 32, 1)) + return [x1, x4] + +Reduction +--------- + +einx provides a family of elementary operations that reduce tensors along one or more axes. For example: + +.. code:: + + einx.sum("a [b]", x) + # same as + np.sum(x, axis=1) + + einx.mean("a [...]", x) + # same as + np.mean(x, axis=tuple(range(1, x.ndim))) + +These functions are specializations of :func:`einx.reduce` and use backend operations like `np.sum `_, +`np.prod `_ or `np.any `_ as the ``op`` argument: + +.. code:: + + einx.reduce("a [b]", x, op=np.sum) + # same as + einx.sum("a [b]", x) + +In ``einx.sum``, the respective backend is determined implicitly from the input tensor (see :doc:`How does einx support different tensor frameworks? `). + +Generally, the operation string represents both input and output expressions, and marks reduced axes using brackets: + +>>> x = np.ones((16, 8, 4)) +>>> einx.sum("a [b] c -> a c", x).shape +(16,) + +Since the output of the elementary reduction operation is a scalar, no axis is marked in the output expression. + +The following shorthand notation is supported: + +* When no brackets are found, brackets are placed implicitly around all axes that do not appear in the output: + + .. code:: + + einx.sum("a b c -> a c", x) # Expands to: "a [b] c -> a c" + +* When no output is given, it is determined implicitly by removing marked subexpressions from the input: + + .. code:: + + einx.sum("a [b] c", x) # Expands to: "a [b] c -> a c" + +:func:`einx.reduce` also allows custom reduction operations that accept the ``axis`` argument similar to `np.sum `_: + +.. code:: + + def custom_mean(x, axis): + return np.sum(x, axis=axis) / x.shape[axis] + einx.reduce("a [b] c", x, op=custom_mean) + +:func:`einx.reduce` fully supports expression rearranging: + +>>> x = np.ones((16, 8)) +>>> einx.prod("a (b [c]) -> b a", x, c=2).shape +(4, 16) + +Element-by-element +------------------ + +einx provides a family of elementary operations that apply element-by-element operations to tensors. For example: + +.. code:: + + einx.add("a b, b -> a b", x, y) + # same as + x + y[np.newaxis, :] + + einx.multiply("a, a b -> a b", x, y) + # same as + x[:, np.newaxis] * y + + einx.subtract("a, (a b) -> b a", x, y) + # requires reshape and transpose in index-based notation + +The elementary operations accept and return scalars and no axes are marked with ``[]``-brackets. +Internally, the inputs are rearranged such that the operation can be applied using `Numpy broadcasting rules `_. +These functions are specializations of :func:`einx.elementwise` and use backend operations like `np.add `_, +`np.logical_and `_ and `np.where `_ +as the ``op`` argument: + +.. code:: + + einx.elementwise("a b, b -> a b", x, y, op=np.add) + # same as + einx.add("a b, b -> a b", x, y) + +Generally, the operation string of :func:`einx.elementwise` represents all input and output expressions explicitly: + +>>> x = np.ones((16, 8)) +>>> y = np.ones((16,)) +>>> einx.add("a b, a -> a b", x, y).shape +(16, 8) + +The following shorthand notation is supported: + +* The output is determined implicitly if one of the input expressions contains the named axes of all other inputs and if this choice is unique: + + .. code:: + + einx.add("a b, a", x, y) # Expands to: "a b, a -> a b" + + einx.where("b a, b, a", x, y, z) # Expands to "b a, b, a -> b a" + + einx.subtract("a b, b a", x, y) # Raises an exception + + einx.add("a b, a b", x, y) # Expands to: "a b, a b -> a b" + +* Bracket notation can be used to indicate that the second input is a subexpression of the first: + + .. code:: + + einx.add("a [b]", x, y) # Expands to: "a b, b" + + .. note:: + + Conceptually, a different elementary operation is used in this case which is applied to tensors of equal shape rather than just scalars. + This variant might be removed in future versions. + +:func:`einx.elementwise` fully supports expression rearranging: + +>>> x = np.ones((16, 16, 32)) +>>> bias = np.ones((4,)) +>>> einx.add("b... (g [c])", x, bias).shape +(16, 16, 32) + +Indexing +-------- + +einx provides a family of elementary operations that perform multi-dimensional indexing and update/retrieve values from tensors at specific coordinates: + +.. code:: + + image = np.ones((256, 256, 3)) + coordinates = np.ones((100, 2), dtype=np.int32) + updates = np.ones((100, 3)) + + # Retrieve values at specific locations in an image + y = einx.get_at("[h w] c, i [2] -> i c", image, coordinates) + # same as + y = image[coordinates[:, 0], coordinates[:, 1]] + + # Update values at specific locations in an image + y = einx.set_at("[h w] c, i [2], i c -> [h w] c", image, coordinates, updates) + # same as + image[coordinates[:, 0], coordinates[:, 1]] = updates + y = image + +Brackets in the first input indicate axes that are indexed, and a single bracket in the second input indicates the coordinate axis. The length of the coordinate axis should equal +the number of indexed axes in the first input. Coordinates can also be passed in separate tensors: + +.. code:: + + coordinates_x = np.ones((100,), dtype=np.int32) + coordinates_y = np.ones((100,), dtype=np.int32) + + y = einx.get_at("[h w] c, i, i -> i c", image, coordinates_x, coordinates_y) + +Indexing functions are specializations of :func:`einx.index` and fully support expression rearranging: + +.. code:: + + einx.add_at("b ([h w]) c, ([2] b) i, c i -> c [h w] b", image, coordinates, updates) + +Dot-product +----------- + +The function :func:`einx.dot` computes a dot-product along the marked axes: + +>>> # Matrix multiplication between x and y +>>> x = np.ones((4, 16)) +>>> y = np.ones((16, 8)) +>>> einx.dot("a [b], [b] c -> a c", x, y).shape +(4, 8) + +While operations such as matrix multiplications are represented conceptually as a vectorized dot-products in einx, they are still implemented using +efficient matmul calls in the respective backend rather than a vectorized evaluation of the dot-product. + +The interface of :func:`einx.dot` closely resembles the existing `np.einsum `_ +which also uses Einstein-inspired notation to express matrix multiplications. In fact, :func:`einx.dot` internally forwards computation +to the ``einsum`` implementation of the respective backend, but additionally supports rearranging of expressions: + +>>> # Simple grouped linear layer +>>> x = np.ones((20, 16)) +>>> w = np.ones((8, 4)) +>>> print(einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2, graph=True)) +import numpy as np +def op0(i0, i1): + x0 = np.reshape(i0, (20, 2, 8)) + x1 = np.einsum("abc,cd->abd", x0, i1) + x2 = np.reshape(x1, (20, 8)) + return x2 + +The following shorthand notation is supported: + +* When no brackets are found, brackets are placed implicitly around all axes that do not appear in the output: + + .. code:: + + einx.dot("a b, b c -> a c", x, y) # Expands to: "a [b], [b] c -> a c" + + This allows using einsum-like notation with :func:`einx.dot`. + +* When given two input tensors, the expression of the second input is determined implicitly by marking + its components in the input and output expression: + + .. code:: + + einx.dot("a [b] -> a [c]", x, y) # Expands to: "a b, b c -> a c" + + .. note:: + + Conceptually, the elementary operation in this case is not a simple dot-product, but rather a linear map from + ``b`` to ``c`` channels, which motivates the usage of bracket notation in this manner. + + Axes marked multiple times appear only once in the implicit second input expression: + + .. code:: + + einx.dot("[a b] -> [a c]", x, y) # Expands to: "a b, a b c -> a c" + +Other operations: ``vmap`` +-------------------------- + +If an operation is not provided as a separate einx API, it can still be applied in einx using :func:`einx.vmap` or :func:`einx.vmap_with_axis`. +Both functions apply the same vectorization rules as other einx functions, but accept an ``op`` argument that specifies the elementary operation to apply. + +In :func:`einx.vmap`, the input and output tensors of ``op`` match the marked axes in the input and output expressions: + +.. code:: + + # A custom operation: + def op(x): + # Input: x has shape "b c" + x = np.sum(x, axis=1) + x = np.flip(x, axis=0) + # Output: x has shape "b" + return x + + einx.vmap("a [b c] -> a [b]", x, op=op) + +:func:`einx.vmap` is implemented using efficient automatic vectorization in the respective backend (e.g. +`jax.vmap `_, `torch.vmap `_). +einx also implements a simple ``vmap`` function for the Numpy backend for testing/ debugging purposes using a Python loop. + +In :func:`einx.vmap_with_axis`, ``op`` is instead given an ``axis`` argument and must follow +`Numpy broadcasting rules `_: + +.. code:: + + # A custom operation: + def op(x, axis): + # Input: x has shape "a b c", axis is (1, 2) + x = np.sum(x, axis=axis[1]) + x = np.flip(x, axis=axis[0]) + # Output: x has shape "b" + return x + + einx.vmap_with_axis("(a [b c]) -> (a [b])", x, op=op, a=2, b=3, c=4) + +Both :func:`einx.reduce` and :func:`einx.elementwise` are adaptations of :func:`einx.vmap_with_axis`. + +Since most backend operations that accept an ``axis`` argument operate on the entire input tensor when ``axis`` is not given, :func:`einx.vmap_with_axis` can often +analogously be expressed using :func:`einx.vmap`: + +>>> x = np.ones((4, 16)) +>>> einx.vmap_with_axis("a [b] -> a", x, op=np.sum).shape +(4,) +>>> einx.vmap ("a [b] -> a", x, op=np.sum).shape +(4,) + +>>> x = np.ones((4, 16)) +>>> y = np.ones((4,)) +>>> einx.vmap_with_axis("a b, a -> a b", x, y, op=np.add).shape +(4, 16) +>>> einx.vmap ("a b, a -> a b", x, y, op=np.add).shape +(4, 16) + +:func:`einx.vmap` provides more general vectorization capabilities than :func:`einx.vmap_with_axis`, but might in some cases be slower if the latter relies on a +specialized implementation. + +.. _lazytensorconstruction: + +Misc: Tensor factories +---------------------------- + +All einx operations also accept tensor factories instead of tensors as arguments. A tensor factory is a function that accepts a ``shape`` +argument and returns a tensor with that shape. This allows deferring the construction of a tensor to the point inside +an einx operation where its shape has been resolved, and avoids having to manually determine the shape in advance: + +.. code:: + + einx.dot("b... c1, c1 c2 -> b... c2", x, lambda shape: np.random.uniform(shape), c2=32) + +In this example, the shape of ``x`` is used by the expression solver to determine the values of ``b...`` and ``c1``. Since the tensor factory provides no shape +constraints to the solver, the remaining axis values have to be specified explicitly, i.e. ``c2=32``. + +Tensor factories are particularly useful in the context of deep learning modules: The shapes of a layer's weights are typically chosen to align with the shapes +of the layer input and outputs (e.g. the number of input channels in a linear layer must match the corresponding axis in the layer's weight matrix). +This can be achieved implicitly by constructing layer weights using tensor factories. + +The following tutorial describes in more detail how this is used in einx to implement deep learning models. \ No newline at end of file diff --git a/docs/source/gettingstarted/tutorial_overview.rst b/docs/source/gettingstarted/tutorial_overview.rst new file mode 100644 index 0000000..b0197ae --- /dev/null +++ b/docs/source/gettingstarted/tutorial_overview.rst @@ -0,0 +1,35 @@ +Tutorial: Overview +################## + +einx provides a universal interface to formulate tensor operations as concise expressions in frameworks such as +Numpy, PyTorch, Tensorflow and Jax. This tutorial will introduce the main concepts of Einstein-inspired notation +(or *einx notation*) and how it is used as a universal language for expressing tensor operations. + +An einx expression is a string that represents the axis names of a tensor. For example, given the tensor + +>>> import numpy as np +>>> x = np.ones((2, 3, 4)) + +we can name its dimensions ``a``, ``b`` and ``c``: + +>>> import einx +>>> einx.matches("a b c", x) # Check whether expression matches the tensor's shape +True +>>> einx.matches("a b", x) +False + +The purpose of einx expressions is to specify how tensor operations will be applied to the input tensors: + +>>> np.sum(x, axis=1) +>>> # same as +>>> einx.sum("a [b] c", x) + +Here, ``einx.sum`` represents the elementary *sum-reduction* operation that is computed. The expression ``a [b] c`` specifies +that it is applied to sub-tensors +spanning the ``b`` axis, and vectorized over axes ``a`` and ``c``. This is an example of the general paradigm +for formulating complex tensor operations with einx: + +1. Provide a set of elementary tensor operations such as ``einx.{sum|max|where|add|dot|flip|get_at|...}``. +2. Use einx notation as a universal language to express vectorization of the elementary ops. + +The following tutorials will give a deeper dive into einx expressions and how they are used to express a large variety of tensor operations. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index e4b3002..ad648ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,23 +1,26 @@ Welcome to einx's documentation! ================================ -einx is a Python library that allows formulating many tensor operations as concise expressions using Einstein notation. -It is inspired by `einops `_. - .. toctree:: :caption: Getting Started :maxdepth: 3 gettingstarted/introduction.rst gettingstarted/installation.rst - gettingstarted/einsteinnotation.rst - gettingstarted/tensormanipulation.rst - gettingstarted/neuralnetworks.rst + gettingstarted/tutorial_overview.rst + gettingstarted/tutorial_notation.rst + gettingstarted/tutorial_ops.rst + gettingstarted/tutorial_neuralnetworks.rst gettingstarted/commonnnops.rst gettingstarted/gpt2.rst - gettingstarted/jit.rst - gettingstarted/gotchas.rst - gettingstarted/cheatsheet.rst + +.. toctree:: + :caption: Further Resources + :maxdepth: 1 + + more/jit.rst + more/gotchas.rst + more/related.rst .. toctree:: :caption: Frequently Asked Questions @@ -27,6 +30,7 @@ It is inspired by `einops `_. faq/solver.rst faq/backend.rst faq/einops.rst + faq/universal.rst .. toctree:: :caption: einx API diff --git a/docs/source/gettingstarted/gotchas.rst b/docs/source/more/gotchas.rst similarity index 85% rename from docs/source/gettingstarted/gotchas.rst rename to docs/source/more/gotchas.rst index e2ee789..9eee19f 100644 --- a/docs/source/gettingstarted/gotchas.rst +++ b/docs/source/more/gotchas.rst @@ -3,18 +3,13 @@ Gotchas 1. **Unnamed axes are always unique** and cannot refer to the same axis in different expressions. E.g. ``3 -> 3`` refers to two different axes, both with length 3. This can lead to unexpected behavior in some cases: ``einx.sum("3 -> 3", x)`` will reduce the first ``3`` axis and insert -a new axis broadcasted to length 3. +a new axis broadcasted with length 3. 2. **Spaces in expressions are important.** E.g. in ``(a b)...`` the ellipsis repeats ``(a b)``, while in ``(a b) ...`` the ellipsis repeats a new axis that is inserted in front of it. -3. **einx.dot is not called einx.einsum** despite providing einsum-like functionality to avoid confusion with ``einx.sum``. The name is -motivated by the fact that the function computes a generalized dot-product, and is in line with expressing the same operation using :func:`einx.vmap`: - -.. code:: - - einx.dot("a b, b c -> a c", x, y) - einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) +3. **einx.dot is not called einx.einsum** despite providing einsum-like functionality. This follows the general paradigm of naming functions after +the elementary operation that is computed, and avoids confusion with ``einx.sum``. 4. **einx does not support dynamic shapes** that can occur for example when tracing some types of functions (e.g. `tf.unique `_) in Tensorflow using ``tf.function``. As a workaround, the shape can be specified statically, @@ -29,7 +24,7 @@ The problem typically does not arise in frameworks like Jax 6. **einx implements a custom vmap for Numpy using Python loops**. This is slower than ``vmap`` in other backends, but is included for debugging and testing purposes. -7. **In einx.nn layers, weights are created on the first forward pass** (see :doc:`Tutorial: Neural networks `). +7. **In einx.nn layers, weights are created on the first forward pass** (see :doc:`Tutorial: Neural networks `). This is common practice in jax-based frameworks like Flax and Haiku where the model is initialized using a forward pass on a dummy batch. In other frameworks, an initial forward pass should be added before using the model. (In some circumstances the first actual training batch might be sufficient, but it is safer to always include the initial forward pass.) In PyTorch, diff --git a/docs/source/gettingstarted/jit.rst b/docs/source/more/jit.rst similarity index 96% rename from docs/source/gettingstarted/jit.rst rename to docs/source/more/jit.rst index 5e556f7..f3fdc93 100644 --- a/docs/source/gettingstarted/jit.rst +++ b/docs/source/more/jit.rst @@ -1,7 +1,7 @@ Just-in-time compilation ######################## -When an einx function is invoked, the required backend operations are determined from the given Einstein expressions and traced into graph representation. The graph is +When an einx function is invoked, the required backend operations are determined from the given einx expressions and traced into graph representation. The graph is then just-in-time compiled into a regular Python function using Python's `exec() `_. As a simple example, consider the following einx call: diff --git a/docs/source/more/related.rst b/docs/source/more/related.rst new file mode 100644 index 0000000..313cc7b --- /dev/null +++ b/docs/source/more/related.rst @@ -0,0 +1,20 @@ +Related projects +################ + +* `einops `_ +* `einsum `_ +* `eindex `_ +* `torchdim `_ +* `einindex `_ +* `einshape `_ +* `einop `_ +* `eingather `_ +* `einshard `_ +* `shardops `_ +* `eins `_ +* `Named axes in PyTorch `_ +* `Named axes in Jax `_ +* `Named axes in Penzai `_ +* `Dex `_ +* `Named Tensor Notation `_ +* `Tensor Considered Harmful `_ \ No newline at end of file diff --git a/einx/experimental/op/shard.py b/einx/experimental/op/shard.py index 6354806..2453374 100644 --- a/einx/experimental/op/shard.py +++ b/einx/experimental/op/shard.py @@ -144,6 +144,8 @@ def shard( ) -> einx.Tensor: """Shards a tensor over a mesh of devices. + *This function is currently experimental and will likely change in future versions.* + *This function is currently only supported for Jax: A sharding is created based on the given expression, and applied to the tensor using* ``jax.device_put``. @@ -184,8 +186,6 @@ def shard( >>> x.sharding NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd1',)) - **This function is currently experimental and will likely change in future versions.** - Args: description: Description string in Einstein notation (see above). tensor: Input tensor or tensor factory matching the description string. diff --git a/einx/op/arange.py b/einx/op/arange.py index 7db91e1..e71b042 100644 --- a/einx/op/arange.py +++ b/einx/op/arange.py @@ -128,7 +128,9 @@ def arange( ) -> einx.Tensor: """n-dimensional ``arange`` operation. - Runs ``backend.arange`` for every axis in ``input``, and stacks the results along the single + *This function might be removed in a future version.* + + Runs ``arange`` for every axis in ``input``, and stacks the results along the single marked axis in ``output``. Always uses ``start=0`` and ``step=1``. The `description` argument must meet one of the following formats: diff --git a/einx/op/dot.py b/einx/op/dot.py index 26a1794..e8cf8cb 100644 --- a/einx/op/dot.py +++ b/einx/op/dot.py @@ -203,24 +203,21 @@ def dot( ) -> einx.Tensor: """Computes a general dot-product of the input tensors. - The function flattens all input tensors, applies the general dot-product yielding a single - output tensor, and rearranges the result to match the output expression (see - :doc:`How does einx handle input and output tensors? `). + The following shorthand notation is supported: - The `description` argument specifies the input and output expressions. It must meet - one of the following formats: + * When no brackets are found, brackets are placed implicitly around all axes that do not + appear in the output. - 1. ``input1, input2, ... -> output`` - All input and output expressions are specified explicitly. Similar to - `np.einsum `_ - notation. + Example: ``a b, b c -> a c`` expands to ``a [b], [b] c -> a c`` - 2. ``input1 -> output`` - The function accepts two input tensors. ``[]``-brackets mark all axes in ``input1`` - and ``output`` that should also appear in the second input. The second input is then - determined as an ordered list of all marked axes (without duplicates). + * When given two input tensors, the expression of the second input is determined implicitly + from the marked axes in the input and output expression. - Example: ``[b c1] -> [b c2]`` resolves to ``b c1, b c1 c2 -> b c2`` + Example: ``a [b] -> a [c]`` expands to ``a b, b c -> a c`` + + Axes marked multiple times appear only once in the implicit second input expression. + + Example: ``[a b] -> [a c]`` expands to ``a b, a b c -> a c`` The function additionally passes the ``in_axes``, ``out_axes`` and ``batch_axes`` arguments to tensor factories that can be used to determine the fan-in and fan-out of a neural network @@ -228,7 +225,7 @@ def dot( `_) Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. tensors: Input tensors or tensor factories matching the description string. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None. @@ -239,7 +236,7 @@ def dot( **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``. Returns: - The result of the dot-product operation if `graph=False`, otherwise the graph + The result of the dot-product operation if ``graph=False``, otherwise the graph representation of the operation. Examples: diff --git a/einx/op/elementwise.py b/einx/op/elementwise.py index fcaf4d7..5716c4a 100644 --- a/einx/op/elementwise.py +++ b/einx/op/elementwise.py @@ -133,38 +133,29 @@ def elementwise( cse: bool = True, **parameters: npt.ArrayLike, ) -> einx.Tensor: - """Applies an element-by-element operation over the given tensors. Specializes - :func:`einx.vmap_with_axis`. + """Applies an element-by-element operation over the given tensors. - The function flattens all input tensors, applies the given element-by-element operation - yielding a single output tensor, and rearranges the result to match the output expression - (see :doc:`How does einx handle input and output tensors? `). + It supports the following shorthand notation: - The `description` argument specifies the input and output expressions. It must meet one of - the following formats: + * The output is determined implicitly if one of the input expressions contains the named axes + of all other inputs and if this choice is unique. - 1. ``input1, input2, ... -> output`` - All input and output expressions are specified explicitly. + | Example: ``a b, a`` expands to ``a b, a -> a b``. + | Example: ``b a, b, a`` expands to ``b a, b, a -> b a``. + | Example: ``a b, b a`` raises an exception. + | Example: ``a b, a b`` expands to ``a b, a b -> a b``. - 2. ``input1, input2, ...`` - All input expressions are specified explicitly. If one of the input expressions is a - parent of or equal to all other input expressions, it is used as the output expression. - Otherwise, an exception is raised. + * Bracket notation can be used when passing two input tensors to indicate that the second + input is a subexpression of the first. - Example: ``a b, a`` resolves to ``a b, a -> a b``. - - 3. ``input1`` with ``[]``-brackets - The function accepts two input tensors. `[]`-brackets mark all subexpressions in the - first input that should also appear in the second input. - - Example: ``a [b]`` resolves to ``a b, b`` + Example: ``a [b]`` expands to ``a b, b``. Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. tensors: Input tensors or tensor factories matching the description string. op: Backend elemebt-by-element operation. Must accept the same number of tensors as specified in the description string and comply with numpy broadcasting rules. - If `op` is a string, retrieves the attribute of `backend` with the same name. + If ``op`` is a string, retrieves the attribute of ``backend`` with the same name. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None. cse: Whether to apply common subexpression elimination to the expressions. Defaults @@ -174,7 +165,7 @@ def elementwise( **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``. Returns: - The result of the elementwise operation if `graph=False`, otherwise the graph + The result of the elementwise operation if ``graph=False``, otherwise the graph representation of the operation. Examples: diff --git a/einx/op/index.py b/einx/op/index.py index 139262c..d2466a9 100644 --- a/einx/op/index.py +++ b/einx/op/index.py @@ -387,28 +387,23 @@ def index( cse: bool = True, **parameters: npt.ArrayLike, ) -> einx.Tensor: - """Updates and/ or returns values from an array at the given coordinates. + """Updates and/ or returns values from a tensor at the given coordinates. - The `description` argument specifies the input and output expressions and must meet one of - the following formats: + * If ``update`` is True: The first tensor receives updates, the last tensor contains the + updates, and all other tensors represent the coordinates. If the output expression is + not given, it is assumed to be equal to the first input expression. - 1. ``tensor, coordinates1, coordinates2, ..., update -> output`` - when modifying values in the tensor. - 2. ``tensor, coordinates1, coordinates2, ... -> output`` - when only returning values from the tensor. + * If ``update`` is False, values are retrieved from the first tensor and the remaining tensors + contain the coordinates. - Brackets in the ``tensor`` expression mark the axes that will be indexed. Brackets in the - ``coordinates`` expression mark the single coordinate axis. All other axes are considered - batch axes. Using multiple coordinate expressions will yield the same output as concatenating + Using multiple coordinate expressions will yield the same output as concatenating the coordinate expressions along the coordinate axis first. Args: - description: Description string in Einstein notation (see above). - *tensors: Tensors that the operation will be applied to. The first tensor will receive - updates, the last tensor contains the updates, and all other tensors represent - the coordinates. - op: The update/gather function. If `op` is a string, retrieves the attribute of `backend` - with the same name. + description: Description string for the operation in einx notation. + *tensors: Tensors that the operation will be applied to. + op: The update/gather function. If ``op`` is a string, retrieves the attribute of + ``backend`` with the same name. update: Whether to update the tensor or return values from the tensor. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None. @@ -419,7 +414,7 @@ def index( **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``. Returns: - The result of the update/ gather operation if `graph=False`, otherwise the graph + The result of the update/ gather operation if ``graph=False``, otherwise the graph representation of the operation. Examples: diff --git a/einx/op/rearrange.py b/einx/op/rearrange.py index 0b76300..3a0b77d 100644 --- a/einx/op/rearrange.py +++ b/einx/op/rearrange.py @@ -98,14 +98,9 @@ def rearrange( ) -> Union[einx.Tensor, Tuple[einx.Tensor, ...]]: """Rearranges the input tensors to match the output expressions. - See :doc:`How does einx handle input and output tensors? `. - - The `description` argument specifies the input and output expressions. It must - meet the following format: - ``input1, input2, ... -> output1, output2, ...`` - Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. Must not contain + brackets. tensors: Input tensors or tensor factories matching the description string. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None. @@ -116,7 +111,7 @@ def rearrange( **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``. Returns: - The result of the elementwise operation if `graph=False`, otherwise the graph + The result of the rearrange operation if ``graph=False``, otherwise the graph representation of the operation. Examples: diff --git a/einx/op/reduce.py b/einx/op/reduce.py index a50cc86..c1a8d1e 100644 --- a/einx/op/reduce.py +++ b/einx/op/reduce.py @@ -108,32 +108,26 @@ def reduce( ) -> einx.Tensor: """Applies a reduction operation on the given tensors. - The function flattens all input tensors, applies the given reduction operation and rearranges - the result to match the output expression (see :doc:`How does einx handle input and - output tensors? `). + The operation reduces all marked axes in the input to a single scalar. It supports + the following shorthand notation: - The `description` argument specifies the input and output expressions, as well as - reduced axes. It must meet one of the following formats: + * When no brackets are found, brackets are placed implicitly around all axes that do not + appear in the output. + + Example: ``a b c -> a c`` resolves to ``a [b] c -> a c``. - 1. ``input -> output`` - Input and output expressions are specified explicitly. Reduced axes are marked - with ``[]``-brackets in the input expression. If no axes are - marked, reduces all axes that do not appear in the output expression. + * When no output is given, it is determined implicitly by removing marked subexpressions + from the input. - 2. ``input`` - A single input expression is specified. Reduced axes are marked with ``[]``-brackets. - The output expression is determined by removing all marked expressions - from the input expression. - - Example: ``a [b]`` resolves to ``a b -> a``. + Example: ``a [b] c`` resolves to ``a [b] c -> a c``. Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. tensor: Input tensor or tensor factory matching the description string. - op: Backend reduction operation. Is called with ``op(tensor, axis=...)``. If `op` is - a string, retrieves the attribute of `backend` with the same name. + op: Backend reduction operation. Is called with ``op(tensor, axis=...)``. If ``op`` is + a string, retrieves the attribute of ``backend`` with the same name. keepdims: Whether to replace marked expressions with 1s instead of dropping them. Must - be None when `description` already contains an output expression. Defaults to None. + be None when ``description`` already contains an output expression. Defaults to None. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None. cse: Whether to apply common subexpression elimination to the expressions. Defaults diff --git a/einx/op/solve.py b/einx/op/solve.py index 2b5ea08..5c6cb26 100644 --- a/einx/op/solve.py +++ b/einx/op/solve.py @@ -59,11 +59,8 @@ def solve( ) -> Optional[Mapping[str, npt.ArrayLike]]: """Solve for the axis values of the given expressions and tensors. - The `description` argument must meet the following format: - ``input1, input2, ...`` - Args: - description: Description string in Einstein notation. + description: Description string for the tensors in einx notation. tensors: Input tensors or tensor factories matching the description string. cse: Whether to apply common subexpression elimination to the expressions. Defaults to False. @@ -87,11 +84,8 @@ def matches( ) -> bool: """Check whether the given expressions and tensors match. - The `description` argument must meet the following format: - ``input1, input2, ...`` - Args: - description: Description string in Einstein notation. + description: Description string for the tensors in einx notation. tensors: Input tensors or tensor factories matching the description string. cse: Whether to apply common subexpression elimination to the expressions. Defaults to False. @@ -99,6 +93,13 @@ def matches( Returns: True if the expressions and tensors match, False otherwise. + + Examples: + >>> x = np.zeros((10, 5)) + >>> einx.matches("a b", x) + True + >>> einx.matches("a b c", x) + False """ return solve(description, *tensors, cse=cse, **parameters) is not None @@ -109,11 +110,8 @@ def check( ) -> None: """Check whether the given expressions and tensors match and raise an exception if they don't. - The `description` argument must meet the following format: - ``input1, input2, ...`` - Args: - description: Description string in Einstein notation. + description: Description string for the tensors in einx notation. tensors: Input tensors or tensor factories matching the description string. cse: Whether to apply common subexpression elimination to the expressions. Defaults to False. diff --git a/einx/op/util.py b/einx/op/util.py index 046a8bb..2845430 100644 --- a/einx/op/util.py +++ b/einx/op/util.py @@ -4,21 +4,6 @@ def flatten(exprs, tensors=None, backend=None): - """Flatten the given expressions and optionally the corresponding tensors. - - Flattening removes all compositions and concatenations and returns a list of new expressions - (and optinally a list of flattened tensors). - - Parameters: - exprs: Expressions to flatten. - tensors: Tensors corresponding to ``exprs``. If None, flattens and returns only - ``exprs``. Defaults to None. - backend: Backend to use for tensor operations. - - Returns: - exprs: The flattened expressions. - tensors, optional: The flattened tensors. Only returned if ``tensors`` is not None. - """ if tensors is None: exprs_out = [] for expr in exprs: @@ -90,18 +75,6 @@ def flatten(exprs, tensors=None, backend=None): def assignment(exprs_in, exprs_out): - """Solve the assignment problem between input and output expressions. - - If multiple solutions exist: For each output expression in order, - choose the first input expression that matches. - - Args: - exprs_in: Input expressions. - exprs_out: Output expressions. - - Returns: - indices: Indices into ``exprs_in`` with the same ordering as ``exprs_out``. - """ if len(exprs_in) != len(exprs_out): raise ValueError("Got different number of input and output expressions") axes_in = [{a.name for a in einx.expr.stage3.get_named_axes(expr_in)} for expr_in in exprs_in] diff --git a/einx/op/vmap.py b/einx/op/vmap.py index 324bd87..be31334 100644 --- a/einx/op/vmap.py +++ b/einx/op/vmap.py @@ -338,24 +338,16 @@ def vmap( kwargs: Mapping = {}, **parameters: npt.ArrayLike, ): - """Applies a function to the marked axes of the input tensors using vectorization. + """Vectorizes and applies a function to the input tensors using automatic vectorization. - The function flattens all input tensors, applies the vectorized operation on the - tensors and rearranges the result to match the output expressions (see :doc:`How does - einx handle input and output tensors? `). - - The `description` argument specifies the input and output expressions. The operation is - applied over all axes marked with ``[]``-brackets. All other axes are considered batch - axes and vectorized over. - - The function ``op`` should accept input tensors and yield output tensors as specified in + The function ``op`` must accept input tensors and yield output tensors as specified in ``description`` with shapes matching the subexpressions that are marked with ``[]``-brackets. Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. tensors: Input tensors or tensor factories matching the description string. op: Function that will be vectorized. If ``op`` is a string, retrieves the attribute - of `backend` with the same name. + of ``backend`` with the same name. flat: Whether to pass the tensors to ``op`` in flattened form or matching the nested layout in the input expressions. Defaults to False. kwargs: Additional keyword arguments that are passed to ``op``. Defaults to ``{}``. diff --git a/einx/op/vmap_with_axis.py b/einx/op/vmap_with_axis.py index 3f0f82a..2e58475 100644 --- a/einx/op/vmap_with_axis.py +++ b/einx/op/vmap_with_axis.py @@ -236,24 +236,19 @@ def vmap_with_axis( **parameters: npt.ArrayLike, ): """Applies a function to the marked axes of the input tensors by passing the ``axis`` - argument. + argument and relying on implicit broadcasting rules. - The function flattens all input tensors, applies the given operation and rearranges - the result to match the output expressions (see :doc:`How does einx handle input and output - tensors? `). - - The `description` argument specifies the input and output expressions. The operation is - applied over all axes marked with ``[]``-brackets. All other axes are considered batch axes. - - When the function is applied on scalars, the ``axis`` argument is not passed. For multiple - input tensors, the function must follow + The function ``op`` must accept input tensors and an ``axis`` argument specifying the + indices of the axes along which the operation is applied. When the function is applied on + scalars, the ``axis`` argument is not passed. For multiple input tensors, the function + must follow `Numpy broadcasting rules `_. Args: - description: Description string in Einstein notation (see above). + description: Description string for the operation in einx notation. tensors: Input tensors or tensor factories matching the description string. - op: Backend operation. Is called with ``op(tensor, axis=...)``. If `op` is a string, - retrieves the attribute of `backend` with the same name. + op: Backend operation. Is called with ``op(tensor, axis=...)``. If ``op`` is a string, + retrieves the attribute of ``backend`` with the same name. kwargs: Additional keyword arguments that are passed to ``op``. backend: Backend to use for all operations. If None, determines the backend from the input tensors. Defaults to None.