diff --git a/test/test_shapes.py b/test/test_shapes.py index 10e72f0..c6324d0 100644 --- a/test/test_shapes.py +++ b/test/test_shapes.py @@ -241,8 +241,21 @@ def w(shape): v = setup.full((2, 4, 100)) with pytest.raises(Exception): einx.dot("b t (h ck), b t (h cv) -> b h ck cv", k, v, h=32, graph=True) + + x = setup.full((10, 20)) + y = setup.full((10, 24)) + z = setup.full((3, 24)) + assert einx.dot("[a] b, [a c], d [c] -> b d", x, y, z).shape == (20, 3) + assert einx.dot("a b, a c, d c -> b d", x, y, z).shape == (20, 3) - + x = setup.full((10, 20, 24)) + y = setup.full((10, 24)) + z = setup.full((3, 24)) + assert einx.dot("[a] b [c], [a c], d [c] -> b d", x, y, z).shape == (20, 3) + assert einx.dot("a b c, a c, d c -> b d", x, y, z).shape == (20, 3) + with pytest.raises(Exception): + einx.dot("[a] b [c], a c, d [c] -> b d", x, y, z) + @pytest.mark.parametrize("test", conftest.tests) def test_shape_reduce(test): einx, backend, setup = test