diff --git a/docs/make.jl b/docs/make.jl index dcc31a8f87..54741bba2d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -48,6 +48,8 @@ pages = [ "manual/distributed_utils.md", "manual/nested_autodiff.md", "manual/compiling_lux_models.md", + "manual/exporting_to_jax.md", + "manual/nested_autodiff_reactant.md" ], "API Reference" => [ "Lux" => [ diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 12c1c8468a..f1cece5bef 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -322,6 +322,10 @@ export default defineConfig({ text: "Exporting Lux Models to Jax", link: "/manual/exporting_to_jax", }, + { + text: "Nested AutoDiff", + link: "/manual/nested_autodiff_reactant", + } ], }, { diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md index 19270b5be6..d92b9f8526 100644 --- a/docs/src/manual/nested_autodiff.md +++ b/docs/src/manual/nested_autodiff.md @@ -1,16 +1,16 @@ # [Nested Automatic Differentiation](@id nested_autodiff) -!!! note - - This is a relatively new feature in Lux, so there might be some rough edges. If you - encounter any issues, please let us know by opening an issue on the - [GitHub repository](https://github.com/LuxDL/Lux.jl). - In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend with a faster one when needed. -!!! tip +!!! tip "Reactant Support" + + Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using + Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant) + manual. + +!!! tip "Disabling Nested AD Switching" Don't wan't Lux to do this switching for you? You can disable it by setting the `automatic_nested_ad_switching` Preference to `false`. diff --git a/docs/src/manual/nested_autodiff_reactant.md b/docs/src/manual/nested_autodiff_reactant.md new file mode 100644 index 0000000000..70e5f7dafc --- /dev/null +++ b/docs/src/manual/nested_autodiff_reactant.md @@ -0,0 +1,72 @@ +# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant) + +We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614). + +```@example nested_ad_reactant +using Reactant, Enzyme, Lux, Random, LinearAlgebra + +const xdev = reactant_device(; force=true) +const cdev = cpu_device() + +# XXX: We need to be able to compile this with a for-loop else tracing time will scale +# proportionally to the number of elements in the input. +function ∇potential(potential, x) + dxs = stack(onehot(x)) + ∇p = similar(x) + colons = [Colon() for _ in 1:ndims(x)] + @trace for i in 1:length(x) + dxᵢ = dxs[colons..., i] + res = only(Enzyme.autodiff( + Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ) + )) + @allowscalar ∇p[i] = res[i] + end + return ∇p +end + +# function ∇²potential(potential, x) +# dxs = onehot(x) +# ∇²p = similar(x) +# for i in eachindex(dxs) +# dxᵢ = dxs[i] +# res = only(Enzyme.autodiff( +# Enzyme.set_abi(Forward, Reactant.ReactantABI), +# ∇potential, Const(potential), Duplicated(x, dxᵢ) +# )) +# @allowscalar ∇²p[i] = res[i] +# end +# return ∇²p +# end + +struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential} + potential::P +end + +function (potential::PotentialNet)(x, ps, st) + pnet = StatefulLuxLayer{true}(potential.potential, ps, st) + return ∇potential(pnet, x), pnet.st + # return ∇²potential(pnet, x), pnet.st +end + +model = PotentialNet(Dense(5 => 5, gelu)) +ps, st = Lux.setup(Random.default_rng(), model) |> xdev + +x_ra = randn(Float32, 5, 3) |> xdev + +@code_hlo model(x_ra, ps, st) + +1 + 1 + +model_compiled = @compile model(x_ra, ps, st) +model_compiled(x_ra, ps, st) + +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) + +function enzyme_gradient(model, x, ps, st) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st) + ) +end + +@jit enzyme_gradient(model, x_ra, ps, st) +``` diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 14acc442ef..f62527a7bf 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -6,7 +6,7 @@ using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, Tr using Setfield: @set! using Static: True, False -using Lux: Lux, LuxOps, Training, Utils +using Lux: Lux, LuxOps, Training, Utils, StatefulLuxLayer using Lux.Training: TrainingBackendCache, ReactantBackend Lux.is_extension_loaded(::Val{:Reactant}) = true diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 6d79f2b60f..d7d7fba275 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -7,3 +7,16 @@ Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple) return ps.weight[:, x], st end + +# Tracing Patches +function Reactant.make_tracer( + seen, @nospecialize(model::StatefulLuxLayer), @nospecialize(path), mode; kwargs... +) + return StatefulLuxLayer( + model.model, + Reactant.make_tracer(seen, model.ps, path, mode; kwargs...), + Reactant.make_tracer(seen, model.st, path, mode; kwargs...), + Reactant.make_tracer(seen, model.st_any, path, mode; kwargs...), + model.fixed_state_type + ) +end diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index 5b095d97d3..8813766c0f 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.2.2" +version = "1.2.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/ext/LuxCoreReactantExt.jl b/lib/LuxCore/ext/LuxCoreReactantExt.jl index f6e7770964..2389e05b90 100644 --- a/lib/LuxCore/ext/LuxCoreReactantExt.jl +++ b/lib/LuxCore/ext/LuxCoreReactantExt.jl @@ -10,6 +10,11 @@ function Reactant.make_tracer( return model end +function Reactant.traced_type_inner( + T::Type{<:AbstractLuxLayer}, seen, mode::Reactant.TraceMode, track_numbers::Type) + return T +end + LuxCore.replicate(rng::Reactant.TracedRNG) = copy(rng) LuxCore.replicate(rng::Reactant.ConcreteRNG) = copy(rng)