Skip to content

Commit

Permalink
[CP-SAT] cleanup python layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Jan 19, 2025
1 parent 8e9930b commit 67efef0
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 118 deletions.
3 changes: 0 additions & 3 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,6 @@ def query(self, index: int) -> IntVar:
raise ValueError("Index out of bounds.")
return self.__var_list[index]

def num_variables(self) -> int:
return len(self.__var_list)

def rebuild_expr(
self,
proto: cp_model_pb2.LinearExpressionProto,
Expand Down
142 changes: 42 additions & 100 deletions ortools/sat/python/cp_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -568,95 +568,36 @@ PYBIND11_MODULE(cp_model_helper, m) {
},
DOC(operations_research, sat, python, LinearExpr, IsInteger))
// Operators.
// Note that we keep the 3 APIS (expr, int, double) instead of using
// an ExprOrValue argument as this is more efficient.
.def(
"__add__",
[](std::shared_ptr<LinearExpr> expr,
std::shared_ptr<LinearExpr> other) -> std::shared_ptr<LinearExpr> {
return expr->Add(other);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, Add))
.def(
"__add__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->AddInt(cst); },
arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddInt))
.def(
"__add__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->AddFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddFloat))
.def(
"__radd__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->AddInt(cst); },
arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddInt))
.def(
"__radd__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->AddFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddFloat))
.def(
"__sub__",
[](std::shared_ptr<LinearExpr> expr,
std::shared_ptr<LinearExpr> other) -> std::shared_ptr<LinearExpr> {
return expr->Sub(other);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, Sub))
.def(
"__sub__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->SubInt(cst); },
arg("cst"), DOC(operations_research, sat, python, LinearExpr, SubInt))
.def(
"__sub__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->SubFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, SubFloat))
.def(
"__rsub__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->RSubInt(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, RSubInt))
.def(
"__rsub__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->RSubFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, RSubFloat))
.def(
"__mul__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->MulInt(cst); },
arg("cst"), DOC(operations_research, sat, python, LinearExpr, MulInt))
.def(
"__mul__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->MulFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulFloat))
.def(
"__rmul__",
[](std::shared_ptr<LinearExpr> expr, int64_t cst)
-> std::shared_ptr<LinearExpr> { return expr->MulInt(cst); },
arg("cst"), DOC(operations_research, sat, python, LinearExpr, MulInt))
.def(
"__rmul__",
[](std::shared_ptr<LinearExpr> expr, double cst)
-> std::shared_ptr<LinearExpr> { return expr->MulFloat(cst); },
arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulFloat))
.def(
"__neg__",
[](std::shared_ptr<LinearExpr> expr) { return expr->Neg(); },
DOC(operations_research, sat, python, LinearExpr, Neg))
.def("__add__", &LinearExpr::Add, arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, Add))
.def("__add__", &LinearExpr::AddInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddInt))
.def("__add__", &LinearExpr::AddFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddFloat))
.def("__radd__", &LinearExpr::AddInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddInt))
.def("__radd__", &LinearExpr::AddFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, AddFloat))
.def("__sub__", &LinearExpr::Sub, arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, Sub))
.def("__sub__", &LinearExpr::SubInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, SubInt))
.def("__sub__", &LinearExpr::SubFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, SubFloat))
.def("__rsub__", &LinearExpr::RSubInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, RSubInt))
.def("__rsub__", &LinearExpr::RSubFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, RSubFloat))
.def("__mul__", &LinearExpr::MulInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulInt))
.def("__mul__", &LinearExpr::MulFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulFloat))
.def("__rmul__", &LinearExpr::MulInt, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulInt))
.def("__rmul__", &LinearExpr::MulFloat, arg("cst"),
DOC(operations_research, sat, python, LinearExpr, MulFloat))
.def("__neg__", &LinearExpr::Neg,
DOC(operations_research, sat, python, LinearExpr, Neg))
.def(
"__eq__",
[](std::shared_ptr<LinearExpr> lhs, std::shared_ptr<LinearExpr> rhs) {
Expand Down Expand Up @@ -874,7 +815,6 @@ PYBIND11_MODULE(cp_model_helper, m) {
}
return expr->AddInt(cst);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, AddInt))
.def(
"__add__",
Expand Down Expand Up @@ -916,7 +856,6 @@ PYBIND11_MODULE(cp_model_helper, m) {
}
return expr->AddFloat(cst);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, AddFloat))
.def(
"__sub__",
Expand Down Expand Up @@ -945,7 +884,6 @@ PYBIND11_MODULE(cp_model_helper, m) {
}
return expr->SubInt(cst);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, SubInt))
.def(
"__sub__",
Expand All @@ -959,7 +897,6 @@ PYBIND11_MODULE(cp_model_helper, m) {
}
return expr->SubFloat(cst);
},
arg("other").none(false),
DOC(operations_research, sat, python, LinearExpr, SubFloat))
.def_property_readonly("num_exprs", &SumArray::num_exprs)
.def_property_readonly("int_offset", &SumArray::int_offset)
Expand All @@ -972,6 +909,8 @@ PYBIND11_MODULE(cp_model_helper, m) {
.def_property_readonly("coefficient", &FloatAffine::coefficient)
.def_property_readonly("offset", &FloatAffine::offset);

// We adding an operator like __add__(int), we need to add all overloads,
// otherwise they are not found.
py::class_<IntAffine, std::shared_ptr<IntAffine>, LinearExpr>(
m, "IntAffine", DOC(operations_research, sat, python, IntAffine))
.def(py::init<std::shared_ptr<LinearExpr>, int64_t, int64_t>())
Expand Down Expand Up @@ -1070,13 +1009,15 @@ PYBIND11_MODULE(cp_model_helper, m) {
},
DOC(operations_research, sat, python, BaseIntVar, negated))
// PEP8 Compatibility.
.def("Not", [](std::shared_ptr<BaseIntVar> self) {
if (!self->is_boolean()) {
ThrowError(PyExc_TypeError,
"negated() is only supported for Boolean variables.");
}
return self->negated();
});
.def("Not",
[](std::shared_ptr<BaseIntVar> self) {
if (!self->is_boolean()) {
ThrowError(PyExc_TypeError,
"negated() is only supported for Boolean variables.");
}
return self->negated();
})
.def("Index", &BaseIntVar::index);

py::class_<NotBooleanVariable, std::shared_ptr<NotBooleanVariable>, Literal>(
m, "NotBooleanVariable",
Expand Down Expand Up @@ -1140,6 +1081,7 @@ PYBIND11_MODULE(cp_model_helper, m) {
return not_var->negated();
},
DOC(operations_research, sat, python, NotBooleanVariable, negated));

py::class_<BoundedLinearExpression, std::shared_ptr<BoundedLinearExpression>>(
m, "BoundedLinearExpression",
DOC(operations_research, sat, python, BoundedLinearExpression))
Expand Down
59 changes: 50 additions & 9 deletions ortools/sat/python/cp_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,14 @@ def testSumParsing(self) -> None:
"FlatFloatExpr([x0(0..2), x2(0..2)], [1, -1], -3)",
)

s13 = cp_model.LinearExpr.sum(2)
self.assertEqual(str(s13), "2")
self.assertEqual(repr(s13), "IntConstant(2)")

s14 = cp_model.LinearExpr.sum(2.5)
self.assertEqual(str(s14), "2.5")
self.assertEqual(repr(s14), "FloatConstant(2.5)")

class FakeNpDTypeA:

def __init__(self):
Expand Down Expand Up @@ -677,6 +685,12 @@ def testWeightedSumParsing(self) -> None:
self.assertLen(flat_s5.vars, 4)
self.assertEqual(-2, flat_s5.offset)

s6 = cp_model.LinearExpr.weighted_sum([2], [1])
self.assertEqual(repr(s6), "IntConstant(2)")

s7 = cp_model.LinearExpr.weighted_sum([2], [1.25])
self.assertEqual(repr(s7), "FloatConstant(2.5)")

def testSumWithApi(self) -> None:
model = cp_model.CpModel()
x = [model.new_int_var(0, 2, f"x{i}") for i in range(100)]
Expand Down Expand Up @@ -1203,16 +1217,24 @@ def testRebuildFromLinearExpressionProto(self) -> None:
proto.coeffs.append(1)
proto.vars.append(y.index)
proto.coeffs.append(2)
expr1 = model.rebuild_from_linear_expression_proto(proto)
canonical_expr1 = cmh.FlatIntExpr(expr1)
self.assertEqual(canonical_expr1.vars[0], x)
self.assertEqual(canonical_expr1.vars[1], y)
self.assertEqual(canonical_expr1.coeffs[0], 1)
self.assertEqual(canonical_expr1.coeffs[1], 2)
self.assertEqual(canonical_expr1.offset, 0)
self.assertEqual(~canonical_expr1.vars[1], ~y)
self.assertRaises(TypeError, canonical_expr1.vars[0].negated)

proto.offset = 2
expr = model.rebuild_from_linear_expression_proto(proto)
canonical_expr = cmh.FlatIntExpr(expr)
self.assertEqual(canonical_expr.vars[0], x)
self.assertEqual(canonical_expr.vars[1], y)
self.assertEqual(canonical_expr.coeffs[0], 1)
self.assertEqual(canonical_expr.coeffs[1], 2)
self.assertEqual(canonical_expr.offset, 2)
self.assertEqual(~canonical_expr.vars[1], ~y)
self.assertRaises(TypeError, canonical_expr.vars[0].negated)
expr2 = model.rebuild_from_linear_expression_proto(proto)
canonical_expr2 = cmh.FlatIntExpr(expr2)
self.assertEqual(canonical_expr2.vars[0], x)
self.assertEqual(canonical_expr2.vars[1], y)
self.assertEqual(canonical_expr2.coeffs[0], 1)
self.assertEqual(canonical_expr2.coeffs[1], 2)
self.assertEqual(canonical_expr2.offset, 2)

def testAbsentInterval(self) -> None:
model = cp_model.CpModel()
Expand Down Expand Up @@ -1354,9 +1376,18 @@ def testRepr(self) -> None:
y = model.new_int_var(0, 3, "y")
z = model.new_int_var(0, 3, "z")
self.assertEqual(repr(x), "x(0..4)")
self.assertEqual(repr(x + 0), "x(0..4)")
self.assertEqual(repr(x + 0.0), "x(0..4)")
self.assertEqual(repr(x - 0), "x(0..4)")
self.assertEqual(repr(x - 0.0), "x(0..4)")
self.assertEqual(repr(x * 1), "x(0..4)")
self.assertEqual(repr(x * 1.0), "x(0..4)")
self.assertEqual(repr(x * 0), "IntConstant(0)")
self.assertEqual(repr(x * 0.0), "IntConstant(0)")
self.assertEqual(repr(x * 2), "IntAffine(expr=x(0..4), coeff=2, offset=0)")
self.assertEqual(
repr(x + 1.5), "FloatAffine(expr=x(0..4), coeff=1, offset=1.5)"
)
self.assertEqual(repr(x + y), "SumArray(x(0..4), y(0..3))")
self.assertEqual(
repr(cp_model.LinearExpr.sum([x, y, z])),
Expand All @@ -1369,6 +1400,8 @@ def testRepr(self) -> None:
i = model.new_interval_var(x, 2, y, "i")
self.assertEqual(repr(i), "i(start = x, size = 2, end = y)")
b = model.new_bool_var("b")
self.assertEqual(repr(b), "b(0..1)")
self.assertEqual(repr(~b), "NotBooleanVariable(index=3)")
x1 = model.new_int_var(0, 4, "x1")
y1 = model.new_int_var(0, 3, "y1")
j = model.new_optional_interval_var(x1, 2, y1, b, "j")
Expand Down Expand Up @@ -1761,6 +1794,14 @@ def testIntVarSeries(self) -> None:
solution = solver.values(x)
self.assertTrue((solution.values == [0, 5, 0]).all())
self.assertRaises(TypeError, x.apply, lambda x: ~x)
y = model.new_int_var_series(
name="y", index=df.index, lower_bounds=-1, upper_bounds=1
)
self.assertRaises(TypeError, y.apply, lambda x: ~x)
z = model.new_int_var_series(
name="y", index=df.index, lower_bounds=0, upper_bounds=1
)
_ = z.apply(lambda x: ~x)

def testBoolVarSeries(self) -> None:
df = pd.DataFrame([1, -1, 1], columns=["coeffs"])
Expand Down
12 changes: 6 additions & 6 deletions ortools/sat/python/linear_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ std::shared_ptr<Literal> BaseIntVar::negated() {

int NotBooleanVariable::index() const {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
return -var->index() - 1;
}

Expand All @@ -774,34 +774,34 @@ int NotBooleanVariable::index() const {
*/
std::shared_ptr<Literal> NotBooleanVariable::negated() {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
return var;
}

bool NotBooleanVariable::VisitAsInt(IntExprVisitor& lin, int64_t c) {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
lin.AddVarCoeff(var, -c);
lin.AddConstant(c);
return true;
}

void NotBooleanVariable::VisitAsFloat(FloatExprVisitor& lin, double c) {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
lin.AddVarCoeff(var, -c);
lin.AddConstant(c);
}

std::string NotBooleanVariable::ToString() const {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
return absl::StrCat("not(", var->ToString(), ")");
}

std::string NotBooleanVariable::DebugString() const {
std::shared_ptr<BaseIntVar> var = var_.lock();
CHECK(var != nullptr);
CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code.
return absl::StrCat("NotBooleanVariable(index=", var->index(), ")");
}

Expand Down

0 comments on commit 67efef0

Please sign in to comment.