Add partial support for tinygrad
fferflo committed Jun 8, 2024
1 parent 23f5e04 commit 19e875f
Showing 7 changed files with 275 additions and 3 deletions.
1 change: 1 addition & 0 deletions einx/backend/__init__.py
@@ -7,3 +7,4 @@
from . import _jax as jax
from . import _dask as dask
from . import _mlx as mlx
from . import _tinygrad as tinygrad
1 change: 1 addition & 0 deletions einx/backend/_mlx.py
@@ -82,6 +82,7 @@ def arange(start, stop=None, step=None, dtype="int32"):
true_divide = op.elementwise(tmx.divide)
floor_divide = op.elementwise(tmx.floor_divide)
divide = op.elementwise(tmx.divide)
mod = op.elementwise(tmx.remainder)
logical_and = associative_binary_to_nary(op.elementwise(tmx.logical_and))
logical_or = associative_binary_to_nary(op.elementwise(tmx.logical_or))
where = op.elementwise(tmx.where)
220 changes: 220 additions & 0 deletions einx/backend/_tinygrad.py
@@ -0,0 +1,220 @@
from .base import *
import einx.tracer as tracer
from einx.tracer.tensor import op
import einx, types
from functools import partial
import functools


def create():
tTensor = tracer.import_("Tensor", from_="tinygrad")
tdtypes = tracer.import_("dtypes", from_="tinygrad")
from tinygrad import Tensor, dtypes

def scalar_to_tensor(x):
if isinstance(x, (einx.tracer.Scalar, float, int)):
return einx.tracer.apply(
tTensor,
args=[x],
output=einx.tracer.Tensor([]),
)
else:
return x

def elementwise(func, convert_all_to_tensor=False):
@einx.trace
@functools.wraps(func)
def outer(*args):
if convert_all_to_tensor:
args = [scalar_to_tensor(a) for a in args]
else:
args = [a for a in args]
args[0] = scalar_to_tensor(args[0])
return op.elementwise(func)(*args)
return outer

def reduce(func):
@einx.trace
@functools.wraps(func)
def reduce(tensor, axis=None, **kwargs):
keepdims = kwargs.get("keepdims", False)
if axis is None:
shape = ()
else:
axes = [axis] if isinstance(axis, int) else axis
shape = list(tensor.shape)
if keepdims:
for a in axes:
shape[a] = 1
else:
for a in sorted(axes, reverse=True):
del shape[a]
kwargs = {**kwargs, **{"axis": axis}}
if "keepdims" in kwargs:
kwargs["keepdim"] = kwargs.pop("keepdims")
return tracer.apply(func, args=[tensor], kwargs=kwargs, output=tracer.Tensor(shape))
return reduce

def to_dtype(x):
if isinstance(x, str):
return getattr(dtypes, x)
else:
return x

to_dtype2 = to_dtype

class tinygrad(Backend):
name = "tinygrad"
tensor_types = [Tensor]

to_dtype = staticmethod(to_dtype2)

@staticmethod
@einx.trace
def to_tensor(tensor, shape):
return einx.tracer.apply(
tTensor,
args=[tensor],
output=einx.tracer.Tensor(shape),
)

reshape = op.reshape(tTensor.reshape)
transpose = op.transpose(tTensor.permute)
broadcast_to = op.broadcast_to(tTensor.expand)

@classmethod
@einx.trace
def einsum(backend, equation, *tensors):
x = equation.split("->")
if len(x) != 2:
raise ValueError("Invalid equation")
inputs, output = x
inputs = inputs.split(",")
if len(inputs) != len(tensors):
raise ValueError("Invalid equation")
inputs = [x.strip().replace(" ", "") for x in inputs]
tensors = [t for t in tensors]

scalars = []
for i in list(range(len(inputs)))[::-1]:
if (len(inputs[i]) > 0) != (len(tensors[i].shape) > 0):
raise ValueError("Invalid equation")
if len(inputs[i]) == 0:
scalars.append(tensors[i])
inputs.pop(i)
tensors.pop(i)

if len(tensors) > 1:
equation = ",".join(inputs) + "->" + output
x = op.einsum(tTensor.einsum)(equation, *tensors)
elif len(tensors) == 1:
x = tensors[0]
else:
x = scalars[0]
scalars = scalars[1:]
for scalar in scalars:
x = backend.multiply(x, scalar)

return x

@staticmethod
@einx.trace
def arange(n, dtype="int32"):
if isinstance(dtype, str):
dtype = getattr(tdtypes, dtype)
return op.arange(tTensor.arange)(n, dtype=dtype)

@staticmethod
@einx.trace
def concatenate(tensors, axis=0):
shape = list(tensors[0].shape)
shape[axis] = sum(tensor.shape[axis] for tensor in tensors)
return tracer.apply(
tTensor.cat, args=[*tensors], kwargs={"dim": axis}, output=tracer.Tensor(shape)
)

add = associative_binary_to_nary(elementwise(tTensor.add))
subtract = elementwise(tTensor.sub)
multiply = associative_binary_to_nary(elementwise(tTensor.mul))
true_divide = elementwise(tTensor.div)
floor_divide = elementwise(partial(tTensor.div, upcast=False))
divide = elementwise(tTensor.div)
logical_and = associative_binary_to_nary(elementwise(tTensor.mul))
logical_or = associative_binary_to_nary(elementwise(tTensor.add))
where = elementwise(tTensor.where)
less = elementwise(tracer.Operator("<"))
less_equal = elementwise(tracer.Operator("<="))
greater = elementwise(tracer.Operator(">"))
greater_equal = elementwise(tracer.Operator(">="))
equal = elementwise(tracer.Operator("=="))
not_equal = elementwise(tracer.Operator("!="))
maximum = associative_binary_to_nary(elementwise(tTensor.maximum))
minimum = associative_binary_to_nary(elementwise(tTensor.minimum))

sum = reduce(tTensor.sum)
mean = reduce(tTensor.mean)
var = reduce(tTensor.var)
std = reduce(tTensor.std)

count_nonzero = reduce(tTensor.sum)
min = reduce(tTensor.min)
max = reduce(tTensor.max)
# tinygrad's logsumexp currently does not support multiple axes, so
# we use our custom implementation instead:
# logsumexp = reduce(tTensor.logsumexp)

log = op.elementwise(tTensor.log)
exp = op.elementwise(tTensor.exp)
sqrt = op.elementwise(tTensor.sqrt)
rsqrt = op.elementwise(tTensor.rsqrt)
square = op.elementwise(tTensor.square)

@staticmethod
@einx.trace
def get_at(tensor, coordinates):
raise NotImplementedError()

@staticmethod
@einx.trace
def set_at(tensor, coordinates, updates):
raise NotImplementedError()

@staticmethod
@einx.trace
def add_at(tensor, coordinates, updates):
raise NotImplementedError()

@staticmethod
@einx.trace
def subtract_at(tensor, coordinates, updates):
raise NotImplementedError()

flip = op.keep_shape(tTensor.flip)
softmax = op.keep_shape(tTensor.softmax)
log_softmax = op.keep_shape(tTensor.log_softmax)

@staticmethod
@einx.trace
def stop_gradient(tensor):
return tensor # TODO: set requires_grad to False?

@staticmethod
@einx.trace
def vmap(op, in_axes, out_axes, input_shapes, output_shapes):
raise NotImplementedError(
"Functions relying on vmap are not supported for the tinygrad backend"
)

class random:
@einx.trace
def bernoulli(rng, p, shape):
return (
einx.tracer.apply(
tTensor.rand,
args=[*shape],
output=einx.tracer.Tensor(shape),
)
<= p
)

return tinygrad()
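
For context, a minimal usage sketch (not part of the commit) of how the new backend gets picked up: einx dispatches on the type of the input tensors, so passing tinygrad Tensors selects the backend defined above. The einx expressions and shapes below are illustrative only.

    from tinygrad import Tensor
    import einx

    x = Tensor.rand(4, 8)
    y = einx.mean("b [c]", x)                         # reduction over c -> shape (4,)
    z = einx.rearrange("b c -> c b", x)               # transpose, maps to Tensor.permute
    w = einx.add("b c, c -> b c", x, Tensor.rand(8))  # broadcast elementwise add
    print(y.shape, z.shape, w.shape)

Indexing and vmap-based operations (einx.get_at and friends) raise NotImplementedError with this backend, as in the code above; hence the "partial" support.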
37 changes: 36 additions & 1 deletion einx/backend/base.py
@@ -99,6 +99,17 @@ def to_tensor(tensor):

return [to_tensor(tensor) for tensor in tensors]

@classmethod
@einx.trace
def stack(backend, tensors, axis=0):
s = (slice(None),) * axis + (None,)
return backend.concatenate([tensor[s] for tensor in tensors], axis=axis)

@classmethod
@einx.trace
def mod(backend, x, y):
return backend.subtract(x, backend.multiply(backend.floor_divide(x, y), y))

@classmethod
@einx.trace
def logsumexp(backend, x, axis=None):
@@ -120,6 +131,30 @@ def logsumexp(backend, x, axis=None):
def std(backend, x, axis=None, keepdims=False):
return backend.sqrt(backend.var(x, axis=axis, keepdims=keepdims))

@classmethod
@einx.trace
def prod(backend, tensor, axis=None):
tensor = backend.log(tensor)
tensor = backend.sum(tensor, axis=axis)
tensor = backend.exp(tensor)
return tensor

@classmethod
@einx.trace
def any(backend, tensor, axis=None):
return backend.count_nonzero(tensor, axis=axis) > 0

@classmethod
@einx.trace
def all(backend, tensor, axis=None):
if axis is None:
total_num = np.prod(tensor.shape)
elif isinstance(axis, int):
total_num = tensor.shape[axis]
else:
total_num = np.prod([tensor.shape[i] for i in axis])
return backend.count_nonzero(tensor, axis=axis) == total_num

@classmethod
@einx.trace
def softmax(backend, x, axis=None):
@@ -157,7 +192,7 @@ def roll(backend, tensor, shift, axis):
raise ValueError(f"Got {len(shift)} shifts, expected {len(axis)}")
for shift, axis in zip(shift, axis):
indices = backend.arange(tensor.shape[axis])
indices = (indices - shift) % tensor.shape[axis]
indices = backend.mod(indices - shift, tensor.shape[axis])
c = (slice(None),) * axis + (indices,)
tensor = tensor[c]
return tensor
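
The new base-class fallbacks express mod, prod, any and all through primitives every backend already provides (subtract/multiply/floor_divide, log/sum/exp, count_nonzero). A small NumPy check of the underlying identities, for illustration only and not part of the commit (the log-based prod assumes strictly positive entries):

    import numpy as np

    x, y = np.array([7, -7, 5]), np.array([3, 3, 4])
    # mod as x - floor_divide(x, y) * y (floored modulo, same convention as Python's %)
    assert np.array_equal(x - np.floor_divide(x, y) * y, x % y)

    t = np.array([1.5, 2.0, 4.0])
    # prod as exp(sum(log(t)))
    assert np.allclose(np.exp(np.sum(np.log(t))), np.prod(t))

    b = np.array([[True, False], [True, True]])
    # any/all via count_nonzero compared against 0 / the reduced axis size
    assert np.array_equal(np.count_nonzero(b, axis=1) > 0, b.any(axis=1))
    assert np.array_equal(np.count_nonzero(b, axis=1) == b.shape[1], b.all(axis=1))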
2 changes: 2 additions & 0 deletions einx/backend/register.py
@@ -42,6 +42,7 @@ def register(backend):
from . import _jax
from . import _dask
from . import _mlx
from . import _tinygrad

# Create numpy backend now
numpy = register(_numpy.create())
@@ -52,6 +53,7 @@ def register(backend):
register_for_module("jax", _jax.create)
register_for_module("dask.array", _dask.create)
register_for_module("mlx", _mlx.create)
register_for_module("tinygrad", _tinygrad.create)


# Check if any new modules have been imported and construct backends that have been
13 changes: 13 additions & 0 deletions test/conftest.py
@@ -252,3 +252,16 @@ def wrap(op):
)

tests.append((einx, backend, test))

if importlib.util.find_spec("tinygrad"):
from tinygrad import Tensor

backend = einx.backend.tinygrad.create()

test = types.SimpleNamespace(
full=lambda shape, value=0.0, dtype="float32": Tensor.full(shape, value, dtype=backend.to_dtype(dtype)),
to_tensor=Tensor,
to_numpy=lambda x: x.numpy(),
)

tests.append((einx, backend, test))
4 changes: 2 additions & 2 deletions test/test_shapes.py
@@ -398,7 +398,7 @@ def create_scalar(dtype):
def test_shape_vmap(test):
einx, backend, setup = test

if backend.name in {"mlx", "dask"}:
if backend.name in {"mlx", "dask", "tinygrad"}:
pytest.xfail(reason="Backend does not fully support vmap")

x = setup.full((13,))
@@ -529,7 +529,7 @@ def func(x): # c d -> 2
def test_shape_index(test):
einx, backend, setup = test

if backend.name in {"mlx", "dask"}:
if backend.name in {"mlx", "dask", "tinygrad"}:
pytest.xfail(reason="Backend does not fully support vmap")

coord_dtype = "int32" if backend.name != "torch" else "long"
