Skip to content

Commit

Permalink
Fail during graph construction when expressions cannot be solved
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed May 1, 2024
1 parent 6a5f9f4 commit ee99cd2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 3 additions & 2 deletions einx/expr/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def __rmul__(self, other):


class Variable(Expression):
def __init__(self, id, name):
def __init__(self, id, name, integer=True):
Expression.__init__(self)
self.id = id
self.name = name
self.integer = integer

def __iter__(self):
yield self
Expand All @@ -38,7 +39,7 @@ def __str__(self):
return f"{self.name}"

def sympy(self):
return sympy.Symbol(self.id)
return sympy.Symbol(self.id, integer=self.integer)


class Constant(Expression):
Expand Down
5 changes: 5 additions & 0 deletions test/test_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ def w(shape):
assert einx.dot("... b, ... -> b", x, y).shape == (10,)
assert einx.dot("[...] b -> b", x, y).shape == (10,)

k = setup.full((2, 4, 100))
v = setup.full((2, 4, 100))
with pytest.raises(Exception):
einx.dot("b t (h ck), b t (h cv) -> b h ck cv", k, v, h=32, graph=True)


@pytest.mark.parametrize("test", conftest.tests)
def test_shape_reduce(test):
Expand Down

0 comments on commit ee99cd2

Please sign in to comment.