From 14065186bb8e3611fb7e6829ce4b109a92cd1e63 Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Wed, 18 Sep 2024 11:55:03 +0200 Subject: [PATCH] sat: python remove some type assert to improve model building performance --- ortools/sat/python/cp_model.py | 77 ++++++++-------------- ortools/sat/python/cp_model_helper.py | 22 ------- ortools/sat/python/cp_model_helper_test.py | 10 --- 3 files changed, 29 insertions(+), 80 deletions(-) diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 78fd14971fa..cea258dca56 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -465,7 +465,6 @@ def __sub__(self, arg): if cmh.is_zero(arg): return self if isinstance(arg, NumberTypes): - arg = cmh.assert_is_a_number(arg) return _Sum(self, -arg) else: return _Sum(self, -arg) @@ -566,7 +565,6 @@ def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri if arg is None: return False if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) return BoundedLinearExpression(self, [arg, arg]) elif isinstance(arg, LinearExpr): return BoundedLinearExpression(self - arg, [0, 0]) @@ -575,22 +573,23 @@ def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri def __ge__(self, arg: LinearExprT) -> "BoundedLinearExpression": if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) + if arg >= INT_MAX: + raise ArithmeticError(">= INT_MAX is not supported") return BoundedLinearExpression(self, [arg, INT_MAX]) else: return BoundedLinearExpression(self - arg, [0, INT_MAX]) def __le__(self, arg: LinearExprT) -> "BoundedLinearExpression": if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) + if arg <= INT_MIN: + raise ArithmeticError("<= INT_MIN is not supported") return BoundedLinearExpression(self, [INT_MIN, arg]) else: return BoundedLinearExpression(self - arg, [INT_MIN, 0]) def __lt__(self, arg: LinearExprT) -> "BoundedLinearExpression": if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) - if arg == INT_MIN: + if arg <= INT_MIN: raise ArithmeticError("< INT_MIN is not supported") return BoundedLinearExpression(self, [INT_MIN, arg - 1]) else: @@ -598,8 +597,7 @@ def __lt__(self, arg: LinearExprT) -> "BoundedLinearExpression": def __gt__(self, arg: LinearExprT) -> "BoundedLinearExpression": if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) - if arg == INT_MAX: + if arg >= INT_MAX: raise ArithmeticError("> INT_MAX is not supported") return BoundedLinearExpression(self, [arg + 1, INT_MAX]) else: @@ -609,10 +607,9 @@ def __ne__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri if arg is None: return True if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) - if arg == INT_MAX: + if arg >= INT_MAX: return BoundedLinearExpression(self, [INT_MIN, INT_MAX - 1]) - elif arg == INT_MIN: + elif arg <= INT_MIN: return BoundedLinearExpression(self, [INT_MIN + 1, INT_MAX]) else: return BoundedLinearExpression( @@ -702,7 +699,6 @@ class _ProductCst(LinearExpr): """Represents the product of a LinearExpr by a constant.""" def __init__(self, expr, coeff) -> None: - coeff = cmh.assert_is_a_number(coeff) if isinstance(expr, _ProductCst): self.__expr = expr.expression() self.__coef = expr.coefficient() * coeff @@ -736,7 +732,6 @@ def __init__(self, expressions, constant=0) -> None: if isinstance(x, NumberTypes): if cmh.is_zero(x): continue - x = cmh.assert_is_a_number(x) self.__constant += x elif isinstance(x, LinearExpr): self.__expressions.append(x) @@ -776,11 +771,9 @@ def __init__(self, expressions, coefficients, constant=0) -> None: " coefficient array must have the same length." ) for e, c in zip(expressions, coefficients): - c = cmh.assert_is_a_number(c) if cmh.is_zero(c): continue if isinstance(e, NumberTypes): - e = cmh.assert_is_a_number(e) self.__constant += e * c elif isinstance(e, LinearExpr): self.__expressions.append(e) @@ -1509,9 +1502,8 @@ def add_linear_expression_in_domain( for t in coeffs_map.items(): if not isinstance(t[0], IntVar): raise TypeError("Wrong argument" + str(t)) - c = cmh.assert_is_int64(t[1]) model_ct.linear.vars.append(t[0].index) - model_ct.linear.coeffs.append(c) + model_ct.linear.coeffs.append(t[1]) model_ct.linear.domain.extend( [ cmh.capped_subtraction(x, constant) @@ -1640,12 +1632,9 @@ def add_circuit(self, arcs: Sequence[ArcT]) -> Constraint: ct = Constraint(self) model_ct = self.__model.constraints[ct.index] for arc in arcs: - tail = cmh.assert_is_int32(arc[0]) - head = cmh.assert_is_int32(arc[1]) - lit = self.get_or_make_boolean_index(arc[2]) - model_ct.circuit.tails.append(tail) - model_ct.circuit.heads.append(head) - model_ct.circuit.literals.append(lit) + model_ct.circuit.tails.append(arc[0]) + model_ct.circuit.heads.append(arc[1]) + model_ct.circuit.literals.append(self.get_or_make_boolean_index(arc[2])) return ct def add_multiple_circuit(self, arcs: Sequence[ArcT]) -> Constraint: @@ -1677,12 +1666,9 @@ def add_multiple_circuit(self, arcs: Sequence[ArcT]) -> Constraint: ct = Constraint(self) model_ct = self.__model.constraints[ct.index] for arc in arcs: - tail = cmh.assert_is_int32(arc[0]) - head = cmh.assert_is_int32(arc[1]) - lit = self.get_or_make_boolean_index(arc[2]) - model_ct.routes.tails.append(tail) - model_ct.routes.heads.append(head) - model_ct.routes.literals.append(lit) + model_ct.routes.tails.append(arc[0]) + model_ct.routes.heads.append(arc[1]) + model_ct.routes.literals.append(self.get_or_make_boolean_index(arc[2])) return ct def add_allowed_assignments( @@ -1720,15 +1706,19 @@ def add_allowed_assignments( model_ct = self.__model.constraints[ct.index] model_ct.table.vars.extend([self.get_or_make_index(x) for x in variables]) arity: int = len(variables) - for t in tuples_list: - if len(t) != arity: - raise TypeError("Tuple " + str(t) + " has the wrong arity") + for one_tuple in tuples_list: + if len(one_tuple) != arity: + raise TypeError("Tuple " + str(one_tuple) + " has the wrong arity") # duck-typing (no explicit type checks here) try: - model_ct.table.values.extend(a for b in tuples_list for a in b) + for one_tuple in tuples_list: + model_ct.table.values.extend(one_tuple) except ValueError as ex: - raise TypeError(f"add_xxx_assignment: Not an integer or does not fit in an int64_t: {ex.args}") from ex + raise TypeError( + "add_xxx_assignment: Not an integer or does not fit in an int64_t:" + f" {ex.args}" + ) from ex return ct @@ -1762,7 +1752,7 @@ def add_forbidden_assignments( "add_forbidden_assignments expects a non-empty variables array" ) - index = len(self.__model.constraints) + index: int = len(self.__model.constraints) ct: Constraint = self.add_allowed_assignments(variables, tuples_list) self.__model.constraints[index].table.negated = True return ct @@ -1829,20 +1819,15 @@ def add_automaton( model_ct.automaton.vars.extend( [self.get_or_make_index(x) for x in transition_variables] ) - starting_state = cmh.assert_is_int64(starting_state) model_ct.automaton.starting_state = starting_state for v in final_states: - v = cmh.assert_is_int64(v) model_ct.automaton.final_states.append(v) for t in transition_triples: if len(t) != 3: raise TypeError("Tuple " + str(t) + " has the wrong arity (!= 3)") - tail = cmh.assert_is_int64(t[0]) - label = cmh.assert_is_int64(t[1]) - head = cmh.assert_is_int64(t[2]) - model_ct.automaton.transition_tail.append(tail) - model_ct.automaton.transition_label.append(label) - model_ct.automaton.transition_head.append(head) + model_ct.automaton.transition_tail.append(t[0]) + model_ct.automaton.transition_label.append(t[1]) + model_ct.automaton.transition_head.append(t[2]) return ct def add_inverse( @@ -2358,7 +2343,6 @@ def new_fixed_size_interval_var( Returns: An `IntervalVar` object. """ - size = cmh.assert_is_int64(size) start_expr = self.parse_linear_expression(start) size_expr = self.parse_linear_expression(size) end_expr = self.parse_linear_expression(start + size) @@ -2545,7 +2529,6 @@ def new_optional_fixed_size_interval_var( Returns: An `IntervalVar` object. """ - size = cmh.assert_is_int64(size) start_expr = self.parse_linear_expression(start) size_expr = self.parse_linear_expression(size) end_expr = self.parse_linear_expression(start + size) @@ -2776,7 +2759,6 @@ def get_or_make_index(self, arg: VariableT) -> int: ): return -arg.expression().index - 1 if isinstance(arg, IntegralTypes): - arg = cmh.assert_is_int64(arg) return self.get_or_make_index_from_constant(arg) raise TypeError("NotSupported: model.get_or_make_index(" + str(arg) + ")") @@ -2842,9 +2824,8 @@ def parse_linear_expression( for t in coeffs_map.items(): if not isinstance(t[0], IntVar): raise TypeError("Wrong argument" + str(t)) - c = cmh.assert_is_int64(t[1]) result.vars.append(t[0].index) - result.coeffs.append(c * mult) + result.coeffs.append(t[1] * mult) return result def _set_objective(self, obj: ObjLinearExprT, minimize: bool): diff --git a/ortools/sat/python/cp_model_helper.py b/ortools/sat/python/cp_model_helper.py index 364aea8485c..8abae4cf1a0 100644 --- a/ortools/sat/python/cp_model_helper.py +++ b/ortools/sat/python/cp_model_helper.py @@ -60,26 +60,6 @@ def is_minus_one(x: Any) -> bool: return False -def assert_is_int64(x: Any) -> int: - """Asserts that x is integer and x is in [min_int_64, max_int_64] and returns it casted to an int.""" - if not isinstance(x, numbers.Integral): - raise TypeError(f"Not an integer: {x} of type {type(x)}") - x_as_int = int(x) - if x_as_int < INT_MIN or x_as_int > INT_MAX: - raise OverflowError(f"Does not fit in an int64_t: {x}") - return x_as_int - - -def assert_is_int32(x: Any) -> int: - """Asserts that x is integer and x is in [min_int_32, max_int_32] and returns it casted to an int.""" - if not isinstance(x, numbers.Integral): - raise TypeError(f"Not an integer: {x} of type {type(x)}") - x_as_int = int(x) - if x_as_int < INT32_MIN or x_as_int > INT32_MAX: - raise OverflowError(f"Does not fit in an int32_t: {x}") - return x_as_int - - def assert_is_zero_or_one(x: Any) -> int: """Asserts that x is 0 or 1 and returns it as an int.""" if not isinstance(x, numbers.Integral): @@ -110,8 +90,6 @@ def to_capped_int64(v: int) -> int: def capped_subtraction(x: int, y: int) -> int: """Saturated arithmetics. Returns x - y truncated to the int64_t range.""" - assert_is_int64(x) - assert_is_int64(y) if y == 0: return x if x == y: diff --git a/ortools/sat/python/cp_model_helper_test.py b/ortools/sat/python/cp_model_helper_test.py index 40fe7b2c25c..62c4b5c8d5a 100644 --- a/ortools/sat/python/cp_model_helper_test.py +++ b/ortools/sat/python/cp_model_helper_test.py @@ -30,16 +30,6 @@ def test_is_boolean(self): self.assertTrue(cp_model_helper.is_boolean(np.bool_(1))) self.assertTrue(cp_model_helper.is_boolean(np.bool_(0))) - def testassert_is_int64(self): - print("testassert_is_int64") - self.assertRaises(TypeError, cp_model_helper.assert_is_int64, "Hello") - self.assertRaises(TypeError, cp_model_helper.assert_is_int64, 1.2) - self.assertRaises(OverflowError, cp_model_helper.assert_is_int64, 2**63) - self.assertRaises(OverflowError, cp_model_helper.assert_is_int64, -(2**63) - 1) - cp_model_helper.assert_is_int64(123) - cp_model_helper.assert_is_int64(2**63 - 1) - cp_model_helper.assert_is_int64(-(2**63)) - def testto_capped_int64(self): print("testto_capped_int64") self.assertEqual(