Skip to content

Commit

Permalink
Merge pull request Pyomo#3196 from jsiirola/simplify-linear-expression
Browse files Browse the repository at this point in the history
Simplify expressions generated by `TemplateSumExpression`
  • Loading branch information
blnicho authored Apr 2, 2024
2 parents 2a16b19 + ba105ea commit cc33f35
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 64 deletions.
2 changes: 1 addition & 1 deletion pyomo/core/expr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def size(self):
"""
return visitor.sizeof_expression(self)

def _apply_operation(self, result): # pragma: no cover
def _apply_operation(self, result):
"""
Compute the values of this node given the values of its children.
Expand Down
33 changes: 18 additions & 15 deletions pyomo/core/expr/numeric_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,8 +2288,11 @@ def _iadd_mutablenpvsum_mutable(a, b):
def _iadd_mutablenpvsum_native(a, b):
if not b:
return a
a._args_.append(b)
a._nargs += 1
if a._args_ and a._args_[-1].__class__ in native_numeric_types:
a._args_[-1] += b
else:
a._args_.append(b)
a._nargs += 1
return a


Expand All @@ -2301,9 +2304,7 @@ def _iadd_mutablenpvsum_npv(a, b):

def _iadd_mutablenpvsum_param(a, b):
if b.is_constant():
b = b.value
if not b:
return a
return _iadd_mutablesum_native(a, b.value)
a._args_.append(b)
a._nargs += 1
return a
Expand Down Expand Up @@ -2384,8 +2385,11 @@ def _iadd_mutablelinear_mutable(a, b):
def _iadd_mutablelinear_native(a, b):
if not b:
return a
a._args_.append(b)
a._nargs += 1
if a._args_ and a._args_[-1].__class__ in native_numeric_types:
a._args_[-1] += b
else:
a._args_.append(b)
a._nargs += 1
return a


Expand All @@ -2397,9 +2401,7 @@ def _iadd_mutablelinear_npv(a, b):

def _iadd_mutablelinear_param(a, b):
if b.is_constant():
b = b.value
if not b:
return a
return _iadd_mutablesum_native(a, b.value)
a._args_.append(b)
a._nargs += 1
return a
Expand Down Expand Up @@ -2483,8 +2485,11 @@ def _iadd_mutablesum_mutable(a, b):
def _iadd_mutablesum_native(a, b):
if not b:
return a
a._args_.append(b)
a._nargs += 1
if a._args_ and a._args_[-1].__class__ in native_numeric_types:
a._args_[-1] += b
else:
a._args_.append(b)
a._nargs += 1
return a


Expand All @@ -2496,9 +2501,7 @@ def _iadd_mutablesum_npv(a, b):

def _iadd_mutablesum_param(a, b):
if b.is_constant():
b = b.value
if not b:
return a
return _iadd_mutablesum_native(a, b.value)
a._args_.append(b)
a._nargs += 1
return a
Expand Down
33 changes: 17 additions & 16 deletions pyomo/core/expr/template_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from pyomo.core.expr.base import ExpressionBase, ExpressionArgs_Mixin, NPV_Mixin
from pyomo.core.expr.logical_expr import BooleanExpression
from pyomo.core.expr.numeric_expr import (
ARG_TYPE,
NumericExpression,
SumExpression,
Numeric_NPV_Mixin,
SumExpression,
mutable_expression,
register_arg_type,
ARG_TYPE,
_balanced_parens,
)
from pyomo.core.expr.numvalue import (
Expand Down Expand Up @@ -116,18 +117,10 @@ def _to_string(self, values, verbose, smap):
return "%s[%s]" % (values[0], ','.join(values[1:]))

def _resolve_template(self, args):
return args[0].__getitem__(tuple(args[1:]))
return args[0].__getitem__(args[1:])

def _apply_operation(self, result):
args = tuple(
(
arg
if arg.__class__ in native_types or not arg.is_numeric_type()
else value(arg)
)
for arg in result[1:]
)
return result[0].__getitem__(tuple(result[1:]))
return result[0].__getitem__(result[1:])


class Numeric_GetItemExpression(GetItemExpression, NumericExpression):
Expand Down Expand Up @@ -258,8 +251,8 @@ def nargs(self):
return 2

def _apply_operation(self, result):
assert len(result) == 2
return getattr(result[0], result[1])
obj, attr = result
return getattr(obj, attr)

def _to_string(self, values, verbose, smap):
assert len(values) == 2
Expand All @@ -273,7 +266,7 @@ def _to_string(self, values, verbose, smap):
return "%s.%s" % (values[0], attr)

def _resolve_template(self, args):
return getattr(*tuple(args))
return getattr(*args)


class Numeric_GetAttrExpression(GetAttrExpression, NumericExpression):
Expand Down Expand Up @@ -521,7 +514,15 @@ def _to_string(self, values, verbose, smap):
return 'SUM(%s %s)' % (val, iterStr)

def _resolve_template(self, args):
return SumExpression(args)
with mutable_expression() as e:
for arg in args:
e += arg
if e.nargs() > 1:
return e
elif not e.nargs():
return 0
else:
return e.arg(0)


class IndexTemplate(NumericValue):
Expand Down
8 changes: 4 additions & 4 deletions pyomo/core/tests/unit/test_numeric_expr_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6490,11 +6490,11 @@ def test_mutable_nvp_iadd(self):
(mutable_npv, self.invalid, NotImplemented),
(mutable_npv, self.asbinary, _MutableLinearExpression([10, self.bin])),
(mutable_npv, self.zero, _MutableNPVSumExpression([10])),
(mutable_npv, self.one, _MutableNPVSumExpression([10, 1])),
(mutable_npv, self.one, _MutableNPVSumExpression([11])),
# 4:
(mutable_npv, self.native, _MutableNPVSumExpression([10, 5])),
(mutable_npv, self.native, _MutableNPVSumExpression([15])),
(mutable_npv, self.npv, _MutableNPVSumExpression([10, self.npv])),
(mutable_npv, self.param, _MutableNPVSumExpression([10, 6])),
(mutable_npv, self.param, _MutableNPVSumExpression([16])),
(
mutable_npv,
self.param_mut,
Expand Down Expand Up @@ -6534,7 +6534,7 @@ def test_mutable_nvp_iadd(self):
_MutableSumExpression([10] + self.mutable_l2.args),
),
(mutable_npv, self.param0, _MutableNPVSumExpression([10])),
(mutable_npv, self.param1, _MutableNPVSumExpression([10, 1])),
(mutable_npv, self.param1, _MutableNPVSumExpression([11])),
# 20:
(mutable_npv, self.mutable_l3, _MutableNPVSumExpression([10, self.npv])),
]
Expand Down
8 changes: 4 additions & 4 deletions pyomo/core/tests/unit/test_numeric_expr_zerofilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6020,11 +6020,11 @@ def test_mutable_nvp_iadd(self):
(mutable_npv, self.invalid, NotImplemented),
(mutable_npv, self.asbinary, _MutableLinearExpression([10, self.bin])),
(mutable_npv, self.zero, _MutableNPVSumExpression([10])),
(mutable_npv, self.one, _MutableNPVSumExpression([10, 1])),
(mutable_npv, self.one, _MutableNPVSumExpression([11])),
# 4:
(mutable_npv, self.native, _MutableNPVSumExpression([10, 5])),
(mutable_npv, self.native, _MutableNPVSumExpression([15])),
(mutable_npv, self.npv, _MutableNPVSumExpression([10, self.npv])),
(mutable_npv, self.param, _MutableNPVSumExpression([10, 6])),
(mutable_npv, self.param, _MutableNPVSumExpression([16])),
(
mutable_npv,
self.param_mut,
Expand Down Expand Up @@ -6064,7 +6064,7 @@ def test_mutable_nvp_iadd(self):
_MutableSumExpression([10] + self.mutable_l2.args),
),
(mutable_npv, self.param0, _MutableNPVSumExpression([10])),
(mutable_npv, self.param1, _MutableNPVSumExpression([10, 1])),
(mutable_npv, self.param1, _MutableNPVSumExpression([11])),
# 20:
(mutable_npv, self.mutable_l3, _MutableNPVSumExpression([10, self.npv])),
]
Expand Down
48 changes: 24 additions & 24 deletions pyomo/core/tests/unit/test_template_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,14 @@ def c(m):
self.assertEqual(
str(resolve_template(template)),
'x[1,1,10] + '
'(x[2,1,10] + x[2,1,20]) + '
'(x[3,1,10] + x[3,1,20] + x[3,1,30]) + '
'(x[1,2,10]) + '
'(x[2,2,10] + x[2,2,20]) + '
'(x[3,2,10] + x[3,2,20] + x[3,2,30]) + '
'(x[1,3,10]) + '
'(x[2,3,10] + x[2,3,20]) + '
'(x[3,3,10] + x[3,3,20] + x[3,3,30]) <= 0',
'x[2,1,10] + x[2,1,20] + '
'x[3,1,10] + x[3,1,20] + x[3,1,30] + '
'x[1,2,10] + '
'x[2,2,10] + x[2,2,20] + '
'x[3,2,10] + x[3,2,20] + x[3,2,30] + '
'x[1,3,10] + '
'x[2,3,10] + x[2,3,20] + '
'x[3,3,10] + x[3,3,20] + x[3,3,30] <= 0',
)

def test_multidim_nested_sum_rule(self):
Expand Down Expand Up @@ -566,14 +566,14 @@ def c(m):
self.assertEqual(
str(resolve_template(template)),
'x[1,1,10] + '
'(x[2,1,10] + x[2,1,20]) + '
'(x[3,1,10] + x[3,1,20] + x[3,1,30]) + '
'(x[1,2,10]) + '
'(x[2,2,10] + x[2,2,20]) + '
'(x[3,2,10] + x[3,2,20] + x[3,2,30]) + '
'(x[1,3,10]) + '
'(x[2,3,10] + x[2,3,20]) + '
'(x[3,3,10] + x[3,3,20] + x[3,3,30]) <= 0',
'x[2,1,10] + x[2,1,20] + '
'x[3,1,10] + x[3,1,20] + x[3,1,30] + '
'x[1,2,10] + '
'x[2,2,10] + x[2,2,20] + '
'x[3,2,10] + x[3,2,20] + x[3,2,30] + '
'x[1,3,10] + '
'x[2,3,10] + x[2,3,20] + '
'x[3,3,10] + x[3,3,20] + x[3,3,30] <= 0',
)

def test_multidim_nested_getattr_sum_rule(self):
Expand Down Expand Up @@ -609,14 +609,14 @@ def c(m):
self.assertEqual(
str(resolve_template(template)),
'x[1,1,10] + '
'(x[2,1,10] + x[2,1,20]) + '
'(x[3,1,10] + x[3,1,20] + x[3,1,30]) + '
'(x[1,2,10]) + '
'(x[2,2,10] + x[2,2,20]) + '
'(x[3,2,10] + x[3,2,20] + x[3,2,30]) + '
'(x[1,3,10]) + '
'(x[2,3,10] + x[2,3,20]) + '
'(x[3,3,10] + x[3,3,20] + x[3,3,30]) <= 0',
'x[2,1,10] + x[2,1,20] + '
'x[3,1,10] + x[3,1,20] + x[3,1,30] + '
'x[1,2,10] + '
'x[2,2,10] + x[2,2,20] + '
'x[3,2,10] + x[3,2,20] + x[3,2,30] + '
'x[1,3,10] + '
'x[2,3,10] + x[2,3,20] + '
'x[3,3,10] + x[3,3,20] + x[3,3,30] <= 0',
)

def test_eval_getattr(self):
Expand Down

0 comments on commit cc33f35

Please sign in to comment.