Commit

Update docs

fferflo committed Apr 23, 2024
1 parent 6258d71 commit 028b01f
Showing 5 changed files with 72 additions and 76 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -115,11 +115,11 @@ einx traces the required backend operations for a given call into graph represen
 >>> x = np.zeros((3, 10, 10))
 >>> graph = einx.sum("... (g [c])", x, g=2, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (3, 10, 2, 5))
-    x0 = backend.sum(x1, axis=3)
-    return x0
+    x0 = np.reshape(i0, (3, 10, 2, 5))
+    x1 = np.sum(x0, axis=3)
+    return x1
```

See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/gettingstarted/jit.html) for more details.
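For a quick sanity check, here is a minimal sketch (assuming `einx` and `numpy` are installed) that runs the traced call above and compares it against the equivalent plain-numpy reduction:

```python
import numpy as np
import einx

x = np.zeros((3, 10, 10))

# Split the last axis into (g=2, c=5) and sum over c, as in the traced graph.
y = einx.sum("... (g [c])", x, g=2)
assert y.shape == (3, 10, 2)

# The same computation spelled out in plain numpy.
expected = np.sum(np.reshape(x, (3, 10, 2, 5)), axis=3)
np.testing.assert_allclose(y, expected)
```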
38 changes: 19 additions & 19 deletions docs/source/gettingstarted/einsteinnotation.rst
@@ -38,9 +38,9 @@ To verify that the correct backend calls are made, the just-in-time compiled fun

 >>> graph = einx.rearrange("a b c -> a c b", x, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.transpose(i0, (0, 2, 1))
+    x0 = np.transpose(i0, (0, 2, 1))
     return x0

The function shows that einx performs the expected call to ``np.transpose``.
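A short equivalence check for this call (a sketch; the input shape ``(2, 3, 4)`` is chosen here for illustration):

```python
import numpy as np
import einx

x = np.zeros((2, 3, 4))  # illustrative shape, any 3D tensor works
y = einx.rearrange("a b c -> a c b", x)

# Matches the np.transpose call shown in the traced graph.
np.testing.assert_array_equal(y, np.transpose(x, (0, 2, 1)))
```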
@@ -86,15 +86,15 @@ that it uses a `np.reshape <https://numpy.org/doc/stable/reference/generated/num
operation with the requested shape:

>>> print(einx.rearrange("(a b) c -> a b c", x, a=2, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0):
x0 = backend.reshape(i0, (2, 3, 4))
x0 = np.reshape(i0, (2, 3, 4))
return x0

>>> print(einx.rearrange("a b c -> (a b) c", x, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0):
x0 = backend.reshape(i0, (6, 4))
x0 = np.reshape(i0, (6, 4))
return x0
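The two calls above are inverses of each other; a minimal sketch with the axis sizes from the traced shapes (a=2, b=3, c=4):

```python
import numpy as np
import einx

x = np.arange(24).reshape(6, 4)

# Decompose the first axis into (a=2, b=3)...
y = einx.rearrange("(a b) c -> a b c", x, a=2)
assert y.shape == (2, 3, 4)

# ...and compose it back, recovering the original tensor.
z = einx.rearrange("a b c -> (a b) c", y)
np.testing.assert_array_equal(x, z)
```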

.. note::
@@ -136,12 +136,12 @@ This operation requires multiple backend calls in index-based notation that migh
the intent of the operation and requires less code:

>>> print(einx.rearrange("(s p)... c -> (s...) p... c", x, p=8, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0):
x2 = backend.reshape(i0, (32, 8, 32, 8, 3))
x1 = backend.transpose(x2, (0, 2, 1, 3, 4))
x0 = backend.reshape(x1, (1024, 8, 8, 3))
return x0
x0 = np.reshape(i0, (32, 8, 32, 8, 3))
x1 = np.transpose(x0, (0, 2, 1, 3, 4))
x2 = np.reshape(x1, (1024, 8, 8, 3))
return x2
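A sketch of the same patch-splitting call (the input shape ``(256, 256, 3)`` is inferred from the traced reshape):

```python
import numpy as np
import einx

x = np.zeros((256, 256, 3))

# Split each spatial axis into (s, p=8) patches and merge all patch positions
# into one leading axis, as in the traced reshape-transpose-reshape sequence.
y = einx.rearrange("(s p)... c -> (s...) p... c", x, p=8)
assert y.shape == (1024, 8, 8, 3)
```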

In einops-style notation, an ellipsis can only appear once at root level without a preceding expression. To be fully compatible with einops notation, einx implicitly
converts anonymous ellipses by adding an axis in front:
@@ -201,11 +201,11 @@ This can be used for example to concatenate tensors that do not have compatible
The graph shows that einx first reshapes ``y`` by adding a channel dimension, and then concatenates the tensors along that axis:

>>> print(einx.rearrange("h w c, h w -> h w (c + 1)", x, y, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0, i1):
x1 = backend.reshape(i1, (256, 256, 1))
x0 = backend.concatenate([i0, x1], 2)
return x0
x0 = np.reshape(i1, (256, 256, 1))
x1 = np.concatenate([i0, x0], axis=2)
return x1
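A sketch of this concatenation (shapes ``(256, 256, 3)`` and ``(256, 256)`` taken from the traced graph):

```python
import numpy as np
import einx

x = np.zeros((256, 256, 3))
y = np.ones((256, 256))

# y gains a trailing channel axis and is concatenated onto x along it.
z = einx.rearrange("h w c, h w -> h w (c + 1)", x, y)
assert z.shape == (256, 256, 4)
```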

Splitting is supported analogously:

@@ -281,11 +281,11 @@ Bracket notation is fully compatible with expression rearranging and can therefo
(4, 64, 64, 3)

 >>> print(einx.mean("b (s [ds])... c", x, ds=4, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (4, 64, 4, 64, 4, 3))
-    x0 = backend.mean(x1, axis=(2, 4))
-    return x0
+    x0 = np.reshape(i0, (4, 64, 4, 64, 4, 3))
+    x1 = np.mean(x0, axis=(2, 4))
+    return x1
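A sketch of this mean-pooling call (the input shape ``(4, 256, 256, 3)`` is inferred from the traced reshape, and the output shape matches the one printed above):

```python
import numpy as np
import einx

x = np.zeros((4, 256, 256, 3))

# Mean over every (ds=4, ds=4) window of the two spatial axes.
y = einx.mean("b (s [ds])... c", x, ds=4)
assert y.shape == (4, 64, 64, 3)
```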

.. note::

10 changes: 5 additions & 5 deletions docs/source/gettingstarted/gpt2.rst
@@ -82,12 +82,12 @@ We can verify the correctness of these operations by inspecting the jit-compiled

 >>> graph = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=self.heads, graph=True)
 >>> print(graph)
-# backend: einx.backend.jax
+import jax.numpy as jnp
 def op0(i0, i1):
-    x1 = backend.reshape(i0, (1, 1024, 25, 64))
-    x2 = backend.reshape(i1, (1, 1024, 25, 64))
-    x0 = backend.einsum("abcd,aecd->abec", x1, x2)
-    return x0
+    x0 = jnp.reshape(i0, (1, 1024, 25, 64))
+    x1 = jnp.reshape(i1, (1, 1024, 25, 64))
+    x2 = jnp.einsum("abcd,aecd->abec", x0, x1)
+    return x2
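A standalone sketch of the same contraction (shapes taken from the traced graph: batch 1, sequence length 1024, 25 heads of 64 channels; plain arrays stand in for the layer's activations):

```python
import jax.numpy as jnp
import einx

q = jnp.zeros((1, 1024, 1600))  # (batch, query positions, heads * channels)
k = jnp.zeros((1, 1024, 1600))  # (batch, key positions, heads * channels)

# Contract the per-head channels of q and k into pairwise attention logits.
scores = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=25)
assert scores.shape == (1, 1024, 1024, 25)
```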

The final GPT-2 model first embeds the input tokens and adds positional embeddings. It then applies a number of main blocks and maps the output onto next token
logits using a linear layer:
62 changes: 30 additions & 32 deletions docs/source/gettingstarted/jit.rst
@@ -14,12 +14,12 @@ We can inspect the compiled function by passing ``graph=True``:

 >>> graph = einx.sum("a [b]", x, graph=True)
 >>> print(graph)
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x0 = backend.sum(i0, axis=1)
+    x0 = np.sum(i0, axis=1)
     return x0

-einx passes this string and variables such as ``backend`` to `exec() <https://docs.python.org/3/library/functions.html#exec>`_ to just-in-time compile the function.
+einx passes this string to `exec() <https://docs.python.org/3/library/functions.html#exec>`_ to just-in-time compile the function.
It then invokes the function using the required arguments. The traced function is cached, such that subsequent calls with the same signature of inputs can
reuse it and incur no overhead other than for cache lookup.
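A small sketch of this caching behavior (timings are machine-dependent and illustrative only):

```python
import timeit
import numpy as np
import einx

x = np.zeros((1000, 1000))
einx.sum("a [b]", x)  # first call: traces and compiles the function

# Subsequent calls with the same input signature reuse the cached function.
t = timeit.timeit(lambda: einx.sum("a [b]", x), number=100)
print(f"{t / 100 * 1e6:.1f} us per cached call")
```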

@@ -39,52 +39,50 @@ A sum-reduction that requires a reshape operation:

 >>> x = np.zeros((10, 10))
 >>> print(einx.sum("b... (g [c])", x, g=2, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0):
-    x1 = backend.reshape(i0, (10, 2, 5))
-    x0 = backend.sum(x1, axis=2)
-    return x0
+    x0 = np.reshape(i0, (10, 2, 5))
+    x1 = np.sum(x0, axis=2)
+    return x1

-A call to ``einx.dot`` that forwards computation to ``backend.einsum``:
+A call to ``einx.dot`` that forwards computation to ``np.einsum``:

 >>> x = np.zeros((10, 10))
 >>> print(einx.dot("b... (g [c1->c2])", x, np.ones, g=2, c2=8, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x2 = backend.reshape(i0, (10, 2, 5))
-    x3 = einx.param.instantiate(i1, shape=(5, 8), in_axis=(0,), out_axis=(1,), batch_axis=(), name="weight", init="dot", backend=backend)
-    assert x3.shape == (5, 8)
-    x1 = backend.einsum("abc,cd->abd", x2, x3)
-    x0 = backend.reshape(x1, (10, 16))
-    return x0
+    x0 = np.reshape(i0, (10, 2, 5))
+    x1 = np.einsum("abc,cd->abd", x0, i1((5, 8)))
+    x2 = np.reshape(x1, (10, 16))
+    return x2
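As the ``i1((5, 8))`` call in the trace shows, the second argument acts as a tensor factory: einx invokes ``np.ones`` with the resolved weight shape. A sketch:

```python
import numpy as np
import einx

x = np.zeros((10, 10))

# np.ones is called by einx with the resolved shape (5, 8) to create the weight.
y = einx.dot("b... (g [c1->c2])", x, np.ones, g=2, c2=8)
assert y.shape == (10, 16)
```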

-A call to ``einx.get_at`` that applies ``backend.vmap`` to handle batch axes:
+A call to ``einx.get_at`` that applies ``jax.vmap`` to handle batch axes:

>>> x = np.zeros((4, 128, 128, 3))
>>> y = np.zeros((4, 1024, 2), "int32")
>>> x = jnp.zeros((4, 128, 128, 3))
>>> y = jnp.zeros((4, 1024, 2), "int32")
>>> print(einx.get_at("b [h w] c, b p [2] -> b p c", x, y, graph=True))
# backend: einx.backend.numpy
import jax
def op1(i0, i1):
x1 = i1[:, 0]
x2 = i1[:, 1]
x0 = backend.get_at(i0, (x1, x2))
return (x0,)
def op0(i0, i1, op1=op1):
op2 = backend.vmap(op1, in_axes=(0, 0), out_axes=(0,))
op3 = backend.vmap(op2, in_axes=(3, None), out_axes=(2,))
x0 = op3(i0, i1)
return x0[0]
x0 = i1[:, 0]
x1 = i1[:, 1]
x2 = i0[x0, x1]
return (x2,)
x3 = jax.vmap(op1, in_axes=(0, 0), out_axes=(0,))
x4 = jax.vmap(x3, in_axes=(3, None), out_axes=(2,))
def op0(i0, i1):
x0, = x4(i0, i1)
return x0
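In effect, each output element is a plain indexed read that is vectorized over the batch and channel axes; a sketch of the semantics (assuming ``jax`` is installed):

```python
import jax.numpy as jnp
import einx

x = jnp.zeros((4, 128, 128, 3))
y = jnp.zeros((4, 1024, 2), "int32")

# For each batch b and point p: out[b, p, :] == x[b, y[b, p, 0], y[b, p, 1], :]
out = einx.get_at("b [h w] c, b p [2] -> b p c", x, y)
assert out.shape == (4, 1024, 3)
```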

An operation that requires concatenation of tensors:

 >>> x = np.zeros((10, 10, 3))
 >>> y = np.ones((10, 10))
 >>> print(einx.rearrange("h w c, h w -> h w (c + 1)", x, y, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x1 = backend.reshape(i1, (10, 10, 1))
-    x0 = backend.concatenate([i0, x1], 2)
-    return x0
+    x0 = np.reshape(i1, (10, 10, 1))
+    x1 = np.concatenate([i0, x0], axis=2)
+    return x1

The just-in-time compiled function can also be called directly with the correct arguments to avoid a cache lookup:

30 changes: 14 additions & 16 deletions docs/source/gettingstarted/tensormanipulation.rst
@@ -39,16 +39,14 @@ Using :func:`einx.rearrange` often produces more readable and concise code than
inspected using the just-in-time compiled function that einx creates for this expression (see :doc:`Just-in-time compilation </gettingstarted/jit>`):

>>> print(einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0):
x1 = backend.reshape(i0, (4, 32, 8, 17))
x2 = x1[:, :, :, 0:16]
x0 = backend.reshape(x2, (128, 8, 16))
x6 = x1[:, :, :, 16:17]
x5 = backend.reshape(x6, (4, 32, 8))
x4 = backend.transpose(x5, (0, 2, 1))
x3 = backend.reshape(x4, (32, 32, 1))
return [x0, x3]
x0 = np.reshape(i0, (4, 32, 8, 17))
x1 = np.reshape(x0[:, :, :, 0:16], (128, 8, 16))
x2 = np.reshape(x0[:, :, :, 16:17], (4, 32, 8))
x3 = np.transpose(x2, (0, 2, 1))
x4 = np.reshape(x3, (32, 32, 1))
return [x1, x4]
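The call returns one tensor per comma-separated output expression; a sketch (the input shape ``(4, 256, 17)`` is inferred from the traced reshape):

```python
import numpy as np
import einx

x = np.zeros((4, 256, 17))

# Two outputs: the first 16 channels and the final channel, each rearranged.
a, b = einx.rearrange("b (s p) (c + 1) -> (b s) p c, (b p) s 1", x, p=8)
assert a.shape == (128, 8, 16)
assert b.shape == (32, 32, 1)
```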

Reduction ops
-------------
@@ -350,12 +348,12 @@ not the output are reduced via a dot-product:
The graph representation shows that the inputs and output are rearranged as required and the dot-product is forwarded to the ``einsum`` function of the backend:

>>> print(einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2, graph=True))
# backend: einx.backend.numpy
import numpy as np
def op0(i0, i1):
x2 = backend.reshape(i0, (20, 2, 8))
x1 = backend.einsum("abc,cd->abd", x2, i1)
x0 = backend.reshape(x1, (20, 8))
return x0
x0 = np.reshape(i0, (20, 2, 8))
x1 = np.einsum("abc,cd->abd", x0, i1)
x2 = np.reshape(x1, (20, 8))
return x2
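A sketch of this grouped linear map (the shapes ``x: (20, 16)`` and ``w: (8, 4)`` are inferred from the traced graph):

```python
import numpy as np
import einx

x = np.ones((20, 16))  # 16 channels: g=2 groups of c1=8
w = np.ones((8, 4))    # maps c1=8 to c2=4 within each group

y = einx.dot("b (g c1), c1 c2 -> b (g c2)", x, w, g=2)
assert y.shape == (20, 8)
```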

Shorthand notation in :func:`einx.dot` is supported as follows. When given two input tensors, the expression of the second input is determined implicitly by marking
its components in the input and output expression:
@@ -378,9 +376,9 @@ The graph representation shows that the expression forwarded to the ``einsum`` c
 >>> x = np.ones((4, 8))
 >>> y = np.ones((8, 5))
 >>> print(einx.dot("a [b->c]", x, y, graph=True))
-# backend: einx.backend.numpy
+import numpy as np
 def op0(i0, i1):
-    x0 = backend.einsum("ab,bc->ac", i0, i1)
+    x0 = np.einsum("ab,bc->ac", i0, i1)
     return x0
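The shorthand is equivalent to spelling out both input expressions; a quick check:

```python
import numpy as np
import einx

x = np.ones((4, 8))
y = np.ones((8, 5))

# "[b->c]" marks the second input implicitly: b is reduced, c is introduced.
np.testing.assert_allclose(
    einx.dot("a [b->c]", x, y),
    einx.dot("a b, b c -> a c", x, y),
)
```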

.. _lazytensorconstruction:
