Skip to content

Commit

Permalink
Merge pull request #116 from JuliaDecisionFocusedLearning/frankwolfe-tests
Browse files Browse the repository at this point in the history

Replace Agnostic with Adaptive in FrankWolfe tests
  • Loading branch information
BatyLeo authored Dec 23, 2024
2 parents 59ab4fe + 9e04343 commit 343c95e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
7 changes: 4 additions & 3 deletions examples/tutorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient op

encoder = deepcopy(initial_encoder)
opt = Flux.Adam();
opt_state = Flux.setup(opt, encoder)
losses = Float64[]
for epoch in 1:100
l = 0.0
for (x, y) in zip(X_train, Y_train)
grads = gradient(Flux.params(encoder)) do
l += loss(encoder(x), y; directions=queen_directions)
grads = Flux.gradient(encoder) do m
l += loss(m(x), y; directions=queen_directions)
end
Flux.update!(opt, Flux.params(encoder), grads)
Flux.update!(opt_state, encoder, grads[1])
end
push!(losses, l)
end;
Expand Down
6 changes: 3 additions & 3 deletions test/argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end
one_hot_argmax;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
loss=mse_kw,
error_function=hamming_distance,
Expand Down Expand Up @@ -198,7 +198,7 @@ end
one_hot_argmax;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
),
error_function=hamming_distance,
Expand Down Expand Up @@ -263,7 +263,7 @@ end
one_hot_argmax;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
cost,
),
Expand Down
6 changes: 3 additions & 3 deletions test/paths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
shortest_path_maximizer;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
loss=mse_kw,
error_function=mse_kw,
Expand Down Expand Up @@ -177,7 +177,7 @@ end
shortest_path_maximizer;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
),
error_function=mse_kw,
Expand Down Expand Up @@ -247,7 +247,7 @@ end
shortest_path_maximizer;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
cost,
),
Expand Down
6 changes: 3 additions & 3 deletions test/ranking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
ranking;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
loss=mse_kw,
error_function=hamming_distance,
Expand Down Expand Up @@ -170,7 +170,7 @@ end
ranking;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
),
error_function=hamming_distance,
Expand Down Expand Up @@ -303,7 +303,7 @@ end
ranking;
Ω=half_square_norm,
Ω_grad=identity_kw,
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
),
cost,
),
Expand Down

0 comments on commit 343c95e

Please sign in to comment.