Skip to content

Commit

Permalink
Merge pull request Pyomo#3194 from jsiirola/dispatcher-simplification
Browse files Browse the repository at this point in the history
Support "default" dispatchers in `ExitNodeDispatcher`
  • Loading branch information
blnicho authored Apr 2, 2024
2 parents cc33f35 + 440bf86 commit 69082ac
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 153 deletions.
78 changes: 23 additions & 55 deletions pyomo/repn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ def _handle_negation_ANY(visitor, node, arg):


_exit_node_handlers[NegationExpression] = {
None: _handle_negation_ANY,
(_CONSTANT,): _handle_negation_constant,
(_LINEAR,): _handle_negation_ANY,
(_GENERAL,): _handle_negation_ANY,
}

#
Expand All @@ -211,20 +210,18 @@ def _handle_negation_ANY(visitor, node, arg):


def _handle_product_constant_constant(visitor, node, arg1, arg2):
_, arg1 = arg1
_, arg2 = arg2
ans = arg1 * arg2
ans = arg1[1] * arg2[1]
if ans != ans:
if not arg1 or not arg2:
if not arg1[1] or not arg2[1]:
deprecation_warning(
f"Encountered {str(arg1)}*{str(arg2)} in expression tree. "
f"Encountered {str(arg1[1])}*{str(arg2[1])} in expression tree. "
"Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return _, 0
return _, arg1 * arg2
return _CONSTANT, 0
return _CONSTANT, ans


def _handle_product_constant_ANY(visitor, node, arg1, arg2):
Expand Down Expand Up @@ -276,15 +273,12 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[ProductExpression] = {
None: _handle_product_nonlinear,
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
(_LINEAR, _LINEAR): _handle_product_nonlinear,
(_LINEAR, _GENERAL): _handle_product_nonlinear,
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
(_GENERAL, _LINEAR): _handle_product_nonlinear,
(_GENERAL, _GENERAL): _handle_product_nonlinear,
}
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]

Expand All @@ -309,24 +303,18 @@ def _handle_division_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[DivisionExpression] = {
None: _handle_division_nonlinear,
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
(_CONSTANT, _LINEAR): _handle_division_nonlinear,
(_CONSTANT, _GENERAL): _handle_division_nonlinear,
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
(_LINEAR, _LINEAR): _handle_division_nonlinear,
(_LINEAR, _GENERAL): _handle_division_nonlinear,
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
(_GENERAL, _LINEAR): _handle_division_nonlinear,
(_GENERAL, _GENERAL): _handle_division_nonlinear,
}

#
# EXPONENTIATION handlers
#


def _handle_pow_constant_constant(visitor, node, *args):
arg1, arg2 = args
def _handle_pow_constant_constant(visitor, node, arg1, arg2):
ans = apply_node_operation(node, (arg1[1], arg2[1]))
if ans.__class__ in native_complex_types:
ans = complex_number_error(ans, visitor, node)
Expand Down Expand Up @@ -358,15 +346,10 @@ def _handle_pow_nonlinear(visitor, node, arg1, arg2):


_exit_node_handlers[PowExpression] = {
None: _handle_pow_nonlinear,
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
(_CONSTANT, _LINEAR): _handle_pow_nonlinear,
(_CONSTANT, _GENERAL): _handle_pow_nonlinear,
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
(_LINEAR, _LINEAR): _handle_pow_nonlinear,
(_LINEAR, _GENERAL): _handle_pow_nonlinear,
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
(_GENERAL, _LINEAR): _handle_pow_nonlinear,
(_GENERAL, _GENERAL): _handle_pow_nonlinear,
}

#
Expand All @@ -389,9 +372,8 @@ def _handle_unary_nonlinear(visitor, node, arg):


_exit_node_handlers[UnaryFunctionExpression] = {
None: _handle_unary_nonlinear,
(_CONSTANT,): _handle_unary_constant,
(_LINEAR,): _handle_unary_nonlinear,
(_GENERAL,): _handle_unary_nonlinear,
}
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]

Expand All @@ -414,9 +396,8 @@ def _handle_named_ANY(visitor, node, arg1):


_exit_node_handlers[Expression] = {
None: _handle_named_ANY,
(_CONSTANT,): _handle_named_constant,
(_LINEAR,): _handle_named_ANY,
(_GENERAL,): _handle_named_ANY,
}

#
Expand Down Expand Up @@ -449,12 +430,7 @@ def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3):
return _GENERAL, ans


_exit_node_handlers[Expr_ifExpression] = {
(i, j, k): _handle_expr_if_nonlinear
for i in (_LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
for k in (_CONSTANT, _LINEAR, _GENERAL)
}
_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
for j in (_CONSTANT, _LINEAR, _GENERAL):
for k in (_CONSTANT, _LINEAR, _GENERAL):
_exit_node_handlers[Expr_ifExpression][_CONSTANT, j, k] = _handle_expr_if_const
Expand Down Expand Up @@ -487,11 +463,9 @@ def _handle_equality_general(visitor, node, arg1, arg2):


_exit_node_handlers[EqualityExpression] = {
(i, j): _handle_equality_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_equality_general,
(_CONSTANT, _CONSTANT): _handle_equality_const,
}
_exit_node_handlers[EqualityExpression][_CONSTANT, _CONSTANT] = _handle_equality_const


def _handle_inequality_const(visitor, node, arg1, arg2):
Expand All @@ -517,13 +491,9 @@ def _handle_inequality_general(visitor, node, arg1, arg2):


_exit_node_handlers[InequalityExpression] = {
(i, j): _handle_inequality_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_inequality_general,
(_CONSTANT, _CONSTANT): _handle_inequality_const,
}
_exit_node_handlers[InequalityExpression][
_CONSTANT, _CONSTANT
] = _handle_inequality_const


def _handle_ranged_const(visitor, node, arg1, arg2, arg3):
Expand Down Expand Up @@ -554,14 +524,9 @@ def _handle_ranged_general(visitor, node, arg1, arg2, arg3):


_exit_node_handlers[RangedExpression] = {
(i, j, k): _handle_ranged_general
for i in (_CONSTANT, _LINEAR, _GENERAL)
for j in (_CONSTANT, _LINEAR, _GENERAL)
for k in (_CONSTANT, _LINEAR, _GENERAL)
None: _handle_ranged_general,
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
}
_exit_node_handlers[RangedExpression][
_CONSTANT, _CONSTANT, _CONSTANT
] = _handle_ranged_const


class LinearBeforeChildDispatcher(BeforeChildDispatcher):
Expand Down Expand Up @@ -754,7 +719,10 @@ def _initialize_exit_node_dispatcher(exit_handlers):
exit_dispatcher = {}
for cls, handlers in exit_handlers.items():
for args, fcn in handlers.items():
exit_dispatcher[(cls, *args)] = fcn
if args is None:
exit_dispatcher[cls] = fcn
else:
exit_dispatcher[(cls, *args)] = fcn
return exit_dispatcher


Expand Down
89 changes: 16 additions & 73 deletions pyomo/repn/quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,119 +277,62 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):

_exit_node_handlers[ProductExpression].update(
{
None: _handle_product_nonlinear,
(_CONSTANT, _QUADRATIC): linear._handle_product_constant_ANY,
(_LINEAR, _QUADRATIC): _handle_product_nonlinear,
(_QUADRATIC, _QUADRATIC): _handle_product_nonlinear,
(_GENERAL, _QUADRATIC): _handle_product_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_product_ANY_constant,
(_QUADRATIC, _LINEAR): _handle_product_nonlinear,
(_QUADRATIC, _GENERAL): _handle_product_nonlinear,
# Replace handler from the linear walker
(_LINEAR, _LINEAR): _handle_product_linear_linear,
(_GENERAL, _GENERAL): _handle_product_nonlinear,
(_GENERAL, _LINEAR): _handle_product_nonlinear,
(_LINEAR, _GENERAL): _handle_product_nonlinear,
}
)

#
# DIVISION
#
_exit_node_handlers[DivisionExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_division_nonlinear,
(_LINEAR, _QUADRATIC): linear._handle_division_nonlinear,
(_QUADRATIC, _QUADRATIC): linear._handle_division_nonlinear,
(_GENERAL, _QUADRATIC): linear._handle_division_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant,
(_QUADRATIC, _LINEAR): linear._handle_division_nonlinear,
(_QUADRATIC, _GENERAL): linear._handle_division_nonlinear,
}
{(_QUADRATIC, _CONSTANT): linear._handle_division_ANY_constant}
)


#
# EXPONENTIATION
#
_exit_node_handlers[PowExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_pow_nonlinear,
(_LINEAR, _QUADRATIC): linear._handle_pow_nonlinear,
(_QUADRATIC, _QUADRATIC): linear._handle_pow_nonlinear,
(_GENERAL, _QUADRATIC): linear._handle_pow_nonlinear,
(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant,
(_QUADRATIC, _LINEAR): linear._handle_pow_nonlinear,
(_QUADRATIC, _GENERAL): linear._handle_pow_nonlinear,
}
{(_QUADRATIC, _CONSTANT): linear._handle_pow_ANY_constant}
)

#
# ABS and UNARY handlers
#
_exit_node_handlers[AbsExpression][(_QUADRATIC,)] = linear._handle_unary_nonlinear
_exit_node_handlers[UnaryFunctionExpression][
(_QUADRATIC,)
] = linear._handle_unary_nonlinear
# (no changes needed)

#
# NAMED EXPRESSION handlers
#
_exit_node_handlers[Expression][(_QUADRATIC,)] = linear._handle_named_ANY
# (no changes needed)

#
# EXPR_IF handlers
#
# Note: it is easier to just recreate the entire data structure, rather
# than update it
_exit_node_handlers[Expr_ifExpression] = {
(i, j, k): linear._handle_expr_if_nonlinear
for i in (_LINEAR, _QUADRATIC, _GENERAL)
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
}
for j in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
for k in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL):
_exit_node_handlers[Expr_ifExpression][
_CONSTANT, j, k
] = linear._handle_expr_if_const

#
# RELATIONAL handlers
#
_exit_node_handlers[EqualityExpression].update(
_exit_node_handlers[Expr_ifExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_equality_general,
(_LINEAR, _QUADRATIC): linear._handle_equality_general,
(_QUADRATIC, _QUADRATIC): linear._handle_equality_general,
(_GENERAL, _QUADRATIC): linear._handle_equality_general,
(_QUADRATIC, _CONSTANT): linear._handle_equality_general,
(_QUADRATIC, _LINEAR): linear._handle_equality_general,
(_QUADRATIC, _GENERAL): linear._handle_equality_general,
(_CONSTANT, i, _QUADRATIC): linear._handle_expr_if_const
for i in (_CONSTANT, _LINEAR, _QUADRATIC, _GENERAL)
}
)
_exit_node_handlers[InequalityExpression].update(
_exit_node_handlers[Expr_ifExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_inequality_general,
(_LINEAR, _QUADRATIC): linear._handle_inequality_general,
(_QUADRATIC, _QUADRATIC): linear._handle_inequality_general,
(_GENERAL, _QUADRATIC): linear._handle_inequality_general,
(_QUADRATIC, _CONSTANT): linear._handle_inequality_general,
(_QUADRATIC, _LINEAR): linear._handle_inequality_general,
(_QUADRATIC, _GENERAL): linear._handle_inequality_general,
}
)
_exit_node_handlers[RangedExpression].update(
{
(_CONSTANT, _QUADRATIC): linear._handle_ranged_general,
(_LINEAR, _QUADRATIC): linear._handle_ranged_general,
(_QUADRATIC, _QUADRATIC): linear._handle_ranged_general,
(_GENERAL, _QUADRATIC): linear._handle_ranged_general,
(_QUADRATIC, _CONSTANT): linear._handle_ranged_general,
(_QUADRATIC, _LINEAR): linear._handle_ranged_general,
(_QUADRATIC, _GENERAL): linear._handle_ranged_general,
(_CONSTANT, _QUADRATIC, i): linear._handle_expr_if_const
for i in (_CONSTANT, _LINEAR, _GENERAL)
}
)

#
# RELATIONAL handlers
#
# (no changes needed)


class QuadraticRepnVisitor(linear.LinearRepnVisitor):
Result = QuadraticRepn
Expand Down
6 changes: 2 additions & 4 deletions pyomo/repn/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,16 +718,14 @@ class UnknownExpression(NumericExpression):
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 9)
self.assertIn(UnknownExpression, end)
self.assertEqual(len(end), 8)

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__, 6, 7](None, node, *node.args)
self.assertEqual(len(end), 10)
self.assertIn((UnknownExpression, 6, 7), end)
self.assertEqual(len(end), 8)

def test_BeforeChildDispatcher_registration(self):
class BeforeChildDispatcherTester(BeforeChildDispatcher):
Expand Down
Loading

0 comments on commit 69082ac

Please sign in to comment.