From afc76a71131d4fc091ab361c7ec2ed0841c0cb7f Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 24 Jun 2024 10:59:56 +0200
Subject: [PATCH 1/3] Replace Agnostic with Adaptive

---
 test/argmax.jl  | 6 ++++--
 test/paths.jl   | 6 +++---
 test/ranking.jl | 6 +++---
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/test/argmax.jl b/test/argmax.jl
index 4075c51..25f7d27 100644
--- a/test/argmax.jl
+++ b/test/argmax.jl
@@ -116,7 +116,8 @@ end
             one_hot_argmax;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            # frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         loss=mse_kw,
         error_function=hamming_distance,
@@ -198,7 +199,8 @@ end
                 one_hot_argmax;
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
-                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
+                # frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
             ),
         ),
         error_function=hamming_distance,
diff --git a/test/paths.jl b/test/paths.jl
index c38176b..0681e8a 100644
--- a/test/paths.jl
+++ b/test/paths.jl
@@ -101,7 +101,7 @@ end
             shortest_path_maximizer;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         loss=mse_kw,
         error_function=mse_kw,
@@ -177,7 +177,7 @@ end
                 shortest_path_maximizer;
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
-                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
             ),
         ),
         error_function=mse_kw,
@@ -247,7 +247,7 @@ end
                 shortest_path_maximizer;
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
-                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
             ),
             cost,
         ),
diff --git a/test/ranking.jl b/test/ranking.jl
index 122ed99..63954a5 100644
--- a/test/ranking.jl
+++ b/test/ranking.jl
@@ -101,7 +101,7 @@ end
             ranking;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         loss=mse_kw,
         error_function=hamming_distance,
@@ -170,7 +170,7 @@ end
                 ranking;
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
-                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
             ),
         ),
         error_function=hamming_distance,
@@ -303,7 +303,7 @@ end
                 ranking;
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
-                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+                frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
             ),
             cost,
         ),

From dc73db3f752255fcefd2f6bfa04e0538fc17ac18 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 24 Jun 2024 11:01:58 +0200
Subject: [PATCH 2/3] cleanup

---
 test/argmax.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/argmax.jl b/test/argmax.jl
index 25f7d27..c3c2a57 100644
--- a/test/argmax.jl
+++ b/test/argmax.jl
@@ -116,7 +116,6 @@ end
             one_hot_argmax;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            # frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
             frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
         loss=mse_kw,
@@ -200,7 +199,6 @@ end
                 Ω=half_square_norm,
                 Ω_grad=identity_kw,
                 frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
-                # frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
             ),
         ),
         error_function=hamming_distance,
@@ -265,7 +263,7 @@ end
             one_hot_argmax;
             Ω=half_square_norm,
             Ω_grad=identity_kw,
-            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
+            frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
         ),
             cost,
         ),

From 9e043435bdebadceb70248f95cf598f5d4b124ed Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 13:02:37 +0100
Subject: [PATCH 3/3] fix tutorial

---
 examples/tutorial.jl | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/tutorial.jl b/examples/tutorial.jl
index f6838ab..876ca1e 100644
--- a/examples/tutorial.jl
+++ b/examples/tutorial.jl
@@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient op
 
 encoder = deepcopy(initial_encoder)
 opt = Flux.Adam();
+opt_state = Flux.setup(opt, encoder)
 losses = Float64[]
 for epoch in 1:100
     l = 0.0
     for (x, y) in zip(X_train, Y_train)
-        grads = gradient(Flux.params(encoder)) do
-            l += loss(encoder(x), y; directions=queen_directions)
+        grads = Flux.gradient(encoder) do m
+            l += loss(m(x), y; directions=queen_directions)
         end
-        Flux.update!(opt, Flux.params(encoder), grads)
+        Flux.update!(opt_state, encoder, grads[1])
     end
     push!(losses, l)
 end;
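
For context on PATCH 1/3 and 2/3: the tests swap FrankWolfe.jl's open-loop ("agnostic") step-size rule for its adaptive one inside the frank_wolfe_kwargs passed to RegularizedFrankWolfe. A minimal sketch of what the two settings mean, on a toy probability-simplex problem (the objective, LMO, and starting point below are illustrative assumptions, not taken from the test suite):

# Toy comparison of the two step-size rules swapped in the tests.
# The objective, LMO, and starting point are illustrative, not from InferOpt's tests.
using FrankWolfe, LinearAlgebra

n = 5
target = fill(1 / n, n)
f(x) = 0.5 * norm(x - target)^2          # smooth quadratic objective
grad!(g, x) = (g .= x .- target)         # in-place gradient, as FrankWolfe.jl expects
lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)
x0 = FrankWolfe.compute_extreme_point(lmo, zeros(n))

# Old setting: open-loop ("agnostic") step size 2 / (t + 2)
x_agnostic, _ = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0; max_iteration=10, line_search=FrankWolfe.Agnostic()
)

# New setting: adaptive step-size rule, as now used in the tests
x_adaptive, _ = FrankWolfe.frank_wolfe(
    f, grad!, lmo, x0; max_iteration=10, line_search=FrankWolfe.Adaptive()
)

The adaptive rule estimates a local smoothness constant instead of following the fixed 2/(t+2) schedule, which generally helps when only a handful of iterations are allowed, as in these tests with max_iteration=10.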
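
PATCH 3/3 migrates the tutorial's training loop from the implicit Flux.params style to Flux's explicit setup / gradient / update! API. A standalone sketch of that pattern, assuming a toy encoder, data, and loss in place of the tutorial's (initial_encoder, the maze data, and the Fenchel-Young loss are not reproduced here):

# Explicit-gradient Flux training pattern, mirroring the change made in the tutorial.
# The encoder, data, and loss below are toy placeholders; only the structure matters.
using Flux

encoder = Chain(Dense(4 => 4))               # placeholder for the tutorial's encoder
X_train = [rand(Float32, 4) for _ in 1:8]    # placeholder inputs
Y_train = [rand(Float32, 4) for _ in 1:8]    # placeholder targets
toy_loss(ŷ, y) = Flux.mse(ŷ, y)              # placeholder for the Fenchel-Young loss

opt_state = Flux.setup(Flux.Adam(), encoder) # optimiser state tied to the model tree
for epoch in 1:10
    for (x, y) in zip(X_train, Y_train)
        # differentiate with respect to the model itself, not an implicit parameter set
        grads = Flux.gradient(m -> toy_loss(m(x), y), encoder)
        Flux.update!(opt_state, encoder, grads[1])  # in-place update of model and state
    end
end

Flux.gradient(f, model) returns a tuple with one entry per differentiated argument, hence grads[1] both in this sketch and in the patched tutorial loop.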