
Commit 9e04343

fix tutorial
BatyLeo committed Dec 23, 2024
1 parent b90258c · commit 9e04343
Showing 1 changed file with 4 additions and 3 deletions.
examples/tutorial.jl: 7 changes (4 additions & 3 deletions)
@@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient optimizer.
 
 encoder = deepcopy(initial_encoder)
 opt = Flux.Adam();
+opt_state = Flux.setup(opt, encoder)
 losses = Float64[]
 for epoch in 1:100
     l = 0.0
     for (x, y) in zip(X_train, Y_train)
-        grads = gradient(Flux.params(encoder)) do
-            l += loss(encoder(x), y; directions=queen_directions)
+        grads = Flux.gradient(encoder) do m
+            l += loss(m(x), y; directions=queen_directions)
         end
-        Flux.update!(opt, Flux.params(encoder), grads)
+        Flux.update!(opt_state, encoder, grads[1])
     end
     push!(losses, l)
 end;

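For reference, below is a minimal, self-contained sketch of the updated training loop in Flux's explicit-parameter style. The encoder, data, and loss here are hypothetical stand-ins (a small Flux.Dense layer, random arrays, and a squared-error toy_loss); the tutorial's actual initial_encoder, X_train, Y_train, loss, and queen_directions are defined earlier in examples/tutorial.jl.

using Flux

# Hypothetical stand-ins for objects defined earlier in the tutorial.
initial_encoder = Flux.Dense(12 => 1)            # placeholder encoder
X_train = [rand(Float32, 12) for _ in 1:10]      # placeholder features
Y_train = [rand(Float32, 1) for _ in 1:10]       # placeholder targets
toy_loss(y_pred, y) = sum(abs2, y_pred .- y)     # placeholder for the tutorial's smoothed loss

encoder = deepcopy(initial_encoder)
opt = Flux.Adam()
opt_state = Flux.setup(opt, encoder)   # explicit optimiser state, replaces implicit Flux.params
losses = Float64[]
for epoch in 1:100
    l = 0.0
    for (x, y) in zip(X_train, Y_train)
        # Differentiate with respect to the model itself; grads is a tuple and
        # grads[1] holds the gradient for encoder as a nested NamedTuple.
        grads = Flux.gradient(encoder) do m
            l += toy_loss(m(x), y)
        end
        Flux.update!(opt_state, encoder, grads[1])
    end
    push!(losses, l)
end

The fix mirrors Flux's move away from the implicit Flux.params interface: the optimiser state is created once with Flux.setup, gradients are taken with respect to the model object, and Flux.update! receives that state together with grads[1].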