From 2a803ffa0a20af1838211197baafb33faa16988a Mon Sep 17 00:00:00 2001
From: Vaibhav Dixit
Date: Thu, 12 Sep 2024 21:29:56 -0400
Subject: [PATCH] tests pass now pls

---
 Project.toml      |  2 +-
 src/sophia.jl     | 10 +++++-----
 test/minibatch.jl |  2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 822043263..92be1bb46 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,7 +30,7 @@ LinearAlgebra = "1.10"
 Logging = "1.10"
 LoggingExtras = "0.4, 1"
 MLUtils = "0.4.4"
-OptimizationBase = "2.0.2"
+OptimizationBase = "2.0.3"
 Printf = "1.10"
 ProgressLogging = "0.1"
 Reexport = "1.2"
diff --git a/src/sophia.jl b/src/sophia.jl
index 2bf602ce8..5419b87d7 100644
--- a/src/sophia.jl
+++ b/src/sophia.jl
@@ -1,4 +1,4 @@
-using Optimization.LinearAlgebra
+using Optimization.LinearAlgebra, MLUtils
 
 struct Sophia
     η::Float64
@@ -80,14 +80,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
     for _ in 1:maxiters
         for (i, d) in enumerate(data)
             if cache.f.fg !== nothing && dataiterate
-                x = cache.f.fg(G, θ, d)
+                x = cache.f.fg(gₜ, θ, d)
             elseif dataiterate
-                cache.f.grad(G, θ, d)
+                cache.f.grad(gₜ, θ, d)
                 x = cache.f(θ, d)
             elseif cache.f.fg !== nothing
-                x = cache.f.fg(G, θ)
+                x = cache.f.fg(gₜ, θ)
             else
-                cache.f.grad(G, θ)
+                cache.f.grad(gₜ, θ)
                 x = cache.f(θ)
             end
             opt_state = Optimization.OptimizationState(; iter = i,
diff --git a/test/minibatch.jl b/test/minibatch.jl
index aea533a95..f818f4ee1 100644
--- a/test/minibatch.jl
+++ b/test/minibatch.jl
@@ -19,7 +19,7 @@ function dudt_(u, p, t)
     ann(u, p, st)[1] .* u
 end
 
-function callback(state, l, pred) #callback function to observe training
+function callback(state, l) #callback function to observe training
     display(l)
     return false
 end
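
Note on the callback change: test/minibatch.jl now uses the two-argument
callback signature, where the callback receives the solver's
OptimizationState and the current loss, rather than also taking a
predictions argument. Below is a minimal sketch of how such a callback
plugs into a solve call. The Rosenbrock objective, the AutoZygote choice,
and the maxiters value are illustrative stand-ins, not part of this patch,
and the sketch assumes Sophia's keyword constructor supplies defaults.

using Optimization, Zygote

# Two-argument callback matching the updated signature: `state` is an
# Optimization.OptimizationState (see the `opt_state` construction in the
# sophia.jl hunk above) and `l` is the current loss value.
function callback(state, l)
    println("iter $(state.iter): loss = $l")
    return false  # returning true halts the optimization early
end

# Hypothetical objective and starting point, only for illustration.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1])^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

sol = solve(prob, Optimization.Sophia(), callback = callback, maxiters = 100)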