Skip to content

Commit

Permalink
Merge pull request #827 from SciML/ap/refactor
Browse files Browse the repository at this point in the history
feat: make MLUtils into a weakdep & support MLDataDevices
  • Loading branch information
Vaibhavdixit02 authored Sep 22, 2024
2 parents 904cac0 + 1f4cba3 commit 39fa5fb
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 6 deletions.
19 changes: 16 additions & 3 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[weakdeps]
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

[extensions]
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
OptimizationOptimisersMLUtilsExt = "MLUtils"

[compat]
MLDataDevices = "1.1"
MLUtils = "0.4.4"
Optimisers = "0.2, 0.3"
Optimization = "4"
Expand All @@ -20,9 +28,14 @@ Reexport = "1.2"
julia = "1"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ForwardDiff", "Test", "Zygote"]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLDataDevicesExt

using MLDataDevices
using OptimizationOptimisers

# Recognize MLDataDevices' device-wrapped iterators (e.g. a DataLoader moved to a
# device via `CPUDevice()(data)`) as minibatch data iterators, so __solve iterates
# over them instead of treating them as a single fixed parameter.
# Fix: removed leftover debug statement `(@show "dkjht"; true)` from the method body.
OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLUtilsExt

using OptimizationOptimisers
using MLUtils

# Extension glue: a `MLUtils.DataLoader` counts as a minibatch data iterator,
# enabling per-batch iteration in `OptimizationOptimisers`' solve loop.
function OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader)
    return true
end

end
9 changes: 6 additions & 3 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
@reexport using Optimisers, Optimization
using Optimization.SciMLBase, MLUtils
using Optimization.SciMLBase

SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
SciMLBase.requiresgradient(opt::AbstractRule) = true
Expand All @@ -16,6 +16,8 @@ function SciMLBase.__init(
kwargs...)
end

# Fallback trait: by default, `data` is NOT treated as a minibatch iterator.
# Package extensions (MLUtils, MLDataDevices) add methods returning `true`
# for their iterator types; everything else is wrapped as a single batch.
isa_dataiterator(data) = false

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
Expand Down Expand Up @@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if cache.p isa MLUtils.DataLoader
if isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

opt = cache.opt
θ = copy(cache.u0)
G = copy(θ)
Expand Down Expand Up @@ -114,7 +117,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d...)
cache.f.grad(G, θ, d)
opt_state = Optimization.OptimizationState(iter = i,
u = θ,
objective = x[1],
Expand Down
40 changes: 40 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,43 @@ using Zygote

@test_throws ArgumentError sol=solve(prob, Optimisers.Adam())
end

@testset "Minibatching" begin
    using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays

    x = rand(10000)
    y = sin.(x)
    data = MLUtils.DataLoader((x, y), batchsize = 100)

    # Define the neural network
    model = Chain(Dense(1, 32, tanh), Dense(32, 1))
    ps, st = Lux.setup(Random.default_rng(), model)
    ps_ca = ComponentArray(ps)
    smodel = StatefulLuxLayer{true}(model, nothing, st)

    # Periodic progress logging; returning `true` (loss below tolerance)
    # signals the solver to halt early.
    # Fix: the original used `@show "Iteration: %5d, Loss: %.6e\n" state.iter l`,
    # which prints the literal format string rather than formatted output —
    # `@show` does not take printf-style format strings.
    function callback(state, l)
        if state.iter % 25 == 1
            @info "Iteration: $(state.iter), Loss: $l"
        end
        return l < 1e-4
    end

    # Per-batch loss: data is a (xs, ys) tuple from the DataLoader.
    function loss(ps, data)
        ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
        return sum(abs2, ypred .- data[2])
    end

    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

    @test res.objective < 1e-4

    # Same problem, but with the DataLoader wrapped by MLDataDevices —
    # exercises the DeviceIterator path of `isa_dataiterator`.
    using MLDataDevices
    data = CPUDevice()(data)
    optf = OptimizationFunction(loss, AutoZygote())
    prob = OptimizationProblem(optf, ps_ca, data)

    res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)

    @test res.objective < 1e-4
end

0 comments on commit 39fa5fb

Please sign in to comment.