From a624bb6529891da5e428e9bc3af3aef86e29164a Mon Sep 17 00:00:00 2001
From: Florian Fervers
Date: Tue, 11 Jun 2024 12:56:52 +0200
Subject: [PATCH] ruff format

---
 einx/backend/_tinygrad.py | 8 +++++---
 test/conftest.py          | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/einx/backend/_tinygrad.py b/einx/backend/_tinygrad.py
index f191698..f3d61ee 100644
--- a/einx/backend/_tinygrad.py
+++ b/einx/backend/_tinygrad.py
@@ -20,7 +20,7 @@ def scalar_to_tensor(x):
             )
         else:
             return x
-    
+
     def elementwise(func, convert_all_to_tensor=False):
         @einx.trace
         @functools.wraps(func)
@@ -31,6 +31,7 @@ def outer(*args):
                 args = [a for a in args]
                 args[0] = scalar_to_tensor(args[0])
             return op.elementwise(func)(*args)
+
         return outer
 
     def reduce(func):
@@ -53,6 +54,7 @@ def reduce(tensor, axis=None, **kwargs):
             if "keepdims" in kwargs:
                 kwargs["keepdim"] = kwargs.pop("keepdims")
             return tracer.apply(func, args=[tensor], kwargs=kwargs, output=tracer.Tensor(shape))
+
         return reduce
 
     def to_dtype(x):
@@ -114,7 +116,7 @@ def einsum(backend, equation, *tensors):
                 scalars = scalars[1:]
             for scalar in scalars:
                 x = backend.multiply(x, scalar)
-            
+
             return x
 
         @staticmethod
@@ -196,7 +198,7 @@ def subtract_at(tensor, coordinates, updates):
         @staticmethod
         @einx.trace
         def stop_gradient(tensor):
-            return tensor # TODO: set requires_grad to False?
+            return tensor  # TODO: set requires_grad to False?
 
         @staticmethod
         @einx.trace
diff --git a/test/conftest.py b/test/conftest.py
index 1408a47..9a6336b 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -255,13 +255,16 @@ def wrap(op):
 
 if importlib.util.find_spec("tinygrad"):
     import os
+
     os.environ["PYTHON"] = "1"
     from tinygrad import Tensor
 
     backend = einx.backend.tinygrad.create()
 
     test = types.SimpleNamespace(
-        full=lambda shape, value=0.0, dtype="float32": Tensor.full(shape, value, dtype=backend.to_dtype(dtype)),
+        full=lambda shape, value=0.0, dtype="float32": Tensor.full(
+            shape, value, dtype=backend.to_dtype(dtype)
+        ),
         to_tensor=Tensor,
         to_numpy=lambda x: x.numpy(),
     )
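
Not part of the patch: a minimal sketch of how the reformatted `test` helper namespace from `test/conftest.py` can be exercised, assuming tinygrad and einx are installed and `einx.backend.tinygrad.create()` is available as used in the hunk above; the shape and fill value below are illustrative only.

import os

# Mirror the patched conftest: set PYTHON=1 before the first tinygrad import,
# which selects tinygrad's pure-Python device for testing.
os.environ["PYTHON"] = "1"

import types
import einx
from tinygrad import Tensor

backend = einx.backend.tinygrad.create()

# Same helper namespace as in the patched test/conftest.py.
test = types.SimpleNamespace(
    full=lambda shape, value=0.0, dtype="float32": Tensor.full(
        shape, value, dtype=backend.to_dtype(dtype)
    ),
    to_tensor=Tensor,
    to_numpy=lambda x: x.numpy(),
)

# Illustrative usage (shape and value are hypothetical, not from the patch):
x = test.full((2, 3), value=1.0)
print(test.to_numpy(x).shape)  # -> (2, 3)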