diff --git a/Project.toml b/Project.toml index 26d959d80..822043263 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ LBFGSB = "0.4.1" LinearAlgebra = "1.10" Logging = "1.10" LoggingExtras = "0.4, 1" +MLUtils = "0.4.4" OptimizationBase = "2.0.2" Printf = "1.10" ProgressLogging = "0.1" diff --git a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl index 736632da2..aea9ada02 100644 --- a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl +++ b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl @@ -159,14 +159,18 @@ function SciMLBase.__solve(cache::OptimizationCache{ return cache.sense === Optimization.MaxSense ? -__x : __x end - fg! = function (G, θ) - if G !== nothing - cache.f.grad(G, θ) - if cache.sense === Optimization.MaxSense - G .*= -one(eltype(G)) + if cache.f.fg === nothing + fg! = function (G, θ) + if G !== nothing + cache.f.grad(G, θ) + if cache.sense === Optimization.MaxSense + G .*= -one(eltype(G)) + end end + return _loss(θ) end - return _loss(θ) + else + fg! = cache.f.fg end if cache.opt isa Optim.KrylovTrustRegion diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index bdae71df9..80904bba5 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -12,6 +12,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] +MLUtils = "0.4.4" Optimisers = "0.2, 0.3" Optimization = "3.21" ProgressLogging = "0.1"