From 9aa4462b1adc73491155a1a141b511b941bf7739 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 6 Jan 2025 06:13:43 +0100 Subject: [PATCH 1/6] reinstate all enzyme tests --- Project.toml | 2 -- test/ext_enzyme/enzyme.jl | 18 +++++++++--------- test/runtests.jl | 5 ++--- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 630a956415..3e5224889d 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -61,7 +60,6 @@ OneHotArrays = "0.2.4" Optimisers = "0.4.1" Preferences = "1" ProgressLogging = "0.1" -Reactant = "0.2.16" Reexport = "1.0" Setfield = "1.1" SpecialFunctions = "2.1.2" diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 7b772fcca4..cc67ac09df 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -12,14 +12,14 @@ using Enzyme: Enzyme, Duplicated, Const, Active (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), - # (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), + (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), - # (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet) + (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), - # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ) - # (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ) + (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), + (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), ] for (model, x, name) in models_xs @@ -36,11 +36,11 @@ end end models_xs = [ - # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), - # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), - # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), - # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), + (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), + (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), + (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), + (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), ] for (model, x, name) in models_xs diff --git a/test/runtests.jl b/test/runtests.jl index ed6189fce7..365b0aaa1a 100644 --- a/test/runtests.jl 
+++ b/test/runtests.jl @@ -13,8 +13,6 @@ using Pkg using FiniteDifferences: FiniteDifferences using Functors: fmapstructure_with_path -using Reactant - ## Uncomment below to change the default test settings # ENV["FLUX_TEST_AMDGPU"] = "true" # ENV["FLUX_TEST_CUDA"] = "true" @@ -23,9 +21,10 @@ using Reactant # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" # ENV["FLUX_TEST_ENZYME"] = "false" +# ENV["FLUX_TEST_REACTANT"] = "true" const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true" -const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true" +const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", "false") == "true" if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT Pkg.add("Enzyme") From 31b08a685f82b6e85dbfb62f034d7e9789c6f4a2 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 6 Jan 2025 06:26:21 +0100 Subject: [PATCH 2/6] some tests crash --- test/ext_enzyme/enzyme.jl | 14 +++++++------- test/runtests.jl | 7 ++++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index cc67ac09df..0b159b393a 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -17,8 +17,8 @@ using Enzyme: Enzyme, Duplicated, Const, Active (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), - (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), - (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), + # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), + # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), ] @@ -36,11 +36,11 @@ end end models_xs = [ - (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), - (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), - (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), - (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), + # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), + # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), + # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), + # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), ] for (model, x, name) in models_xs diff --git a/test/runtests.jl b/test/runtests.jl index 365b0aaa1a..f6b76024d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,12 @@ using Functors: fmapstructure_with_path # ENV["FLUX_TEST_REACTANT"] = "true" const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true" -const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", "false") == "true" + +# Reactant will automatically select a GPU backend, if available, and TPU backend, if available. +# Otherwise it will fall back to CPU. 
+const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
+    VERSION < v"1.12-" && !Sys.iswindows() ?
+    "true" : "false") == "true"
 
 if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
     Pkg.add("Enzyme")

From 68e428bcfdbba1630de6db227f86026bfbc462f5 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Mon, 6 Jan 2025 06:49:38 +0100
Subject: [PATCH 3/6] cleanup

---
 test/Project.toml                        |  1 -
 test/ext_enzyme/enzyme.jl                |  4 ++--
 test/ext_reactant/test_utils_reactant.jl | 13 +++++++++++++
 test/runtests.jl                         |  6 +++---
 test/test_utils.jl                       | 23 +++++++++++++----------
 5 files changed, 31 insertions(+), 16 deletions(-)
 create mode 100644 test/ext_reactant/test_utils_reactant.jl

diff --git a/test/Project.toml b/test/Project.toml
index faba412fd0..f4161e8af3 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,7 +14,6 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index 0b159b393a..edff72f3fe 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -25,7 +25,7 @@ using Enzyme: Enzyme, Duplicated, Const, Active
     for (model, x, name) in models_xs
         @testset "Enzyme grad check $name" begin
             println("testing $name with Enzyme")
-            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
+            test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
         end
     end
 end
@@ -46,7 +46,7 @@ end
     for (model, x, name) in models_xs
         @testset "check grad $name" begin
             println("testing $name")
-            test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
+            test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
         end
     end
 end
diff --git a/test/ext_reactant/test_utils_reactant.jl b/test/ext_reactant/test_utils_reactant.jl
new file mode 100644
index 0000000000..6b3cec53d8
--- /dev/null
+++ b/test/ext_reactant/test_utils_reactant.jl
@@ -0,0 +1,13 @@
+# These are used only in test_utils.jl but cannot live there
+# because Reactant is only optionally loaded and the macros fail when it is not loaded.
+
+function reactant_withgradient(f, x...)
+    y, g = Reactant.@jit enzyme_withgradient(f, x...)
+    return y, g
+end
+
+function reactant_loss(loss, x...)
+    l = Reactant.@jit loss(x...)
+    @test l isa Reactant.ConcreteRNumber
+    return l
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index f6b76024d2..3f86a54341 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -21,15 +21,14 @@ using Functors: fmapstructure_with_path
 # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
 # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
 # ENV["FLUX_TEST_ENZYME"] = "false"
-# ENV["FLUX_TEST_REACTANT"] = "true"
+ENV["FLUX_TEST_REACTANT"] = "false"
 
 const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
 
 # Reactant will automatically select a GPU backend, if available, and TPU backend, if available.
 # Otherwise it will fall back to CPU.
 const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
-    VERSION < v"1.12-" && !Sys.iswindows() ?
"true" : "false") == "true" if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT Pkg.add("Enzyme") @@ -187,6 +186,7 @@ end if FLUX_TEST_REACTANT @testset "Reactant" begin + include("ext_reactant/test_utils_reactant.jl") include("ext_reactant/reactant.jl") end else diff --git a/test/test_utils.jl b/test/test_utils.jl index 2697213a77..fc598862d0 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -45,20 +45,24 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4) end end +# By default, this computes the gradients on cpu using the default AD (Zygote) +# and compares them with finite differences. +# Changing the arguments, you can assume the cpu Zygote gradients as the ground truth +# and test other scenarios. function test_gradients( f, xs...; rtol=1e-4, atol=1e-4, test_gpu = false, test_reactant = false, + test_enzyme = false, test_grad_f = true, test_grad_x = true, compare_finite_diff = true, - compare_enzyme = false, loss = (f, xs...) -> mean(f(xs...)), ) - if !test_gpu && !compare_finite_diff && !compare_enzyme && !test_reactant + if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant error("You should either compare numerical gradients methods or CPU vs GPU.") end @@ -79,8 +83,7 @@ function test_gradients( cpu_dev = cpu_device() xs_re = xs |> reactant_dev f_re = f |> reactant_dev - l_re = Reactant.@jit loss(f_re, xs_re...) - @test l_re isa Reactant.ConcreteRNumber + l_re = reactant_loss(loss, f_re, xs_re...) @test l ≈ l_re rtol=rtol atol=atol end @@ -97,7 +100,7 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end - if compare_enzyme + if test_enzyme y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...) @test y ≈ y_ez rtol=rtol atol=atol check_equal_leaves(g, g_ez; rtol, atol) @@ -111,9 +114,9 @@ function test_gradients( check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) end - if test_reactant + if compare_reactant # Enzyme gradient with respect to input on Reactant. - y_re, g_re = Reactant.@jit enzyme_withgradient((xs...) -> loss(f_re, xs...), xs_re...) + y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...) @test y ≈ y_re rtol=rtol atol=atol check_equal_leaves(g_re |> cpu_dev, g; rtol, atol) end @@ -133,7 +136,7 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end - if compare_enzyme + if test_enzyme y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f) @test y ≈ y_ez rtol=rtol atol=atol check_equal_leaves(g, g_ez; rtol, atol) @@ -147,9 +150,9 @@ function test_gradients( check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) end - if test_reactant + if compare_reactant # Enzyme gradient with respect to input on Reactant. 
- y_re, g_re = Reactant.@jit enzyme_withgradient(f -> loss(f, xs_re...), f_re) + y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re) @test y ≈ y_re rtol=rtol atol=atol check_equal_leaves(g_re |> cpu_dev, g; rtol, atol) end From a4ab3a942092091f22ab426cf0ba613ed6bfee11 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 6 Jan 2025 06:50:35 +0100 Subject: [PATCH 4/6] cleanup --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3f86a54341..3a111d0f04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,7 +21,7 @@ using Functors: fmapstructure_with_path # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" # ENV["FLUX_TEST_ENZYME"] = "false" -ENV["FLUX_TEST_REACTANT"] = "false" +# ENV["FLUX_TEST_REACTANT"] = "false" const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true" From 5c3d72d89140ca258d6ebd79d1e67d6815c84618 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 6 Jan 2025 07:02:09 +0100 Subject: [PATCH 5/6] fix --- test/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index fc598862d0..b8f4fec717 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -114,7 +114,7 @@ function test_gradients( check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) end - if compare_reactant + if test_reactant # Enzyme gradient with respect to input on Reactant. y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...) @test y ≈ y_re rtol=rtol atol=atol @@ -150,7 +150,7 @@ function test_gradients( check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) end - if compare_reactant + if test_reactant # Enzyme gradient with respect to input on Reactant. 
         y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
         @test y ≈ y_re rtol=rtol atol=atol
         check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
     end

From 363a043e154fd0dbf07ed1d1ebdfd58c4be75792 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Mon, 6 Jan 2025 07:36:52 +0100
Subject: [PATCH 6/6] cleanup

---
 test/ext_enzyme/enzyme.jl | 14 +++++++-------
 test/runtests.jl          | 12 +++++++-----
 test/test_utils.jl        |  3 +++
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index edff72f3fe..b3a1cf8f8c 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -17,8 +17,8 @@ using Enzyme: Enzyme, Duplicated, Const, Active
     (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
     (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
     (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
-    # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
-    # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+    (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
+    (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
     (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
 ]
 
@@ -36,11 +36,11 @@ end
 end
 
 models_xs = [
-    # (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-    # (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-    # (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-    # (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-    # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+    (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+    (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+    (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+    (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+    (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
 ]
 
 for (model, x, name) in models_xs
diff --git a/test/runtests.jl b/test/runtests.jl
index 3a111d0f04..8c375a3f22 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,11 +35,6 @@ if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
     using Enzyme: Enzyme
 end
 
-if FLUX_TEST_REACTANT
-    Pkg.add("Reactant")
-    using Reactant: Reactant
-end
-
 include("test_utils.jl") # for test_gradients
 
 Random.seed!(0)
@@ -185,6 +180,13 @@ end
 
 if FLUX_TEST_REACTANT
+    ## This Pkg.add has to be done after Pkg.add("CUDA"), otherwise CUDA.jl
+    ## will not be functional and will complain with:
+    # ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
+    # │
+    # │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
+    Pkg.add("Reactant")
+    using Reactant: Reactant
     @testset "Reactant" begin
         include("ext_reactant/test_utils_reactant.jl")
         include("ext_reactant/reactant.jl")
     end
 else
diff --git a/test/test_utils.jl b/test/test_utils.jl
index b8f4fec717..d8e10f3bf6 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -37,6 +37,7 @@ end
 
 function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
     fmapstructure_with_path(a, b) do kp, x, y
+        # @show kp
         if x isa AbstractArray
             @test x ≈ y rtol=rtol atol=atol
         elseif x isa Number
@@ -66,6 +67,8 @@ function test_gradients(
         error("You should either compare numerical gradients methods or CPU vs GPU.")
     end
 
+    Flux.trainmode!(f) # for layers like BatchNorm
+
     ## Let's make sure first that the forward pass works.
     l = loss(f, xs...)
     @test l isa Number
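---

Usage sketch (illustrative, not part of the patches above): with this series applied, an Enzyme
gradient check runs through the test_gradients helper with the renamed test_enzyme keyword,
taking Zygote's CPU gradients as the ground truth. The model, input, and loss below are
hypothetical placeholders, and the helper and enzyme_withgradient are assumed to be in scope
via test/test_utils.jl:

    using Flux, Statistics
    using Enzyme: Enzyme              # required for the enzyme_withgradient path

    include("test_utils.jl")          # defines test_gradients (see patch 3/6)

    loss(f, xs...) = mean(f(xs...))   # same shape as the helper's default loss
    model = Dense(3 => 2, tanh)       # hypothetical small model
    x = randn(Float32, 3, 4)          # hypothetical input batch

    # Skip finite differences and compare Zygote (CPU) against Enzyme,
    # mirroring the calls in test/ext_enzyme/enzyme.jl after this series.
    test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)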