diff --git a/.github/workflows/Benchmark.yml b/.github/workflows/Benchmark.yml
index 224535a5..e3e9bf5a 100644
--- a/.github/workflows/Benchmark.yml
+++ b/.github/workflows/Benchmark.yml
@@ -16,7 +16,7 @@ jobs:
           version: '1'
           arch: x64
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - uses: julia-actions/julia-buildpkg@v1
       - name: Install dependencies
         shell: julia --color=yes {0}
diff --git a/.github/workflows/CI-ENH.yml b/.github/workflows/CI-ENH.yml
index a73cfbd1..534bdeed 100644
--- a/.github/workflows/CI-ENH.yml
+++ b/.github/workflows/CI-ENH.yml
@@ -45,7 +45,7 @@ jobs:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
         with:
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 2ede42bc..551ec32b 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -49,7 +49,7 @@ jobs:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
         with:
diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml
index 0838843c..85973a5c 100644
--- a/.github/workflows/CompatHelper.yml
+++ b/.github/workflows/CompatHelper.yml
@@ -16,7 +16,7 @@ jobs:
           version: '1'
           arch: x64
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Pkg.add
         shell: julia --color=yes {0}
         run: |
diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml
index 16552aff..239dd84d 100644
--- a/.github/workflows/Documentation.yml
+++ b/.github/workflows/Documentation.yml
@@ -25,7 +25,7 @@ jobs:
           version: '1'
           arch: x64
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Configure doc environment
         shell: julia --project=docs --color=yes {0}
         run: |
diff --git a/.github/workflows/Formatter.yml b/.github/workflows/Formatter.yml
index e92f590b..cba343c8 100644
--- a/.github/workflows/Formatter.yml
+++ b/.github/workflows/Formatter.yml
@@ -16,7 +16,7 @@ jobs:
           version: '1'
           arch: x64
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - name: Install JuliaFormatter and format
         shell: julia --color=yes {0}
         run: |
diff --git a/.github/workflows/Invalidations.yml b/.github/workflows/Invalidations.yml
index d50f1bc4..3fc14185 100644
--- a/.github/workflows/Invalidations.yml
+++ b/.github/workflows/Invalidations.yml
@@ -19,7 +19,7 @@ jobs:
           version: '1'
           arch: x64
           show-versioninfo: true
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-invalidations@v1
         id: invs_pr
diff --git a/Project.toml b/Project.toml
index b72c8ed3..39609642 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ContinuousNormalizingFlows"
 uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac"
 authors = ["Hossein Pourbozorg and contributors"]
-version = "0.22.1"
+version = "0.22.3"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -50,7 +50,7 @@ ComputationalResources = "0.3"
 DataFrames = "1"
 Dates = "1"
 DifferentialEquations = "7"
-DifferentiationInterface = "0.1, 0.2"
+DifferentiationInterface = "0.1, 0.2, 0.3"
 Distributions = "0.25"
 DistributionsAD = "0.6"
 FillArrays = "1"
diff --git a/benchmark/Project.toml b/benchmark/Project.toml
index 39b31a3c..fdc6ed50 100644
--- a/benchmark/Project.toml
+++ b/benchmark/Project.toml
@@ -12,7 +12,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ADTypes = "0.2, 1"
 BenchmarkTools = "1"
 ComponentArrays = "0.15"
-DifferentiationInterface = "0.1, 0.2"
+DifferentiationInterface = "0.1, 0.2, 0.3"
 Lux = "0.5"
 PkgBenchmark = "0.2"
 StableRNGs = "1"
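A note on the `[compat]` widening above (repeated in `benchmark/Project.toml` here and in `test/Project.toml` at the end of the diff): in Julia's Pkg, each comma-separated compat entry is a caret specifier, and for pre-1.0 packages every minor release counts as breaking, so `"0.1, 0.2"` excludes the DifferentiationInterface 0.3 series while `"0.1, 0.2, 0.3"` allows the union `[0.1.0, 0.4.0)`. A quick way to confirm resolution in a scratch environment (a hypothetical check, not part of this diff):

```julia
# Hypothetical sanity check: with the widened compat, the resolver can now
# pick the DifferentiationInterface 0.3 series.
import Pkg
Pkg.activate(; temp = true)
Pkg.add(Pkg.PackageSpec(; name = "DifferentiationInterface", version = "0.3"))
Pkg.status("DifferentiationInterface")
```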
diff --git a/src/base_icnf.jl b/src/base_icnf.jl
index d9ff817b..a266b201 100644
--- a/src/base_icnf.jl
+++ b/src/base_icnf.jl
@@ -233,9 +233,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -256,9 +256,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -278,9 +278,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -301,9 +301,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -324,9 +324,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -348,9 +348,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -372,9 +372,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -397,9 +397,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -512,28 +512,21 @@ end
 @inline function make_ode_func(
     icnf::AbstractICNF{T, CM, INPLACE},
     mode::Mode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVecOrMat{T},
 ) where {T <: AbstractFloat, CM, INPLACE}
     function ode_func_op(u, p, t)
-        augmented_f(u, p, t, icnf, mode, nn, ϵ)
+        augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
     end

     function ode_func_ip(du, u, p, t)
-        augmented_f(du, u, p, t, icnf, mode, nn, ϵ)
+        augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ)
     end

     ifelse(INPLACE, ode_func_ip, ode_func_op)
 end

-@inline function make_dyn_func(nn::Lux.StatefulLuxLayer, ps::Any)
-    function dyn_func(x)
-        LuxCore.apply(nn, x, ps)
-    end
-
-    dyn_func
-end
-
 @inline function (icnf::AbstractICNF{T, CM, INPLACE, false})(
     xs::AbstractVecOrMat,
     ps::Any,
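The `src/base_icnf.jl` changes above share one idea: `inference_prob` and `generate_prob` no longer bake a `Lux.StatefulLuxLayer(icnf.nn, ps, st)` into the ODE function; they pass the bare layer (or its `CondLayer` wrapper) together with `st`, and the stateful wrapper is rebuilt inside `augmented_f` from the parameter vector `p` that the solver actually carries. That also makes the `make_dyn_func` closure helper unnecessary. A minimal out-of-place sketch of the pattern, with hypothetical toy names and assuming Lux 0.5 (as pinned in `benchmark/Project.toml`):

```julia
# Toy illustration of the new plumbing, not the package API: the dynamics
# capture only the layer `nn` and its state `st`; the stateful wrapper is
# rebuilt per call from the live ODE parameters `p`, so the network always
# sees current parameters (e.g. under optimization or sensitivity analysis).
using Lux, Random, SciMLBase, OrdinaryDiffEq

rng = Xoshiro(0)
nn = Lux.Dense(2 => 2, tanh)
ps, st = Lux.setup(rng, nn)

function toy_ode_func(u, p, t)
    snn = Lux.StatefulLuxLayer(nn, p, st)  # rebuilt here, from `p`
    snn(u)
end

prob = SciMLBase.ODEProblem{false}(toy_ode_func, randn(rng, Float32, 2), (0.0f0, 1.0f0), ps)
sol = OrdinaryDiffEq.solve(prob, OrdinaryDiffEq.Tsit5())
```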
diff --git a/src/icnf.jl b/src/icnf.jl
index 908169a2..c3cd6936 100644
--- a/src/icnf.jl
+++ b/src/icnf.jl
@@ -116,16 +116,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVectorMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = AbstractDifferentiation.value_and_jacobian(
-        icnf.differentiation_backend,
-        make_dyn_func(nn, p),
-        z,
-    )
+    ż, J = AbstractDifferentiation.value_and_jacobian(icnf.differentiation_backend, snn, z)
     l̇ = -LinearAlgebra.tr(only(J))
     vcat(ż, l̇)
 end
@@ -137,16 +135,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVectorMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = AbstractDifferentiation.value_and_jacobian(
-        icnf.differentiation_backend,
-        make_dyn_func(nn, p),
-        z,
-    )
+    ż, J = AbstractDifferentiation.value_and_jacobian(icnf.differentiation_backend, snn, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(only(J))
     nothing
@@ -158,16 +154,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVectorMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-    )
+    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.autodiff_backend, z)
     l̇ = -LinearAlgebra.tr(J)
     vcat(ż, l̇)
 end
@@ -179,16 +173,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVectorMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-    )
+    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.autodiff_backend, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(J)
     nothing
@@ -200,12 +192,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:MatrixMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, make_dyn_func(nn, p), z)
+    ż, J = jacobian_batched(icnf, snn, z)
     l̇ = -transpose(LinearAlgebra.tr.(J))
     vcat(ż, l̇)
 end
@@ -217,12 +211,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:MatrixMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, make_dyn_func(nn, p), z)
+    ż, J = jacobian_batched(icnf, snn, z)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J))
     nothing
@@ -234,14 +230,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, VJ = AbstractDifferentiation.value_and_pullback_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ϵJ = only(VJ(ϵ))
@@ -266,14 +264,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, VJ = AbstractDifferentiation.value_and_pullback_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ϵJ = only(VJ(ϵ))
@@ -298,14 +298,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż_JV = AbstractDifferentiation.value_and_pushforward_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ż, Jϵ = ż_JV(ϵ)
@@ -331,14 +333,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż_JV = AbstractDifferentiation.value_and_pushforward_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ż, Jϵ = ż_JV(ϵ)
@@ -364,17 +368,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -396,17 +397,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -428,17 +426,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -460,17 +456,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -492,17 +486,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -528,17 +519,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
@@ -560,17 +548,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -596,17 +582,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
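In `src/icnf.jl`, every `augmented_f` method now receives `nn` and `st` and constructs `snn = Lux.StatefulLuxLayer(nn, p, st)` itself. Since the stateful layer is directly callable, it is handed to AbstractDifferentiation/DifferentiationInterface as the function to differentiate, which is what lets `make_dyn_func` be deleted. A hedged sketch of the `TestMode` call shape, assuming DifferentiationInterface 0.3 with Zygote loaded (`AutoZygote()` stands in for `icnf.autodiff_backend`):

```julia
# Sketch of the DIVectorMode/TestMode call shape used above; AutoZygote() is
# one possible backend choice, not necessarily what a given ICNF uses.
import DifferentiationInterface
using ADTypes: AutoZygote
using LinearAlgebra, Lux, Random, Zygote

rng = Xoshiro(0)
nn = Lux.Dense(2 => 2, tanh)
ps, st = Lux.setup(rng, nn)

snn = Lux.StatefulLuxLayer(nn, ps, st)  # callable: snn(z) applies nn with ps/st
z = randn(rng, Float32, 2)

# Same argument order the diff uses: function, backend, input.
ż, J = DifferentiationInterface.value_and_jacobian(snn, AutoZygote(), z)
l̇ = -LinearAlgebra.tr(J)  # trace term of the instantaneous change of variables
```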
diff --git a/src/utils.jl b/src/utils.jl
index 33559c85..08108daf 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,6 +1,6 @@
 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y, VJ = DifferentiationInterface.value_and_pullback_split(f, icnf.autodiff_backend, xs)
@@ -17,7 +17,7 @@ end

 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y = f(xs)
@@ -34,7 +34,7 @@ end

 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y, J = DifferentiationInterface.value_and_jacobian(f, icnf.autodiff_backend, xs)
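The `src/utils.jl` hunks above narrow `jacobian_batched` from `f::Function` to `f::Lux.StatefulLuxLayer`, matching the new call sites that pass `snn` directly. The apparent reason (inferred from the need for this change, not verified against a specific Lux release) is that `StatefulLuxLayer` is a callable struct rather than a `Function` subtype, so the old `::Function` methods would no longer dispatch:

```julia
# Hedged illustration of the dispatch issue; the `isa Function` result is an
# inference from this diff, not a checked fact for every Lux version.
using Lux, Random

rng = Xoshiro(0)
nn = Lux.Dense(2 => 2)
ps, st = Lux.setup(rng, nn)
snn = Lux.StatefulLuxLayer(nn, ps, st)

@show snn isa Function                       # expected: false
@show size(snn(randn(rng, Float32, 2, 4)))   # still callable on a (2, 4) batch
```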
"0.15" ComputationalResources = "0.3" DataFrames = "1" -DifferentiationInterface = "0.1, 0.2" +DifferentiationInterface = "0.1, 0.2, 0.3" Distributions = "0.25" ForwardDiff = "0.10" JET = "0.8"