diff --git a/README.md b/README.md
index 80ce479..ef6d967 100644
--- a/README.md
+++ b/README.md
@@ -115,11 +115,11 @@ einx traces the required backend operations for a given call into graph represen
 >>> x = np.zeros((3, 10, 10))
 >>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (3, 10, 2, 5))
-    x0 = backend.sum(x1, axis=3)
-    return x0
+    x0 = np.reshape(i0, (3, 10, 2, 5))
+    x1 = np.sum(x0, axis=3)
+    return x1
 ```

 See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/gettingstarted/jit.html) for more details.
\ No newline at end of file
diff --git a/docs/source/gettingstarted/einsteinnotation.rst b/docs/source/gettingstarted/einsteinnotation.rst
index da7c73c..cead1aa 100644
--- a/docs/source/gettingstarted/einsteinnotation.rst
+++ b/docs/source/gettingstarted/einsteinnotation.rst
@@ -38,9 +38,9 @@ To verify that the correct backend calls are made, the just-in-time compiled fun
 >>> graph = einx.rearrange("a b c -> a c b", x, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.transpose(i0, (0, 2, 1))
+    x0 = np.transpose(i0, (0, 2, 1))
     return x0

 The function shows that einx performs the expected call to ``np.transpose``.
@@ -86,15 +86,15 @@ that it uses a `np.reshape
 >>> print(einx.rearrange("(a b) c -> a b c", x, a=2, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.reshape(i0, (2, 3, 4))
+    x0 = np.reshape(i0, (2, 3, 4))
     return x0

 >>> print(einx.rearrange("a b c -> (a b) c", x, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.reshape(i0, (6, 4))
+    x0 = np.reshape(i0, (6, 4))
     return x0

 .. note::
@@ -136,12 +136,12 @@ This operation requires multiple backend calls in index-based notation that migh
 the intent of the operation and requires less code:

 >>> print(einx.rearrange("(s p)... c -> (s...) p... c", x, p=8, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x2 = backend.reshape(i0, (32, 8, 32, 8, 3))
-    x1 = backend.transpose(x2, (0, 2, 1, 3, 4))
-    x0 = backend.reshape(x1, (1024, 8, 8, 3))
-    return x0
+    x0 = np.reshape(i0, (32, 8, 32, 8, 3))
+    x1 = np.transpose(x0, (0, 2, 1, 3, 4))
+    x2 = np.reshape(x1, (1024, 8, 8, 3))
+    return x2

 In einops-style notation, an ellipsis can only appear once at root level without a preceding expression. To be fully compatible with einops notation, einx implicitly converts anonymous ellipses by adding an axis in front:
@@ -201,11 +201,11 @@ This can be used for example to concatenate tensors that do not have compatible
 The graph shows that einx first reshapes ``y`` by adding a channel dimension, and then concatenates the tensors along that axis:

 >>> print(einx.rearrange("h w c, h w -> h w (c + 1)", x, y, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x1 = backend.reshape(i1, (256, 256, 1))
-    x0 = backend.concatenate([i0, x1], 2)
-    return x0
+    x0 = np.reshape(i1, (256, 256, 1))
+    x1 = np.concatenate([i0, x0], axis=2)
+    return x1

 Splitting is supported analogously:
@@ -281,11 +281,11 @@ Bracket notation is fully compatible with expression rearranging and can therefo
 (4, 64, 64, 3)

 >>> print(einx.mean("b (s [ds])... c", x, ds=4, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (4, 64, 4, 64, 4, 3))
-    x0 = backend.mean(x1, axis=(2, 4))
-    return x0
+    x0 = np.reshape(i0, (4, 64, 4, 64, 4, 3))
+    x1 = np.mean(x0, axis=(2, 4))
+    return x1

 .. note::
diff --git a/docs/source/gettingstarted/gpt2.rst b/docs/source/gettingstarted/gpt2.rst
index dca192b..1ab78c8 100644
--- a/docs/source/gettingstarted/gpt2.rst
+++ b/docs/source/gettingstarted/gpt2.rst
@@ -82,12 +82,12 @@ We can verify the correctness of these operations by inspecting the jit-compiled
 >>> graph = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=self.heads, graph=True)
 >>> print(graph)
-# backend: einx.backend.jax
+import jax.numpy as jnp
 def op0(i0, i1):
-    x1 = backend.reshape(i0, (1, 1024, 25, 64))
-    x2 = backend.reshape(i1, (1, 1024, 25, 64))
-    x0 = backend.einsum("abcd,aecd->abec", x1, x2)
-    return x0
+    x0 = jnp.reshape(i0, (1, 1024, 25, 64))
+    x1 = jnp.reshape(i1, (1, 1024, 25, 64))
+    x2 = jnp.einsum("abcd,aecd->abec", x0, x1)
+    return x2

 The final GPT-2 model first embeds the input tokens and adds positional embeddings. It then applies a number of main blocks and maps the output onto next token logits using a linear layer:
diff --git a/docs/source/gettingstarted/jit.rst b/docs/source/gettingstarted/jit.rst
index 024a4f4..5e556f7 100644
--- a/docs/source/gettingstarted/jit.rst
+++ b/docs/source/gettingstarted/jit.rst
@@ -14,12 +14,12 @@ We can inspect the compiled function by passing ``graph=True``:
 >>> graph = einx.sum("a [b]", x, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.sum(i0, axis=1)
+    x0 = np.sum(i0, axis=1)
     return x0

-einx passes this string and variables such as ``backend`` to `exec() `_ to just-in-time compile the function.
+einx passes this string to `exec() `_ to just-in-time compile the function.
 It then invokes the function using the required arguments. The traced function is cached, such that subsequent calls with the same signature of inputs can reuse it and incur no overhead other than for cache lookup.
@@ -39,52 +39,50 @@ A sum-reduction that requires a reshape operation:
 >>> x = np.zeros((10, 10))
 >>> print(einx.sum("b... (g [c])", x, g=2, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (10, 2, 5))
-    x0 = backend.sum(x1, axis=2)
-    return x0
+    x0 = np.reshape(i0, (10, 2, 5))
+    x1 = np.sum(x0, axis=2)
+    return x1

-A call to ``einx.dot`` that forwards computation to ``backend.einsum``:
+A call to ``einx.dot`` that forwards computation to ``np.einsum``:

 >>> x = np.zeros((10, 10))
 >>> print(einx.dot("b... (g [c1->c2])", x, np.ones, g=2, c2=8, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x2 = backend.reshape(i0, (10, 2, 5))
-    x3 = einx.param.instantiate(i1, shape=(5, 8), in_axis=(0,), out_axis=(1,), batch_axis=(), name="weight", init="dot", backend=backend)
-    assert x3.shape == (5, 8)
-    x1 = backend.einsum("abc,cd->abd", x2, x3)
-    x0 = backend.reshape(x1, (10, 16))
-    return x0
+    x0 = np.reshape(i0, (10, 2, 5))
+    x1 = np.einsum("abc,cd->abd", x0, i1((5, 8)))
+    x2 = np.reshape(x1, (10, 16))
+    return x2

-A call to ``einx.get_at`` that applies ``backend.vmap`` to handle batch axes:
+A call to ``einx.get_at`` that applies ``jax.vmap`` to handle batch axes:

->>> x = np.zeros((4, 128, 128, 3))
->>> y = np.zeros((4, 1024, 2), "int32")
+>>> x = jnp.zeros((4, 128, 128, 3))
+>>> y = jnp.zeros((4, 1024, 2), "int32")
 >>> print(einx.get_at("b [h w] c, b p [2] -> b p c", x, y, graph=True))
-# backend: einx.backend.numpy
+import jax
 def op1(i0, i1):
-    x1 = i1[:, 0]
-    x2 = i1[:, 1]
-    x0 = backend.get_at(i0, (x1, x2))
-    return (x0,)
-def op0(i0, i1, op1=op1):
-    op2 = backend.vmap(op1, in_axes=(0, 0), out_axes=(0,))
-    op3 = backend.vmap(op2, in_axes=(3, None), out_axes=(2,))
-    x0 = op3(i0, i1)
-    return x0[0]
+    x0 = i1[:, 0]
+    x1 = i1[:, 1]
+    x2 = i0[x0, x1]
+    return (x2,)
+x3 = jax.vmap(op1, in_axes=(0, 0), out_axes=(0,))
+x4 = jax.vmap(x3, in_axes=(3, None), out_axes=(2,))
+def op0(i0, i1):
+    x0, = x4(i0, i1)
+    return x0

 An operation that requires concatenation of tensors:

 >>> x = np.zeros((10, 10, 3))
 >>> y = np.ones((10, 10))
 >>> print(einx.rearrange("h w c, h w -> h w (c + 1)", x, y, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x1 = backend.reshape(i1, (10, 10, 1))
-    x0 = backend.concatenate([i0, x1], 2)
-    return x0
+    x0 = np.reshape(i1, (10, 10, 1))
+    x1 = np.concatenate([i0, x0], axis=2)
+    return x1

 The just-in-time compiled function can also be called directly with the correct arguments to avoid a cache lookup:
diff --git a/docs/source/gettingstarted/tensormanipulation.rst b/docs/source/gettingstarted/tensormanipulation.rst
index 3ae7ecb..492f63b 100644
--- a/docs/source/gettingstarted/tensormanipulation.rst
+++ b/docs/source/gettingstarted/tensormanipulation.rst
@@ -39,16 +39,14 @@ Using :func:`einx.rearrange` often produces more readable and concise code than
 inspected using the just-in-time compiled function that einx creates for this expression (see :doc:`Just-in-time compilation `):

 >>> print(einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (4, 32, 8, 17))
-    x2 = x1[:, :, :, 0:16]
-    x0 = backend.reshape(x2, (128, 8, 16))
-    x6 = x1[:, :, :, 16:17]
-    x5 = backend.reshape(x6, (4, 32, 8))
-    x4 = backend.transpose(x5, (0, 2, 1))
-    x3 = backend.reshape(x4, (32, 32, 1))
-    return [x0, x3]
+    x0 = np.reshape(i0, (4, 32, 8, 17))
+    x1 = np.reshape(x0[:, :, :, 0:16], (128, 8, 16))
+    x2 = np.reshape(x0[:, :, :, 16:17], (4, 32, 8))
+    x3 = np.transpose(x2, (0, 2, 1))
+    x4 = np.reshape(x3, (32, 32, 1))
+    return [x1, x4]

 Reduction ops
 -------------
@@ -350,12 +348,12 @@ not the output are reduced via a dot-product:
 The graph representation shows that the inputs and output are rearranged as required and the dot-product is forwarded to the ``einsum`` function of the backend:

 >>> print(einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x2 = backend.reshape(i0, (20, 2, 8))
-    x1 = backend.einsum("abc,cd->abd", x2, i1)
-    x0 = backend.reshape(x1, (20, 8))
-    return x0
+    x0 = np.reshape(i0, (20, 2, 8))
+    x1 = np.einsum("abc,cd->abd", x0, i1)
+    x2 = np.reshape(x1, (20, 8))
+    return x2

 Shorthand notation in :func:`einx.dot` is supported as follows. When given two input tensors, the expression of the second input is determined implicitly by marking its components in the input and output expression:
@@ -378,9 +376,9 @@ The graph representation shows that the expression forwarded to the ``einsum`` c
 >>> x = np.ones((4, 8))
 >>> y = np.ones((8, 5))
 >>> print(einx.dot("a [b->c]", x, y, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x0 = backend.einsum("ab,bc->ac", i0, i1)
+    x0 = np.einsum("ab,bc->ac", i0, i1)
     return x0

 .. _lazytensorconstruction:
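
For a quick local check of the new output format, a minimal sketch (assuming `einx` and NumPy are installed) mirroring the `einx.sum("a [b]", ...)` example from `docs/source/gettingstarted/jit.rst` above:

```python
# Minimal sketch: inspect the just-in-time compiled function for a sum-reduction.
import numpy as np
import einx

x = np.zeros((10, 10))

# Passing graph=True returns the jit-compiled function instead of the result;
# printing it now shows plain Python with np.* calls (e.g. np.sum) rather than
# the former "# backend: ..." header and backend.* calls.
graph = einx.sum("a [b]", x, graph=True)
print(graph)
```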