
Commit

Fix and cleanup SPO+
BatyLeo committed Dec 23, 2024
1 parent 4068071 commit 908482d
Showing 2 changed files with 11 additions and 32 deletions.
39 changes: 9 additions & 30 deletions src/losses/spoplus_loss.jl
@@ -1,16 +1,17 @@
"""
-SPOPlusLoss <: AbstractLossLayer
+$TYPEDEF
Convex surrogate of the Smart "Predict-then-Optimize" loss.
-# Fields
-- `maximizer`: linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy`
-- `α::Float64`: convexification parameter, default = 2.0
+$TYPEDFIELDS
Reference: <https://arxiv.org/abs/1710.08005>
"""
struct SPOPlusLoss{F} <: AbstractLossLayer
"linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy`"
maximizer::F
"convexification parameter, default = 2.0"
α::Float64
end

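For orientation (a restatement of the code below, not text from the commit): in the plain linear case the forward pass evaluates the convex surrogate

    ℓ_SPO+(θ, θ_true, y_true) = max_y (αθ - θ_true)ᵀy - (αθ - θ_true)ᵀy_true

where the max is computed by calling `maximizer` on θ_α = αθ - θ_true.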
@@ -20,7 +21,9 @@ function Base.show(io::IO, spol::SPOPlusLoss)
end

"""
-SPOPlusLoss(maximizer; α=2.0)
+$TYPEDSIGNATURES
+Constructor for [`SPOPlusLoss`](@ref).
"""
SPOPlusLoss(maximizer; α=2.0) = SPOPlusLoss(maximizer, float(α))

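A minimal usage sketch, assuming the three-argument functor signature visible in the diff below; the one-hot argmax maximizer and the toy vectors are made up for illustration:

    # Hypothetical maximizer: feasible set = one-hot vectors, so ŷ(θ) = argmax θᵀy
    one_hot_argmax(θ; kwargs...) = (y = zero(θ); y[argmax(θ)] = 1; y)

    loss = SPOPlusLoss(one_hot_argmax)  # default α = 2.0
    θ_pred = [1.0, 2.0, 0.5]
    θ_true = [0.0, 1.0, 2.0]
    y_true = one_hot_argmax(θ_true)
    loss(θ_pred, θ_true, y_true)  # θ_α = [2, 3, -1], y_α = [0, 1, 0], loss = 3 - (-1) = 4.0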
@@ -35,17 +38,7 @@ function (spol::SPOPlusLoss)(
(; maximizer, α) = spol
θ_α = α * θ - θ_true
y_α = maximizer(θ_α; kwargs...)
-l = dot(θ_α, y_α) - dot(θ_α, y_true)
-return l
-end
-
-function (spol::SPOPlusLoss{<:GeneralizedMaximizer})(
-θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs...
-)
-(; maximizer, α) = spol
-θ_α = α * θ - θ_true
-y_α = maximizer(θ_α; kwargs...)
-# This only works in theory if α = 2
+# In theory, in the general case with a LinearMaximizer, this only works if α = 2
l =
objective_value(maximizer, θ_α, y_α; kwargs...) -
objective_value(maximizer, θ_α, y_true; kwargs...)
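
Assuming `objective_value(maximizer, θ, y; kwargs...)` follows the `LinearMaximizer` convention θᵀg(y) + h(y) (a convention from the rest of the package, not shown in this diff), the merged method computes

    ℓ = (θ_αᵀg(y_α) + h(y_α)) - (θ_αᵀg(y_true) + h(y_true)),    θ_α = αθ - θ_true,

which reduces to the deleted `dot`-based formula when g is the identity and h ≡ 0. The rewritten comment records that the paper's convexification argument is derived for α = 2.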
@@ -68,28 +61,14 @@ function compute_loss_and_gradient(
θ_true::AbstractArray,
y_true::AbstractArray;
kwargs...,
)
(; maximizer, α) = spol
θ_α = α * θ - θ_true
y_α = maximizer(θ_α; kwargs...)
-l = dot(θ_α, y_α) - dot(θ_α, y_true)
-return l, α .* (y_α .- y_true)
-end
-
-function compute_loss_and_gradient(
-spol::SPOPlusLoss{<:GeneralizedMaximizer},
-θ::AbstractArray,
-θ_true::AbstractArray,
-y_true::AbstractArray;
-kwargs...,
-)
-(; maximizer, α) = spol
-θ_α = α * θ - θ_true
-y_α = maximizer(θ_α; kwargs...)
l =
objective_value(maximizer, θ_α, y_α; kwargs...) -
objective_value(maximizer, θ_α, y_true; kwargs...)
-g = α .* (maximizer.g(y_α; kwargs...) - maximizer.g(y_true; kwargs...))
+g = α .* (apply_g(maximizer, y_α; kwargs...) - apply_g(maximizer, y_true; kwargs...))
return l, g
end
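
The new `apply_g` line matches a Danskin-style reading of the loss (an inference from the code, not part of the commit): treating the argmax solution y_α as locally constant, ∂ℓ/∂θ = α (g(y_α) - g(y_true)), which collapses to the deleted `α .* (y_α .- y_true)` in the identity-g case.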

4 changes: 2 additions & 2 deletions test/learning_generalized_maximizer.jl
@@ -133,7 +133,7 @@ end

true_encoder = encoder_factory()

-generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
+generalized_maximizer = LinearMaximizer(max_pricing; g, h)
function cost(y; instance)
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
end
@@ -157,7 +157,7 @@ end

true_encoder = encoder_factory()

-generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
+generalized_maximizer = LinearMaximizer(max_pricing; g, h)
function cost(y; instance)
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
end
