From e86780b734a3f8a1e416fa5782a8aa51011282b3 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 10 Jul 2023 11:40:21 -0700 Subject: [PATCH] [CP-SAT] more work on cp_model.py --- ortools/sat/python/cp_model.py | 91 ++++++++++++++-------------------- 1 file changed, 36 insertions(+), 55 deletions(-) diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index b30979b1b42..f9abfd81270 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -46,6 +46,7 @@ """ import collections +import itertools import numbers import threading import time @@ -111,9 +112,9 @@ IntegralT = Union[numbers.Integral, np.integer, int] NumberT = Union[numbers.Integral, np.integer, int, numbers.Number, np.double, float] LiteralT = Union["IntVar", "_NotBooleanVariable", IntegralT, bool] -VariableT = Union["IntVar", "_ProductCst", IntegralT] -LinearExprT = Union["LinearExpr", "IntVar", "_ProductCst", IntegralT] -ObjLinearExprT = Union["LinearExpr", "IntVar", "_ProductCst", NumberT] +VariableT = Union["IntVar", IntegralT] +LinearExprT = Union["LinearExpr", "IntVar", IntegralT] +ObjLinearExprT = Union["LinearExpr", "IntVar", NumberT] ArcT = Tuple[IntegralT, IntegralT, LiteralT] @@ -602,11 +603,13 @@ def __init__(self, expressions, constant=0): raise TypeError("Not an linear expression: " + str(x)) def __str__(self): - exprs_str = " + ".join(map(repr, self.__expressions)) - if self.__constant == 0: - return f"({exprs_str})" - else: - return f"({exprs_str} + {self.__constant})" + constant_terms = (self.__constant,) if self.__constant != 0 else () + exprs_str = " + ".join( + map(repr, itertools.chain(self.__expressions, constant_terms)) + ) + if not exprs_str: + return "0" + return exprs_str def __repr__(self): exprs_str = ", ".join(map(repr, self.__expressions)) @@ -801,7 +804,7 @@ class BoundedLinearExpression: model.Add(x + 2 * y -1 >= z) """ - def __init__(self, expr, bounds): + def __init__(self, expr: LinearExprT, bounds: Sequence[int]): self.__expr: LinearExprT = expr self.__bounds: Sequence[int] = bounds @@ -915,13 +918,11 @@ def OnlyEnforceIf(self, *boolvar) -> "Constraint": self. """ for lit in ExpandGeneratorOrTuple(boolvar): - if (isinstance(lit, bool) and bool(lit)) or ( - cmh.is_integral(lit) and int(lit) == 1 - ): + if (isinstance(lit, bool) and lit) or (cmh.is_integral(lit) and lit == 1): # Always true. Do nothing. pass - elif (isinstance(lit, bool) and not bool(lit)) or ( - cmh.is_integral(lit) and int(lit) == 0 + elif (isinstance(lit, bool) and not lit) or ( + cmh.is_integral(lit) and lit == 0 ): self.__constraint.enforcement_literal.append( self.__cp_model.NewConstant(0).Index() @@ -2555,83 +2556,65 @@ def StopSearch(self) -> None: if self.__solve_wrapper: self.__solve_wrapper.StopSearch() - def Value(self, expression: LinearExprT) -> int: - """Returns the value of a linear expression after solve.""" + def _solution(self) -> cp_model_pb2.CpSolverResponse: + """Checks Solve() has been called, and returns the solution.""" if self.__solution is None: raise RuntimeError("Solve() has not been called.") - return EvaluateLinearExpr(expression, self.__solution) + return self.__solution + + def Value(self, expression: LinearExprT) -> int: + """Returns the value of a linear expression after solve.""" + return EvaluateLinearExpr(expression, self._solution()) def BooleanValue(self, literal: LiteralT) -> bool: """Returns the boolean value of a literal after solve.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return EvaluateBooleanExpression(literal, self.__solution) + return EvaluateBooleanExpression(literal, self._solution()) def ObjectiveValue(self) -> float: """Returns the value of the objective after solve.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.objective_value + return self._solution().objective_value def BestObjectiveBound(self) -> float: """Returns the best lower (upper) bound found when min(max)imizing.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.best_objective_bound + return self._solution().best_objective_bound def StatusName(self, status: ... = None) -> str: """Returns the name of the status returned by Solve().""" if status is None: - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - status = self.__solution.status + status = self._solution().status return cp_model_pb2.CpSolverStatus.Name(status) def NumBooleans(self) -> int: """Returns the number of boolean variables managed by the SAT solver.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.num_booleans + return self._solution().num_booleans def NumConflicts(self) -> int: """Returns the number of conflicts since the creation of the solver.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.num_conflicts + return self._solution().num_conflicts def NumBranches(self) -> int: """Returns the number of search branches explored by the solver.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.num_branches + return self._solution().num_branches def WallTime(self) -> float: """Returns the wall time in seconds since the creation of the solver.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.wall_time + return self._solution().wall_time def UserTime(self) -> float: """Returns the user time in seconds since the creation of the solver.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.user_time + return self._solution().user_time def ResponseStats(self) -> str: """Returns some statistics on the solution found as a string.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return swig_helper.CpSatHelper.SolverResponseStats(self.__solution) + return swig_helper.CpSatHelper.SolverResponseStats(self._solution()) - def ResponseProto(self) -> Optional[cp_model_pb2.CpSolverResponse]: + def ResponseProto(self) -> cp_model_pb2.CpSolverResponse: """Returns the response object.""" - return self.__solution + return self._solution() def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]: """Returns the indices of the infeasible assumptions.""" - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.sufficient_assumptions_for_infeasibility + return self._solution().sufficient_assumptions_for_infeasibility def SolutionInfo(self) -> str: """Returns some information on the solve process. @@ -2642,9 +2625,7 @@ def SolutionInfo(self) -> str: Raises: RuntimeError: if Solve() has not been called. """ - if self.__solution is None: - raise RuntimeError("Solve() has not been called.") - return self.__solution.solution_info + return self._solution().solution_info class CpSolverSolutionCallback(swig_helper.SolutionCallback):