Skip to content

Commit

Permalink
[CP-SAT] Initial pandas support + samples
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Jul 11, 2023
1 parent e86780b commit 0094784
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 45 deletions.
2 changes: 2 additions & 0 deletions ortools/sat/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ py_library(
deps = [
":cp_model_helper",
":swig_helper",
requirement("numpy"),
requirement("pandas"),
"//ortools/sat:cp_model_py_pb2",
"//ortools/sat:sat_parameters_py_pb2",
"//ortools/util/python:sorted_interval_list",
Expand Down
210 changes: 195 additions & 15 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import threading
import time
from typing import (
Callable,
Dict,
Iterable,
List,
Expand All @@ -64,6 +65,7 @@
import warnings

import numpy as np
import pandas as pd

from ortools.sat import cp_model_pb2
from ortools.sat import sat_parameters_pb2
Expand Down Expand Up @@ -116,6 +118,7 @@
LinearExprT = Union["LinearExpr", "IntVar", IntegralT]
ObjLinearExprT = Union["LinearExpr", "IntVar", NumberT]
ArcT = Tuple[IntegralT, IntegralT, LiteralT]
_IndexOrSeries = Union[pd.Index, pd.Series]


def DisplayBounds(bounds: Sequence[int]) -> str:
Expand Down Expand Up @@ -1143,6 +1146,86 @@ def NewConstant(self, value: IntegralT) -> IntVar:
"""Declares a constant integer."""
return IntVar(self.__model, self.GetOrMakeIndexFromConstant(value), None)

def NewIntVarSeries(
self,
name: str,
index: pd.Index,
lower_bounds: Union[IntegralT, pd.Series],
upper_bounds: Union[IntegralT, pd.Series],
) -> pd.Series:
"""Creates a series of (scalar-valued) variables with the given name.
Args:
name (str): Required. The name of the variable set.
index (pd.Index): Required. The index to use for the variable set.
lower_bounds (Union[int, pd.Series]): A lower bound for variables in the
set. If a `pd.Series` is passed in, it will be based on the
corresponding values of the pd.Series.
upper_bounds (Union[int, pd.Series]): An upper bound for variables in the
set. If a `pd.Series` is passed in, it will be based on the
corresponding values of the pd.Series.
Returns:
pd.Series: The variable set indexed by its corresponding dimensions.
Raises:
TypeError: if the `index` is invalid (e.g. a `DataFrame`).
ValueError: if the `name` is not a valid identifier or already exists.
ValueError: if the `lowerbound` is greater than the `upperbound`.
ValueError: if the index of `lower_bound`, or `upper_bound` does not match
the input index.
"""
if not isinstance(index, pd.Index):
raise TypeError("Non-index object is used as index")
if not name.isidentifier():
raise ValueError("name={} is not a valid identifier".format(name))
if (
isinstance(lower_bounds, IntegralT)
and isinstance(upper_bounds, IntegralT)
and lower_bounds > upper_bounds
):
raise ValueError(
f"lower_bound={lower_bounds} is greater than"
f" upper_bound={upper_bounds} for variable set={name}"
)

lower_bounds = _ConvertToSeriesAndValidateIndex(lower_bounds, index)
upper_bounds = _ConvertToSeriesAndValidateIndex(upper_bounds, index)
return pd.Series(
index=index,
data=[
# pylint: disable=g-complex-comprehension
IntVar(
model=self.__model,
name=f"{name}[{i}]",
domain=Domain(lower_bounds[i], upper_bounds[i]),
)
for i in index
],
)

def NewBoolVarSeries(
self,
name: str,
index: pd.Index,
) -> pd.Series:
"""Creates a series of (scalar-valued) variables with the given name.
Args:
name (str): Required. The name of the variable set.
index (pd.Index): Required. The index to use for the variable set.
Returns:
pd.Series: The variable set indexed by its corresponding dimensions.
Raises:
TypeError: if the `index` is invalid (e.g. a `DataFrame`).
ValueError: if the `name` is not a valid identifier or already exists.
"""
return self.NewIntVarSeries(
name=name, index=index, lower_bounds=0, upper_bounds=1
)

# Linear constraints.

def AddLinearConstraint(
Expand Down Expand Up @@ -2556,65 +2639,107 @@ def StopSearch(self) -> None:
if self.__solve_wrapper:
self.__solve_wrapper.StopSearch()

def _solution(self) -> cp_model_pb2.CpSolverResponse:
def _Solution(self) -> cp_model_pb2.CpSolverResponse:
"""Checks Solve() has been called, and returns the solution."""
if self.__solution is None:
raise RuntimeError("Solve() has not been called.")
return self.__solution

def Value(self, expression: LinearExprT) -> int:
"""Returns the value of a linear expression after solve."""
return EvaluateLinearExpr(expression, self._solution())
return EvaluateLinearExpr(expression, self._Solution())

def Values(self, variables: _IndexOrSeries) -> pd.Series:
"""Returns the values of the input variables.
If `variables` is a `pd.Index`, then the output will be indexed by the
variables. If `variables` is a `pd.Series` indexed by the underlying
dimensions, then the output will be indexed by the same underlying
dimensions.
Args:
variables (Union[pd.Index, pd.Series]): The set of variables from which to
get the values.
Returns:
pd.Series: The values of all variables in the set.
"""
solution = self._Solution()
return _AttributeSeries(
func=lambda v: solution.solution[v.Index()],
values=variables,
)

def BooleanValue(self, literal: LiteralT) -> bool:
"""Returns the boolean value of a literal after solve."""
return EvaluateBooleanExpression(literal, self._solution())
return EvaluateBooleanExpression(literal, self._Solution())

def BooleanValues(self, variables: _IndexOrSeries) -> pd.Series:
"""Returns the values of the input variables.
If `variables` is a `pd.Index`, then the output will be indexed by the
variables. If `variables` is a `pd.Series` indexed by the underlying
dimensions, then the output will be indexed by the same underlying
dimensions.
Args:
variables (Union[pd.Index, pd.Series]): The set of variables from which to
get the values.
Returns:
pd.Series: The values of all variables in the set.
"""
solution = self._Solution()
return _AttributeSeries(
func=lambda literal: EvaluateBooleanExpression(literal, solution),
values=variables,
)

def ObjectiveValue(self) -> float:
"""Returns the value of the objective after solve."""
return self._solution().objective_value
return self._Solution().objective_value

def BestObjectiveBound(self) -> float:
"""Returns the best lower (upper) bound found when min(max)imizing."""
return self._solution().best_objective_bound
return self._Solution().best_objective_bound

def StatusName(self, status: ... = None) -> str:
"""Returns the name of the status returned by Solve()."""
if status is None:
status = self._solution().status
status = self._Solution().status
return cp_model_pb2.CpSolverStatus.Name(status)

def NumBooleans(self) -> int:
"""Returns the number of boolean variables managed by the SAT solver."""
return self._solution().num_booleans
return self._Solution().num_booleans

def NumConflicts(self) -> int:
"""Returns the number of conflicts since the creation of the solver."""
return self._solution().num_conflicts
return self._Solution().num_conflicts

def NumBranches(self) -> int:
"""Returns the number of search branches explored by the solver."""
return self._solution().num_branches
return self._Solution().num_branches

def WallTime(self) -> float:
"""Returns the wall time in seconds since the creation of the solver."""
return self._solution().wall_time
return self._Solution().wall_time

def UserTime(self) -> float:
"""Returns the user time in seconds since the creation of the solver."""
return self._solution().user_time
return self._Solution().user_time

def ResponseStats(self) -> str:
"""Returns some statistics on the solution found as a string."""
return swig_helper.CpSatHelper.SolverResponseStats(self._solution())
return swig_helper.CpSatHelper.SolverResponseStats(self._Solution())

def ResponseProto(self) -> cp_model_pb2.CpSolverResponse:
"""Returns the response object."""
return self._solution()
return self._Solution()

def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]:
"""Returns the indices of the infeasible assumptions."""
return self._solution().sufficient_assumptions_for_infeasibility
return self._Solution().sufficient_assumptions_for_infeasibility

def SolutionInfo(self) -> str:
"""Returns some information on the solve process.
Expand All @@ -2625,7 +2750,7 @@ def SolutionInfo(self) -> str:
Raises:
RuntimeError: if Solve() has not been called.
"""
return self._solution().solution_info
return self._Solution().solution_info


class CpSolverSolutionCallback(swig_helper.SolutionCallback):
Expand Down Expand Up @@ -2806,3 +2931,58 @@ def on_solution_callback(self) -> None:
def solution_count(self) -> int:
"""Returns the number of solutions found."""
return self.__solution_count


def _GetIndex(obj: _IndexOrSeries) -> pd.Index:
"""Returns the indices of `obj` as a `pd.Index`."""
if isinstance(obj, pd.Series):
return obj.index
return obj


def _AttributeSeries(
*,
func: Callable[[IntVar], IntegralT],
values: _IndexOrSeries,
) -> pd.Series:
"""Returns the attributes of `values`.
Args:
func: The function to call for getting the attribute data.
values: The values that the function will be applied (element-wise) to.
Returns:
pd.Series: The attribute values.
"""
return pd.Series(
data=[func(v) for v in values],
index=_GetIndex(values),
)


def _ConvertToSeriesAndValidateIndex(
value_or_series: Union[IntegralT, pd.Series], index: pd.Index
) -> pd.Series:
"""Returns a pd.Series of the given index with the corresponding values.
Args:
value_or_series: the values to be converted (if applicable).
index: the index of the resulting pd.Series.
Returns:
pd.Series: The set of values with the given index.
Raises:
TypeError: If the type of `value_or_series` is not recognized.
ValueError: If the index does not match.
"""
if isinstance(value_or_series, (bool, IntegralT)):
result = pd.Series(data=value_or_series, index=index)
elif isinstance(value_or_series, pd.Series):
if value_or_series.index.equals(index):
result = value_or_series
else:
raise ValueError("index does not match")
else:
raise TypeError("invalid type={}".format(type(value_or_series)))
return result
2 changes: 2 additions & 0 deletions ortools/sat/samples/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ code_sample_cc_py(name = "assumptions_sample_sat")

code_sample_cc_py(name = "binpacking_problem_sat")

code_sample_py(name = "bin_packing_sat")

code_sample_cc_py(name = "bool_or_sample_sat")

code_sample_py(name = "boolean_product_sample_sat")
Expand Down
Loading

0 comments on commit 0094784

Please sign in to comment.