Skip to content

Commit

Permalink
support hinting literals in CP-SAT Python
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Oct 7, 2024
1 parent 5912937 commit f269264
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
16 changes: 13 additions & 3 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2939,10 +2939,20 @@ def export_to_file(self, file: str) -> bool:
"""
return swig_helper.CpSatHelper.write_model_to_file(self.__model, file)

def add_hint(self, var: IntVar, value: int) -> None:
@overload
def add_hint(self, var: IntVar, value: int) -> None: ...

@overload
def add_hint(self, literal: BoolVarT, value: bool) -> None: ...

def add_hint(self, var, value) -> None:
"""Adds 'var == value' as a hint to the solver."""
self.__model.solution_hint.vars.append(self.get_or_make_index(var))
self.__model.solution_hint.values.append(value)
if var.index >= 0:
self.__model.solution_hint.vars.append(self.get_or_make_index(var))
self.__model.solution_hint.values.append(int(value))
else:
self.__model.solution_hint.vars.append(self.negated(var.index))
self.__model.solution_hint.values.append(int(not value))

def clear_hints(self):
"""Removes any solution hint from the model."""
Expand Down
15 changes: 15 additions & 0 deletions ortools/sat/python/cp_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,21 @@ def testSolutionHinting(self):
self.assertEqual(2, solver.value(x))
self.assertEqual(4, solver.value(y))

def testSolutionHintingWithBooleans(self):
print("testSolutionHintingWithBooleans")
model = cp_model.CpModel()
x = model.new_bool_var("x")
y = model.new_bool_var("y")
model.add_linear_constraint(x + y, 1, 1)
model.add_hint(x, True)
model.add_hint(~y, True)
solver = cp_model.CpSolver()
solver.parameters.cp_model_presolve = False
status = solver.solve(model)
self.assertEqual(cp_model.OPTIMAL, status)
self.assertTrue(solver.boolean_value(x))
self.assertFalse(solver.boolean_value(y))

def testStats(self):
print("testStats")
model = cp_model.CpModel()
Expand Down

0 comments on commit f269264

Please sign in to comment.