Skip to content

Commit

Permalink
[CP-SAT] more work on cp_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Jul 10, 2023
1 parent ab19570 commit e86780b
Showing 1 changed file with 36 additions and 55 deletions.
91 changes: 36 additions & 55 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"""

import collections
import itertools
import numbers
import threading
import time
Expand Down Expand Up @@ -111,9 +112,9 @@
IntegralT = Union[numbers.Integral, np.integer, int]
NumberT = Union[numbers.Integral, np.integer, int, numbers.Number, np.double, float]
LiteralT = Union["IntVar", "_NotBooleanVariable", IntegralT, bool]
VariableT = Union["IntVar", "_ProductCst", IntegralT]
LinearExprT = Union["LinearExpr", "IntVar", "_ProductCst", IntegralT]
ObjLinearExprT = Union["LinearExpr", "IntVar", "_ProductCst", NumberT]
VariableT = Union["IntVar", IntegralT]
LinearExprT = Union["LinearExpr", "IntVar", IntegralT]
ObjLinearExprT = Union["LinearExpr", "IntVar", NumberT]
ArcT = Tuple[IntegralT, IntegralT, LiteralT]


Expand Down Expand Up @@ -602,11 +603,13 @@ def __init__(self, expressions, constant=0):
raise TypeError("Not an linear expression: " + str(x))

def __str__(self):
exprs_str = " + ".join(map(repr, self.__expressions))
if self.__constant == 0:
return f"({exprs_str})"
else:
return f"({exprs_str} + {self.__constant})"
constant_terms = (self.__constant,) if self.__constant != 0 else ()
exprs_str = " + ".join(
map(repr, itertools.chain(self.__expressions, constant_terms))
)
if not exprs_str:
return "0"
return exprs_str

def __repr__(self):
exprs_str = ", ".join(map(repr, self.__expressions))
Expand Down Expand Up @@ -801,7 +804,7 @@ class BoundedLinearExpression:
model.Add(x + 2 * y -1 >= z)
"""

def __init__(self, expr, bounds):
def __init__(self, expr: LinearExprT, bounds: Sequence[int]):
self.__expr: LinearExprT = expr
self.__bounds: Sequence[int] = bounds

Expand Down Expand Up @@ -915,13 +918,11 @@ def OnlyEnforceIf(self, *boolvar) -> "Constraint":
self.
"""
for lit in ExpandGeneratorOrTuple(boolvar):
if (isinstance(lit, bool) and bool(lit)) or (
cmh.is_integral(lit) and int(lit) == 1
):
if (isinstance(lit, bool) and lit) or (cmh.is_integral(lit) and lit == 1):
# Always true. Do nothing.
pass
elif (isinstance(lit, bool) and not bool(lit)) or (
cmh.is_integral(lit) and int(lit) == 0
elif (isinstance(lit, bool) and not lit) or (
cmh.is_integral(lit) and lit == 0
):
self.__constraint.enforcement_literal.append(
self.__cp_model.NewConstant(0).Index()
Expand Down Expand Up @@ -2555,83 +2556,65 @@ def StopSearch(self) -> None:
if self.__solve_wrapper:
self.__solve_wrapper.StopSearch()

def Value(self, expression: LinearExprT) -> int:
"""Returns the value of a linear expression after solve."""
def _solution(self) -> cp_model_pb2.CpSolverResponse:
"""Checks Solve() has been called, and returns the solution."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return EvaluateLinearExpr(expression, self.__solution)
return self.__solution

def Value(self, expression: LinearExprT) -> int:
"""Returns the value of a linear expression after solve."""
return EvaluateLinearExpr(expression, self._solution())

def BooleanValue(self, literal: LiteralT) -> bool:
"""Returns the boolean value of a literal after solve."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return EvaluateBooleanExpression(literal, self.__solution)
return EvaluateBooleanExpression(literal, self._solution())

def ObjectiveValue(self) -> float:
"""Returns the value of the objective after solve."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.objective_value
return self._solution().objective_value

def BestObjectiveBound(self) -> float:
"""Returns the best lower (upper) bound found when min(max)imizing."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.best_objective_bound
return self._solution().best_objective_bound

def StatusName(self, status: ... = None) -> str:
"""Returns the name of the status returned by Solve()."""
if status is None:
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
status = self.__solution.status
status = self._solution().status
return cp_model_pb2.CpSolverStatus.Name(status)

def NumBooleans(self) -> int:
"""Returns the number of boolean variables managed by the SAT solver."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.num_booleans
return self._solution().num_booleans

def NumConflicts(self) -> int:
"""Returns the number of conflicts since the creation of the solver."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.num_conflicts
return self._solution().num_conflicts

def NumBranches(self) -> int:
"""Returns the number of search branches explored by the solver."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.num_branches
return self._solution().num_branches

def WallTime(self) -> float:
"""Returns the wall time in seconds since the creation of the solver."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.wall_time
return self._solution().wall_time

def UserTime(self) -> float:
"""Returns the user time in seconds since the creation of the solver."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.user_time
return self._solution().user_time

def ResponseStats(self) -> str:
"""Returns some statistics on the solution found as a string."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return swig_helper.CpSatHelper.SolverResponseStats(self.__solution)
return swig_helper.CpSatHelper.SolverResponseStats(self._solution())

def ResponseProto(self) -> Optional[cp_model_pb2.CpSolverResponse]:
def ResponseProto(self) -> cp_model_pb2.CpSolverResponse:
"""Returns the response object."""
return self.__solution
return self._solution()

def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]:
"""Returns the indices of the infeasible assumptions."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.sufficient_assumptions_for_infeasibility
return self._solution().sufficient_assumptions_for_infeasibility

def SolutionInfo(self) -> str:
"""Returns some information on the solve process.
Expand All @@ -2642,9 +2625,7 @@ def SolutionInfo(self) -> str:
Raises:
RuntimeError: if Solve() has not been called.
"""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution.solution_info
return self._solution().solution_info


class CpSolverSolutionCallback(swig_helper.SolutionCallback):
Expand Down

0 comments on commit e86780b

Please sign in to comment.