From f269264849b7ae551c8913eed97757c7543eef61 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 7 Oct 2024 15:54:04 +0200 Subject: [PATCH] support hinting literals in CP-SAT Python --- ortools/sat/python/cp_model.py | 16 +++++++++++++--- ortools/sat/python/cp_model_test.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index cea258dca5..0f1b93ea07 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -2939,10 +2939,20 @@ def export_to_file(self, file: str) -> bool: """ return swig_helper.CpSatHelper.write_model_to_file(self.__model, file) - def add_hint(self, var: IntVar, value: int) -> None: + @overload + def add_hint(self, var: IntVar, value: int) -> None: ... + + @overload + def add_hint(self, literal: BoolVarT, value: bool) -> None: ... + + def add_hint(self, var, value) -> None: """Adds 'var == value' as a hint to the solver.""" - self.__model.solution_hint.vars.append(self.get_or_make_index(var)) - self.__model.solution_hint.values.append(value) + if var.index >= 0: + self.__model.solution_hint.vars.append(self.get_or_make_index(var)) + self.__model.solution_hint.values.append(int(value)) + else: + self.__model.solution_hint.vars.append(self.negated(var.index)) + self.__model.solution_hint.values.append(int(not value)) def clear_hints(self): """Removes any solution hint from the model.""" diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 8bd2aae00b..09235c4e0b 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -1428,6 +1428,21 @@ def testSolutionHinting(self): self.assertEqual(2, solver.value(x)) self.assertEqual(4, solver.value(y)) + def testSolutionHintingWithBooleans(self): + print("testSolutionHintingWithBooleans") + model = cp_model.CpModel() + x = model.new_bool_var("x") + y = model.new_bool_var("y") + model.add_linear_constraint(x + y, 1, 1) + model.add_hint(x, True) + model.add_hint(~y, True) + solver = cp_model.CpSolver() + solver.parameters.cp_model_presolve = False + status = solver.solve(model) + self.assertEqual(cp_model.OPTIMAL, status) + self.assertTrue(solver.boolean_value(x)) + self.assertFalse(solver.boolean_value(y)) + def testStats(self): print("testStats") model = cp_model.CpModel()