iterating through a DataLoader is still broken #843

Open
tiemvanderdeure opened this issue Oct 17, 2024 · 3 comments
Labels
bug Something isn't working


@tiemvanderdeure

Describe the bug 🐞
Iterating over a DataLoader is still broken on 4.0.3. If `epochs` is less than the number of elements in the DataLoader, only the first `epochs` elements are evaluated in each epoch.

Expected behavior
Surely this should iterate over the entire dataset `epochs` times.

Minimal Reproducible Example 👇

using Optimization, OptimizationOptimisers, Optimisers, Zygote
using MLUtils: DataLoader

function lossf(θ, data)
    @show data  # print each batch the solver passes in
    return sum(θ .^ 2)
end
dataloader = DataLoader(collect(1:10), batchsize = 1)
opt_func = OptimizationFunction(
    lossf,
    Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, [2.0], dataloader)
res_adam = solve(
    opt_prob, Optimisers.Adam(0.001); epochs = 3)

Error & Stacktrace ⚠️
No stacktrace, but Julia prints the following:

data = [1]
data = [2]
data = [3]
data = [3]
data = [1]
data = [2]
data = [3]
data = [3]
data = [1]
data = [2]
data = [3]
data = [3]
retcode: Default
u: 1-element Vector{Float64}:
 1.9940003690339771

Environment (please complete the following information):
Latest versions of Optimization and OptimizationOptimisers

Additional context
The problem is in this line:

if i == maxiters #Last iter, revert to best.

The `if` condition should probably be something like `epoch == maxiters && i == length(data)`. It would also be good to add a test for this.
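A minimal, package-free sketch of the loop structure the suggested condition implies, assuming the solver makes `epochs` passes over an indexable `data` and needs to detect only the final step of the final epoch (here `epochs` plays the role of `maxiters` in the condition above). `run_epochs` and its `visit` callback are hypothetical names for illustration, not the OptimizationOptimisers.jl internals.

```julia
# Hypothetical sketch, not the actual OptimizationOptimisers.jl code:
# visit every element of `data` once per epoch and flag only the very
# last step of the last epoch as "last".
function run_epochs(visit, data, epochs)
    for epoch in 1:epochs
        for (i, d) in enumerate(data)
            is_last = epoch == epochs && i == length(data)
            visit(d, is_last)
        end
    end
    return nothing
end

seen = Int[]
last_flags = Bool[]
run_epochs(1:10, 3) do d, is_last
    push!(seen, d)
    push!(last_flags, is_last)
end
```

With 10 batches and `epochs = 3`, this visits elements 1 through 10 three times (30 steps), and only the final step is flagged, instead of stopping after the first three elements as in the output above.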

Related to #835 and #842

tiemvanderdeure added the bug label on Oct 17, 2024
tiemvanderdeure changed the title from "epochs is still broken" to "iterating through a DataLoader is still broken" on Oct 17, 2024
@tiemvanderdeure
Author

Similarly, on this line:

opt_state = Optimization.OptimizationState(iter = i,

this should probably be `iter = i + (epoch - 1) * length(data)` instead.
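For example, assuming every epoch has the same number of batches, the suggested formula numbers steps consecutively across epochs. `global_iter` is a hypothetical helper for illustration only.

```julia
# Hypothetical helper: map (epoch, batch index i) to a 1-based global step
# count, assuming each epoch has exactly `nbatches` batches.
global_iter(epoch, i, nbatches) = i + (epoch - 1) * nbatches

nbatches = 10
# Outer loop over epochs, inner loop over batches, flattened into one vector.
iters = [global_iter(epoch, i, nbatches) for epoch in 1:3 for i in 1:nbatches]
```

With `iter = i` alone, the counter restarts at 1 every epoch, so a callback cannot tell step 2 of epoch 1 from step 2 of epoch 2; the formula above yields 1 through 30 for three epochs of 10 batches.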

@tiemvanderdeure
Author

tiemvanderdeure commented Oct 18, 2024

Just to continue this thread: the callback also doesn't work as expected. If the callback returns `true`, this only breaks the inner loop, so iteration through the data is reset, but the solver continues with the next epoch regardless. If the data is something other than a DataLoader, returning `true` from the callback has no effect at all.

MWE

using Optimization, OptimizationOptimisers, Optimisers, Zygote
using MLUtils: DataLoader

function callback(state, l)
    @show state.iter
    if state.iter % 10 == 2
        println("stopping training!")
        return true  # returning true asks the solver to halt
    else
        return false
    end
end
function lossf(θ, data)
    return sum(θ .^ 2)
end
dataloader = DataLoader(collect(1:10), batchsize = 1)  # unused below: the problem is built with data = 1
opt_func = OptimizationFunction(
    lossf,
    Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, [2.0], 1)
solve(opt_prob, Optimisers.Adam(0.001); callback, epochs = 5)

This outputs:

state.iter = 1
state.iter = 2
stopping training!
state.iter = 3
state.iter = 4
state.iter = 5
state.iter = 5
retcode: Default
u: 1-element Vector{Float64}:
 1.9970000478972731
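A minimal sketch of the behavior one would expect, assuming (as in the MWE) that a callback returning `true` requests a halt: the stop request has to end the epoch loop as well, not only the batch loop. `run_training` and its `step!` argument are hypothetical names; the real solver's loop may be organized differently.

```julia
# Hypothetical sketch: make `epochs` passes over `data`, calling `callback`
# with the global step count after each step. Returning `true` from the
# callback stops training entirely, not just the current epoch.
function run_training(step!, callback, data, epochs)
    iter = 0
    for _ in 1:epochs
        for d in data
            iter += 1
            step!(d)
            callback(iter) && return iter  # leave both loops at once
        end
    end
    return iter
end

calls = Int[]
stopped_at = run_training(d -> nothing, it -> (push!(calls, it); it % 10 == 2), 1:10, 5)
```

Returning from the function exits both loops in one step; a plain `break` in the inner loop would only end the current epoch, which matches the behavior reported above.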

@ChrisRackauckas
Member

@Vaibhavdixit02 you got this one?
