From 1480e796e5347f3de63af3b44d35d3810706cd0d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 28 Nov 2024 08:51:05 +0100 Subject: [PATCH] Explicit Enzyme rules on Enzyme 0.13 (#350) * Bump Enzyme to v0.13 * Mark more broken Enzyme tests * Bump Julia compat entry to 1.10 (#342) * Bump minimum Julia version to 1.10 * Use 'min' in CI * Bump versions of GHA (julia-actions needs to be v2 for 'min') * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Mark Enzyme tests as unbroken * Explicit Enzyme rules * Mark a few tests as broken * Remove batch reverse mode tests completely * Generic Exception? * Skip failing tests * Simplify Enzyme tests --------- Co-authored-by: Markus Hauru Co-authored-by: Penelope Yong Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- .github/dependabot.yml | 7 + .github/workflows/AD.yml | 27 ++-- .github/workflows/Docs.yml | 4 +- .github/workflows/Format.yml | 4 +- .github/workflows/Interface.yml | 20 +-- Project.toml | 15 +- ext/BijectorsDistributionsADExt.jl | 48 ++----- ext/BijectorsEnzymeCoreExt.jl | 217 +++++++++++++++++++++++++++++ ext/BijectorsEnzymeExt.jl | 18 --- ext/BijectorsForwardDiffExt.jl | 9 +- ext/BijectorsLazyArraysExt.jl | 9 +- ext/BijectorsMooncakeExt.jl | 11 +- ext/BijectorsReverseDiffExt.jl | 127 ++++++----------- ext/BijectorsTrackerExt.jl | 77 ++++------ ext/BijectorsZygoteExt.jl | 102 +++++--------- src/Bijectors.jl | 34 ----- src/bijectors/product_bijector.jl | 12 +- test/Project.toml | 10 +- test/ad/enzyme.jl | 51 +++++++ test/ad/utils.jl | 64 +++++---- test/bijectors/product_bijector.jl | 59 +++----- test/runtests.jl | 15 +- 22 files changed, 486 insertions(+), 454 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 ext/BijectorsEnzymeCoreExt.jl delete mode 100644 ext/BijectorsEnzymeExt.jl create mode 100644 test/ad/enzyme.jl diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..d60f0707 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "monthly" diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 7d0aa4ae..74a7f59d 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -6,6 +6,12 @@ on: - master pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: runs-on: ${{ matrix.os }} @@ -13,13 +19,12 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - 'min' + - 'lts' - '1' os: - ubuntu-latest - macOS-latest - arch: - - x64 AD: - Enzyme - ForwardDiff @@ -27,21 +32,13 @@ jobs: - Tracker - ReverseDiff - Zygote - exclude: - - version: 1.6 - AD: Mooncake - # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see - # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - - version: 1.6 - AD: Enzyme steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 env: GROUP: AD AD: ${{ matrix.AD }} diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml index e5f79a0f..1f53e736 100644 --- a/.github/workflows/Docs.yml +++ b/.github/workflows/Docs.yml @@ -18,8 +18,8 @@ jobs: docs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@latest + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: '1' - name: Install dependencies diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml index 5847fee5..965f2ea6 100644 --- a/.github/workflows/Format.yml +++ b/.github/workflows/Format.yml @@ -20,8 +20,8 @@ jobs: format: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@latest + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: 1 - name: Format code diff --git a/.github/workflows/Interface.yml b/.github/workflows/Interface.yml index b305124e..37e692b6 100644 --- a/.github/workflows/Interface.yml +++ b/.github/workflows/Interface.yml @@ -7,6 +7,12 @@ on: - master pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: runs-on: ${{ matrix.os }} @@ -14,20 +20,18 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - 'min' + - 'lts' - '1' os: - ubuntu-latest - macOS-latest - arch: - - x64 steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 env: GROUP: Interface diff --git a/Project.toml b/Project.toml index 70cfcc48..cd70f097 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,8 @@ version = "0.14.2" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -18,14 +16,12 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -36,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" -BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] +BijectorsEnzymeCoreExt = "EnzymeCore" BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsMooncakeExt = "Mooncake" @@ -46,15 +42,12 @@ BijectorsZygoteExt = "Zygote" [compat] ArgCheck = "1, 2" -ChainRules = "1" ChainRulesCore = "0.10.11, 1" ChangesOfVariables = "0.1" -Compat = "3.46, 4.2" Distributions = "0.25.33" DistributionsAD = "0.6" DocStringExtensions = "0.9" -Enzyme = "0.12.22" -EnzymeCore = "0.7.8" +EnzymeCore = "0.8.4" ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4, 0.5" InverseFunctions = "0.1" @@ -64,17 +57,15 @@ LogExpFunctions = "0.3.3" MappedArrays = "0.2.2, 0.3, 0.4" Mooncake = "0.4.19" Reexport = "0.2, 1" -Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.15, 2" Statistics = "1" Tracker = "0.2" Zygote = "0.6.63" -julia = "1.6" +julia = "1.10" [extras] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" diff --git a/ext/BijectorsDistributionsADExt.jl b/ext/BijectorsDistributionsADExt.jl index 69dce30a..4be6b489 100644 --- a/ext/BijectorsDistributionsADExt.jl +++ b/ext/BijectorsDistributionsADExt.jl @@ -1,38 +1,20 @@ module BijectorsDistributionsADExt -if isdefined(Base, :get_extension) - using Bijectors - using Bijectors: LinearAlgebra - using Bijectors.Distributions: AbstractMvLogNormal - using DistributionsAD: - TuringDirichlet, - TuringWishart, - TuringInverseWishart, - FillVectorOfUnivariate, - FillMatrixOfUnivariate, - MatrixOfUnivariate, - FillVectorOfMultivariate, - VectorOfMultivariate, - TuringScalMvNormal, - TuringDiagMvNormal, - TuringDenseMvNormal -else - using ..Bijectors - using ..Bijectors: LinearAlgebra - using ..Bijectors.Distributions: AbstractMvLogNormal - using ..DistributionsAD: - TuringDirichlet, - TuringWishart, - TuringInverseWishart, - FillVectorOfUnivariate, - FillMatrixOfUnivariate, - MatrixOfUnivariate, - FillVectorOfMultivariate, - VectorOfMultivariate, - TuringScalMvNormal, - TuringDiagMvNormal, - TuringDenseMvNormal -end +using Bijectors +using Bijectors: LinearAlgebra +using Bijectors.Distributions: AbstractMvLogNormal +using DistributionsAD: + TuringDirichlet, + TuringWishart, + TuringInverseWishart, + FillVectorOfUnivariate, + FillMatrixOfUnivariate, + MatrixOfUnivariate, + FillVectorOfMultivariate, + VectorOfMultivariate, + TuringScalMvNormal, + TuringDiagMvNormal, + TuringDenseMvNormal # Bijectors diff --git a/ext/BijectorsEnzymeCoreExt.jl b/ext/BijectorsEnzymeCoreExt.jl new file mode 100644 index 00000000..5e3f5b89 --- /dev/null +++ b/ext/BijectorsEnzymeCoreExt.jl @@ -0,0 +1,217 @@ +module BijectorsEnzymeCoreExt + +using EnzymeCore: + Active, + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + EnzymeRules +using Bijectors: find_alpha + +# Compute a tuple of partial derivatives wrt non-`Const` arguments +# and `nothing`s for `Const` arguments +function ∂find_alpha( + Ω::Real, + wt_y::Union{Const,Active,Duplicated,BatchDuplicated}, + wt_u_hat::Union{Const,Active,Duplicated,BatchDuplicated}, + b::Union{Const,Active,Duplicated,BatchDuplicated}, +) + # We reuse the following term in the computation of the derivatives + Ωpb = Ω + b.val + c = wt_u_hat.val * sech(Ωpb)^2 + cp1 = c + 1 + + ∂Ω_∂wt_y = wt_y isa Const ? nothing : oneunit(wt_y.val) / cp1 + ∂Ω_∂wt_u_hat = wt_u_hat isa Const ? nothing : -tanh(Ωpb) / cp1 + ∂Ω_∂b = b isa Const ? nothing : -c / cp1 + + return (∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b) +end + +# `muladd` for partial derivatives that can deal with `nothing` derivatives +_muladd_partial(::Nothing, ::Const, x::Union{Real,Tuple{Vararg{Real}},Nothing}) = x +_muladd_partial(x::Real, y::Duplicated, z::Real) = muladd(x, y.dval, z) +_muladd_partial(x::Real, y::Duplicated, ::Nothing) = x * y.dval +function _muladd_partial(x::Real, y::BatchDuplicated{<:Real,N}, z::NTuple{N,Real}) where {N} + let x = x + map((a, b) -> muladd(x, a, b), y.dval, z) + end +end +_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval) + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(find_alpha)}, + ::Type{RT}, + wt_y::Union{Const,Duplicated,BatchDuplicated}, + wt_u_hat::Union{Const,Duplicated,BatchDuplicated}, + b::Union{Const,Duplicated,BatchDuplicated}, +) where {RT<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}} + # Check that the types of the activities are consistent + if !( + RT <: Union{Const,Duplicated,DuplicatedNoNeed} && + wt_y isa Union{Const,Duplicated} && + wt_u_hat isa Union{Const,Duplicated} && + b isa Union{Const,Duplicated} + ) && !( + RT <: Union{Const,BatchDuplicated,BatchDuplicatedNoNeed} && + wt_y isa Union{Const,BatchDuplicated} && + wt_u_hat isa Union{Const,BatchDuplicated} && + b isa Union{Const,BatchDuplicated} + ) + throw(ArgumentError("inconsistent activities")) + end + + # Early exit: Neither primal nor shadow needed + if !EnzymeRules.needs_primal(config) && !EnzymeRules.needs_shadow(config) + return nothing + end + + # Compute primal value + Ω = find_alpha(wt_y.val, wt_u_hat.val, b.val) + + # Early exit if no derivatives are requested + if !EnzymeRules.needs_shadow(config) + return Ω + end + + Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const + # Trivial case: All partial derivatives are 0 + if EnzymeRules.width(config) == 1 + zero(Ω) + else + ntuple(Zero(Ω), Val(EnzymeRules.width(config))) + end + else + # In all other cases we have to compute the partial derivatives + ∂Ω_∂wt_y, ∂Ω_∂wt_u_hat, ∂Ω_∂b = ∂find_alpha(Ω, wt_y, wt_u_hat, b) + _muladd_partial( + ∂Ω_∂wt_y, + wt_y, + _muladd_partial(∂Ω_∂wt_u_hat, wt_u_hat, _muladd_partial(∂Ω_∂b, b, nothing)), + ) + end + @assert (EnzymeRules.width(config) == 1 && Ω̇ isa Real) || + (EnzymeRules.width(config) > 1 && Ω̇ isa NTuple{EnzymeRules.width(config),Real}) + + if EnzymeRules.needs_primal(config) + if EnzymeRules.width(config) == 1 + return Duplicated(Ω, Ω̇) + else + return BatchDuplicated(Ω, Ω̇) + end + else + return Ω̇ + end +end + +struct Zero{T} + x::T +end +(f::Zero)(_) = zero(f.x) + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + ::Const{typeof(find_alpha)}, + ::Type{RT}, + wt_y::Union{Const,Active}, + wt_u_hat::Union{Const,Active}, + b::Union{Const,Active}, +) where {RT<:Union{Const,Active}} + # Only compute the the original return value if it is actually needed + Ω = + if EnzymeRules.needs_primal(config) || + EnzymeRules.needs_shadow(config) || + !(RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const)) + find_alpha(wt_y.val, wt_u_hat.val, b.val) + else + nothing + end + + tape = if RT <: Const || (wt_y isa Const && wt_u_hat isa Const && b isa Const) + # Trivial case: No differentiation or all derivatives are 0 + # Thus no tape is needed + nothing + else + # Derivatives with respect to at least one argument needed + # They are computed in the reverse pass, and therefore the original return is cached + # In principle, the partial derivatives could be computed here and be cached + # But Enzyme only executes the reverse pass once, + # thus this would not increase efficiency but instead more values would have to be cached + Ω + end + + # Ensure that we follow the interface requirements of `augmented_primal` + primal = EnzymeRules.needs_primal(config) ? Ω : nothing + shadow = if EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) === 1 + zero(Ω) + else + ntuple(Zero(Ω), Val(EnzymeRules.width(config))) + end + else + nothing + end + + return EnzymeRules.AugmentedReturn(primal, shadow, tape) +end + +struct ZeroOrNothing{N} end +(::ZeroOrNothing)(::Const) = nothing +(::ZeroOrNothing{1})(x::Active) = zero(x.val) +(::ZeroOrNothing{N})(x::Active) where {N} = ntuple(Zero(x.val), Val{N}()) + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(find_alpha)}, + ::Type{<:Const}, + ::Nothing, + wt_y::Union{Const,Active}, + wt_u_hat::Union{Const,Active}, + b::Union{Const,Active}, +) + # Trivial case: Nothing to be differentiated (return activity is `Const`) + return map(ZeroOrNothing{EnzymeRules.width(config)}(), (wt_y, wt_u_hat, b)) +end +function EnzymeRules.reverse( + ::EnzymeRules.RevConfig, + ::Const{typeof(find_alpha)}, + ::Active, + ::Nothing, + ::Const, + ::Const, + ::Const, +) + # Trivial case: Tape does not exist sice all partial derivatives are 0 + return (nothing, nothing, nothing) +end + +struct MulPartialOrNothing{T<:Union{Real,Tuple{Vararg{Real}}}} + x::T +end +(::MulPartialOrNothing)(::Nothing) = nothing +(f::MulPartialOrNothing{<:Real})(∂f_∂x::Real) = ∂f_∂x * f.x +function (f::MulPartialOrNothing{<:NTuple{N,Real}})(∂f_∂x::Real) where {N} + return map(Base.Fix1(*, ∂f_∂x), f.x) +end + +function EnzymeRules.reverse( + ::EnzymeRules.RevConfig, + ::Const{typeof(find_alpha)}, + ΔΩ::Active, + Ω::Real, + wt_y::Union{Const,Active}, + wt_u_hat::Union{Const,Active}, + b::Union{Const,Active}, +) + # Tape must be `nothing` if all arguments are `Const` + @assert !(wt_y isa Const && wt_u_hat isa Const && b isa Const) + + # Compute partial derivatives + ∂Ω_∂xs = ∂find_alpha(Ω, wt_y, wt_u_hat, b) + return map(MulPartialOrNothing(ΔΩ.val), ∂Ω_∂xs) +end + +end # module diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl deleted file mode 100644 index 303fd92f..00000000 --- a/ext/BijectorsEnzymeExt.jl +++ /dev/null @@ -1,18 +0,0 @@ -module BijectorsEnzymeExt - -if isdefined(Base, :get_extension) - using Enzyme: @import_rrule, @import_frule - using Bijectors: find_alpha -else - using ..Enzyme: @import_rrule, @import_frule - using ..Bijectors: find_alpha -end - -@static if v"1.11.1" <= VERSION < v"1.12" - @warn "Bijectors and Enzyme do not work together on Julia $VERSION" -else - @import_rrule typeof(find_alpha) Real Real Real - @import_frule typeof(find_alpha) Real Real Real -end - -end # module diff --git a/ext/BijectorsForwardDiffExt.jl b/ext/BijectorsForwardDiffExt.jl index 29db3028..76ae0dd2 100644 --- a/ext/BijectorsForwardDiffExt.jl +++ b/ext/BijectorsForwardDiffExt.jl @@ -1,12 +1,7 @@ module BijectorsForwardDiffExt -if isdefined(Base, :get_extension) - using Bijectors: Bijectors, find_alpha - using ForwardDiff: ForwardDiff -else - using ..Bijectors: Bijectors, find_alpha - using ..ForwardDiff: ForwardDiff -end +using Bijectors: Bijectors, find_alpha +using ForwardDiff: ForwardDiff Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = Bijectors._eps(Real) Bijectors._eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = Bijectors._eps(Real) diff --git a/ext/BijectorsLazyArraysExt.jl b/ext/BijectorsLazyArraysExt.jl index fa060470..03943fdc 100644 --- a/ext/BijectorsLazyArraysExt.jl +++ b/ext/BijectorsLazyArraysExt.jl @@ -1,12 +1,7 @@ module BijectorsLazyArraysExt -if isdefined(Base, :get_extension) - import Bijectors: maporbroadcast - using LazyArrays: LazyArrays -else - import ..Bijectors: maporbroadcast - using ..LazyArrays: LazyArrays -end +import Bijectors: maporbroadcast +using LazyArrays: LazyArrays function maporbroadcast(f, x1::LazyArrays.BroadcastArray, x...) return copy(f.(x1, x...)) diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index d7285bf6..0c2d8903 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,13 +1,8 @@ module BijectorsMooncakeExt -if isdefined(Base, :get_extension) - using Mooncake: - @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule - using Bijectors: find_alpha, ChainRulesCore -else - using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule - using ..Bijectors: find_alpha, ChainRulesCore -end +using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule +using Bijectors: find_alpha, ChainRulesCore for P in [Float16, Float32, Float64] @from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) diff --git a/ext/BijectorsReverseDiffExt.jl b/ext/BijectorsReverseDiffExt.jl index 4489cb26..a0cbfb9d 100644 --- a/ext/BijectorsReverseDiffExt.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -1,94 +1,47 @@ module BijectorsReverseDiffExt -if isdefined(Base, :get_extension) - using ReverseDiff: - ReverseDiff, - @grad, - value, - track, - TrackedReal, - TrackedVector, - TrackedMatrix, - @grad_from_chainrules +using ReverseDiff: + ReverseDiff, + @grad, + value, + track, + TrackedReal, + TrackedVector, + TrackedMatrix, + @grad_from_chainrules - using Bijectors: - ChainRulesCore, - Elementwise, - SimplexBijector, - maphcat, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse - import Bijectors: - Bijectors, - _eps, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - _inv_link_chol_lkj, - _link_chol_lkj, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - pd_from_lower, - lower_triangular, - upper_triangular, - transpose_eager, - cholesky_lower, - cholesky_upper +using Bijectors: + ChainRulesCore, + Elementwise, + SimplexBijector, + maphcat, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse +import Bijectors: + Bijectors, + _eps, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_from_lower, + lower_triangular, + upper_triangular, + transpose_eager, + cholesky_lower, + cholesky_upper - using Bijectors.LinearAlgebra - using Bijectors.Compat: eachcol - using Bijectors.Distributions: LocationScale -else - using ..ReverseDiff: - ReverseDiff, - @grad, - value, - track, - TrackedReal, - TrackedVector, - TrackedMatrix, - @grad_from_chainrules - - using ..Bijectors: - ChainRulesCore, - Elementwise, - SimplexBijector, - maphcat, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse - import ..Bijectors: - Bijectors, - _eps, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - _inv_link_chol_lkj, - _link_chol_lkj, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - pd_from_lower, - lower_triangular, - upper_triangular, - transpose_eager, - cholesky_lower, - cholesky_upper - - using ..Bijectors.LinearAlgebra - using ..Bijectors.Compat: eachcol - using ..Bijectors.Distributions: LocationScale -end +using Bijectors.LinearAlgebra +using Bijectors.Distributions: LocationScale _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) diff --git a/ext/BijectorsTrackerExt.jl b/ext/BijectorsTrackerExt.jl index 9d04ea32..193bfb52 100644 --- a/ext/BijectorsTrackerExt.jl +++ b/ext/BijectorsTrackerExt.jl @@ -1,58 +1,29 @@ module BijectorsTrackerExt -if isdefined(Base, :get_extension) - using Tracker: - Tracker, - TrackedReal, - TrackedVector, - TrackedMatrix, - TrackedArray, - TrackedVecOrMat, - @grad, - track, - data, - param - - using Bijectors: - Elementwise, - SimplexBijector, - Inverse, - Stacked, - Bijectors, - ChainRulesCore, - LogExpFunctions, - _triu1_dim_from_length - - using Bijectors.LinearAlgebra - using Bijectors.Compat: eachcol - using Bijectors.Distributions: LocationScale -else - using ..Tracker: - Tracker, - TrackedReal, - TrackedVector, - TrackedMatrix, - TrackedArray, - TrackedVecOrMat, - @grad, - track, - data, - param - - using Bijectors: - Elementwise, - SimplexBijector, - Inverse, - Stacked, - Bijectors, - ChainRulesCore, - LogExpFunctions, - _triu1_dim_from_length - - using ..Bijectors.LinearAlgebra - using ..Bijectors.Compat: eachcol - using ..Bijectors.Distributions: LocationScale -end +using Tracker: + Tracker, + TrackedReal, + TrackedVector, + TrackedMatrix, + TrackedArray, + TrackedVecOrMat, + @grad, + track, + data, + param + +using Bijectors: + Elementwise, + SimplexBijector, + Inverse, + Stacked, + Bijectors, + ChainRulesCore, + LogExpFunctions, + _triu1_dim_from_length + +using Bijectors.LinearAlgebra +using Bijectors.Distributions: LocationScale Bijectors.maporbroadcast(f, x::TrackedArray...) = f.(x...) function Bijectors.maporbroadcast( diff --git a/ext/BijectorsZygoteExt.jl b/ext/BijectorsZygoteExt.jl index 1befd0f2..79195b88 100644 --- a/ext/BijectorsZygoteExt.jl +++ b/ext/BijectorsZygoteExt.jl @@ -1,76 +1,38 @@ module BijectorsZygoteExt -if isdefined(Base, :get_extension) - using Zygote: Zygote, @adjoint, pullback - using Bijectors: - Elementwise, - SimplexBijector, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse, - maphcat, - IrrationalConstants, - Distributions, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - pd_logpdf_with_trans, - istraining, - mapvcat, - eachcolmaphcat, - sumeachcol, - pd_link, - pd_from_lower, - lower_triangular, - upper_triangular, - getlogp +using Zygote: Zygote, @adjoint, pullback +using Bijectors: + Elementwise, + SimplexBijector, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse, + maphcat, + IrrationalConstants, + Distributions, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + _transform_ordered, + _transform_inverse_ordered, + find_alpha, + pd_logpdf_with_trans, + istraining, + mapvcat, + eachcolmaphcat, + sumeachcol, + pd_link, + pd_from_lower, + lower_triangular, + upper_triangular, + getlogp - using Bijectors.LinearAlgebra - using Bijectors.Compat: eachcol - using Bijectors.Distributions: LocationScale -else - using ..Zygote: Zygote, @adjoint, pullback - using ..Bijectors: - Elementwise, - SimplexBijector, - simplex_link_jacobian, - simplex_invlink_jacobian, - simplex_logabsdetjac_gradient, - Inverse, - maphcat, - IrrationalConstants, - Distributions, - logabsdetjac, - _logabsdetjac_scale, - _simplex_bijector, - _simplex_inv_bijector, - replace_diag, - jacobian, - _transform_ordered, - _transform_inverse_ordered, - find_alpha, - pd_logpdf_with_trans, - istraining, - mapvcat, - eachcolmaphcat, - sumeachcol, - pd_link, - pd_from_lower, - lower_triangular, - upper_triangular, - getlogp - - using ..Bijectors.LinearAlgebra - using ..Bijectors.Compat: eachcol - using ..Bijectors.Distributions: LocationScale -end +using Bijectors.LinearAlgebra +using Bijectors.Distributions: LocationScale @adjoint istraining() = true, _ -> nothing diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 9f70243c..389b3e46 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -41,12 +41,10 @@ import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian import InverseFunctions: inverse using ChainRulesCore: ChainRulesCore -using ChainRules: ChainRules using Functors: Functors using IrrationalConstants: IrrationalConstants using LogExpFunctions: LogExpFunctions using Roots: Roots -using Compat: Compat using DocStringExtensions: TYPEDFIELDS export TransformDistribution, @@ -79,10 +77,6 @@ export TransformDistribution, InvertibleBatchNorm, elementwise -if VERSION < v"1.9" - using Compat: stack -end - const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) _debug(str) = @debug str @@ -362,32 +356,4 @@ include("chainrules.jl") maporbroadcast(f, x::AbstractArray{<:Any,N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) -# optional dependencies -if !isdefined(Base, :get_extension) - using Requires -end - -function __init__() - @static if !isdefined(Base, :get_extension) - @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" include( - "../ext/BijectorsLazyArraysExt.jl" - ) - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( - "../ext/BijectorsForwardDiffExt.jl" - ) - @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include( - "../ext/BijectorsTrackerExt.jl" - ) - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include( - "../ext/BijectorsZygoteExt.jl" - ) - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( - "../ext/BijectorsReverseDiffExt.jl" - ) - @require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include( - "../ext/BijectorsDistributionsADExt.jl" - ) - end -end - end # module diff --git a/src/bijectors/product_bijector.jl b/src/bijectors/product_bijector.jl index c24ef394..3eee17d1 100644 --- a/src/bijectors/product_bijector.jl +++ b/src/bijectors/product_bijector.jl @@ -23,17 +23,7 @@ function _product_bijector_slices( # If N < M, then the bijectors expect an input vector of dimension `M - N`. # To achieve this, we need to slice along the last `N` dimensions. slice_indices = ntuple(i -> i + (M - N), N) - if VERSION >= v"1.9" - return eachslice(x; dims=slice_indices) - else - # Earlier Julia versions can't eachslice over multiple dimensions, so reshape the - # slice dimensions into a single one. - other_dims = tuple((size(x, i) for i in 1:(M - N))...) - slice_dims = tuple((size(x, i) for i in (1 + M - N):M)...) - x_reshaped = reshape(x, other_dims..., prod(slice_dims)) - slices = eachslice(x_reshaped; dims=M - N + 1) - return reshape(collect(slices), slice_dims) - end + return eachslice(x; dims=slice_indices) end # Specialization for case where we're just applying elementwise. diff --git a/test/Project.toml b/test/Project.toml index bd62af58..48c4b662 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,9 +4,9 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -30,9 +31,9 @@ AdvancedHMC = "0.6" ChainRulesTestUtils = "0.7, 1" ChangesOfVariables = "0.1" Combinatorics = "1.0.2" -Compat = "3.46, 4.2" DistributionsAD = "0.6.3" -Enzyme = "0.12.22" +Enzyme = "0.13.12" +EnzymeTestUtils = "0.2.1" FillArrays = "1" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" @@ -42,7 +43,8 @@ LazyArrays = "1, 2" LogDensityProblems = "2" LogExpFunctions = "0.3.1" MCMCDiagnosticTools = "0.3" +Mooncake = "0.4" ReverseDiff = "1.4.2" Tracker = "0.2.11" Zygote = "0.6.63" -julia = "1.3" +julia = "1.10" diff --git a/test/ad/enzyme.jl b/test/ad/enzyme.jl new file mode 100644 index 00000000..78fcab05 --- /dev/null +++ b/test/ad/enzyme.jl @@ -0,0 +1,51 @@ +@testset "Enzyme: Bijectors.find_alpha" begin + x = randn() + y = expm1(randn()) + z = randn() + + @testset "forward" begin + # No batches + @testset for RT in (Const, Duplicated, DuplicatedNoNeed), + Tx in (Const, Duplicated), + Ty in (Const, Duplicated), + Tz in (Const, Duplicated) + + # Rule not picked up by Enzyme on Julia 1.11?! + # Ref https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2470766968 + if VERSION >= v"1.11" && Tx <: Const && Ty <: Const && Tz <: Const + continue + end + + test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) + end + + # Batches + @testset for RT in (Const, BatchDuplicated, BatchDuplicatedNoNeed), + Tx in (Const, BatchDuplicated), + Ty in (Const, BatchDuplicated), + Tz in (Const, BatchDuplicated) + + # Rule not picked up by Enzyme on Julia 1.11?! + # Ref https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2470766968 + if VERSION >= v"1.11" && Tx <: Const && Ty <: Const && Tz <: Const + continue + end + + test_forward(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) + end + end + @testset "reverse" begin + # No batches + @testset for RT in (Const, Active), + Tx in (Const, Active), + Ty in (Const, Active), + Tz in (Const, Active) + + test_reverse(Bijectors.find_alpha, RT, (x, Tx), (y, Ty), (z, Tz)) + end + + # TODO: Test batch mode + # This is a bit problematic since Enzyme does not support all combinations of activities currently + # https://github.com/TuringLang/Bijectors.jl/pull/350#issuecomment-2480468728 + end +end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 1358173b..d7e0a245 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -12,6 +12,10 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) :Enzyme, :EnzymeForward, :EnzymeReverse, + # The `Crash` ones indicate that the error will cause a Julia crash, and + # thus we can't even run `@test_broken on it. + :EnzymeForwardCrash, + :EnzymeReverseCrash, ) ) error("Unknown broken AD backend: $b") @@ -19,7 +23,6 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] - et = eltype(finitediff) if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @@ -34,7 +37,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol = rtol atol = atol else ∇zygote = Zygote.gradient(f, x)[1] - @test (all(finitediff .== 0) && ∇zygote === nothing) || + @test (all(iszero, finitediff) && ∇zygote === nothing) || isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) end end @@ -47,38 +50,43 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - # TODO(mhauru) The version bound should be relaxed once some Enzyme issues get - # sorted out. I think forward mode will remain broken for versions <= 1.6 due to - # some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and - # discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. - if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" + if AD == "All" || AD == "Enzyme" forward_broken = :EnzymeForward in broken || :Enzyme in broken reverse_broken = :EnzymeReverse in broken || :Enzyme in broken - if forward_broken - @test_broken( - collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, - atol = atol - ) - else - @test( - collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, - atol = atol - ) + if !(:EnzymeForwardCrash in broken) + if forward_broken + @test_broken( + Enzyme.gradient(Enzyme.Forward, f, x)[1] ≈ finitediff, + rtol = rtol, + atol = atol + ) + else + @test( + Enzyme.gradient(Enzyme.Forward, f, x)[1] ≈ finitediff, + rtol = rtol, + atol = atol + ) + end end - if reverse_broken - @test_broken( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol - ) - else - @test( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol - ) + + if !(:EnzymeReverseCrash in broken) + if reverse_broken + @test_broken( + Enzyme.gradient(Enzyme.Reverse, f, x)[1] ≈ finitediff, + rtol = rtol, + atol = atol + ) + else + @test( + Enzyme.gradient(Enzyme.Reverse, f, x)[1] ≈ finitediff, + rtol = rtol, + atol = atol + ) + end end end - if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + if AD == "All" || AD == "Mooncake" try Mooncake.build_rrule(f, x) catch exc diff --git a/test/bijectors/product_bijector.jl b/test/bijectors/product_bijector.jl index 78310572..818f89c0 100644 --- a/test/bijectors/product_bijector.jl +++ b/test/bijectors/product_bijector.jl @@ -33,27 +33,14 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) end y, logjac = stack(map(first, results)), sum(last, results) - if VERSION < v"1.9" && length(size(d)) > 0 - # `eachslice`, which is used by `ProductBijector`, is type-unstable - # for multivariate cases on Julia < 1.9. Hence the type-inference fails. - @test_broken test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) - else - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) - end + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) end @testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds @@ -70,27 +57,13 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) results = map(Base.Fix1(with_logabsdet_jacobian, b), xs) y, logjac = stack(map(first, results)), sum(last, results) - if VERSION < v"1.9" && length(size(d)) > 0 - # `eachslice`, which is used by `ProductBijector`, does not support - # `dims` with more than one value. As a result, stacking anything that - # isn't univariate won't work here. - @test_broken test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) - else - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) - end + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) end end diff --git a/test/runtests.jl b/test/runtests.jl index 638bd15c..c66f4c3f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,10 +4,12 @@ using ChainRulesTestUtils using Combinatorics using DistributionsAD using Enzyme +using EnzymeTestUtils using FiniteDifferences using ForwardDiff using Functors using LogExpFunctions +using Mooncake using ReverseDiff using Tracker using Zygote @@ -30,18 +32,6 @@ using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions using LazyArrays: LazyArrays -if VERSION < v"1.9" - using Compat: stack -end - -# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing -# on at least version 1.10. -if VERSION >= v"1.10" - using Pkg - Pkg.add("Mooncake") - using Mooncake -end - const GROUP = get(ENV, "GROUP", "All") # Always include this since it can be useful for other tests. @@ -68,6 +58,7 @@ end if GROUP == "All" || GROUP == "AD" include("ad/chainrules.jl") + include("ad/enzyme.jl") include("ad/flows.jl") include("ad/pd.jl") include("ad/corr.jl")