diff --git a/einx/expr/solver.py b/einx/expr/solver.py index 3811415..fd95715 100644 --- a/einx/expr/solver.py +++ b/einx/expr/solver.py @@ -20,10 +20,11 @@ def __rmul__(self, other): class Variable(Expression): - def __init__(self, id, name): + def __init__(self, id, name, integer=True): Expression.__init__(self) self.id = id self.name = name + self.integer = integer def __iter__(self): yield self @@ -38,7 +39,7 @@ def __str__(self): return f"{self.name}" def sympy(self): - return sympy.Symbol(self.id) + return sympy.Symbol(self.id, integer=self.integer) class Constant(Expression): diff --git a/test/test_shapes.py b/test/test_shapes.py index 4cb9052..eea9142 100644 --- a/test/test_shapes.py +++ b/test/test_shapes.py @@ -233,6 +233,11 @@ def w(shape): assert einx.dot("... b, ... -> b", x, y).shape == (10,) assert einx.dot("[...] b -> b", x, y).shape == (10,) + k = setup.full((2, 4, 100)) + v = setup.full((2, 4, 100)) + with pytest.raises(Exception): + einx.dot("b t (h ck), b t (h cv) -> b h ck cv", k, v, h=32, graph=True) + @pytest.mark.parametrize("test", conftest.tests) def test_shape_reduce(test):