From 4af566ad06b2a344f460edf6fa320d7241c726ee Mon Sep 17 00:00:00 2001 From: Florian Fervers Date: Tue, 23 Apr 2024 22:25:53 +0200 Subject: [PATCH] Add missing numpy test --- test/conftest.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 63d8ce8..ecf94e2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -53,6 +53,22 @@ def op3(*args, **kwargs): return op3 + +# numpy is always available +import numpy as np + +backend = einx.backend.numpy.create() + +test = types.SimpleNamespace( + full=lambda shape, value=0.0, dtype="float32": np.full(shape, value, dtype=dtype), + to_tensor=np.asarray, + to_numpy=np.asarray, +) + +tests.append((einx, backend, test)) + + + if importlib.util.find_spec("jax"): import jax import jax.numpy as jnp