Skip to content

Commit

Permalink
Add minibatching tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Sep 22, 2024
1 parent 8d7cd3a commit 5a76a9e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
8 changes: 4 additions & 4 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[extensions]
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
OptimizationOptimisersMLUtilsExt = "MLUtils"

[weakdeps]
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

[extensions]
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
OptimizationOptimisersMLUtilsExt = "MLUtils"

[compat]
MLDataDevices = "1.1"
MLUtils = "0.4.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module OptimizationOptimisersMLDataDevicesExt
using MLDataDevices
using OptimizationOptimisers

# Mark MLDataDevices' `DeviceIterator` as a recognized data iterator so the
# minibatching solve path treats it like a `DataLoader`.
# (Removed leftover debug output: `(@show "dkjht"; true)` printed a scratch
# string on every dispatch check.)
OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true

end
2 changes: 1 addition & 1 deletion lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d...)
cache.f.grad(G, θ, d)
opt_state = Optimization.OptimizationState(iter = i,
u = θ,
objective = x[1],
Expand Down
41 changes: 41 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,44 @@ using Zygote

@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
end

@testset "Minibatching" begin
    # NOTE(review): dropped unused `Plots` and `Statistics` imports (neither is
    # referenced below; Plots is a heavy dependency to pull into a test run).
    # Added `Printf` for the progress logging in `callback`.
    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils,
          Random, ComponentArrays, Printf

    # Seed the RNG so the data, the initial parameters, and therefore the
    # convergence assertions below are reproducible across runs.
    rng = Random.default_rng()
    Random.seed!(rng, 42)

    # Scalar regression task: learn y = sin(x) on x ∈ [0, 1), minibatched.
    x = rand(rng, 10000)
    y = sin.(x)
    data = MLUtils.DataLoader((x, y), batchsize = 100)

    # Define the neural network: 1 → 32 → 1 MLP with tanh hidden activation.
    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
    ps, st = Lux.setup(rng, model)
    ps_ca = ComponentArray(ps)
    smodel = StatefulLuxLayer{true}(model, nothing, st)

    # Progress callback: log every 25 iterations; returning `true` once the
    # loss target is reached halts the solve early.
    # (Fixed: the original used `@show` with a printf format string, which
    # just echoes the literal — `@printf` actually formats it.)
    function callback(state, l)
        state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
        return l < 1e-4
    end

    # Sum-of-squared-errors loss over one minibatch `data = (xs, ys)`.
    # Each scalar input is wrapped in a 1-vector to match `Dense(1, ...)`.
    function loss(ps, data)
        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
        return sum(abs2, ypred .- data[2])
    end

    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)

    @test res.objective < 1e-4

    # Repeat through the MLDataDevices path: moving the loader through
    # `CPUDevice()` yields a `DeviceIterator`, exercising the
    # `isa_dataiterator` extension hook.
    using MLDataDevices
    data = CPUDevice()(data)
    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)

    @test res.objective < 1e-4
end

0 comments on commit 5a76a9e

Please sign in to comment.