Skip to content

Commit

Permalink
Fix tinygrad tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Jun 9, 2024
1 parent ec14667 commit 2e15548
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install optional dependencies
run: |
python -m pip install --upgrade pip
pip install pytest "jax[cpu]" flax torch tensorflow einops mlx dask
pip install pytest "jax[cpu]" flax torch tensorflow einops mlx dask tinygrad
- uses: actions/checkout@v3
- name: Test with pytest
run: |
Expand All @@ -39,7 +39,7 @@ jobs:
- name: Install optional dependencies
run: |
python -m pip install --upgrade pip
pip install pytest "jax[cpu]" flax dm-haiku torch tensorflow einops equinox mlx dask
pip install pytest "jax[cpu]" flax dm-haiku torch tensorflow einops equinox mlx dask tinygrad
pip install --upgrade keras
- uses: actions/checkout@v3
- name: Test with pytest
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def wrap(op):
tests.append((einx, backend, test))

if importlib.util.find_spec("tinygrad"):
import os
os.environ["PYTHON"] = "1"
from tinygrad import Tensor

backend = einx.backend.tinygrad.create()
Expand Down
20 changes: 10 additions & 10 deletions test/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_values(test):

rng = np.random.default_rng(42)

if backend.name != "mlx":
if backend.name not in {"mlx", "dask", "tinygrad"}:
x = setup.to_tensor(rng.uniform(size=(13,)).astype("float32"))
assert allclose(
einx.vmap("b -> b [3]", x, op=lambda x: x + setup.full((3,), value=1)),
Expand All @@ -32,7 +32,7 @@ def test_values(test):

x = setup.to_tensor(rng.uniform(size=(10, 20, 3)).astype("float32"))
y = setup.to_tensor(rng.uniform(size=(10, 24)).astype("float32"))
if backend.name != "mlx":
if backend.name not in {"mlx", "dask", "tinygrad"}:
assert allclose(
einx.dot("a b c, a d -> a b c d", x, y),
einx.vmap(
Expand All @@ -50,7 +50,7 @@ def test_values(test):
setup=setup,
)

if backend.name != "mlx":
if backend.name not in {"mlx", "dask", "tinygrad"}:
assert allclose(
einx.mean("a b [c]", x),
einx.vmap("a b [c] -> a b", x, op=backend.mean),
Expand Down Expand Up @@ -118,12 +118,12 @@ def test_values(test):
x = setup.to_tensor(np.arange(10))
y = setup.to_tensor(np.arange(10)[::-1].copy())
z = setup.to_tensor(np.arange(10))
assert allclose(
einx.get_at("[h], h2 -> h2", x, y),
y,
setup=setup,
)
if backend.name != "mlx":
if backend.name not in {"mlx", "dask", "tinygrad"}:
assert allclose(
einx.get_at("[h], h2 -> h2", x, y),
y,
setup=setup,
)
assert allclose(
einx.set_at("[h], h2, h2 -> [h]", x, y, z),
y,
Expand All @@ -141,7 +141,7 @@ def test_values(test):
setup=setup,
)

if not backend.name in {"mlx", "dask"}:
if backend.name not in {"mlx", "dask", "tinygrad"}:
coord_dtype = "int32" if backend.name != "torch" else "long"
x = setup.to_tensor(rng.uniform(size=(4, 5, 6)).astype("float32"))
y = setup.full((4, 5), value=3, dtype=coord_dtype)
Expand Down

0 comments on commit 2e15548

Please sign in to comment.