Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Oct 24, 2024
1 parent 0ca4c6f commit 20648ee
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 14 deletions.
16 changes: 5 additions & 11 deletions src/predictors/ReLU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ ReducedSpace(ReLU())
struct ReLU <: AbstractPredictor end

function add_predictor(model::JuMP.AbstractModel, predictor::ReLU, x::Vector)
ub = last.(_get_variable_bounds.(x))
y = JuMP.@variable(model, [1:length(x)], base_name = "moai_ReLU")
cons = Any[]
_set_bounds_if_finite.(Ref(cons), y, 0, max.(0, ub))
cons = _set_direct_bounds(x -> max(0, x), 0, nothing, x, y)
append!(cons, JuMP.@constraint(model, y .== max.(0, x)))
return y, Formulation(predictor, y, cons)
end
Expand Down Expand Up @@ -132,14 +130,12 @@ function add_predictor(
x::Vector,
)
m = length(x)
bounds = _get_variable_bounds.(x)
y = JuMP.@variable(model, [1:m], base_name = "moai_ReLU")
cons = Any[]
_set_bounds_if_finite.(Ref(cons), y, 0, max.(0, last.(bounds)))
cons = _set_direct_bounds(x -> max(0, x), 0, nothing, x, y)
formulation = Formulation(predictor, Any[], cons)
append!(formulation.variables, y)
for i in 1:m
lb, ub = bounds[i]
lb, ub = _get_variable_bounds(x[i])
z = JuMP.@variable(model, binary = true)
JuMP.set_name(z, "moai_z[$i]")
push!(formulation.variables, z)
Expand Down Expand Up @@ -217,8 +213,7 @@ function add_predictor(
m = length(x)
bounds = _get_variable_bounds.(x)
y = JuMP.@variable(model, [i in 1:m], base_name = "moai_ReLU")
cons = Any[]
_set_bounds_if_finite.(Ref(cons), y, 0, max.(0, last.(bounds)))
cons = _set_direct_bounds(x -> max(0, x), 0, nothing, x, y)
z = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "moai_z")
_set_bounds_if_finite.(Ref(cons), z, nothing, -first.(bounds))
append!(cons, JuMP.@constraint(model, x .== y - z))
Expand Down Expand Up @@ -294,8 +289,7 @@ function add_predictor(
m = length(x)
bounds = _get_variable_bounds.(x)
y = JuMP.@variable(model, [1:m], base_name = "moai_ReLU")
cons = Any[]
_set_bounds_if_finite.(Ref(cons), y, 0, max.(0, last.(bounds)))
cons = _set_direct_bounds(x -> max(0, x), 0, nothing, x, y)
z = JuMP.@variable(model, [1:m], base_name = "moai_z")
_set_bounds_if_finite.(Ref(cons), z, 0, max.(0, -first.(bounds)))
append!(cons, JuMP.@constraint(model, x .== y - z))
Expand Down
2 changes: 0 additions & 2 deletions src/predictors/Tanh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ ReducedSpace(Tanh())
"""
struct Tanh <: AbstractPredictor end

_eval(::Tanh, x::Real) = tanh(x)

function add_predictor(model::JuMP.AbstractModel, predictor::Tanh, x::Vector)
y = JuMP.@variable(model, [1:length(x)], base_name = "moai_Tanh")
cons = _set_variable_bounds(tanh, -1, 1, x, y)
Expand Down
1 change: 0 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ _get_variable_bounds(::Any) = -Inf, Inf
# Default fallback: skip setting variable bound
_set_bounds_if_finite(::Vector, ::Any, ::Any, ::Any) = nothing


function _set_direct_bounds(f::F, l, u, x::Vector, y::Vector) where {F}
cons = Any[]
for (xi, yi) in zip(x, y)
Expand Down

0 comments on commit 20648ee

Please sign in to comment.