diff --git a/einx/backend/__init__.py b/einx/backend/__init__.py
index a0f0230..a10e05b 100644
--- a/einx/backend/__init__.py
+++ b/einx/backend/__init__.py
@@ -7,3 +7,4 @@
 from . import _jax as jax
 from . import _dask as dask
 from . import _mlx as mlx
+from . import _tinygrad as tinygrad
diff --git a/einx/backend/_mlx.py b/einx/backend/_mlx.py
index 893bbf5..f8ae725 100644
--- a/einx/backend/_mlx.py
+++ b/einx/backend/_mlx.py
@@ -82,6 +82,7 @@ def arange(start, stop=None, step=None, dtype="int32"):
         true_divide = op.elementwise(tmx.divide)
         floor_divide = op.elementwise(tmx.floor_divide)
         divide = op.elementwise(tmx.divide)
+        mod = op.elementwise(tmx.remainder)
         logical_and = associative_binary_to_nary(op.elementwise(tmx.logical_and))
         logical_or = associative_binary_to_nary(op.elementwise(tmx.logical_or))
         where = op.elementwise(tmx.where)
diff --git a/einx/backend/_tinygrad.py b/einx/backend/_tinygrad.py
new file mode 100644
index 0000000..f191698
--- /dev/null
+++ b/einx/backend/_tinygrad.py
@@ -0,0 +1,220 @@
+from .base import *
+import einx.tracer as tracer
+from einx.tracer.tensor import op
+import einx, types
+from functools import partial
+import functools
+
+
+def create():
+    tTensor = tracer.import_("Tensor", from_="tinygrad")
+    tdtypes = tracer.import_("dtypes", from_="tinygrad")
+    from tinygrad import Tensor, dtypes
+
+    def scalar_to_tensor(x):
+        if isinstance(x, (einx.tracer.Scalar, float, int)):
+            return einx.tracer.apply(
+                tTensor,
+                args=[x],
+                output=einx.tracer.Tensor([]),
+            )
+        else:
+            return x
+
+    def elementwise(func, convert_all_to_tensor=False):
+        @einx.trace
+        @functools.wraps(func)
+        def outer(*args):
+            if convert_all_to_tensor:
+                args = [scalar_to_tensor(a) for a in args]
+            else:
+                args = [a for a in args]
+                args[0] = scalar_to_tensor(args[0])
+            return op.elementwise(func)(*args)
+        return outer
+
+    def reduce(func):
+        @einx.trace
+        @functools.wraps(func)
+        def reduce(tensor, axis=None, **kwargs):
+            keepdims = kwargs.get("keepdims", False)
+            if axis is None:
+                shape = ()
+            else:
+                axes = [axis] if isinstance(axis, int) else axis
+                shape = list(tensor.shape)
+                if keepdims:
+                    for a in axes:
+                        shape[a] = 1
+                else:
+                    for a in sorted(axes, reverse=True):
+                        del shape[a]
+            kwargs = {**kwargs, **{"axis": axis}}
+            if "keepdims" in kwargs:
+                kwargs["keepdim"] = kwargs.pop("keepdims")
+            return tracer.apply(func, args=[tensor], kwargs=kwargs, output=tracer.Tensor(shape))
+        return reduce
+
+    def to_dtype(x):
+        if isinstance(x, str):
+            return getattr(dtypes, x)
+        else:
+            return x
+
+    to_dtype2 = to_dtype
+
+    class tinygrad(Backend):
+        name = "tinygrad"
+        tensor_types = [Tensor]
+
+        to_dtype = staticmethod(to_dtype2)
+
+        @staticmethod
+        @einx.trace
+        def to_tensor(tensor, shape):
+            return einx.tracer.apply(
+                tTensor,
+                args=[tensor],
+                output=einx.tracer.Tensor(shape),
+            )
+
+        reshape = op.reshape(tTensor.reshape)
+        transpose = op.transpose(tTensor.permute)
+        broadcast_to = op.broadcast_to(tTensor.expand)
+
+        @classmethod
+        @einx.trace
+        def einsum(backend, equation, *tensors):
+            x = equation.split("->")
+            if len(x) != 2:
+                raise ValueError("Invalid equation")
+            inputs, output = x
+            inputs = inputs.split(",")
+            if len(inputs) != len(tensors):
+                raise ValueError("Invalid equation")
+            inputs = [x.strip().replace(" ", "") for x in inputs]
+            tensors = [t for t in tensors]
+
+            scalars = []
+            for i in list(range(len(inputs)))[::-1]:
+                if (len(inputs[i]) > 0) != (len(tensors[i].shape) > 0):
+                    raise ValueError("Invalid equation")
+                if len(inputs[i]) == 0:
+                    scalars.append(tensors[i])
+                    inputs.pop(i)
+                    tensors.pop(i)
+
+            if len(tensors) > 1:
+                equation = ",".join(inputs) + "->" + output
+                x = op.einsum(tTensor.einsum)(equation, *tensors)
+            elif len(tensors) == 1:
+                x = tensors[0]
+            else:
+                x = scalars[0]
+                scalars = scalars[1:]
+            for scalar in scalars:
+                x = backend.multiply(x, scalar)
+
+            return x
+
+        @staticmethod
+        @einx.trace
+        def arange(n, dtype="int32"):
+            if isinstance(dtype, str):
+                dtype = getattr(tdtypes, dtype)
+            return op.arange(tTensor.arange)(n, dtype=dtype)
+
+        @staticmethod
+        @einx.trace
+        def concatenate(tensors, axis=0):
+            shape = list(tensors[0].shape)
+            shape[axis] = sum(tensor.shape[axis] for tensor in tensors)
+            return tracer.apply(
+                tTensor.cat, args=[*tensors], kwargs={"dim": axis}, output=tracer.Tensor(shape)
+            )
+
+        add = associative_binary_to_nary(elementwise(tTensor.add))
+        subtract = elementwise(tTensor.sub)
+        multiply = associative_binary_to_nary(elementwise(tTensor.mul))
+        true_divide = elementwise(tTensor.div)
+        floor_divide = elementwise(partial(tTensor.div, upcast=False))
+        divide = elementwise(tTensor.div)
+        logical_and = associative_binary_to_nary(elementwise(tTensor.mul))
+        logical_or = associative_binary_to_nary(elementwise(tTensor.add))
+        where = elementwise(tTensor.where)
+        less = elementwise(tracer.Operator("<"))
+        less_equal = elementwise(tracer.Operator("<="))
+        greater = elementwise(tracer.Operator(">"))
+        greater_equal = elementwise(tracer.Operator(">="))
+        equal = elementwise(tracer.Operator("=="))
+        not_equal = elementwise(tracer.Operator("!="))
+        maximum = associative_binary_to_nary(elementwise(tTensor.maximum))
+        minimum = associative_binary_to_nary(elementwise(tTensor.minimum))
+
+        sum = reduce(tTensor.sum)
+        mean = reduce(tTensor.mean)
+        var = reduce(tTensor.var)
+        std = reduce(tTensor.std)
+
+        count_nonzero = reduce(tTensor.sum)
+        min = reduce(tTensor.min)
+        max = reduce(tTensor.max)
+        # tinygrad's logsumexp currently does not support multiple axes, so
+        # we use our custom implementation instead:
+        # logsumexp = reduce(tTensor.logsumexp)
+
+        log = op.elementwise(tTensor.log)
+        exp = op.elementwise(tTensor.exp)
+        sqrt = op.elementwise(tTensor.sqrt)
+        rsqrt = op.elementwise(tTensor.rsqrt)
+        square = op.elementwise(tTensor.square)
+
+        @staticmethod
+        @einx.trace
+        def get_at(tensor, coordinates):
+            raise NotImplementedError()
+
+        @staticmethod
+        @einx.trace
+        def set_at(tensor, coordinates, updates):
+            raise NotImplementedError()
+
+        @staticmethod
+        @einx.trace
+        def add_at(tensor, coordinates, updates):
+            raise NotImplementedError()
+
+        @staticmethod
+        @einx.trace
+        def subtract_at(tensor, coordinates, updates):
+            raise NotImplementedError()
+
+        flip = op.keep_shape(tTensor.flip)
+        softmax = op.keep_shape(tTensor.softmax)
+        log_softmax = op.keep_shape(tTensor.log_softmax)
+
+        @staticmethod
+        @einx.trace
+        def stop_gradient(tensor):
+            return tensor  # TODO: set requires_grad to False?
+
+        @staticmethod
+        @einx.trace
+        def vmap(op, in_axes, out_axes, input_shapes, output_shapes):
+            raise NotImplementedError(
+                "Functions relying on vmap are not supported for the tinygrad backend"
+            )
+
+        class random:
+            @einx.trace
+            def bernoulli(rng, p, shape):
+                return (
+                    einx.tracer.apply(
+                        tTensor.rand,
+                        args=[*shape],
+                        output=einx.tracer.Tensor(shape),
+                    )
+                    <= p
+                )
+
+    return tinygrad()
diff --git a/einx/backend/base.py b/einx/backend/base.py
index 3b5b1fe..2948b43 100644
--- a/einx/backend/base.py
+++ b/einx/backend/base.py
@@ -99,6 +99,17 @@ def to_tensor(tensor):
 
         return [to_tensor(tensor) for tensor in tensors]
 
+    @classmethod
+    @einx.trace
+    def stack(backend, tensors, axis=0):
+        s = (slice(None),) * axis + (None,)
+        return backend.concatenate([tensor[s] for tensor in tensors], axis=axis)
+
+    @classmethod
+    @einx.trace
+    def mod(backend, x, y):
+        return backend.subtract(x, backend.multiply(backend.floor_divide(x, y), y))
+
     @classmethod
     @einx.trace
     def logsumexp(backend, x, axis=None):
@@ -120,6 +131,30 @@ def logsumexp(backend, x, axis=None):
     def std(backend, x, axis=None, keepdims=False):
         return backend.sqrt(backend.var(x, axis=axis, keepdims=keepdims))
 
+    @classmethod
+    @einx.trace
+    def prod(backend, tensor, axis=None):
+        tensor = backend.log(tensor)
+        tensor = backend.sum(tensor, axis=axis)
+        tensor = backend.exp(tensor)
+        return tensor
+
+    @classmethod
+    @einx.trace
+    def any(backend, tensor, axis=None):
+        return backend.count_nonzero(tensor, axis=axis) > 0
+
+    @classmethod
+    @einx.trace
+    def all(backend, tensor, axis=None):
+        if axis is None:
+            total_num = np.prod(tensor.shape)
+        elif isinstance(axis, int):
+            total_num = tensor.shape[axis]
+        else:
+            total_num = np.prod([tensor.shape[i] for i in axis])
+        return backend.count_nonzero(tensor, axis=axis) == total_num
+
     @classmethod
     @einx.trace
     def softmax(backend, x, axis=None):
@@ -157,7 +192,7 @@ def roll(backend, tensor, shift, axis):
             raise ValueError(f"Got {len(shift)} shifts, expected {len(axis)}")
         for shift, axis in zip(shift, axis):
             indices = backend.arange(tensor.shape[axis])
-            indices = (indices - shift) % tensor.shape[axis]
+            indices = backend.mod(indices - shift, tensor.shape[axis])
             c = (slice(None),) * axis + (indices,)
             tensor = tensor[c]
         return tensor
diff --git a/einx/backend/register.py b/einx/backend/register.py
index 610faad..770ecd4 100644
--- a/einx/backend/register.py
+++ b/einx/backend/register.py
@@ -42,6 +42,7 @@ def register(backend):
 from . import _jax
 from . import _dask
 from . import _mlx
+from . import _tinygrad
 
 # Create numpy backend now
 numpy = register(_numpy.create())
@@ -52,6 +53,7 @@ def register(backend):
 register_for_module("jax", _jax.create)
 register_for_module("dask.array", _dask.create)
 register_for_module("mlx", _mlx.create)
+register_for_module("tinygrad", _tinygrad.create)
 
 
 # Check if any new modules have been imported and construct backends that have been
diff --git a/test/conftest.py b/test/conftest.py
index c418557..4a59653 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -252,3 +252,16 @@ def wrap(op):
     )
 
     tests.append((einx, backend, test))
+
+if importlib.util.find_spec("tinygrad"):
+    from tinygrad import Tensor
+
+    backend = einx.backend.tinygrad.create()
+
+    test = types.SimpleNamespace(
+        full=lambda shape, value=0.0, dtype="float32": Tensor.full(shape, value, dtype=backend.to_dtype(dtype)),
+        to_tensor=Tensor,
+        to_numpy=lambda x: x.numpy(),
+    )
+
+    tests.append((einx, backend, test))
diff --git a/test/test_shapes.py b/test/test_shapes.py
index bcecc4a..68de3e7 100644
--- a/test/test_shapes.py
+++ b/test/test_shapes.py
@@ -398,7 +398,7 @@ def create_scalar(dtype):
 
 def test_shape_vmap(test):
     einx, backend, setup = test
-    if backend.name in {"mlx", "dask"}:
+    if backend.name in {"mlx", "dask", "tinygrad"}:
         pytest.xfail(reason="Backend does not fully support vmap")
 
     x = setup.full((13,))
@@ -529,7 +529,7 @@ def func(x):  # c d -> 2
 
 def test_shape_index(test):
     einx, backend, setup = test
-    if backend.name in {"mlx", "dask"}:
+    if backend.name in {"mlx", "dask", "tinygrad"}:
         pytest.xfail(reason="Backend does not fully support vmap")
 
     coord_dtype = "int32" if backend.name != "torch" else "long"
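
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch). With the backend registered
# for the "tinygrad" module above, einx is expected to dispatch to it whenever it
# sees tinygrad tensors; the einx calls below are assumed typical usage rather
# than tests taken from this change set.
#
#     from tinygrad import Tensor
#     import einx
#
#     x = Tensor.rand(4, 8)
#     y = Tensor.rand(8, 3)
#
#     z = einx.dot("a b, b c -> a c", x, y)  # routed through tinygrad's Tensor.einsum
#     m = einx.mean("a b -> a", x)           # reduction via Tensor.mean
#     print(z.shape, m.shape)                # (4, 3) (4,)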