Skip to content

Commit

Permalink
[CP-SAT] one more presolve on int_prod
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Oct 7, 2024
1 parent caad9bf commit 0571f3f
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 24 deletions.
20 changes: 20 additions & 0 deletions ortools/sat/cp_model_postsolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,23 @@ void PostsolveIntMod(const ConstraintProto& ct, std::vector<Domain>* domains) {
(*domains)[target.vars(0)] = Domain(value);
}

// We only support assigning to an affine target.
void PostsolveIntProd(const ConstraintProto& ct, std::vector<Domain>* domains) {
int64_t target_value = 1;
for (const LinearExpressionProto& expr : ct.int_prod().exprs()) {
target_value *= EvaluateLinearExpression(expr, *domains);
}

const LinearExpressionProto& target = ct.int_prod().target();
CHECK_EQ(target.vars().size(), 1);
CHECK(RefIsPositive(target.vars(0)));

target_value -= target.offset();
CHECK_EQ(target_value % target.coeffs(0), 0);
target_value /= target.coeffs(0);
(*domains)[target.vars(0)] = Domain(target_value);
}

void PostsolveResponse(const int64_t num_variables_in_original_model,
const CpModelProto& mapping_proto,
const std::vector<int>& postsolve_mapping,
Expand Down Expand Up @@ -390,6 +407,9 @@ void PostsolveResponse(const int64_t num_variables_in_original_model,
case ConstraintProto::kIntMod:
PostsolveIntMod(ct, &domains);
break;
case ConstraintProto::kIntProd:
PostsolveIntProd(ct, &domains);
break;
default:
// This should never happen as we control what kind of constraint we
// add to the mapping_proto;
Expand Down
83 changes: 59 additions & 24 deletions ortools/sat/cp_model_presolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1455,21 +1455,66 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) {
return false;
}

Domain EvaluateImpliedIntProdDomain(const LinearArgumentProto& expr,
const PresolveContext& context) {
if (expr.exprs().size() == 2) {
const LinearExpressionProto& expr0 = expr.exprs(0);
const LinearExpressionProto& expr1 = expr.exprs(1);
if (LinearExpressionProtosAreEqual(expr0, expr1)) {
return context.DomainSuperSetOf(expr0).SquareSuperset();
}
if (expr0.vars().size() == 1 && expr1.vars().size() == 1 &&
expr0.vars(0) == expr1.vars(0)) {
return context.DomainOf(expr0.vars(0))
.QuadraticSuperset(expr0.coeffs(0), expr0.offset(), expr1.coeffs(0),
expr1.offset());
}
}

Domain implied(1);
for (const LinearExpressionProto& expr : expr.exprs()) {
implied =
implied.ContinuousMultiplicationBy(context.DomainSuperSetOf(expr));
}
return implied;
}

bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) {
if (context_->ModelIsUnsat()) return false;
if (HasEnforcementLiteral(*ct)) return false;

// Start by restricting the domain of target. We will be more precise later.
bool domain_modified = false;
{
Domain implied(1);
for (const LinearExpressionProto& expr : ct->int_prod().exprs()) {
implied =
implied.ContinuousMultiplicationBy(context_->DomainSuperSetOf(expr));
}
if (!context_->IntersectDomainWith(ct->int_prod().target(), implied,
&domain_modified)) {
return false;
Domain implied_domain =
EvaluateImpliedIntProdDomain(ct->int_prod(), *context_);
if (!context_->IntersectDomainWith(ct->int_prod().target(), implied_domain,
&domain_modified)) {
return false;
}

// Remove a constraint if the target only appears in the constraint. For this
// to be correct some conditions must be met:
// - The target is an affine linear with coefficient -1 or 1.
// - The target does not appear in the rhs (no x = (a*x + b) * ...).
// - The target domain covers all the possible range of the rhs.
if (ExpressionContainsSingleRef(ct->int_prod().target()) &&
context_->VariableIsUniqueAndRemovable(ct->int_prod().target().vars(0)) &&
std::abs(ct->int_prod().target().coeffs(0)) == 1) {
const LinearExpressionProto& target = ct->int_prod().target();
if (!absl::c_any_of(ct->int_prod().exprs(),
[&target](const LinearExpressionProto& expr) {
return absl::c_linear_search(expr.vars(),
target.vars(0));
})) {
const Domain target_domain =
Domain(target.offset())
.AdditionWith(context_->DomainOf(target.vars(0)));
if (implied_domain.IsIncludedIn(target_domain)) {
context_->MarkVariableAsRemoved(ct->int_prod().target().vars(0));
context_->NewMappingConstraint(*ct, __FILE__, __LINE__);
context_->UpdateRuleStats("int_prod: unused affine target");
return RemoveConstraint(ct);
}
}
}

Expand Down Expand Up @@ -1651,21 +1696,11 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) {
}

// Restrict the target domain if possible.
Domain implied(1);
bool is_square = false;
if (ct->int_prod().exprs_size() == 2 &&
LinearExpressionProtosAreEqual(ct->int_prod().exprs(0),
ct->int_prod().exprs(1))) {
is_square = true;
implied =
context_->DomainSuperSetOf(ct->int_prod().exprs(0)).SquareSuperset();
} else {
for (const LinearExpressionProto& expr : ct->int_prod().exprs()) {
implied =
implied.ContinuousMultiplicationBy(context_->DomainSuperSetOf(expr));
}
}
if (!context_->IntersectDomainWith(ct->int_prod().target(), implied,
implied_domain = EvaluateImpliedIntProdDomain(ct->int_prod(), *context_);
const bool is_square = ct->int_prod().exprs_size() == 2 &&
LinearExpressionProtosAreEqual(
ct->int_prod().exprs(0), ct->int_prod().exprs(1));
if (!context_->IntersectDomainWith(ct->int_prod().target(), implied_domain,
&domain_modified)) {
return false;
}
Expand Down
77 changes: 77 additions & 0 deletions ortools/util/sorted_interval_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/numeric/int128.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
Expand Down Expand Up @@ -634,6 +635,82 @@ Domain Domain::SquareSuperset() const {
}
}

namespace {
ClosedInterval EvaluateQuadraticProdInterval(int64_t a, int64_t b, int64_t c,
int64_t d, int64_t variable_min,
int64_t variable_max) {
// We have (a*x + b)(c*x + d) = a*c*x*x + (a*d + b*c)*x + b*d
// The minimum or maximum is at x = -(a*d + b*c)/(2*a*c)
//
// The minimum and maximum of the expression happens when x is one of the
// following:
// - variable_min;
// - variable_max;
// - the closest point to the parabola extreme, rounded down;
// - the closest point to the parabola extreme, rounded up.

const absl::int128 nominator =
-absl::int128{a} * absl::int128{d} - absl::int128{b} * absl::int128{c};
const absl::int128 denominator = absl::int128{a} * absl::int128{c};
const absl::int128 evaluated_minimum_point = (nominator / denominator) / 2;

const auto& evaluate = [&a, &b, &c, &d](const int64_t x) {
return CapProd(CapAdd(CapProd(a, x), b), CapAdd(CapProd(c, x), d));
};

const int64_t at_min_x = evaluate(variable_min);
const int64_t at_max_x = evaluate(variable_max);
int64_t min_var = std::min(at_min_x, at_max_x);
int64_t max_var = std::max(at_min_x, at_max_x);

if (evaluated_minimum_point > variable_min &&
evaluated_minimum_point < variable_max) {
const int64_t point_at_minimum_64 =
static_cast<int64_t>(evaluated_minimum_point);
const int rounder = ((nominator > 0) == (denominator > 0) ? 1 : -1);
const int64_t point1 = evaluate(point_at_minimum_64);
const int64_t point2 = evaluate(point_at_minimum_64 + rounder);
min_var = std::min(min_var, std::min(point1, point2));
max_var = std::max(max_var, std::max(point1, point2));
}

return ClosedInterval(min_var, max_var);
}
} // namespace

Domain Domain::QuadraticSuperset(int64_t a, int64_t b, int64_t c,
int64_t d) const {
if (IsEmpty()) return Domain();

if (Size() < kDomainComplexityLimit) {
std::vector<int64_t> values;
values.reserve(Size());
for (const int64_t value : Values()) {
values.push_back( //
CapProd( //
CapAdd(CapProd(a, value), b), CapAdd(CapProd(c, value), d)));
}
return Domain::FromValues(std::move(values));
}

if (a == 0) {
return MultiplicationBy(CapProd(c, b)).AdditionWith(Domain(CapProd(d, b)));
}
if (c == 0) {
return MultiplicationBy(CapProd(a, d)).AdditionWith(Domain(CapProd(d, b)));
}

Domain result;
result.intervals_.reserve(NumIntervals());
for (const auto& interval : intervals_) {
result.intervals_.push_back(EvaluateQuadraticProdInterval(
a, b, c, d, interval.start, interval.end));
}
std::sort(result.intervals_.begin(), result.intervals_.end());
UnionOfSortedIntervals(&result.intervals_);
return result;
}

// It is a bit difficult to see, but this code is doing the same thing as
// for all interval in this.UnionWith(implied_domain.Complement())):
// - Take the two extreme points (min and max) in interval \inter implied.
Expand Down
5 changes: 5 additions & 0 deletions ortools/util/sorted_interval_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ class Domain {
*/
Domain SquareSuperset() const;

/**
* Returns a superset of {x ∈ Int64, ∃ y ∈ D, x = (a*y + b)*(c*y + d) }.
*/
Domain QuadraticSuperset(int64_t a, int64_t b, int64_t c, int64_t d) const;

/**
* Advanced usage. Given some \e implied information on this domain that is
* assumed to be always true (i.e. only values in the intersection with
Expand Down

0 comments on commit 0571f3f

Please sign in to comment.