diff --git a/ortools/sat/2d_mandatory_overlap_propagator.cc b/ortools/sat/2d_mandatory_overlap_propagator.cc new file mode 100644 index 0000000000..c08b5ff5d1 --- /dev/null +++ b/ortools/sat/2d_mandatory_overlap_propagator.cc @@ -0,0 +1,106 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/2d_mandatory_overlap_propagator.h" + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" +#include "ortools/sat/scheduling_helpers.h" + +namespace operations_research { +namespace sat { + +int MandatoryOverlapPropagator::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + helper_.WatchAllBoxes(id); + return id; +} + +MandatoryOverlapPropagator::~MandatoryOverlapPropagator() { + if (!VLOG_IS_ON(1)) return; + std::vector> stats; + stats.push_back({"MandatoryOverlapPropagator/called_with_zero_area", + num_calls_zero_area_}); + stats.push_back({"MandatoryOverlapPropagator/called_without_zero_area", + num_calls_nonzero_area_}); + stats.push_back({"MandatoryOverlapPropagator/conflicts", num_conflicts_}); + + shared_stats_->AddStats(stats); +} + +bool MandatoryOverlapPropagator::Propagate() { + if (!helper_.SynchronizeAndSetDirection(true, true, false)) return false; + + mandatory_regions_.clear(); + mandatory_regions_index_.clear(); + bool has_zero_area_boxes = false; + absl::Span tasks = + helper_.x_helper().TaskByIncreasingNegatedStartMax(); + for (int i = tasks.size() - 1; i >= 0; --i) { + const int b = tasks[i].task_index; + if (!helper_.IsPresent(b)) continue; + const ItemWithVariableSize item = helper_.GetItemWithVariableSize(b); + if (item.x.start_max > item.x.end_min || + item.y.start_max > item.y.end_min) { + continue; + } + mandatory_regions_.push_back({.x_min = item.x.start_max, + .x_max = item.x.end_min, + .y_min = item.y.start_max, + .y_max = item.y.end_min}); + mandatory_regions_index_.push_back(b); + + if (mandatory_regions_.back().SizeX() == 0 || + mandatory_regions_.back().SizeY() == 0) { + has_zero_area_boxes = true; + } + } + std::optional> conflict; + if (has_zero_area_boxes) { + num_calls_zero_area_++; + conflict = FindOneIntersectionIfPresentWithZeroArea(mandatory_regions_); + } else { + num_calls_nonzero_area_++; + conflict = FindOneIntersectionIfPresent(mandatory_regions_); + } + + if (conflict.has_value()) { + num_conflicts_++; + return helper_.ReportConflictFromTwoBoxes( + mandatory_regions_index_[conflict->first], + mandatory_regions_index_[conflict->second]); + } + return true; +} + +void CreateAndRegisterMandatoryOverlapPropagator( + NoOverlap2DConstraintHelper* helper, Model* model, + GenericLiteralWatcher* watcher, int priority) { + MandatoryOverlapPropagator* propagator = + new MandatoryOverlapPropagator(helper, model); + watcher->SetPropagatorPriority(propagator->RegisterWith(watcher), priority); + model->TakeOwnership(propagator); +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/2d_mandatory_overlap_propagator.h b/ortools/sat/2d_mandatory_overlap_propagator.h new file mode 100644 index 0000000000..b56ac0d7f8 --- /dev/null +++ b/ortools/sat/2d_mandatory_overlap_propagator.h @@ -0,0 +1,65 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_2D_MANDATORY_OVERLAP_PROPAGATOR_H_ +#define OR_TOOLS_SAT_2D_MANDATORY_OVERLAP_PROPAGATOR_H_ + +#include +#include + +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" +#include "ortools/sat/synchronization.h" + +namespace operations_research { +namespace sat { + +// Propagator that checks that no mandatory area of two boxes overlap in +// O(N * log N) time. +void CreateAndRegisterMandatoryOverlapPropagator( + NoOverlap2DConstraintHelper* helper, Model* model, + GenericLiteralWatcher* watcher, int priority); + +// Exposed for testing. +class MandatoryOverlapPropagator : public PropagatorInterface { + public: + MandatoryOverlapPropagator(NoOverlap2DConstraintHelper* helper, Model* model) + : helper_(*helper), + shared_stats_(model->GetOrCreate()) {} + + ~MandatoryOverlapPropagator() override; + + bool Propagate() final; + int RegisterWith(GenericLiteralWatcher* watcher); + + private: + NoOverlap2DConstraintHelper& helper_; + SharedStatistics* shared_stats_; + std::vector mandatory_regions_; + std::vector mandatory_regions_index_; + + int64_t num_conflicts_ = 0; + int64_t num_calls_zero_area_ = 0; + int64_t num_calls_nonzero_area_ = 0; + + MandatoryOverlapPropagator(const MandatoryOverlapPropagator&) = delete; + MandatoryOverlapPropagator& operator=(const MandatoryOverlapPropagator&) = + delete; +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_2D_MANDATORY_OVERLAP_PROPAGATOR_H_ diff --git a/ortools/sat/2d_try_edge_propagator.cc b/ortools/sat/2d_try_edge_propagator.cc index 95af23535f..e5a49ba49d 100644 --- a/ortools/sat/2d_try_edge_propagator.cc +++ b/ortools/sat/2d_try_edge_propagator.cc @@ -301,6 +301,9 @@ std::vector TryEdgeRectanglePropagator::GetMinimumProblemWithPropagation( } // Now gather the data per box to make easier to use the set cover solver API. + // TODO(user): skip the boxes that are fixed at level zero. They do not + // contribute to the size of the explanation (so we shouldn't minimize their + // number) and make the SetCover problem harder to solve. std::vector> conflicting_position_per_box( active_box_ranges_.size(), std::vector()); for (int i = 0; i < conflicts_per_x_and_y_.size(); ++i) { @@ -403,28 +406,30 @@ bool TryEdgeRectanglePropagator::ExplainAndPropagate( void CreateAndRegisterTryEdgePropagator(NoOverlap2DConstraintHelper* helper, Model* model, - GenericLiteralWatcher* watcher) { + GenericLiteralWatcher* watcher, + int priority) { TryEdgeRectanglePropagator* try_edge_propagator = new TryEdgeRectanglePropagator(true, true, false, helper, model); - watcher->SetPropagatorPriority(try_edge_propagator->RegisterWith(watcher), 5); + watcher->SetPropagatorPriority(try_edge_propagator->RegisterWith(watcher), + priority); model->TakeOwnership(try_edge_propagator); TryEdgeRectanglePropagator* try_edge_propagator_mirrored = new TryEdgeRectanglePropagator(false, true, false, helper, model); watcher->SetPropagatorPriority( - try_edge_propagator_mirrored->RegisterWith(watcher), 5); + try_edge_propagator_mirrored->RegisterWith(watcher), priority); model->TakeOwnership(try_edge_propagator_mirrored); TryEdgeRectanglePropagator* try_edge_propagator_swap = new TryEdgeRectanglePropagator(true, true, true, helper, model); watcher->SetPropagatorPriority( - try_edge_propagator_swap->RegisterWith(watcher), 5); + try_edge_propagator_swap->RegisterWith(watcher), priority); model->TakeOwnership(try_edge_propagator_swap); TryEdgeRectanglePropagator* try_edge_propagator_swap_mirrored = new TryEdgeRectanglePropagator(false, true, true, helper, model); watcher->SetPropagatorPriority( - try_edge_propagator_swap_mirrored->RegisterWith(watcher), 5); + try_edge_propagator_swap_mirrored->RegisterWith(watcher), priority); model->TakeOwnership(try_edge_propagator_swap_mirrored); } diff --git a/ortools/sat/2d_try_edge_propagator.h b/ortools/sat/2d_try_edge_propagator.h index 3ac69f4d14..dec5e31c3d 100644 --- a/ortools/sat/2d_try_edge_propagator.h +++ b/ortools/sat/2d_try_edge_propagator.h @@ -37,7 +37,8 @@ namespace sat { // it is different from the current x_min, it will propagate the new x_min. void CreateAndRegisterTryEdgePropagator(NoOverlap2DConstraintHelper* helper, Model* model, - GenericLiteralWatcher* watcher); + GenericLiteralWatcher* watcher, + int priority); // Exposed for testing. class TryEdgeRectanglePropagator : public PropagatorInterface { diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index ffc96bd55e..ac1ae8e598 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -94,6 +94,22 @@ proto_library( srcs = ["cp_model.proto"], ) +cc_library( + name = "2d_mandatory_overlap_propagator", + srcs = ["2d_mandatory_overlap_propagator.cc"], + hdrs = ["2d_mandatory_overlap_propagator.h"], + deps = [ + ":diffn_util", + ":integer", + ":model", + ":no_overlap_2d_helper", + ":scheduling_helpers", + ":synchronization", + "@com_google_absl//absl/log", + "@com_google_absl//absl/types:span", + ], +) + cc_proto_library( name = "cp_model_cc_proto", deps = [":cp_model_proto"], @@ -765,6 +781,7 @@ cc_library( ":presolve_util", ":sat_parameters_cc_proto", ":sat_solver", + ":solution_crush", ":util", "//ortools/algorithms:sparse_permutation", "//ortools/base", @@ -798,9 +815,9 @@ cc_test( ":cp_model_utils", ":model", ":presolve_context", + ":solution_crush", "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", - "//ortools/base:types", "//ortools/util:affine_relation", "//ortools/util:sorted_interval_list", "@com_google_absl//absl/container:flat_hash_set", @@ -808,6 +825,24 @@ cc_test( ], ) +cc_library( + name = "solution_crush", + srcs = [ + "solution_crush.cc", + ], + hdrs = ["solution_crush.h"], + deps = [ + ":cp_model_cc_proto", + ":cp_model_utils", + ":sat_parameters_cc_proto", + "//ortools/algorithms:sparse_permutation", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cp_model_table", srcs = ["cp_model_table.cc"], @@ -873,6 +908,7 @@ cc_library( ":sat_parameters_cc_proto", ":sat_solver", ":simplification", + ":solution_crush", ":util", ":var_domination", "//ortools/base", @@ -975,6 +1011,7 @@ cc_library( ":cp_model_utils", ":presolve_context", ":sat_parameters_cc_proto", + ":solution_crush", ":util", "//ortools/base", "//ortools/base:stl_util", @@ -1436,6 +1473,7 @@ cc_library( ":integer_base", ":presolve_context", ":presolve_util", + ":solution_crush", ":util", "//ortools/algorithms:dynamic_partition", "//ortools/base", @@ -1723,7 +1761,9 @@ cc_library( ":sat_base", ":scheduling_helpers", ":util", + "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], ) @@ -2283,6 +2323,7 @@ cc_library( ":intervals", ":linear_constraint", ":model", + ":no_overlap_2d_helper", ":precedences", ":presolve_util", ":routing_cuts", @@ -2290,6 +2331,7 @@ cc_library( ":sat_parameters_cc_proto", ":sat_solver", ":scheduling_cuts", + ":scheduling_helpers", ":util", "//ortools/base", "//ortools/base:mathutil", @@ -2665,6 +2707,7 @@ cc_library( ":linear_constraint", ":linear_constraint_manager", ":model", + ":no_overlap_2d_helper", ":sat_base", ":scheduling_helpers", ":util", @@ -3204,6 +3247,7 @@ cc_library( srcs = ["diffn.cc"], hdrs = ["diffn.h"], deps = [ + ":2d_mandatory_overlap_propagator", ":2d_orthogonal_packing", ":2d_try_edge_propagator", ":cumulative_energy", @@ -3632,6 +3676,7 @@ cc_library( ":sat_base", ":sat_parameters_cc_proto", ":sat_solver", + ":solution_crush", ":symmetry_util", ":util", "//ortools/algorithms:binary_search", diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 9ac9fe0c38..df44d8936b 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -1442,18 +1442,15 @@ class ConstraintChecker { } std::optional> one_intersection; - if (!has_zero_sizes) { - absl::c_stable_sort(enforced_rectangles, - [](const Rectangle& a, const Rectangle& b) { - return a.x_min < b.x_min; - }); - one_intersection = FindOneIntersectionIfPresent(enforced_rectangles); + absl::c_stable_sort(enforced_rectangles, + [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + if (has_zero_sizes) { + one_intersection = + FindOneIntersectionIfPresentWithZeroArea(enforced_rectangles); } else { - const std::vector> intersections = - FindPartialRectangleIntersections(enforced_rectangles); - if (!intersections.empty()) { - one_intersection = intersections[0]; - } + one_intersection = FindOneIntersectionIfPresent(enforced_rectangles); } if (one_intersection != std::nullopt) { diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 95ad399d67..c706273acb 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -41,6 +41,7 @@ #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/solution_crush.h" #include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" @@ -55,6 +56,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, int64_t sum_of_negative_demand, ConstraintProto* reservoir_ct, PresolveContext* context) { + SolutionCrush& crush = context->solution_crush(); const ReservoirConstraintProto& reservoir = reservoir_ct->reservoir(); const int num_events = reservoir.time_exprs_size(); @@ -70,9 +72,9 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, std::vector level_vars(num_events); for (int i = 0; i < num_events; ++i) { level_vars[i] = context->NewIntVar(Domain(var_min, var_max)); - if (context->HintIsLoaded()) { + if (crush.HintIsLoaded()) { // The hint of active events is set later. - context->SetNewVariableHint(level_vars[i], 0); + crush.SetNewVariableHint(level_vars[i], 0); } } @@ -85,16 +87,15 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, }; std::vector active_event_hints; bool has_complete_hint = false; - if (context->HintIsLoaded()) { + if (crush.HintIsLoaded()) { has_complete_hint = true; for (int i = 0; i < num_events && has_complete_hint; ++i) { - if (context->VarHasSolutionHint( - PositiveRef(reservoir.active_literals(i)))) { - if (context->LiteralSolutionHint(reservoir.active_literals(i))) { + if (crush.VarHasSolutionHint(PositiveRef(reservoir.active_literals(i)))) { + if (crush.LiteralSolutionHint(reservoir.active_literals(i))) { const std::optional time_hint = - context->GetExpressionSolutionHint(reservoir.time_exprs(i)); + crush.GetExpressionSolutionHint(reservoir.time_exprs(i)); const std::optional change_hint = - context->GetExpressionSolutionHint(reservoir.level_changes(i)); + crush.GetExpressionSolutionHint(reservoir.level_changes(i)); if (time_hint.has_value() && change_hint.has_value()) { active_event_hints.push_back( {i, time_hint.value(), change_hint.value()}); @@ -131,8 +132,8 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, active_event_hints[j].time == active_event_hints[i].time) { if (i != j) std::swap(active_event_hints[i], active_event_hints[j]); current_level += active_event_hints[i].level_change; - context->UpdateVarSolutionHint(level_vars[active_event_hints[i].index], - current_level); + crush.UpdateVarSolutionHint(level_vars[active_event_hints[i].index], + current_level); } else { has_complete_hint = false; break; @@ -148,7 +149,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, circuit->add_heads(num_events); circuit->add_literals(all_inactive); if (has_complete_hint) { - context->SetNewVariableHint(all_inactive, active_event_hints.empty()); + crush.SetNewVariableHint(all_inactive, active_event_hints.empty()); } } @@ -175,9 +176,9 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, circuit->add_heads(i); circuit->add_literals(start_var); if (has_complete_hint) { - context->SetNewVariableHint(start_var, - !active_event_hints.empty() && - active_event_hints.front().index == i); + crush.SetNewVariableHint(start_var, + !active_event_hints.empty() && + active_event_hints.front().index == i); } // Add enforced linear for demand. @@ -200,9 +201,9 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, circuit->add_heads(num_events); circuit->add_literals(end_var); if (has_complete_hint) { - context->SetNewVariableHint(end_var, - !active_event_hints.empty() && - active_event_hints.back().index == i); + crush.SetNewVariableHint(end_var, + !active_event_hints.empty() && + active_event_hints.back().index == i); } } @@ -226,9 +227,9 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, if (has_complete_hint) { const int hint_i_index = active_event_hint_index[i]; const int hint_j_index = active_event_hint_index[j]; - context->SetNewVariableHint(arc_i_j, - hint_i_index != -1 && hint_j_index != -1 && - hint_j_index == hint_i_index + 1); + crush.SetNewVariableHint(arc_i_j, hint_i_index != -1 && + hint_j_index != -1 && + hint_j_index == hint_i_index + 1); } // Add enforced linear for time. @@ -430,16 +431,17 @@ void ExpandReservoir(ConstraintProto* reservoir_ct, PresolveContext* context) { // not(active) => new_var == 0. context->AddImplyInDomain(NegatedRef(active), new_var, Domain(0)); - if (context->HintIsLoaded() && - context->VarHasSolutionHint(PositiveRef(active))) { - if (context->LiteralSolutionHint(active)) { + SolutionCrush& crush = context->solution_crush(); + if (crush.HintIsLoaded() && + crush.VarHasSolutionHint(PositiveRef(active))) { + if (crush.LiteralSolutionHint(active)) { const std::optional demand_hint = - context->GetExpressionSolutionHint(demand); + crush.GetExpressionSolutionHint(demand); if (demand_hint.has_value()) { - context->SetNewVariableHint(new_var, demand_hint.value()); + crush.SetNewVariableHint(new_var, demand_hint.value()); } } else { - context->SetNewVariableHint(new_var, 0); + crush.SetNewVariableHint(new_var, 0); } } } @@ -542,24 +544,25 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { int64_t expr_hint = 0; int64_t mod_expr_hint = 0; int64_t target_expr_hint = 0; - if (context->HintIsLoaded()) { + SolutionCrush& crush = context->solution_crush(); + if (crush.HintIsLoaded()) { has_complete_hint = true; for (const int lit : ct->enforcement_literal()) { - if (!context->VarHasSolutionHint(PositiveRef(lit))) { + if (!crush.VarHasSolutionHint(PositiveRef(lit))) { has_complete_hint = false; break; } - enforced_hint = enforced_hint && context->LiteralSolutionHint(lit); + enforced_hint = enforced_hint && crush.LiteralSolutionHint(lit); } if (has_complete_hint && enforced_hint) { has_complete_hint = false; - std::optional hint = context->GetExpressionSolutionHint(expr); + std::optional hint = crush.GetExpressionSolutionHint(expr); if (hint.has_value()) { expr_hint = hint.value(); - hint = context->GetExpressionSolutionHint(mod_expr); + hint = crush.GetExpressionSolutionHint(mod_expr); if (hint.has_value()) { mod_expr_hint = hint.value(); - hint = context->GetExpressionSolutionHint(target_expr); + hint = crush.GetExpressionSolutionHint(target_expr); if (hint.has_value()) { target_expr_hint = hint.value(); has_complete_hint = true; @@ -582,9 +585,9 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { context->DomainSuperSetOf(mod_expr))); if (has_complete_hint) { if (enforced_hint) { - context->SetNewVariableHint(div_var, expr_hint / mod_expr_hint); + crush.SetNewVariableHint(div_var, expr_hint / mod_expr_hint); } else { - context->SetNewVariableHint(div_var, context->MinOf(div_var)); + crush.SetNewVariableHint(div_var, context->MinOf(div_var)); } } LinearExpressionProto div_expr; @@ -606,9 +609,9 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { const int prod_var = context->NewIntVar(prod_domain); if (has_complete_hint) { if (enforced_hint) { - context->SetNewVariableHint(prod_var, expr_hint - target_expr_hint); + crush.SetNewVariableHint(prod_var, expr_hint - target_expr_hint); } else { - context->SetNewVariableHint(prod_var, context->MinOf(prod_var)); + crush.SetNewVariableHint(prod_var, context->MinOf(prod_var)); } } LinearExpressionProto prod_expr; @@ -636,6 +639,7 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { void ExpandNonBinaryIntProd(ConstraintProto* ct, PresolveContext* context) { CHECK_GT(ct->int_prod().exprs_size(), 2); + SolutionCrush& crush = context->solution_crush(); std::deque terms( {ct->int_prod().exprs().begin(), ct->int_prod().exprs().end()}); while (terms.size() > 2) { @@ -645,14 +649,14 @@ void ExpandNonBinaryIntProd(ConstraintProto* ct, PresolveContext* context) { context->DomainSuperSetOf(left).ContinuousMultiplicationBy( context->DomainSuperSetOf(right)); const int new_var = context->NewIntVar(new_domain); - if (context->HintIsLoaded()) { + if (crush.HintIsLoaded()) { const std::optional left_hint = - context->GetExpressionSolutionHint(left); + crush.GetExpressionSolutionHint(left); const std::optional right_hint = - context->GetExpressionSolutionHint(right); + crush.GetExpressionSolutionHint(right); if (left_hint.has_value() && right_hint.has_value()) { - context->SetNewVariableHint(new_var, - left_hint.value() * right_hint.value()); + crush.SetNewVariableHint(new_var, + left_hint.value() * right_hint.value()); } } LinearArgumentProto* const int_prod = @@ -857,16 +861,17 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { // Second, for each expr, create a new boolean bi, and add bi => target <= ai // With exactly_one(bi) + SolutionCrush& crush = context->solution_crush(); std::vector enforcement_hints; - if (context->HintIsLoaded()) { + if (crush.HintIsLoaded()) { const std::optional target_hint = - context->GetExpressionSolutionHint(ct->lin_max().target()); + crush.GetExpressionSolutionHint(ct->lin_max().target()); if (target_hint.has_value()) { int enforcement_hint_sum = 0; enforcement_hints.reserve(num_exprs); for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { const std::optional expr_hint = - context->GetExpressionSolutionHint(expr); + crush.GetExpressionSolutionHint(expr); if (!expr_hint.has_value()) { enforcement_hints.clear(); break; @@ -886,7 +891,7 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { if (num_exprs == 2) { const int new_bool = context->NewBoolVar("lin max expansion"); if (!enforcement_hints.empty()) { - context->SetNewVariableHint(new_bool, enforcement_hints[0]); + crush.SetNewVariableHint(new_bool, enforcement_hints[0]); } enforcement_literals.push_back(new_bool); enforcement_literals.push_back(NegatedRef(new_bool)); @@ -895,7 +900,7 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { for (int i = 0; i < num_exprs; ++i) { const int new_bool = context->NewBoolVar("lin max expansion"); if (!enforcement_hints.empty()) { - context->SetNewVariableHint(new_bool, enforcement_hints[i]); + crush.SetNewVariableHint(new_bool, enforcement_hints[i]); } exactly_one->mutable_exactly_one()->add_literals(new_bool); enforcement_literals.push_back(new_bool); @@ -1162,7 +1167,8 @@ void AddImplyInReachableValues(int literal, std::vector GetAutomatonStateHints(ConstraintProto* ct, PresolveContext* context) { - if (!context->HintIsLoaded()) return {}; + SolutionCrush& crush = context->solution_crush(); + if (!crush.HintIsLoaded()) return {}; const AutomatonConstraintProto& proto = ct->automaton(); absl::flat_hash_map, int64_t> transitions; @@ -1176,7 +1182,7 @@ std::vector GetAutomatonStateHints(ConstraintProto* ct, state_hints.push_back(current_state); for (int i = 0; i < proto.exprs_size(); ++i) { const std::optional label_hint = - context->GetExpressionSolutionHint(proto.exprs(i)); + crush.GetExpressionSolutionHint(proto.exprs(i)); if (!label_hint.has_value()) return {}; const auto it = transitions.find({current_state, label_hint.value()}); if (it == transitions.end()) return {}; @@ -1218,6 +1224,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { absl::flat_hash_map in_encoding; absl::flat_hash_map out_encoding; bool removed_values = false; + SolutionCrush& crush = context->solution_crush(); const std::vector state_hints = GetAutomatonStateHints(ct, context); DCHECK(state_hints.empty() || state_hints.size() == proto.exprs_size() + 1); @@ -1320,7 +1327,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { if (states.size() == 2) { const int var = context->NewBoolVar("automaton expansion"); if (!state_hints.empty()) { - context->SetNewVariableHint(var, state_hints[time + 1] == states[0]); + crush.SetNewVariableHint(var, state_hints[time + 1] == states[0]); } out_encoding[states[0]] = var; out_encoding[states[1]] = NegatedRef(var); @@ -1372,8 +1379,8 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { out_encoding[state] = context->NewBoolVar("automaton expansion"); if (!state_hints.empty()) { - context->SetNewVariableHint(out_encoding[state], - state_hints[time + 1] == state); + crush.SetNewVariableHint(out_encoding[state], + state_hints[time + 1] == state); } } } @@ -1436,14 +1443,13 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { // TODO(user): Call and use the same heuristics as the table constraint to // expand this small table with 3 columns (i.e. compress, negate, etc...). const int64_t label_hint = - context->GetExpressionSolutionHint(proto.exprs(time)).value_or(0); + crush.GetExpressionSolutionHint(proto.exprs(time)).value_or(0); std::vector tuple_literals; if (num_tuples == 2) { const int bool_var = context->NewBoolVar("automaton expansion"); if (!state_hints.empty()) { - context->SetNewVariableHint( - bool_var, - state_hints[time] == in_states[0] && label_hint == labels[0]); + crush.SetNewVariableHint(bool_var, state_hints[time] == in_states[0] && + label_hint == labels[0]); } tuple_literals.push_back(bool_var); tuple_literals.push_back(NegatedRef(bool_var)); @@ -1464,7 +1470,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { } else { tuple_literal = context->NewBoolVar("automaton expansion"); if (!state_hints.empty()) { - context->SetNewVariableHint( + crush.SetNewVariableHint( tuple_literal, state_hints[time] == in_states[i] && label_hint == labels[i]); } @@ -2007,6 +2013,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, BoolArgumentProto* exactly_one = context->working_model->add_constraints()->mutable_exactly_one(); int exactly_one_hint_sum = 0; + SolutionCrush& crush = context->solution_crush(); std::optional table_is_active_literal = std::nullopt; // Process enforcement literals. @@ -2028,7 +2035,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, if (table_is_active_literal.has_value()) { const int inactive_lit = NegatedRef(table_is_active_literal.value()); exactly_one->add_literals(inactive_lit); - exactly_one_hint_sum += context->LiteralSolutionHintIs(inactive_lit, true); + exactly_one_hint_sum += crush.LiteralSolutionHintIs(inactive_lit, true); } int num_reused_variables = 0; @@ -2049,7 +2056,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, create_new_var = false; tuple_literals[i] = context->GetOrCreateVarValueEncoding(vars[var_index], v); - exactly_one_hint_sum += context->SolutionHint(vars[var_index]) == v; + exactly_one_hint_sum += crush.SolutionHint(vars[var_index]) == v; break; } if (create_new_var) { @@ -2065,7 +2072,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, // values T[v] (an empty set means "any value"). for (const int i : tuples_with_new_variable) { if (exactly_one_hint_sum >= 1) { - context->SetNewVariableHint(tuple_literals[i], false); + crush.SetNewVariableHint(tuple_literals[i], false); continue; } bool tuple_literal_hint = true; @@ -2073,12 +2080,12 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, const auto& values = compressed_table[i][var_index]; if (!values.empty() && std::find(values.begin(), values.end(), - context->SolutionHint(vars[var_index])) == values.end()) { + crush.SolutionHint(vars[var_index])) == values.end()) { tuple_literal_hint = false; break; } } - context->SetNewVariableHint(tuple_literals[i], tuple_literal_hint); + crush.SetNewVariableHint(tuple_literals[i], tuple_literal_hint); exactly_one_hint_sum += tuple_literal_hint; } if (num_reused_variables > 0) { @@ -2377,14 +2384,15 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, int64_t expr_hint = 0; int hint_bucket = -1; bool has_complete_hint = false; - if (context->HintIsLoaded()) { + SolutionCrush& crush = context->solution_crush(); + if (crush.HintIsLoaded()) { has_complete_hint = true; const int num_terms = ct->linear().vars().size(); - const absl::Span hint = context->SolutionHint(); + const absl::Span hint = crush.SolutionHint(); for (int i = 0; i < num_terms; ++i) { const int var = ct->linear().vars(i); DCHECK_LT(var, hint.size()); - if (!context->VarHasSolutionHint(var)) { + if (!crush.VarHasSolutionHint(var)) { has_complete_hint = false; break; } @@ -2411,7 +2419,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, const Domain rhs = ReadDomainFromProto(ct->linear()); const int slack = context->NewIntVar(rhs); if (has_complete_hint) { - context->SetNewVariableHint(slack, expr_hint); + crush.SetNewVariableHint(slack, expr_hint); } ct->mutable_linear()->add_vars(slack); ct->mutable_linear()->add_coeffs(-1); @@ -2427,7 +2435,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, // a single Boolean. single_bool = context->NewBoolVar("complex linear expansion"); if (has_complete_hint) { - context->SetNewVariableHint(single_bool, hint_bucket == 0); + crush.SetNewVariableHint(single_bool, hint_bucket == 0); } } else { clause = context->working_model->add_constraints()->mutable_bool_or(); @@ -2449,7 +2457,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, if (clause != nullptr) { subdomain_literal = context->NewBoolVar("complex linear expansion"); if (has_complete_hint) { - context->SetNewVariableHint(subdomain_literal, hint_bucket == i); + crush.SetNewVariableHint(subdomain_literal, hint_bucket == i); } clause->add_literals(subdomain_literal); domain_literals.push_back(subdomain_literal); @@ -2482,19 +2490,19 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, maintain_linear_is_enforced->add_literals(NegatedRef(e_lit)); } maintain_linear_is_enforced->add_literals(linear_is_enforced); - if (context->HintIsLoaded()) { + if (crush.HintIsLoaded()) { bool has_complete_enforced_hint = true; bool linear_is_enforced_hint = true; for (const int e_lit : enforcement_literals) { - if (!context->VarHasSolutionHint(PositiveRef(e_lit))) { + if (!crush.VarHasSolutionHint(PositiveRef(e_lit))) { has_complete_enforced_hint = false; break; } - linear_is_enforced_hint &= context->LiteralSolutionHint(e_lit); + linear_is_enforced_hint &= crush.LiteralSolutionHint(e_lit); } if (has_complete_enforced_hint) { - context->SetNewVariableHint(linear_is_enforced, - linear_is_enforced_hint); + crush.SetNewVariableHint(linear_is_enforced, + linear_is_enforced_hint); } } } diff --git a/ortools/sat/cp_model_expand_test.cc b/ortools/sat/cp_model_expand_test.cc index 5ee41c8a59..d3038f154b 100644 --- a/ortools/sat/cp_model_expand_test.cc +++ b/ortools/sat/cp_model_expand_test.cc @@ -2015,7 +2015,8 @@ TEST(FinalExpansionForLinearConstraintTest, ComplexLinearExpansion) { EXPECT_THAT(initial_model, testing::EqualsProto(expected_model)); // We should properly complete the hint and choose the bucket [4, 6]. - EXPECT_THAT(context.SolutionHint(), ::testing::ElementsAre(1, 5, 0, 1, 0)); + EXPECT_THAT(context.solution_crush().SolutionHint(), + ::testing::ElementsAre(1, 5, 0, 1, 0)); EXPECT_TRUE(context.DebugTestHintFeasibility()); } @@ -2064,7 +2065,8 @@ TEST(FinalExpansionForLinearConstraintTest, ComplexLinearExpansionWithInteger) { EXPECT_THAT(initial_model, testing::EqualsProto(expected_model)); // We should properly complete the hint with the new slack variable. - EXPECT_THAT(context.SolutionHint(), ::testing::ElementsAre(1, 5, 6)); + EXPECT_THAT(context.solution_crush().SolutionHint(), + ::testing::ElementsAre(1, 5, 6)); EXPECT_TRUE(context.DebugTestHintFeasibility()); } @@ -2149,7 +2151,7 @@ TEST(FinalExpansionForLinearConstraintTest, // We should properly complete the hint and choose the bucket [4, 6], as well // as set the new linear_is_enforced hint to true. - EXPECT_THAT(context.SolutionHint(), + EXPECT_THAT(context.solution_crush().SolutionHint(), ::testing::ElementsAre(1, 5, 1, 0, 0, 1, 1)); EXPECT_TRUE(context.DebugTestHintFeasibility()); } diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index b9e8d3d2cb..d93d7f22a6 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -81,6 +81,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/simplification.h" +#include "ortools/sat/solution_crush.h" #include "ortools/sat/util.h" #include "ortools/sat/var_domination.h" #include "ortools/util/affine_relation.h" @@ -504,8 +505,9 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) { // hint(enforcement) = 0. But in this case the `enforcement` hint can be // increased to 1 to preserve the hint feasibility. const int implied_literal = ct->bool_and().literals(0); - if (context_->LiteralSolutionHintIs(implied_literal, true)) { - context_->UpdateLiteralSolutionHint(enforcement, true); + SolutionCrush& crush = context_->solution_crush(); + if (crush.LiteralSolutionHintIs(implied_literal, true)) { + crush.UpdateLiteralSolutionHint(enforcement, true); } context_->StoreBooleanEqualityRelation(enforcement, implied_literal); } @@ -2330,10 +2332,10 @@ bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) { } continue; } - context_->UpdateVarSolutionHint( - ct->linear().vars(i), context_->LiteralSolutionHint(indicator) - ? other_value - : best_value); + SolutionCrush& crush = context_->solution_crush(); + crush.UpdateVarSolutionHint( + ct->linear().vars(i), + crush.LiteralSolutionHint(indicator) ? other_value : best_value); if (RefIsPositive(indicator)) { if (!context_->StoreAffineRelation(ct->linear().vars(i), indicator, other_value - best_value, @@ -2959,17 +2961,17 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) { namespace { // Set the hint in `context` for the variable in `equality` that has no hint, if // there is exactly one. Otherwise do nothing. -void MaybeComputeMissingHint(PresolveContext* context, +void MaybeComputeMissingHint(SolutionCrush& crush, const LinearConstraintProto& equality) { DCHECK(equality.domain_size() == 2 && equality.domain(0) == equality.domain(1)); - if (!context->HintIsLoaded()) return; + if (!crush.HintIsLoaded()) return; int term_with_missing_hint = -1; int64_t missing_term_value = equality.domain(0); for (int i = 0; i < equality.vars_size(); ++i) { - if (context->VarHasSolutionHint(equality.vars(i))) { + if (crush.VarHasSolutionHint(equality.vars(i))) { missing_term_value -= - context->SolutionHint(equality.vars(i)) * equality.coeffs(i); + crush.SolutionHint(equality.vars(i)) * equality.coeffs(i); } else if (term_with_missing_hint == -1) { term_with_missing_hint = i; } else { @@ -2978,7 +2980,7 @@ void MaybeComputeMissingHint(PresolveContext* context, } } if (term_with_missing_hint == -1) return; - context->SetNewVariableHint( + crush.SetNewVariableHint( equality.vars(term_with_missing_hint), missing_term_value / equality.coeffs(term_with_missing_hint)); } @@ -3112,7 +3114,7 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { const int num_constraints = context_->working_model->constraints_size(); for (int i = 0; i < num_replaced_variables; ++i) { MaybeComputeMissingHint( - context_, + context_->solution_crush(), context_->working_model->constraints(num_constraints - 1 - i).linear()); } @@ -3947,19 +3949,20 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, if (fixed) { context_->UpdateRuleStats("linear: tightened into equality"); // Compute a new `var` hint so that the lhs of `ct` is equal to `rhs`. + SolutionCrush& crush = context_->solution_crush(); int64_t var_hint = rhs.FixedValue(); bool var_hint_is_valid = true; for (int j = 0; j < num_vars; ++j) { if (j == i) continue; const int term_var = ct->linear().vars(j); - if (!context_->VarHasSolutionHint(term_var)) { + if (!crush.VarHasSolutionHint(term_var)) { var_hint_is_valid = false; break; } - var_hint -= context_->SolutionHint(term_var) * ct->linear().coeffs(j); + var_hint -= crush.SolutionHint(term_var) * ct->linear().coeffs(j); } if (var_hint_is_valid) { - context_->UpdateRefSolutionHint(var, var_hint / var_coeff); + crush.UpdateRefSolutionHint(var, var_hint / var_coeff); } FillDomainInProto(rhs, ct->mutable_linear()); negated_rhs = rhs.Negation(); @@ -9696,10 +9699,10 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // increase the objective value thanks to the `skip` test above -- the // objective domain is non-constraining, but this only guarantees that // singleton variables can freely *decrease* the objective). - if (context_->LiteralSolutionHint(a) != - context_->LiteralSolutionHint(b)) { - context_->UpdateLiteralSolutionHint(a, true); - context_->UpdateLiteralSolutionHint(b, true); + SolutionCrush& crush = context_->solution_crush(); + if (crush.LiteralSolutionHint(a) != crush.LiteralSolutionHint(b)) { + crush.UpdateLiteralSolutionHint(a, true); + crush.UpdateLiteralSolutionHint(b, true); } context_->StoreBooleanEqualityRelation(a, b); @@ -11903,13 +11906,14 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { // code below ensures this (`value2` is the 'cheapest' value the implied // domain, and `value1` the cheapest value in the variable's domain). bool enforcing_hint = true; + SolutionCrush& crush = context_->solution_crush(); for (const int enforcement_lit : ct.enforcement_literal()) { - if (context_->LiteralSolutionHintIs(enforcement_lit, false)) { + if (crush.LiteralSolutionHintIs(enforcement_lit, false)) { enforcing_hint = false; break; } } - context_->UpdateVarSolutionHint(var, enforcing_hint ? value2 : value1); + crush.UpdateVarSolutionHint(var, enforcing_hint ? value2 : value1); return (void)context_->IntersectDomainWith( var, Domain::FromValues({value1, value2})); } @@ -12148,7 +12152,7 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { int64_t offset = special_value; for (const int64_t value : encoded_values) { const int literal = context_->GetOrCreateVarValueEncoding(var, value); - const int coeff = (value - special_value); + const int64_t coeff = (value - special_value); if (RefIsPositive(literal)) { mapping_ct->mutable_linear()->add_vars(literal); mapping_ct->mutable_linear()->add_coeffs(-coeff); @@ -13769,12 +13773,13 @@ void UpdateHintInProto(PresolveContext* context) { if (context->ModelIsUnsat()) return; // Extract the new hint information from the context. + SolutionCrush& crush = context->solution_crush(); auto* mutable_hint = proto->mutable_solution_hint(); mutable_hint->clear_vars(); mutable_hint->clear_values(); const int num_vars = context->working_model->variables().size(); for (int hinted_var = 0; hinted_var < num_vars; ++hinted_var) { - if (!context->VarHasSolutionHint(hinted_var)) continue; + if (!crush.VarHasSolutionHint(hinted_var)) continue; // Note the use of ClampedSolutionHint() instead of SolutionHint() below. // This also make sure a hint of INT_MIN or INT_MAX does not overflow. @@ -13789,13 +13794,15 @@ void UpdateHintInProto(PresolveContext* context) { if (relation.representative != hinted_var) { // Lets first fetch the value of the representative. const int rep = relation.representative; - if (!context->VarHasSolutionHint(rep)) continue; - const int64_t rep_value = context->ClampedSolutionHint(rep); + if (!crush.VarHasSolutionHint(rep)) continue; + const int64_t rep_value = + crush.ClampedSolutionHint(rep, context->DomainOf(rep)); // Apply the affine relation. hinted_value = rep_value * relation.coeff + relation.offset; } else { - hinted_value = context->ClampedSolutionHint(hinted_var); + hinted_value = + crush.ClampedSolutionHint(hinted_var, context->DomainOf(hinted_var)); } mutable_hint->add_vars(hinted_var); @@ -13882,7 +13889,7 @@ CpSolverStatus CpModelPresolver::Presolve() { } } - if (!context_->HintIsLoaded()) { + if (!context_->solution_crush().HintIsLoaded()) { context_->LoadSolutionHint(); } ExpandCpModelAndCanonicalizeConstraints(); diff --git a/ortools/sat/cp_model_symmetries.cc b/ortools/sat/cp_model_symmetries.cc index 9ae3f9ed78..d29bec4bce 100644 --- a/ortools/sat/cp_model_symmetries.cc +++ b/ortools/sat/cp_model_symmetries.cc @@ -50,6 +50,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/solution_crush.h" #include "ortools/sat/symmetry_util.h" #include "ortools/sat/util.h" #include "ortools/util/affine_relation.h" @@ -933,10 +934,11 @@ std::vector BuildInequalityCoeffsForOrbitope( void UpdateHintAfterFixingBoolToBreakSymmetry( PresolveContext* context, int var, bool fixed_value, absl::Span> generators) { - if (!context->VarHasSolutionHint(var)) { + SolutionCrush& crush = context->solution_crush(); + if (!crush.VarHasSolutionHint(var)) { return; } - const int64_t hinted_value = context->SolutionHint(var); + const int64_t hinted_value = crush.SolutionHint(var); if (hinted_value == static_cast(fixed_value)) { return; } @@ -948,8 +950,8 @@ void UpdateHintAfterFixingBoolToBreakSymmetry( bool found_target = false; int target_var; for (int v : orbit) { - if (context->VarHasSolutionHint(v) && - context->SolutionHint(v) == static_cast(fixed_value)) { + if (crush.VarHasSolutionHint(v) && + crush.SolutionHint(v) == static_cast(fixed_value)) { found_target = true; target_var = v; break; @@ -964,11 +966,11 @@ void UpdateHintAfterFixingBoolToBreakSymmetry( const std::vector generator_idx = TracePoint(target_var, schrier_vector, generators); for (const int i : generator_idx) { - context->PermuteHintValues(*generators[i]); + crush.PermuteHintValues(*generators[i]); } - DCHECK(context->VarHasSolutionHint(var)); - DCHECK_EQ(context->SolutionHint(var), fixed_value); + DCHECK(crush.VarHasSolutionHint(var)); + DCHECK_EQ(crush.SolutionHint(var), fixed_value); } } // namespace @@ -1250,7 +1252,8 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const int var = can_be_fixed_to_false[i]; if (orbits[var] == orbit_index) ++num_in_orbit; context->UpdateRuleStats("symmetry: fixed to false in general orbit"); - if (context->VarHasSolutionHint(var) && context->SolutionHint(var) == 1 && + SolutionCrush& crush = context->solution_crush(); + if (crush.VarHasSolutionHint(var) && crush.SolutionHint(var) == 1 && var_can_be_true_per_orbit[orbits[var]] != -1) { // We are breaking the symmetry in a way that makes the hint invalid. // We want `var` to be false, so we would naively pick a symmetry to diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 4b76caed75..fc05d2ddec 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -30,6 +30,7 @@ #include "absl/numeric/bits.h" #include "absl/types/span.h" #include "ortools/base/logging.h" +#include "ortools/sat/2d_mandatory_overlap_propagator.h" #include "ortools/sat/2d_orthogonal_packing.h" #include "ortools/sat/2d_try_edge_propagator.h" #include "ortools/sat/cumulative_energy.h" @@ -153,6 +154,12 @@ void AddNonOverlappingRectangles(const std::vector& x, NoOverlap2DConstraintHelper* no_overlap_helper = repository->GetOrCreate2DHelper(x, y); + GenericLiteralWatcher* const watcher = + model->GetOrCreate(); + + CreateAndRegisterMandatoryOverlapPropagator(no_overlap_helper, model, watcher, + 3); + NonOverlappingRectanglesDisjunctivePropagator* constraint = new NonOverlappingRectanglesDisjunctivePropagator(no_overlap_helper, model); @@ -161,8 +168,6 @@ void AddNonOverlappingRectangles(const std::vector& x, RectanglePairwisePropagator* pairwise_propagator = new RectanglePairwisePropagator(no_overlap_helper, model); - GenericLiteralWatcher* const watcher = - model->GetOrCreate(); watcher->SetPropagatorPriority(pairwise_propagator->RegisterWith(watcher), 4); model->TakeOwnership(pairwise_propagator); @@ -213,7 +218,7 @@ void AddNonOverlappingRectangles(const std::vector& x, } if (params.use_try_edge_reasoning_in_no_overlap_2d()) { - CreateAndRegisterTryEdgePropagator(no_overlap_helper, model, watcher); + CreateAndRegisterTryEdgePropagator(no_overlap_helper, model, watcher, 5); } } @@ -606,15 +611,22 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: SchedulingConstraintHelper* x = &helper_->x_helper(); SchedulingConstraintHelper* y = &helper_->y_helper(); - const absl::flat_hash_set requested_boxes_set(requested_boxes.begin(), - requested_boxes.end()); + // Optimization: we only initialize the set if we don't have all task here. + absl::flat_hash_set requested_boxes_set; + if (requested_boxes.size() != helper_->NumBoxes()) { + requested_boxes_set = {requested_boxes.begin(), requested_boxes.end()}; + } + // Compute relevant boxes, the one with a mandatory part on y. Because we will // need to sort it this way, we consider them by increasing start max. const auto temp = y->TaskByIncreasingNegatedStartMax(); auto fixed_boxes = already_checked_fixed_boxes_.view(); for (int i = temp.size(); --i >= 0;) { const int box = temp[i].task_index; - if (!requested_boxes_set.contains(box)) continue; + if (requested_boxes.size() != helper_->NumBoxes() && + !requested_boxes_set.contains(box)) { + continue; + } // By definition, fixed boxes are always present. // Doing this check optimize a bit the case where we have many fixed boxes. @@ -874,12 +886,7 @@ bool RectanglePairwisePropagator::Propagate() { fixed_non_zero_area_boxes_.clear(); non_fixed_non_zero_area_boxes_.clear(); fixed_non_zero_area_rectangles_.clear(); - absl::flat_hash_set component_boxes = { - helper_->connected_components()[component_index].begin(), - helper_->connected_components()[component_index].end()}; - for (auto task : helper_->x_helper().TaskByIncreasingStartMin()) { - const int b = task.task_index; - if (!component_boxes.contains(b)) continue; + for (int b : helper_->connected_components()[component_index]) { if (!helper_->IsPresent(b)) continue; const auto [x_size_max, y_size_max] = helper_->GetBoxSizesMax(b); ItemWithVariableSize* box; @@ -903,16 +910,10 @@ bool RectanglePairwisePropagator::Propagate() { *box = helper_->GetItemWithVariableSize(b); } - // The only thing to propagate between two fixed boxes is a conflict, and we - // can detect those in O(N*log(N)) time. - const std::optional> fixed_conflict = - FindOneIntersectionIfPresent(fixed_non_zero_area_rectangles_); + // We ignore pairs of two fixed boxes. The only thing to propagate between + // two fixed boxes is a conflict and it should already have been taken care + // of by the MandatoryOverlapPropagator propagator. - if (fixed_conflict.has_value()) { - return helper_->ReportConflictFromTwoBoxes( - fixed_non_zero_area_boxes_[fixed_conflict->first].index, - fixed_non_zero_area_boxes_[fixed_conflict->second].index); - } RETURN_IF_FALSE(FindRestrictionsAndPropagateConflict( non_fixed_non_zero_area_boxes_, fixed_non_zero_area_boxes_, &restrictions)); diff --git a/ortools/sat/diffn_cuts.cc b/ortools/sat/diffn_cuts.cc index 85ff82f4d1..43d11f26fc 100644 --- a/ortools/sat/diffn_cuts.cc +++ b/ortools/sat/diffn_cuts.cc @@ -37,6 +37,7 @@ #include "ortools/sat/linear_constraint.h" #include "ortools/sat/linear_constraint_manager.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/util.h" @@ -53,7 +54,8 @@ const double kMinCutViolation = 1e-4; } // namespace -DiffnBaseEvent::DiffnBaseEvent(int t, SchedulingConstraintHelper* x_helper) +DiffnBaseEvent::DiffnBaseEvent(int t, + const SchedulingConstraintHelper* x_helper) : x_start_min(x_helper->StartMin(t)), x_start_max(x_helper->StartMax(t)), x_end_min(x_helper->EndMin(t)), @@ -61,7 +63,7 @@ DiffnBaseEvent::DiffnBaseEvent(int t, SchedulingConstraintHelper* x_helper) x_size_min(x_helper->SizeMin(t)) {} struct DiffnEnergyEvent : DiffnBaseEvent { - DiffnEnergyEvent(int t, SchedulingConstraintHelper* x_helper) + DiffnEnergyEvent(int t, const SchedulingConstraintHelper* x_helper) : DiffnBaseEvent(t, x_helper) {} // We need this for linearizing the energy in some cases. @@ -385,7 +387,7 @@ CutGenerator CreateNoOverlap2dEnergyCutGenerator( return result; } -DiffnCtEvent::DiffnCtEvent(int t, SchedulingConstraintHelper* x_helper) +DiffnCtEvent::DiffnCtEvent(int t, const SchedulingConstraintHelper* x_helper) : DiffnBaseEvent(t, x_helper) {} std::string DiffnCtEvent::DebugString() const { @@ -573,28 +575,30 @@ void GenerateNoOvelap2dCompletionTimeCutsWithEnergy( // TODO(user): Use demands_helper and decomposed energy. CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( - SchedulingConstraintHelper* x_helper, SchedulingConstraintHelper* y_helper, - Model* model) { + NoOverlap2DConstraintHelper* helper, Model* model) { CutGenerator result; result.only_run_at_level_zero = true; - AddIntegerVariableFromIntervals(x_helper, model, &result.vars); - AddIntegerVariableFromIntervals(y_helper, model, &result.vars); + AddIntegerVariableFromIntervals(&helper->x_helper(), model, &result.vars); + AddIntegerVariableFromIntervals(&helper->y_helper(), model, &result.vars); gtl::STLSortAndRemoveDuplicates(&result.vars); auto* product_decomposer = model->GetOrCreate(); - result.generate_cuts = [x_helper, y_helper, product_decomposer, + result.generate_cuts = [helper, product_decomposer, model](LinearConstraintManager* manager) { - if (!x_helper->SynchronizeAndSetTimeDirection(true)) return false; - if (!y_helper->SynchronizeAndSetTimeDirection(true)) return false; + if (!helper->SynchronizeAndSetDirection(true, true, false)) { + return false; + } - const int num_rectangles = x_helper->NumTasks(); + const int num_rectangles = helper->NumBoxes(); std::vector active_rectangles_indexes; active_rectangles_indexes.reserve(num_rectangles); std::vector active_rectangles; active_rectangles.reserve(num_rectangles); std::vector cached_areas(num_rectangles); + const SchedulingConstraintHelper* x_helper = &helper->x_helper(); + const SchedulingConstraintHelper* y_helper = &helper->y_helper(); for (int rect = 0; rect < num_rectangles; ++rect) { - if (!y_helper->IsPresent(rect) || !y_helper->IsPresent(rect)) continue; + if (!helper->IsPresent(rect)) continue; cached_areas[rect] = x_helper->SizeMin(rect) * y_helper->SizeMin(rect); if (cached_areas[rect] == 0) continue; @@ -626,12 +630,12 @@ CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( rectangles.push_back(active_rectangles_indexes[index]); } - auto generate_cuts = [product_decomposer, manager, model, &rectangles]( - absl::string_view cut_name, - SchedulingConstraintHelper* x_helper, - SchedulingConstraintHelper* y_helper) { + auto generate_cuts = [product_decomposer, manager, model, helper, + &rectangles](absl::string_view cut_name) { std::vector events; + const SchedulingConstraintHelper* x_helper = &helper->x_helper(); + const SchedulingConstraintHelper* y_helper = &helper->y_helper(); const auto& lp_values = manager->LpValues(); for (const int rect : rectangles) { DiffnCtEvent event(rect, x_helper); @@ -654,14 +658,22 @@ CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( /*skip_low_sizes=*/false, model, manager); }; - if (!x_helper->SynchronizeAndSetTimeDirection(true)) return false; - if (!y_helper->SynchronizeAndSetTimeDirection(true)) return false; - generate_cuts("NoOverlap2dXCompletionTime", x_helper, y_helper); - generate_cuts("NoOverlap2dYCompletionTime", y_helper, x_helper); - if (!x_helper->SynchronizeAndSetTimeDirection(false)) return false; - if (!y_helper->SynchronizeAndSetTimeDirection(false)) return false; - generate_cuts("NoOverlap2dXCompletionTimeMirror", x_helper, y_helper); - generate_cuts("NoOverlap2dYCompletionTimeMirror", y_helper, x_helper); + if (!helper->SynchronizeAndSetDirection(true, true, false)) { + return false; + } + generate_cuts("NoOverlap2dXCompletionTime"); + if (!helper->SynchronizeAndSetDirection(true, true, true)) { + return false; + } + generate_cuts("NoOverlap2dYCompletionTime"); + if (!helper->SynchronizeAndSetDirection(false, false, false)) { + return false; + } + generate_cuts("NoOverlap2dXCompletionTimeMirror"); + if (!helper->SynchronizeAndSetDirection(false, false, true)) { + return false; + } + generate_cuts("NoOverlap2dYCompletionTimeMirror"); } return true; }; diff --git a/ortools/sat/diffn_cuts.h b/ortools/sat/diffn_cuts.h index c49718d467..8532adb929 100644 --- a/ortools/sat/diffn_cuts.h +++ b/ortools/sat/diffn_cuts.h @@ -22,6 +22,7 @@ #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/scheduling_helpers.h" namespace operations_research { @@ -30,8 +31,7 @@ namespace sat { // Completion time cuts for the no_overlap_2d constraint. It actually generates // the completion time cumulative cuts in both axis. CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( - SchedulingConstraintHelper* x_helper, SchedulingConstraintHelper* y_helper, - Model* model); + NoOverlap2DConstraintHelper* helper, Model* model); // Energetic cuts for the no_overlap_2d constraint. // @@ -57,7 +57,7 @@ CutGenerator CreateNoOverlap2dEnergyCutGenerator( // Base event type for scheduling cuts. struct DiffnBaseEvent { - DiffnBaseEvent(int t, SchedulingConstraintHelper* x_helper); + DiffnBaseEvent(int t, const SchedulingConstraintHelper* x_helper); // Cache of the intervals bound on the x direction. IntegerValue x_start_min; @@ -84,7 +84,7 @@ struct DiffnBaseEvent { // capacity_max. // For a no_overlap_2d constraint, y the other dimension of the rect. struct DiffnCtEvent : DiffnBaseEvent { - DiffnCtEvent(int t, SchedulingConstraintHelper* x_helper); + DiffnCtEvent(int t, const SchedulingConstraintHelper* x_helper); // The lp value of the end of the x interval. AffineExpression x_end; diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index a15c84d3d0..99f71c470e 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -1808,7 +1809,7 @@ struct Rectangle32 { // Requires that rectangles are sorted by x_min and that sizes on both // dimensions are > 0. std::vector> FindPartialRectangleIntersectionsImpl( - absl::Span rectangles, int y_max) { + absl::Span rectangles, int32_t y_max) { // We are going to use a sweep line algorithm to find the intersections. // First, we sort the rectangles by their x coordinates, then consider a sweep // line that goes from the left to the right. See the comment on the @@ -1847,11 +1848,16 @@ std::vector> FindPartialRectangleIntersectionsImpl( return arcs; } -std::vector> FindPartialRectangleIntersections( +namespace { +struct PostProcessedResult { + std::vector rectangles_sorted_by_x_min; + std::pair bounding_box; // Always starting at (0,0). +}; + +PostProcessedResult ConvertToRectangle32WithNonZeroSizes( absl::Span rectangles) { - // This function preprocess the data and calls - // FindPartialRectangleIntersectionsImpl() to actually solve the problem - // using a sweep line algorithm. The preprocessing consists of the following: + // This function is a preprocessing function for algorithms that find overlap + // between rectangles. It does the following: // - It converts the arbitrary int64_t coordinates into a small integer by // sorting the possible values and assigning them consecutive integers. // - It grows zero size intervals to make them size one. This simplifies @@ -1987,6 +1993,7 @@ std::vector> FindPartialRectangleIntersections( prev_event = event; prev_x = x; } + const int max_x_index = cur_index + 1; std::vector sorted_rectangles32; sorted_rectangles32.reserve(rectangles.size()); @@ -1997,23 +2004,22 @@ std::vector> FindPartialRectangleIntersections( } } - gtl::STLClearObject(&x_events); - return FindPartialRectangleIntersectionsImpl( - absl::MakeSpan(sorted_rectangles32), max_y_index); + return {sorted_rectangles32, {max_x_index, max_y_index}}; } - -std::optional> FindOneIntersectionIfPresent( - absl::Span rectangles) { - DCHECK( - absl::c_is_sorted(rectangles, [](const Rectangle& a, const Rectangle& b) { - return a.x_min < b.x_min; - })); +template +std::optional> FindOneIntersectionIfPresentImpl( + absl::Span rectangles) { + using CoordinateType = std::decay_t; + DCHECK(absl::c_is_sorted(rectangles, + [](const RectangleT& a, const RectangleT& b) { + return a.x_min < b.x_min; + })); // Set of box intersection the sweep line. We only store y_min, other // coordinates can be accessed via rectangles[index].coordinate. struct Element { mutable int index; - IntegerValue y_min; + CoordinateType y_min; bool operator<(const Element& other) const { return y_min < other.y_min; } }; @@ -2022,9 +2028,9 @@ std::optional> FindOneIntersectionIfPresent( absl::btree_set interval_set; for (int i = 0; i < rectangles.size(); ++i) { - const IntegerValue x = rectangles[i].x_min; - const IntegerValue y_min = rectangles[i].y_min; - const IntegerValue y_max = rectangles[i].y_max; + const CoordinateType x = rectangles[i].x_min; + const CoordinateType y_min = rectangles[i].y_min; + const CoordinateType y_max = rectangles[i].y_max; // TODO(user): We can handle that, but it require some changes below. DCHECK_LE(y_min, y_max); @@ -2055,7 +2061,8 @@ std::optional> FindOneIntersectionIfPresent( it = interval_set.erase(it_before); } else { DCHECK_LE(it_before->y_min, y_min); - const IntegerValue y_max_before = rectangles[it_before->index].y_max; + const CoordinateType y_max_before = + rectangles[it_before->index].y_max; if (y_max_before > y_min) { // Intersection. return {{it_before->index, i}}; @@ -2086,5 +2093,30 @@ std::optional> FindOneIntersectionIfPresent( return {}; } +} // namespace + +std::vector> FindPartialRectangleIntersections( + absl::Span rectangles) { + auto postprocessed = ConvertToRectangle32WithNonZeroSizes(rectangles); + return FindPartialRectangleIntersectionsImpl( + postprocessed.rectangles_sorted_by_x_min, + postprocessed.bounding_box.second); +} + +std::optional> FindOneIntersectionIfPresent( + absl::Span rectangles) { + return FindOneIntersectionIfPresentImpl(rectangles); +} + +std::optional> FindOneIntersectionIfPresentWithZeroArea( + absl::Span rectangles) { + auto postprocessed = ConvertToRectangle32WithNonZeroSizes(rectangles); + std::optional> result = FindOneIntersectionIfPresentImpl( + absl::MakeConstSpan(postprocessed.rectangles_sorted_by_x_min)); + if (!result.has_value()) return {}; + return {{postprocessed.rectangles_sorted_by_x_min[result->first].index, + postprocessed.rectangles_sorted_by_x_min[result->second].index}}; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index a7a924f0d5..e921e78bef 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -268,6 +268,14 @@ struct ItemWithVariableSize { }; Interval x; Interval y; + + template + friend void AbslStringify(Sink& sink, const ItemWithVariableSize& item) { + absl::Format(&sink, "Item %v: [(%v..%v)-(%v..%v)] x [(%v..%v)-(%v..%v)]", + item.index, item.x.start_min, item.x.start_max, item.x.end_min, + item.x.end_max, item.y.start_min, item.y.start_max, + item.y.end_min, item.y.end_max); + } }; struct PairwiseRestriction { @@ -702,13 +710,19 @@ std::vector> FindPartialRectangleIntersections( // This function is faster that the FindPartialRectangleIntersections() if one // only want to know if there is at least one intersection. It is in O(N log N). // -// IMPORTANT: this assumes rectangles are already sorted by their x_min. +// IMPORTANT: this assumes rectangles are already sorted by their x_min and does +// not support degenerate rectangles with zero area. // // If a pair {i, j} is returned, we will have i < j, and no intersection in // the subset of rectanges in [0, j). std::optional> FindOneIntersectionIfPresent( absl::Span rectangles); +// Same as FindOneIntersectionIfPresent() but supports degenerate rectangles +// with zero area. +std::optional> FindOneIntersectionIfPresentWithZeroArea( + absl::Span rectangles); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index bc8b4862f0..a2e33a9b8a 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -336,11 +336,6 @@ NewOptionalIntervalWithVariableSize(int64_t min_start, int64_t max_end, }; } -// Cuts helpers. -void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, - Model* model, - std::vector* vars); - void AppendVariablesFromCapacityAndDemands( const AffineExpression& capacity, SchedulingDemandHelper* demands_helper, Model* model, std::vector* vars); diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 7c21c7cf71..fd9b36b0db 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -47,6 +47,7 @@ #include "ortools/sat/intervals.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/precedences.h" #include "ortools/sat/presolve_util.h" #include "ortools/sat/routing_cuts.h" @@ -54,6 +55,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/scheduling_cuts.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/saturated_arithmetic.h" @@ -1712,9 +1714,11 @@ void AddNoOverlap2dCutGenerator(const ConstraintProto& ct, Model* m, intervals_repository->GetOrCreateHelper(x_intervals); SchedulingConstraintHelper* y_helper = intervals_repository->GetOrCreateHelper(y_intervals); + NoOverlap2DConstraintHelper* no_overlap_helper = + intervals_repository->GetOrCreate2DHelper(x_intervals, y_intervals); relaxation->cut_generators.push_back( - CreateNoOverlap2dCompletionTimeCutGenerator(x_helper, y_helper, m)); + CreateNoOverlap2dCompletionTimeCutGenerator(no_overlap_helper, m)); // Checks if at least one rectangle has a variable dimension or is optional. bool has_variable_part = false; diff --git a/ortools/sat/no_overlap_2d_helper.cc b/ortools/sat/no_overlap_2d_helper.cc index 77d5fd5cd8..47f3e6473a 100644 --- a/ortools/sat/no_overlap_2d_helper.cc +++ b/ortools/sat/no_overlap_2d_helper.cc @@ -18,6 +18,8 @@ #include #include +#include "absl/base/log_severity.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/sat/2d_rectangle_presolve.h" @@ -85,6 +87,15 @@ void ClearAndAddMandatoryOverlapReason(int box1, int box2, bool NoOverlap2DConstraintHelper::ReportConflictFromTwoBoxes(int box1, int box2) { + DCHECK_NE(box1, box2); + if (DEBUG_MODE) { + std::vector restrictions; + AppendPairwiseRestrictions({GetItemWithVariableSize(box1)}, + {GetItemWithVariableSize(box2)}, &restrictions); + DCHECK_EQ(restrictions.size(), 1); + DCHECK(restrictions[0].type == + PairwiseRestriction::PairwiseRestrictionType::CONFLICT); + } ClearAndAddMandatoryOverlapReason(box1, box2, x_helper_.get()); ClearAndAddMandatoryOverlapReason(box1, box2, y_helper_.get()); x_helper_->ImportOtherReasons(*y_helper_); diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 4eaac392ad..8edfa6915b 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -35,7 +35,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" @@ -115,19 +114,7 @@ int PresolveContext::NewIntVarWithDefinition( UpdateNewConstraintsVariableUsage(); } - // We only fill the hint of the new variable if all the variable involved - // in its definition have a value. - if (hint_is_loaded_) { - int64_t new_value = 0; - for (const auto [var, coeff] : definition) { - CHECK_GE(var, 0); - CHECK_LE(var, hint_.size()); - if (!hint_has_value_[var]) return new_var; - new_value += coeff * hint_[var]; - } - hint_has_value_[new_var] = true; - hint_[new_var] = new_value; - } + solution_crush_.SetVarToLinearExpression(new_var, definition); return new_var; } @@ -138,52 +125,14 @@ int PresolveContext::NewBoolVar(absl::string_view source) { int PresolveContext::NewBoolVarWithClause(absl::Span clause) { const int new_var = NewBoolVar("with clause"); - if (hint_is_loaded_) { - int new_hint = 0; - bool all_have_hint = true; - for (const int literal : clause) { - const int var = PositiveRef(literal); - if (!hint_has_value_[var]) { - all_have_hint = false; - break; - } - if (hint_[var] == (RefIsPositive(literal) ? 1 : 0)) { - new_hint = 1; - break; - } - } - // Leave the `new_var` hint unassigned if any literal is not hinted. - if (all_have_hint) { - hint_has_value_[new_var] = true; - hint_[new_var] = new_hint; - } - } + solution_crush_.SetVarToClause(new_var, clause); return new_var; } int PresolveContext::NewBoolVarWithConjunction( absl::Span conjunction) { const int new_var = NewBoolVar("with conjunction"); - if (hint_is_loaded_) { - int new_hint = 1; - bool all_have_hint = true; - for (const int literal : conjunction) { - const int var = PositiveRef(literal); - if (!hint_has_value_[var]) { - all_have_hint = false; - break; - } - if (hint_[var] == (RefIsPositive(literal) ? 0 : 1)) { - new_hint = 0; - break; - } - } - // Leave the `new_var` hint unassigned if any literal is not hinted. - if (all_have_hint) { - hint_has_value_[new_var] = true; - hint_[new_var] = new_hint; - } - } + solution_crush_.SetVarToConjunction(new_var, conjunction); return new_var; } @@ -589,19 +538,11 @@ ABSL_MUST_USE_RESULT bool PresolveContext::IntersectDomainWithInternal( domain.ToString())); } - if (update_hint && VarHasSolutionHint(var)) { - UpdateVarSolutionHint(var, domains_[var].ClosestValue(SolutionHint(var))); + // TODO(user): always update the hint, remove the `update_hint` argument. + if (update_hint) { + solution_crush_.UpdateVarToDomain(var, domains_[var]); } -#ifdef CHECK_HINT - if (working_model->has_solution_hint() && HintIsLoaded() && - !domains_[var].Contains(hint_[var])) { - LOG(FATAL) << "Hint with value " << hint_[var] - << " infeasible when changing domain of " << var << " to " - << domains_[var]; - } -#endif - // Propagate the domain of the representative right away. // Note that the recursive call should only by one level deep. const AffineRelation::Relation r = GetAffineRelation(var); @@ -745,8 +686,9 @@ void PresolveContext::AddVariableUsage(int c) { #ifdef CHECK_HINT // Crash if the loaded hint is infeasible for this constraint. // This is helpful to debug a wrong presolve that kill a feasible solution. - if (working_model->has_solution_hint() && HintIsLoaded() && - !ConstraintIsFeasible(*working_model, ct, hint_)) { + if (working_model->has_solution_hint() && solution_crush_.HintIsLoaded() && + !ConstraintIsFeasible(*working_model, ct, + solution_crush_.SolutionHint())) { LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " << ct.ShortDebugString(); } @@ -802,8 +744,9 @@ void PresolveContext::UpdateConstraintVariableUsage(int c) { #ifdef CHECK_HINT // Crash if the loaded hint is infeasible for this constraint. // This is helpful to debug a wrong presolve that kill a feasible solution. - if (working_model->has_solution_hint() && HintIsLoaded() && - !ConstraintIsFeasible(*working_model, ct, hint_)) { + if (working_model->has_solution_hint() && solution_crush_.HintIsLoaded() && + !ConstraintIsFeasible(*working_model, ct, + solution_crush_.SolutionHint())) { LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " << ct.ShortDebugString(); } @@ -1136,12 +1079,6 @@ bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff, return true; } -void PresolveContext::PermuteHintValues(const SparsePermutation& perm) { - CHECK(hint_is_loaded_); - perm.ApplyToDenseCollection(hint_); - perm.ApplyToDenseCollection(hint_has_value_); -} - bool PresolveContext::StoreAffineRelation(int var_x, int var_y, int64_t coeff, int64_t offset, bool debug_no_recursion) { @@ -1150,29 +1087,12 @@ bool PresolveContext::StoreAffineRelation(int var_x, int var_y, int64_t coeff, DCHECK_NE(coeff, 0); if (is_unsat_) return false; - if (hint_is_loaded_) { - if (!hint_has_value_[var_y] && hint_has_value_[var_x]) { - hint_has_value_[var_y] = true; - hint_[var_y] = (hint_[var_x] - offset) / coeff; - if (hint_[var_y] * coeff + offset != hint_[var_x]) { - // TODO(user): Do we implement a rounding to closest instead of - // routing towards 0. - UpdateRuleStats( - "Warning: hint didn't satisfy affine relation and was corrected"); - } - } + if (!solution_crush_.MaybeSetVarToAffineEquationSolution(var_x, var_y, coeff, + offset)) { + UpdateRuleStats( + "Warning: hint didn't satisfy affine relation and was corrected"); } -#ifdef CHECK_HINT - const int64_t vx = hint_[var_x]; - const int64_t vy = hint_[var_y]; - if (working_model->has_solution_hint() && HintIsLoaded() && - vx != vy * coeff + offset) { - LOG(FATAL) << "Affine relation incompatible with hint: " << vx - << " != " << vy << " * " << coeff << " + " << offset; - } -#endif - // TODO(user): I am not 100% sure why, but sometimes the representative is // fixed but that is not propagated to var_x or var_y and this causes issues. if (!PropagateAffineRelation(var_x)) return false; @@ -1419,7 +1339,7 @@ void PresolveContext::ResetAfterCopy() { var_to_constraints_.clear(); var_to_num_linear1_.clear(); objective_map_.clear(); - hint_.clear(); + DCHECK(!solution_crush_.HintIsLoaded()); } // Create the internal structure for any new variables in working_model. @@ -1447,40 +1367,35 @@ void PresolveContext::InitializeNewDomains() { } // We resize the hint too even if not loaded. - hint_.resize(new_size, 0); - hint_has_value_.resize(new_size, false); + solution_crush_.Resize(new_size); } void PresolveContext::LoadSolutionHint() { - CHECK(!hint_is_loaded_); - hint_is_loaded_ = true; + absl::flat_hash_map hint_values; if (working_model->has_solution_hint()) { const auto hint_proto = working_model->solution_hint(); - const int num_terms = hint_proto.vars().size(); int num_changes = 0; - for (int i = 0; i < num_terms; ++i) { + for (int i = 0; i < hint_proto.vars().size(); ++i) { const int var = hint_proto.vars(i); if (!RefIsPositive(var)) break; // Abort. Shouldn't happen. - if (var < hint_.size()) { - hint_has_value_[var] = true; - const int64_t hint_value = hint_proto.values(i); - hint_[var] = DomainOf(var).ClosestValue(hint_value); - if (hint_[var] != hint_value) { - ++num_changes; - } + const int64_t hint_value = hint_proto.values(i); + const int64_t clamped_hint_value = DomainOf(var).ClosestValue(hint_value); + if (clamped_hint_value != hint_value) { + ++num_changes; } + hint_values[var] = clamped_hint_value; } if (num_changes > 0) { UpdateRuleStats("hint: moved var hint within its domain.", num_changes); } - for (int i = 0; i < hint_.size(); ++i) { - if (hint_has_value_[i]) continue; - if (IsFixed(i)) { - hint_has_value_[i] = true; - hint_[i] = FixedValue(i); + for (int i = 0; i < working_model->variables().size(); ++i) { + if (!hint_values.contains(i) && IsFixed(i)) { + hint_values[i] = FixedValue(i); } } } + solution_crush_.Resize(working_model->variables().size()); + solution_crush_.LoadSolution(hint_values); } void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { @@ -1541,9 +1456,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { } else { UpdateRuleStats("variables with 2 values: create encoding literal"); max_literal = NewBoolVar("var with 2 values"); - if (hint_is_loaded_ && hint_has_value_[var]) { - SetNewVariableHint(max_literal, hint_[var] == var_max ? 1 : 0); - } + solution_crush_.MaybeSetLiteralToValueEncoding(max_literal, var, var_max); min_literal = NegatedRef(max_literal); var_map[var_min] = SavedLiteral(min_literal); var_map[var_max] = SavedLiteral(max_literal); @@ -1742,15 +1655,7 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int var, eq_half_encoding_.insert({{literal, var}, value}); neq_half_encoding_.insert({{NegatedRef(literal), var}, value}); - if (hint_is_loaded_) { - const int bool_var = PositiveRef(literal); - DCHECK(RefIsPositive(var)); - if (!hint_has_value_[bool_var] && hint_has_value_[var]) { - const int64_t bool_value = hint_[var] == value ? 1 : 0; - hint_has_value_[bool_var] = true; - hint_[bool_var] = RefIsPositive(literal) ? bool_value : 1 - bool_value; - } - } + solution_crush_.MaybeSetLiteralToValueEncoding(literal, var, value); return true; } @@ -2443,20 +2348,8 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( const int result = NewBoolVar(""); reified_precedences_cache_[key] = result; - // Take care of hints. - if (hint_is_loaded_) { - std::optional time_i_hint = GetExpressionSolutionHint(time_i); - std::optional time_j_hint = GetExpressionSolutionHint(time_j); - std::optional active_i_hint = GetRefSolutionHint(active_i); - std::optional active_j_hint = GetRefSolutionHint(active_j); - if (time_i_hint.has_value() && time_j_hint.has_value() && - active_i_hint.has_value() && active_j_hint.has_value()) { - const bool reified_hint = (active_i_hint.value() != 0) && - (active_j_hint.value() != 0) && - (time_i_hint.value() <= time_j_hint.value()); - SetNewVariableHint(result, reified_hint); - } - } + solution_crush_.SetVarToReifiedPrecedenceLiteral(result, time_i, time_j, + active_i, active_j); // result => (time_i <= time_j) && active_i && active_j. ConstraintProto* const lesseq = working_model->add_constraints(); @@ -2869,8 +2762,9 @@ void CreateValidModelWithSingleConstraint(const ConstraintProto& ct, bool PresolveContext::DebugTestHintFeasibility() { WriteVariableDomainsToProto(); - if (hint_.size() != working_model->variables().size()) return false; - return SolutionIsFeasible(*working_model, hint_); + const absl::Span hint = solution_crush_.SolutionHint(); + if (hint.size() != working_model->variables().size()) return false; + return SolutionIsFeasible(*working_model, hint); } } // namespace sat diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 0aab53bfa8..44195ac37e 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -29,13 +29,13 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" #include "ortools/sat/presolve_util.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/solution_crush.h" #include "ortools/sat/util.h" #include "ortools/util/affine_relation.h" #include "ortools/util/bitset.h" @@ -634,83 +634,7 @@ class PresolveContext { // possible (e.g. for the model proto's fixed variables). void LoadSolutionHint(); - void PermuteHintValues(const SparsePermutation& perm); - - // Solution hint accessors. - bool VarHasSolutionHint(int var) const { return hint_has_value_[var]; } - int64_t SolutionHint(int var) const { return hint_[var]; } - bool HintIsLoaded() const { return hint_is_loaded_; } - absl::Span SolutionHint() const { return hint_; } - - // Similar to SolutionHint() but make sure the value is within the current - // bounds of the variable. - int64_t ClampedSolutionHint(int var) { - int64_t value = hint_[var]; - if (value > MaxOf(var)) { - value = MaxOf(var); - } else if (value < MinOf(var)) { - value = MinOf(var); - } - return value; - } - - bool LiteralSolutionHint(int lit) const { - const int var = PositiveRef(lit); - return RefIsPositive(lit) ? hint_[var] : !hint_[var]; - } - - bool LiteralSolutionHintIs(int lit, bool value) const { - const int var = PositiveRef(lit); - return hint_is_loaded_ && hint_has_value_[var] && - hint_[var] == (RefIsPositive(lit) ? value : !value); - } - - // If the given literal is already hinted, updates its hint. - // Otherwise do nothing. - void UpdateLiteralSolutionHint(int lit, bool value) { - UpdateVarSolutionHint(PositiveRef(lit), - RefIsPositive(lit) == value ? 1 : 0); - } - - std::optional GetRefSolutionHint(int ref) { - const int var = PositiveRef(ref); - if (!VarHasSolutionHint(var)) return std::nullopt; - const int64_t var_hint = SolutionHint(var); - return RefIsPositive(ref) ? var_hint : -var_hint; - } - - std::optional GetExpressionSolutionHint( - const LinearExpressionProto& expr) { - int64_t result = expr.offset(); - for (int i = 0; i < expr.vars().size(); ++i) { - if (expr.coeffs(i) == 0) continue; - if (!VarHasSolutionHint(expr.vars(i))) return std::nullopt; - result += expr.coeffs(i) * SolutionHint(expr.vars(i)); - } - return result; - } - - void UpdateRefSolutionHint(int ref, int hint) { - UpdateVarSolutionHint(PositiveRef(ref), RefIsPositive(ref) ? hint : -hint); - } - - // If the given variable is already hinted, updates its hint value. - // Otherwise, do nothing. - void UpdateVarSolutionHint(int var, int64_t value) { - DCHECK(RefIsPositive(var)); - if (!hint_is_loaded_) return; - if (!hint_has_value_[var]) return; - hint_[var] = value; - } - - // Allows to set the hint of a newly created variable. - void SetNewVariableHint(int var, int64_t value) { - CHECK(hint_is_loaded_); - CHECK(!hint_has_value_[var]); - hint_has_value_[var] = true; - hint_[var] = value; - } - + SolutionCrush& solution_crush() { return solution_crush_; } // This is slow O(problem_size) but can be used to debug presolve, either by // pinpointing the transition from feasible to infeasible or the other way // around if for some reason the presolve drop constraint that it shouldn't. @@ -805,13 +729,7 @@ class PresolveContext { // The current domain of each variables. std::vector domains_; - // Parallel to domains. - // - // This contains all the hinted value or zero if the hint wasn't specified. - // We try to maintain this as we create new variable. - bool hint_is_loaded_ = false; - std::vector hint_has_value_; - std::vector hint_; + SolutionCrush solution_crush_; // Internal representation of the objective. During presolve, we first load // the objective in this format in order to have more efficient substitution diff --git a/ortools/sat/presolve_context_test.cc b/ortools/sat/presolve_context_test.cc index 7f4f1181ff..1d212f2136 100644 --- a/ortools/sat/presolve_context_test.cc +++ b/ortools/sat/presolve_context_test.cc @@ -23,6 +23,7 @@ #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" +#include "ortools/sat/solution_crush.h" #include "ortools/util/affine_relation.h" #include "ortools/util/sorted_interval_list.h" @@ -771,7 +772,7 @@ TEST(PresolveContextTest, IntersectDomainAndUpdateHint) { EXPECT_TRUE(context.IntersectDomainWithAndUpdateHint(0, Domain(5, 20))); EXPECT_EQ(context.DomainOf(0), Domain(5, 10)); - EXPECT_EQ(context.SolutionHint(0), 5); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 5); } TEST(PresolveContextTest, DomainSuperSetOf) { @@ -1073,20 +1074,21 @@ TEST(PresolveContextTest, LoadSolutionHint) { context.InitializeNewDomains(); context.LoadSolutionHint(); - EXPECT_TRUE(context.HintIsLoaded()); - EXPECT_TRUE(context.VarHasSolutionHint(0)); - EXPECT_TRUE(context.VarHasSolutionHint(1)); // From the fixed domain. - EXPECT_TRUE(context.VarHasSolutionHint(2)); - EXPECT_EQ(context.SolutionHint(0), 10); // Clamped to the domain. - EXPECT_EQ(context.SolutionHint(1), 5); // From the fixed domain. - EXPECT_EQ(context.SolutionHint(2), 0); - EXPECT_EQ(context.GetRefSolutionHint(0), 10); - EXPECT_EQ(context.GetRefSolutionHint(NegatedRef(0)), -10); - EXPECT_FALSE(context.LiteralSolutionHint(2)); - EXPECT_TRUE(context.LiteralSolutionHint(NegatedRef(2))); - EXPECT_TRUE(context.LiteralSolutionHintIs(2, false)); - EXPECT_TRUE(context.LiteralSolutionHintIs(NegatedRef(2), true)); - EXPECT_THAT(context.SolutionHint(), ::testing::ElementsAre(10, 5, 0)); + SolutionCrush& crush = context.solution_crush(); + EXPECT_TRUE(crush.HintIsLoaded()); + EXPECT_TRUE(crush.VarHasSolutionHint(0)); + EXPECT_TRUE(crush.VarHasSolutionHint(1)); // From the fixed domain. + EXPECT_TRUE(crush.VarHasSolutionHint(2)); + EXPECT_EQ(crush.SolutionHint(0), 10); // Clamped to the domain. + EXPECT_EQ(crush.SolutionHint(1), 5); // From the fixed domain. + EXPECT_EQ(crush.SolutionHint(2), 0); + EXPECT_EQ(crush.GetRefSolutionHint(0), 10); + EXPECT_EQ(crush.GetRefSolutionHint(NegatedRef(0)), -10); + EXPECT_FALSE(crush.LiteralSolutionHint(2)); + EXPECT_TRUE(crush.LiteralSolutionHint(NegatedRef(2))); + EXPECT_TRUE(crush.LiteralSolutionHintIs(2, false)); + EXPECT_TRUE(crush.LiteralSolutionHintIs(NegatedRef(2), true)); + EXPECT_THAT(crush.SolutionHint(), ::testing::ElementsAre(10, 5, 0)); } } // namespace diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index c764caf3cd..4e00d4c918 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -14,8 +14,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -48,21 +50,36 @@ void ThrowError(PyObject* py_exception, const std::string& message) { throw py::error_already_set(); } +// We extend the SolverWrapper class to keep track of the local error already +// set. +class ExtSolveWrapper : public SolveWrapper { + public: + mutable std::optional local_error_already_set_; +}; + // A trampoline class to override the OnSolutionCallback method to acquire the // GIL. class PySolutionCallback : public SolutionCallback { public: using SolutionCallback::SolutionCallback; /* Inherit constructors */ - void OnSolutionCallback() const override { ::py::gil_scoped_acquire acquire; - PYBIND11_OVERRIDE_PURE( - void, /* Return type */ - SolutionCallback, /* Parent class */ - OnSolutionCallback, /* Name of function */ - /* This function has no arguments. The trailing comma - in the previous line is needed for some compilers */ - ); + try { + PYBIND11_OVERRIDE_PURE( + void, /* Return type */ + SolutionCallback, /* Parent class */ + OnSolutionCallback, /* Name of function */ + /* This function has no arguments. The trailing comma + in the previous line is needed for some compilers */ + ); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + ExtSolveWrapper* solve_wrapper = static_cast(wrapper()); + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + StopSearch(); + } } }; @@ -266,13 +283,13 @@ std::shared_ptr SumArguments(py::args expressions) { // Normal list or tuple argument. py::sequence elements = expressions[0].cast(); linear_exprs.reserve(elements.size()); - for (const py::handle& arg : elements) { - process_arg(arg); + for (const py::handle& expr : elements) { + process_arg(expr); } } else { // Direct sum(x, y, 3, ..) without []. linear_exprs.reserve(expressions.size()); - for (const py::handle arg : expressions) { - process_arg(arg); + for (const py::handle expr : expressions) { + process_arg(expr); } } @@ -482,28 +499,80 @@ PYBIND11_MODULE(cp_model_helper, m) { .def("value", &ResponseWrapper::FixedValue, py::arg("value")) .def("wall_time", &ResponseWrapper::WallTime); - py::class_(m, "SolveWrapper") + py::class_(m, "SolveWrapper") .def(py::init<>()) - .def("add_log_callback", &SolveWrapper::AddLogCallback, - py::arg("log_callback")) + .def( + "add_log_callback", + [](ExtSolveWrapper* solve_wrapper, + std::function log_callback) { + std::function safe_log_callback = + [solve_wrapper, log_callback](std::string message) -> void { + ::py::gil_scoped_acquire acquire; + try { + log_callback(message); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + solve_wrapper->StopSearch(); + } + }; + solve_wrapper->AddLogCallback(safe_log_callback); + }, + py::arg("log_callback").none(false)) .def("add_solution_callback", &SolveWrapper::AddSolutionCallback, py::arg("callback")) .def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback) - .def("add_best_bound_callback", &SolveWrapper::AddBestBoundCallback, - py::arg("best_bound_callback")) + .def( + "add_best_bound_callback", + [](ExtSolveWrapper* solve_wrapper, + std::function best_bound_callback) { + std::function safe_best_bound_callback = + [solve_wrapper, best_bound_callback](double bound) -> void { + ::py::gil_scoped_acquire acquire; + try { + best_bound_callback(bound); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + solve_wrapper->StopSearch(); + } + }; + solve_wrapper->AddBestBoundCallback(safe_best_bound_callback); + }, + py::arg("best_bound_callback").none(false)) .def("set_parameters", &SolveWrapper::SetParameters, py::arg("parameters")) .def("solve", - [](SolveWrapper* solve_wrapper, + [](ExtSolveWrapper* solve_wrapper, const CpModelProto& model_proto) -> CpSolverResponse { - ::pybind11::gil_scoped_release release; - return solve_wrapper->Solve(model_proto); + const auto result = [&]() -> CpSolverResponse { + ::py::gil_scoped_release release; + return solve_wrapper->Solve(model_proto); + }(); + if (solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_->restore(); + solve_wrapper->local_error_already_set_.reset(); + throw py::error_already_set(); + } + return result; }) .def("solve_and_return_response_wrapper", - [](SolveWrapper* solve_wrapper, + [](ExtSolveWrapper* solve_wrapper, const CpModelProto& model_proto) -> ResponseWrapper { - ::py::gil_scoped_release release; - return ResponseWrapper(solve_wrapper->Solve(model_proto)); + const auto result = [&]() -> ResponseWrapper { + ::py::gil_scoped_release release; + return ResponseWrapper(solve_wrapper->Solve(model_proto)); + }(); + if (solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_->restore(); + solve_wrapper->local_error_already_set_.reset(); + throw py::error_already_set(); + } + return result; }) .def("stop_search", &SolveWrapper::StopSearch); diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 66b8447dac..0daf692e3e 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -114,6 +114,16 @@ def last_time(self) -> float: return self.__last_time +class RaiseException(cp_model.CpSolverSolutionCallback): + + def __init__(self, msg: str) -> None: + super().__init__() + self.__msg = msg + + def on_solution_callback(self) -> None: + raise ValueError(self.__msg) + + class LogToString: """Record log in a string.""" @@ -2036,8 +2046,6 @@ def testIssue4376MinimizeModel(self) -> None: max_width = 10 horizon = sum(t[0] for t in jobs) - num_jobs = len(jobs) - all_jobs = range(num_jobs) intervals = [] intervals0 = [] @@ -2047,16 +2055,16 @@ def testIssue4376MinimizeModel(self) -> None: ends = [] demands = [] - for i in all_jobs: + for i, job in enumerate(jobs): # Create main interval. start = model.new_int_var(0, horizon, f"start_{i}") - duration = jobs[i][0] + duration, width = job end = model.new_int_var(0, horizon, f"end_{i}") interval = model.new_interval_var(start, duration, end, f"interval_{i}") starts.append(start) intervals.append(interval) ends.append(end) - demands.append(jobs[i][1]) + demands.append(width) # Create an optional copy of interval to be executed on machine 0. performed_on_m0 = model.new_bool_var(f"perform_{i}_on_m0") @@ -2132,6 +2140,101 @@ def testIssue4434(self) -> None: self.assertIsNotNone(expr_ne) self.assertIsNotNone(expr_ge) + def testRaisePythonExceptionInCallback(self) -> None: + model = cp_model.CpModel() + + jobs = [ + [3, 3], # [duration, width] + [2, 5], + [1, 3], + [3, 7], + [7, 3], + [2, 2], + [2, 2], + [5, 5], + [10, 2], + [4, 3], + [2, 6], + [1, 2], + [6, 8], + [4, 5], + [3, 7], + ] + + max_width = 10 + + horizon = sum(t[0] for t in jobs) + + intervals = [] + intervals0 = [] + intervals1 = [] + performed = [] + starts = [] + ends = [] + demands = [] + + for i, job in enumerate(jobs): + # Create main interval. + start = model.new_int_var(0, horizon, f"start_{i}") + duration, width = job + end = model.new_int_var(0, horizon, f"end_{i}") + interval = model.new_interval_var(start, duration, end, f"interval_{i}") + starts.append(start) + intervals.append(interval) + ends.append(end) + demands.append(width) + + # Create an optional copy of interval to be executed on machine 0. + performed_on_m0 = model.new_bool_var(f"perform_{i}_on_m0") + performed.append(performed_on_m0) + start0 = model.new_int_var(0, horizon, f"start_{i}_on_m0") + end0 = model.new_int_var(0, horizon, f"end_{i}_on_m0") + interval0 = model.new_optional_interval_var( + start0, duration, end0, performed_on_m0, f"interval_{i}_on_m0" + ) + intervals0.append(interval0) + + # Create an optional copy of interval to be executed on machine 1. + start1 = model.new_int_var(0, horizon, f"start_{i}_on_m1") + end1 = model.new_int_var(0, horizon, f"end_{i}_on_m1") + interval1 = model.new_optional_interval_var( + start1, + duration, + end1, + ~performed_on_m0, + f"interval_{i}_on_m1", + ) + intervals1.append(interval1) + + # We only propagate the constraint if the tasks is performed on the + # machine. + model.add(start0 == start).only_enforce_if(performed_on_m0) + model.add(start1 == start).only_enforce_if(~performed_on_m0) + + # Width constraint (modeled as a cumulative) + model.add_cumulative(intervals, demands, max_width) + + # Choose which machine to perform the jobs on. + model.add_no_overlap(intervals0) + model.add_no_overlap(intervals1) + + # Objective variable. + makespan = model.new_int_var(0, horizon, "makespan") + model.add_max_equality(makespan, ends) + model.minimize(makespan) + + # Symmetry breaking. + model.add(performed[0] == 0) + + solver = cp_model.CpSolver() + solver.parameters.log_search_progress = True + solver.parameters.num_workers = 1 + msg: str = "this is my test message" + callback = RaiseException(msg) + + with self.assertRaisesRegex(ValueError, msg): + solver.solve(model, callback) + def testInPlaceSumModifications(self) -> None: model = cp_model.CpModel() x = [model.new_int_var(0, 10, f"x{i}") for i in range(5)] diff --git a/ortools/sat/scheduling_helpers.cc b/ortools/sat/scheduling_helpers.cc index a865bb9a18..78e995e93b 100644 --- a/ortools/sat/scheduling_helpers.cc +++ b/ortools/sat/scheduling_helpers.cc @@ -1024,7 +1024,7 @@ void SchedulingDemandHelper::AddEnergyMinInWindowReason( } } -void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, +void AddIntegerVariableFromIntervals(const SchedulingConstraintHelper* helper, Model* model, std::vector* vars) { IntegerEncoder* encoder = model->GetOrCreate(); diff --git a/ortools/sat/scheduling_helpers.h b/ortools/sat/scheduling_helpers.h index dccbcfa06f..1612fb795c 100644 --- a/ortools/sat/scheduling_helpers.h +++ b/ortools/sat/scheduling_helpers.h @@ -771,7 +771,7 @@ inline void SchedulingConstraintHelper::AddEnergyMinInIntervalReason( } // Cuts helpers. -void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, +void AddIntegerVariableFromIntervals(const SchedulingConstraintHelper* helper, Model* model, std::vector* vars); diff --git a/ortools/sat/solution_crush.cc b/ortools/sat/solution_crush.cc new file mode 100644 index 0000000000..db44664c38 --- /dev/null +++ b/ortools/sat/solution_crush.cc @@ -0,0 +1,193 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/solution_crush.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { + +void SolutionCrush::SetVarToLinearExpression( + int new_var, absl::Span> linear) { + // We only fill the hint of the new variable if all the variable involved + // in its definition have a value. + if (hint_is_loaded_) { + int64_t new_value = 0; + for (const auto [var, coeff] : linear) { + CHECK_GE(var, 0); + CHECK_LE(var, hint_.size()); + if (!hint_has_value_[var]) return; + new_value += coeff * hint_[var]; + } + hint_has_value_[new_var] = true; + hint_[new_var] = new_value; + } +} + +void SolutionCrush::SetVarToClause(int new_var, absl::Span clause) { + if (hint_is_loaded_) { + int new_hint = 0; + bool all_have_hint = true; + for (const int literal : clause) { + const int var = PositiveRef(literal); + if (!hint_has_value_[var]) { + all_have_hint = false; + break; + } + if (hint_[var] == (RefIsPositive(literal) ? 1 : 0)) { + new_hint = 1; + break; + } + } + // Leave the `new_var` hint unassigned if any literal is not hinted. + if (all_have_hint) { + hint_has_value_[new_var] = true; + hint_[new_var] = new_hint; + } + } +} + +void SolutionCrush::SetVarToConjunction(int new_var, + absl::Span conjunction) { + if (hint_is_loaded_) { + int new_hint = 1; + bool all_have_hint = true; + for (const int literal : conjunction) { + const int var = PositiveRef(literal); + if (!hint_has_value_[var]) { + all_have_hint = false; + break; + } + if (hint_[var] == (RefIsPositive(literal) ? 0 : 1)) { + new_hint = 0; + break; + } + } + // Leave the `new_var` hint unassigned if any literal is not hinted. + if (all_have_hint) { + hint_has_value_[new_var] = true; + hint_[new_var] = new_hint; + } + } +} + +void SolutionCrush::UpdateVarToDomain(int var, const Domain& domain) { + if (VarHasSolutionHint(var)) { + UpdateVarSolutionHint(var, domain.ClosestValue(SolutionHint(var))); + } + +#ifdef CHECK_HINT + if (model_has_hint_ && HintIsLoaded() && + !domains_[var].Contains(hint_[var])) { + LOG(FATAL) << "Hint with value " << hint_[var] + << " infeasible when changing domain of " << var << " to " + << domains_[var]; + } +#endif +} + +void SolutionCrush::PermuteHintValues(const SparsePermutation& perm) { + CHECK(hint_is_loaded_); + perm.ApplyToDenseCollection(hint_); + perm.ApplyToDenseCollection(hint_has_value_); +} + +bool SolutionCrush::MaybeSetVarToAffineEquationSolution(int var_x, int var_y, + int64_t coeff, + int64_t offset) { + if (hint_is_loaded_) { + if (!hint_has_value_[var_y] && hint_has_value_[var_x]) { + hint_has_value_[var_y] = true; + hint_[var_y] = (hint_[var_x] - offset) / coeff; + if (hint_[var_y] * coeff + offset != hint_[var_x]) { + // TODO(user): Do we implement a rounding to closest instead of + // routing towards 0. + return false; + } + } + } + +#ifdef CHECK_HINT + const int64_t vx = hint_[var_x]; + const int64_t vy = hint_[var_y]; + if (model_has_hint_ && HintIsLoaded() && vx != vy * coeff + offset) { + LOG(FATAL) << "Affine relation incompatible with hint: " << vx + << " != " << vy << " * " << coeff << " + " << offset; + } +#endif + return true; +} + +void SolutionCrush::Resize(int new_size) { + hint_.resize(new_size, 0); + hint_has_value_.resize(new_size, false); +} + +void SolutionCrush::LoadSolution( + const absl::flat_hash_map& solution) { + CHECK(!hint_is_loaded_); + model_has_hint_ = !solution.empty(); + hint_is_loaded_ = true; + for (const auto [var, value] : solution) { + hint_has_value_[var] = true; + hint_[var] = value; + } +} + +void SolutionCrush::MaybeSetLiteralToValueEncoding(int literal, int var, + int64_t value) { + if (hint_is_loaded_) { + const int bool_var = PositiveRef(literal); + DCHECK(RefIsPositive(var)); + if (!hint_has_value_[bool_var] && hint_has_value_[var]) { + const int64_t bool_value = hint_[var] == value ? 1 : 0; + hint_has_value_[bool_var] = true; + hint_[bool_var] = RefIsPositive(literal) ? bool_value : 1 - bool_value; + } + } +} + +void SolutionCrush::SetVarToReifiedPrecedenceLiteral( + int var, const LinearExpressionProto& time_i, + const LinearExpressionProto& time_j, int active_i, int active_j) { + // Take care of hints. + if (hint_is_loaded_) { + std::optional time_i_hint = GetExpressionSolutionHint(time_i); + std::optional time_j_hint = GetExpressionSolutionHint(time_j); + std::optional active_i_hint = GetRefSolutionHint(active_i); + std::optional active_j_hint = GetRefSolutionHint(active_j); + if (time_i_hint.has_value() && time_j_hint.has_value() && + active_i_hint.has_value() && active_j_hint.has_value()) { + const bool reified_hint = (active_i_hint.value() != 0) && + (active_j_hint.value() != 0) && + (time_i_hint.value() <= time_j_hint.value()); + SetNewVariableHint(var, reified_hint); + } + } +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/solution_crush.h b/ortools/sat/solution_crush.h new file mode 100644 index 0000000000..dddc37ee6e --- /dev/null +++ b/ortools/sat/solution_crush.h @@ -0,0 +1,181 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_SOLUTION_CRUSH_H_ +#define OR_TOOLS_SAT_SOLUTION_CRUSH_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { + +class SolutionCrush { + public: + SolutionCrush() = default; + + // SolutionCrush is neither copyable nor movable. + SolutionCrush(const SolutionCrush&) = delete; + SolutionCrush(SolutionCrush&&) = delete; + SolutionCrush& operator=(const SolutionCrush&) = delete; + SolutionCrush& operator=(SolutionCrush&&) = delete; + + // Sets the given values in the solution. `solution` must be a map from + // variable indices to variable values. This must be called only once, before + // any other method (besides `Resize`). + // TODO(user): revisit this near the end of the refactoring. + void LoadSolution(const absl::flat_hash_map& solution); + + void PermuteHintValues(const SparsePermutation& perm); + + // Solution hint accessors. + bool VarHasSolutionHint(int var) const { return hint_has_value_[var]; } + int64_t SolutionHint(int var) const { return hint_[var]; } + bool HintIsLoaded() const { return hint_is_loaded_; } + absl::Span SolutionHint() const { return hint_; } + + // Similar to SolutionHint() but make sure the value is within the given + // domain. + int64_t ClampedSolutionHint(int var, const Domain& domain) { + int64_t value = hint_[var]; + if (value > domain.Max()) { + value = domain.Max(); + } else if (value < domain.Min()) { + value = domain.Min(); + } + return value; + } + + bool LiteralSolutionHint(int lit) const { + const int var = PositiveRef(lit); + return RefIsPositive(lit) ? hint_[var] : !hint_[var]; + } + + bool LiteralSolutionHintIs(int lit, bool value) const { + const int var = PositiveRef(lit); + return hint_is_loaded_ && hint_has_value_[var] && + hint_[var] == (RefIsPositive(lit) ? value : !value); + } + + // If the given literal is already hinted, updates its hint. + // Otherwise do nothing. + void UpdateLiteralSolutionHint(int lit, bool value) { + UpdateVarSolutionHint(PositiveRef(lit), + RefIsPositive(lit) == value ? 1 : 0); + } + + std::optional GetRefSolutionHint(int ref) { + const int var = PositiveRef(ref); + if (!VarHasSolutionHint(var)) return std::nullopt; + const int64_t var_hint = SolutionHint(var); + return RefIsPositive(ref) ? var_hint : -var_hint; + } + + std::optional GetExpressionSolutionHint( + const LinearExpressionProto& expr) { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars().size(); ++i) { + if (expr.coeffs(i) == 0) continue; + if (!VarHasSolutionHint(expr.vars(i))) return std::nullopt; + result += expr.coeffs(i) * SolutionHint(expr.vars(i)); + } + return result; + } + + void UpdateRefSolutionHint(int ref, int hint) { + UpdateVarSolutionHint(PositiveRef(ref), RefIsPositive(ref) ? hint : -hint); + } + + // If the given variable is already hinted, updates its hint value. + // Otherwise, do nothing. + void UpdateVarSolutionHint(int var, int64_t value) { + DCHECK(RefIsPositive(var)); + if (!hint_is_loaded_) return; + if (!hint_has_value_[var]) return; + hint_[var] = value; + } + + // Allows to set the hint of a newly created variable. + void SetNewVariableHint(int var, int64_t value) { + CHECK(hint_is_loaded_); + CHECK(!hint_has_value_[var]); + hint_has_value_[var] = true; + hint_[var] = value; + } + + // Resizes the solution to contain `new_size` variables. Does not change the + // value of existing variables, and does not set any value for the new + // variables. + // WARNING: the methods below do not automatically resize the solution. To set + // the value of a new variable with one of them, call this method first. + void Resize(int new_size); + + // Sets the value of `literal` to "`var`'s value == `value`". Does nothing if + // `literal` already has a value. + void MaybeSetLiteralToValueEncoding(int literal, int var, int64_t value); + + // Sets the value of `var` to the value of the given linear expression. + // `linear` must be a list of (variable index, coefficient) pairs. + void SetVarToLinearExpression( + int var, absl::Span> linear); + + // Sets the value of `var` to 1 if the value of at least one literal in + // `clause` is equal to 1 (or to 0 otherwise). `clause` must be a list of + // literal indices. + void SetVarToClause(int var, absl::Span clause); + + // Sets the value of `var` to 1 if the value of all the literals in + // `conjunction` is 1 (or to 0 otherwise). `conjunction` must be a list of + // literal indices. + void SetVarToConjunction(int var, absl::Span conjunction); + + // Updates the value of the given variable to be within the given domain. The + // variable is updated to the closest value within the domain. `var` must + // already have a value. + void UpdateVarToDomain(int var, const Domain& domain); + + // Sets the value of `var_y` so that "`var_x`'s value = `var_y`'s value + // * `coeff` + `offset`". Does nothing if `var_y` already has a value. + // Returns whether the update was successful. + bool MaybeSetVarToAffineEquationSolution(int var_x, int var_y, int64_t coeff, + int64_t offset); + + // Sets the value of `var` to "`time_i`'s value <= `time_j`'s value && + // `active_i`'s value == true && `active_j`'s value == true". + void SetVarToReifiedPrecedenceLiteral(int var, + const LinearExpressionProto& time_i, + const LinearExpressionProto& time_j, + int active_i, int active_j); + + private: + // This contains all the hinted value or zero if the hint wasn't specified. + // We try to maintain this as we create new variable. + bool model_has_hint_ = false; + bool hint_is_loaded_ = false; + std::vector hint_has_value_; + std::vector hint_; +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_SOLUTION_CRUSH_H_ diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index 25ad97d4c0..4e18ea6a77 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -88,7 +88,7 @@ bool SolutionCallback::SolutionBooleanValue(int index) const { : response_.solution(-index - 1) == 0; } -void SolutionCallback::StopSearch() { +void SolutionCallback::StopSearch() const { if (wrapper_ != nullptr) wrapper_->StopSearch(); } diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index 81b55eb958..d7ed361325 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -63,13 +63,15 @@ class SolutionCallback { bool SolutionBooleanValue(int index) const; // Stops the search. - void StopSearch(); + void StopSearch() const; const operations_research::sat::CpSolverResponse& Response() const; // We use mutable and non const methods to overcome SWIG difficulties. void SetWrapperClass(SolveWrapper* wrapper) const; + SolveWrapper* wrapper() const { return wrapper_; } + bool HasResponse() const; private: diff --git a/ortools/sat/var_domination.cc b/ortools/sat/var_domination.cc index 88cdb181ca..1e84abcf70 100644 --- a/ortools/sat/var_domination.cc +++ b/ortools/sat/var_domination.cc @@ -43,6 +43,7 @@ #include "ortools/sat/integer_base.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/presolve_util.h" +#include "ortools/sat/solution_crush.h" #include "ortools/sat/util.h" #include "ortools/util/affine_relation.h" #include "ortools/util/saturated_arithmetic.h" @@ -722,6 +723,7 @@ void TransformLinearWithSpecialBoolean(const ConstraintProto& ct, int ref, } // namespace bool DualBoundStrengthening::Strengthen(PresolveContext* context) { + SolutionCrush& crush = context->solution_crush(); num_deleted_constraints_ = 0; const CpModelProto& cp_model = *context->working_model; const int num_vars = cp_model.variables_size(); @@ -851,8 +853,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { context->deductions.ImpliedDomain(enf, positive_ref)); if (implied.IsEmpty()) { context->UpdateRuleStats("dual: fix variable"); - context->UpdateLiteralSolutionHint(enf, false); - if (!context->SetLiteralToFalse(enf)) return false; + if (!context->SetLiteralAndHintToFalse(enf)) return false; if (!context->IntersectDomainWithAndUpdateHint(positive_ref, Domain(bound))) { return false; @@ -882,8 +883,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // decreasing [`ct` = enf => (var = implied)] does not apply. We can // thus set the hint of `positive_ref` to `bound` to preserve the hint // feasibility. - if (context->LiteralSolutionHintIs(enf, false)) { - context->UpdateRefSolutionHint(positive_ref, bound); + if (crush.LiteralSolutionHintIs(enf, false)) { + crush.UpdateRefSolutionHint(positive_ref, bound); } if (RefIsPositive(enf)) { // positive_ref = enf * implied + (1 - enf) * bound. @@ -913,8 +914,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // hint(ref) is 1. In this case the only locking constraint `ct` does // not apply and thus does not prevent decreasing the hint of ref in // order to preserve the hint feasibility. - if (context->LiteralSolutionHintIs(enf, false)) { - context->UpdateLiteralSolutionHint(ref, false); + if (crush.LiteralSolutionHintIs(enf, false)) { + crush.UpdateLiteralSolutionHint(ref, false); } context->AddImplication(NegatedRef(enf), NegatedRef(ref)); context->UpdateNewConstraintsVariableUsage(); @@ -938,8 +939,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { .IntersectionWith(var_domain); if (rhs.IsEmpty()) { context->UpdateRuleStats("linear1: infeasible"); - context->UpdateLiteralSolutionHint(ct.enforcement_literal(0), false); - if (!context->SetLiteralToFalse(ct.enforcement_literal(0))) { + if (!context->SetLiteralAndHintToFalse(ct.enforcement_literal(0))) { return false; } processed[PositiveRef(ref)] = true; @@ -972,8 +972,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // break the hint only if hint(ref) = hint(encoding_lit) = 1. But // in this case `ct` is actually not blocking ref from decreasing. // We can thus set its hint to 0 to preserve the hint feasibility. - if (context->LiteralSolutionHintIs(encoding_lit, true)) { - context->UpdateLiteralSolutionHint(ref, false); + if (crush.LiteralSolutionHintIs(encoding_lit, true)) { + crush.UpdateLiteralSolutionHint(ref, false); } if (!context->StoreBooleanEqualityRelation(encoding_lit, NegatedRef(ref))) { @@ -986,8 +986,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // = 1. But in this case `ct` is actually not blocking ref from // decreasing. We can thus set its hint to 0 to preserve the hint // feasibility. - if (context->LiteralSolutionHintIs(encoding_lit, false)) { - context->UpdateLiteralSolutionHint(ref, false); + if (crush.LiteralSolutionHintIs(encoding_lit, false)) { + crush.UpdateLiteralSolutionHint(ref, false); } if (!context->StoreBooleanEqualityRelation(encoding_lit, ref)) { return false; @@ -1012,10 +1012,9 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // set the hint of `ref` to false. This should be safe since the only // constraint blocking `ref` from decreasing is `ct` = not(ref) => // (`var` in `rhs`) -- which does not apply when `ref` is true. - const std::optional var_hint = - context->GetRefSolutionHint(var); + const std::optional var_hint = crush.GetRefSolutionHint(var); if (var_hint.has_value() && !complement.Contains(*var_hint)) { - context->UpdateLiteralSolutionHint(ref, false); + crush.UpdateLiteralSolutionHint(ref, false); } ConstraintProto* new_ct = context->working_model->add_constraints(); new_ct->add_enforcement_literal(ref); @@ -1096,12 +1095,12 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // not blocking the ref at 1 from decreasing. Hence we can set its // hint to false to preserve the hint feasibility despite the new // Boolean equality constraint. - if (context->VarHasSolutionHint(PositiveRef(ref)) && - context->VarHasSolutionHint(PositiveRef(other_ref)) && - context->LiteralSolutionHint(ref) != - context->LiteralSolutionHint(other_ref)) { - context->UpdateLiteralSolutionHint(ref, false); - context->UpdateLiteralSolutionHint(other_ref, false); + if (crush.VarHasSolutionHint(PositiveRef(ref)) && + crush.VarHasSolutionHint(PositiveRef(other_ref)) && + crush.LiteralSolutionHint(ref) != + crush.LiteralSolutionHint(other_ref)) { + crush.UpdateLiteralSolutionHint(ref, false); + crush.UpdateLiteralSolutionHint(other_ref, false); } if (!context->StoreBooleanEqualityRelation(ref, other_ref)) { return false; @@ -1178,8 +1177,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // increase its value. If hint(a) is true then hint(b) must be true as well // if the hint is feasible, due to the a => b constraint. Setting hint(a) to // hint(b) is thus always safe. - if (context->VarHasSolutionHint(PositiveRef(b))) { - context->UpdateLiteralSolutionHint(a, context->LiteralSolutionHint(b)); + if (crush.VarHasSolutionHint(PositiveRef(b))) { + crush.UpdateLiteralSolutionHint(a, crush.LiteralSolutionHint(b)); } if (!context->StoreBooleanEqualityRelation(a, b)) return false; context->UpdateRuleStats("dual: enforced equivalence"); @@ -1478,10 +1477,11 @@ namespace { // respectively. void MaybeUpdateLiteralHintFromDominance(PresolveContext& context, int lit, int dominating_lit) { - if (context.LiteralSolutionHintIs(lit, true) && - context.LiteralSolutionHintIs(dominating_lit, false)) { - context.UpdateLiteralSolutionHint(lit, false); - context.UpdateLiteralSolutionHint(dominating_lit, true); + SolutionCrush& crush = context.solution_crush(); + if (crush.LiteralSolutionHintIs(lit, true) && + crush.LiteralSolutionHintIs(dominating_lit, false)) { + crush.UpdateLiteralSolutionHint(lit, false); + crush.UpdateLiteralSolutionHint(dominating_lit, true); } } @@ -1495,7 +1495,8 @@ void MaybeUpdateLiteralHintFromDominance(PresolveContext& context, int lit, void MaybeUpdateRefHintFromDominance( PresolveContext& context, int ref, const Domain& domain, const absl::Span dominating_variables) { - const std::optional ref_hint = context.GetRefSolutionHint(ref); + SolutionCrush& crush = context.solution_crush(); + const std::optional ref_hint = crush.GetRefSolutionHint(ref); if (!ref_hint.has_value()) return; // The quantity to subtract from the solution hint of `ref`. If the closest // value of *ref_hint in `domain` is not *ref_hint then it is either the lower @@ -1506,12 +1507,12 @@ void MaybeUpdateRefHintFromDominance( // hint is not initially feasible (in which case we can't fix it). if (ref_hint_delta <= 0) return; - context.UpdateRefSolutionHint(ref, *ref_hint - ref_hint_delta); + crush.UpdateRefSolutionHint(ref, *ref_hint - ref_hint_delta); int64_t remaining_delta = ref_hint_delta; for (const IntegerVariable ivar : dominating_variables) { const int dominating_ref = VarDomination::IntegerVariableToRef(ivar); const std::optional dominating_ref_hint = - context.GetRefSolutionHint(dominating_ref); + crush.GetRefSolutionHint(dominating_ref); if (!dominating_ref_hint.has_value()) continue; const Domain& dominating_ref_domain = context.DomainOf(dominating_ref); const int64_t new_dominating_ref_hint = @@ -1519,7 +1520,7 @@ void MaybeUpdateRefHintFromDominance( remaining_delta); // This might happen if the solution hint is not initially feasible. if (!dominating_ref_domain.Contains(new_dominating_ref_hint)) continue; - context.UpdateRefSolutionHint(dominating_ref, new_dominating_ref_hint); + crush.UpdateRefSolutionHint(dominating_ref, new_dominating_ref_hint); remaining_delta -= (new_dominating_ref_hint - *dominating_ref_hint); if (remaining_delta == 0) break; } diff --git a/ortools/sat/var_domination_test.cc b/ortools/sat/var_domination_test.cc index 1017567403..cfeb0dda34 100644 --- a/ortools/sat/var_domination_test.cc +++ b/ortools/sat/var_domination_test.cc @@ -229,8 +229,8 @@ TEST(VarDominationTest, ExploitDominanceOfImplicant) { EXPECT_THAT(var_dom.DominatingVariables(X), ElementsAre(NegationOf(Y))); EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0]"); - EXPECT_EQ(context.SolutionHint(0), 0); - EXPECT_EQ(context.SolutionHint(1), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 0); } // 2X - Y >= 0 @@ -278,8 +278,8 @@ TEST(VarDominationTest, ExploitDominanceOfNegatedImplicand) { EXPECT_THAT(var_dom.DominatingVariables(NegationOf(X)), ElementsAre(Y)); EXPECT_EQ(context.DomainOf(0).ToString(), "[1]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[1]"); - EXPECT_EQ(context.SolutionHint(0), 1); - EXPECT_EQ(context.SolutionHint(1), 1); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 1); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 1); } // X + 2Y >= 0 @@ -324,8 +324,8 @@ TEST(VarDominationTest, ExploitDominanceInExactlyOne) { EXPECT_THAT(var_dom.DominatingVariables(X), ElementsAre(Y)); EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]"); - EXPECT_EQ(context.SolutionHint(0), 0); - EXPECT_EQ(context.SolutionHint(1), 1); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 1); } // Objective: min(X + Y + 2Z) @@ -380,9 +380,9 @@ TEST(VarDominationTest, ExploitDominanceWithIntegerVariables) { EXPECT_EQ(context.DomainOf(0).ToString(), "[-5]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,10]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[5]"); - EXPECT_EQ(context.SolutionHint(0), -5); - EXPECT_EQ(context.SolutionHint(1), 10); - EXPECT_EQ(context.SolutionHint(2), 5); + EXPECT_EQ(context.solution_crush().SolutionHint(0), -5); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 10); + EXPECT_EQ(context.solution_crush().SolutionHint(2), 5); } // Objective: min(X + 2Y) @@ -429,8 +429,8 @@ TEST(VarDominationTest, ExploitRemainingDominance) { EqualsProto(expected_constraint_proto)); EXPECT_EQ(context.DomainOf(0).ToString(), "[0,1]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]"); - EXPECT_EQ(context.SolutionHint(0), 1); - EXPECT_EQ(context.SolutionHint(1), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 1); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 0); } // Objective: min(X) @@ -492,9 +492,9 @@ TEST(VarDominationTest, ExploitRemainingDominanceWithIntegerVariables) { EXPECT_EQ(context.DomainOf(0).ToString(), "[-10,-5]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[5,10]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[5]"); - EXPECT_EQ(context.SolutionHint(0), -5); - EXPECT_EQ(context.SolutionHint(1), 6); - EXPECT_EQ(context.SolutionHint(2), 5); + EXPECT_EQ(context.solution_crush().SolutionHint(0), -5); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 6); + EXPECT_EQ(context.solution_crush().SolutionHint(2), 5); } // X + Y + Z = 0 @@ -835,8 +835,8 @@ TEST(DualBoundReductionTest, FixVariableToDomainBound) { EXPECT_EQ(context.DomainOf(0).ToString(), "[-10]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[10]"); - EXPECT_EQ(context.SolutionHint(0), -10); - EXPECT_EQ(context.SolutionHint(1), 10); + EXPECT_EQ(context.solution_crush().SolutionHint(0), -10); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 10); } // Bound propagation see nothing, but if we can remove feasible solution, from @@ -874,9 +874,9 @@ TEST(DualBoundReductionTest, BasicTest) { EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[0]"); - EXPECT_EQ(context.SolutionHint(0), 0); - EXPECT_EQ(context.SolutionHint(1), 0); - EXPECT_EQ(context.SolutionHint(2), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(2), 0); } TEST(DualBoundReductionTest, CarefulWithHoles) { @@ -938,9 +938,9 @@ TEST(DualBoundReductionTest, Choices) { EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[-2]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[2]"); - EXPECT_EQ(context.SolutionHint(0), 0); - EXPECT_EQ(context.SolutionHint(1), -2); - EXPECT_EQ(context.SolutionHint(2), 2); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(1), -2); + EXPECT_EQ(context.solution_crush().SolutionHint(2), 2); } TEST(DualBoundReductionTest, AddImplication) { @@ -987,9 +987,9 @@ TEST(DualBoundReductionTest, AddImplication) { EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[0,1]"); - EXPECT_EQ(context.SolutionHint(0), 0); - EXPECT_EQ(context.SolutionHint(1), 0); - EXPECT_EQ(context.SolutionHint(2), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(0), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(1), 0); + EXPECT_EQ(context.solution_crush().SolutionHint(2), 0); } TEST(DualBoundReductionTest, EquivalenceDetection) { @@ -1035,7 +1035,7 @@ TEST(DualBoundReductionTest, EquivalenceDetection) { EXPECT_EQ(context.DomainOf(2).ToString(), "[0,1]"); // Equivalence between a and b. EXPECT_EQ(context.GetLiteralRepresentative(1), 0); - EXPECT_TRUE(context.LiteralSolutionHint(0)); + EXPECT_TRUE(context.solution_crush().LiteralSolutionHint(0)); } } // namespace