Skip to content

Commit

Permalink
Fix bug which mutates user expressions in constraint macro (#3883)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 15, 2024
1 parent 7111683 commit fc0409a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ function _rewrite_expression(expr::Expr)
new_expr = MacroTools.postwalk(_rewrite_to_jump_logic, expr)
new_aff, parse_aff = _MA.rewrite(new_expr; move_factors_into_sums = false)
ret = gensym()
has_copy_if_mutable = Ref(false)
MacroTools.postwalk(parse_aff) do x
if x === MutableArithmetics.copy_if_mutable
has_copy_if_mutable[] = true
end
return x
end
if !has_copy_if_mutable[]
new_aff = :($_MA.copy_if_mutable($new_aff))
end
code = quote
$parse_aff
$ret = $flatten!($new_aff)
Expand Down
3 changes: 3 additions & 0 deletions src/macros/@constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,9 @@ function _clear_constant!(α::Number)
return zero(α), α
end

# !!! warning
# This method assumes that we can mutate `expr`. Ensure that this is the
# case upstream of this call site.
function build_constraint(
::Function,
expr::Union{Number,GenericAffExpr,GenericQuadExpr},
Expand Down
27 changes: 27 additions & 0 deletions test/test_macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2487,4 +2487,31 @@ function test_array_scalar_sets()
return
end

function test_do_not_mutate_expression_double_sided_comparison()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x + 1)
@constraint(model, -1 <= a[1] <= 1)
@test isequal_canonical(a[1], x + 1)
return
end

function test_do_not_mutate_expression_single_sided_comparison()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x + 1)
@constraint(model, a[1] >= 1)
@test isequal_canonical(a[1], x + 1)
return
end

function test_do_not_mutate_expression_in_set()
model = Model()
@variable(model, x)
@expression(model, a[1:1], x + 1)
@constraint(model, a[1] in MOI.Interval(-1, 1))
@test isequal_canonical(a[1], x + 1)
return
end

end # module

0 comments on commit fc0409a

Please sign in to comment.