diff --git a/HISTORY.md b/HISTORY.md
index 6b7247c8d..3b743650f 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -49,6 +49,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+### New features
+
+The `DynamicPPL.TestUtils.AD` module now contains several functions for testing the correctness of automatic differentiation of log densities.
+Please refer to the DynamicPPL documentation for more details.
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
diff --git a/Project.toml b/Project.toml
index 38382f98f..28c1256a4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,11 +12,10 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
-# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -56,14 +55,13 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.39"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
-# for why KernelAbstractions is pinned like this.
-KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
+KernelAbstractions = "< 0.9.32"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1.7.0"
diff --git a/docs/src/api.md b/docs/src/api.md
index 36dd24250..433a875d4 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -219,6 +219,15 @@ DynamicPPL.TestUtils.update_values!!
 DynamicPPL.TestUtils.test_values
 ```
 
+To test whether automatic differentiation is working correctly, the following methods can be used:
+
+```@docs
+DynamicPPL.TestUtils.AD.ad_ldp
+DynamicPPL.TestUtils.AD.ad_di
+DynamicPPL.TestUtils.AD.make_function
+DynamicPPL.TestUtils.AD.make_params
+```
+
 ## Debugging Utilities
 
 DynamicPPL provides a few methods for checking validity of a model-definition.
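
For orientation, the four methods listed in the `docs/src/api.md` hunk above can be exercised as in the following sketch. The `demo` model and the choice of ForwardDiff as the backend are illustrative assumptions, not part of this changeset.

```julia
using DynamicPPL, Distributions, ADTypes, ForwardDiff
using DynamicPPL.TestUtils.AD: ad_ldp, ad_di, make_params

# Hypothetical one-parameter model, used purely for illustration.
@model function demo(x)
    mu ~ Normal()
    x ~ Normal(mu, 1)
end
model = demo(1.0)

# Draw a flattened parameter vector from the prior.
params = make_params(model)

# Log density and gradient via LogDensityProblemsAD.jl (the code path Turing.jl uses)...
value_ldp, grad_ldp = ad_ldp(model, params, AutoForwardDiff())

# ...and via DifferentiationInterface.jl directly, for comparison.
value_di, grad_di = ad_di(model, params, AutoForwardDiff())
```
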
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 55e1f7e88..2fd381dba 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -189,7 +189,6 @@ include("context_implementations.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("submodel_macro.jl")
-include("test_utils.jl")
 include("transforming.jl")
 include("logdensityfunction.jl")
 include("model_utils.jl")
@@ -199,6 +198,8 @@ include("values_as_in_model.jl")
 include("debug_utils.jl")
 using .DebugUtils
 
+include("test_utils.jl")
+
 include("experimental.jl")
 include("deprecated.jl")
diff --git a/src/test_utils.jl b/src/test_utils.jl
index c7d12c927..65079f023 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
 include("test_utils/contexts.jl")
 include("test_utils/varinfo.jl")
 include("test_utils/sampler.jl")
+include("test_utils/ad.jl")
 
 end
diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl
new file mode 100644
index 000000000..9940504dd
--- /dev/null
+++ b/src/test_utils/ad.jl
@@ -0,0 +1,201 @@
+module AD
+
+import ADTypes: AbstractADType
+import DifferentiationInterface as DI
+import ..DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+import LogDensityProblems: logdensity, logdensity_and_gradient
+import LogDensityProblemsAD: ADgradient
+import Random: Random, AbstractRNG
+import Test: @test
+
+export make_function, make_params, ad_ldp, ad_di, test_correctness
+
+"""
+    flipped_logdensity(θ, ldf)
+
+Flips the order of arguments for `logdensity` to match the signature needed
+for DifferentiationInterface.jl.
+"""
+flipped_logdensity(θ, ldf) = logdensity(ldf, θ)
+
+"""
+    ad_ldp(
+        model::Model,
+        params::Vector{<:Real},
+        adtype::AbstractADType,
+        varinfo::AbstractVarInfo=VarInfo(model)
+    )
+
+Calculate the logdensity of `model` and its gradient using the AD backend
+`adtype`, evaluated at the parameters `params`, using the implementation of
+`logdensity_and_gradient` in the LogDensityProblemsAD.jl package.
+
+The `varinfo` argument is optional and is used to provide the container
+structure for the parameters. Note that the _parameters_ inside the `varinfo`
+argument itself are overridden by the `params` argument. This argument defaults
+to [`DynamicPPL.VarInfo`](@ref), which is the default container structure used
+throughout the Turing ecosystem; however, you can provide e.g.
+[`DynamicPPL.SimpleVarInfo`](@ref) if you want to use a different container
+structure.
+
+Returns a tuple `(value, gradient)` where `value <: Real` is the logdensity
+of the model evaluated at `params`, and `gradient <: Vector{<:Real}` is the
+gradient of the logdensity with respect to `params`.
+
+Note that DynamicPPL.jl and Turing.jl currently use LogDensityProblemsAD.jl
+throughout, and hence this function most closely mimics the usage of AD within
+the Turing ecosystem.
+
+For some AD backends such as Mooncake.jl, LogDensityProblemsAD.jl simply defers
+to the DifferentiationInterface.jl package. In such a case, `ad_ldp` simplifies
+to `ad_di` (in that if `ad_di` passes, one should expect `ad_ldp` to pass as
+well).
+
+However, there are other AD backends which still have custom code in
+LogDensityProblemsAD.jl (such as ForwardDiff.jl). For these backends, `ad_di`
+may yield different results compared to `ad_ldp`, and the behaviour of `ad_di`
+is in such cases not guaranteed to be consistent with the behaviour of
+Turing.jl.
+
+See also: [`ad_di`](@ref).
+""" +function ad_ldp( + model::Model, + params::Vector{<:Real}, + adtype::AbstractADType, + vi::AbstractVarInfo=VarInfo(model), +) + ldf = LogDensityFunction(model, vi) + # Note that the implementation of logdensity takes care of setting the + # parameters in vi to the correct values (using unflatten) + return logdensity_and_gradient(ADgradient(adtype, ldf), params) +end + +""" + ad_di( + model::Model, + params::Vector{<:Real}, + adtype::AbstractADType, + varinfo::AbstractVarInfo=VarInfo(model) + ) + +Calculate the logdensity of `model` and its gradient using the AD backend +`adtype`, evaluated at the parameters `params`, directly using +DifferentiationInterface.jl. + +See the notes in [`ad_ldp`](@ref) for more details on the differences between +`ad_di` and `ad_ldp`. +""" +function ad_di( + model::Model, + params::Vector{<:Real}, + adtype::AbstractADType, + vi::AbstractVarInfo=VarInfo(model), +) + ldf = LogDensityFunction(model, vi) + # Note that the implementation of logdensity takes care of setting the + # parameters in vi to the correct values (using unflatten) + prep = DI.prepare_gradient(flipped_logdensity, adtype, params, DI.Constant(ldf)) + return DI.value_and_gradient(flipped_logdensity, prep, adtype, params, DI.Constant(ldf)) +end + +""" + make_function(model, varinfo::AbstractVarInfo=VarInfo(model)) + +Generate the function to be differentiated. Specifically, +`make_function(model)` returns a function which takes a single argument +`params` and returns the logdensity of `model` evaluated at `params`. + +The `varinfo` parameter is optional and is used to determine the structure of +the varinfo used during evaluation. See the [`ad_ldp`](@ref) function for more +details on the `varinfo` argument. + +If you have an AD package that does not have integrations with either +LogDensityProblemsAD.jl (in which case you can use [`ad_ldp`](@ref)) or +DifferentiationInterface.jl (in which case you can use [`ad_di`](@ref)), you +can test whether your AD package works with Turing.jl models using: + +```julia +f = make_function(model) +params = make_params(model) +value, grad = YourADPackage.gradient(f, params) +``` + +and compare the results against that obtained from either `ad_ldp` or `ad_di` for +an existing AD package that _is_ supported. + +See also: [`make_params`](@ref). +""" +function make_function(model::Model, vi::AbstractVarInfo=VarInfo(model)) + # TODO: Can we simplify this even further by inlining the definition of + # logdensity? + return Base.Fix1(logdensity, LogDensityFunction(model, vi)) +end + +""" + make_params(model, rng::Random.AbstractRNG=Random.default_rng()) + +Generate a vector of parameters sampled from the prior distribution of `model`. +This can be used as the input to the function to be differentiated. See +[`make_function`](@ref) for more details. +""" +function make_params(model::Model, rng::AbstractRNG=Random.default_rng()) + return VarInfo(rng, model)[:] +end + +""" + test_correctness( + ad_function, + model::Model, + adtypes::Vector{<:ADTypes.AbstractADType}, + reference_adtype::ADTypes.AbstractADType, + rng::Random.AbstractRNG=Random.default_rng(), + params::Vector{<:Real}=VarInfo(rng, model)[:]; + value_atol=1e-6, + grad_atol=1e-6 + ) + +Test the correctness of all the AD backend `adtypes` for the model `model` +using the implementation `ad_function`. `ad_function` should be either +[`ad_ldp`](@ref) or [`ad_di`](@ref), or a custom function that has the same +signature. 
+
+The test is performed by calculating the logdensity and its gradient using all
+the AD backends, and comparing the results against those obtained with the
+reference AD backend `reference_adtype`.
+
+The parameters can either be passed explicitly using the `params` argument, or
+sampled from the prior distribution of the model using the `rng` argument.
+"""
+function test_correctness(
+    ad_function,
+    model::Model,
+    adtypes::Vector{<:AbstractADType},
+    reference_adtype::AbstractADType,
+    rng::AbstractRNG=Random.default_rng(),
+    params::Vector{<:Real}=VarInfo(rng, model)[:];
+    value_atol=1e-6,
+    grad_atol=1e-6,
+)
+    value_true, grad_true = ad_function(model, params, reference_adtype)
+    for adtype in adtypes
+        value, grad = ad_function(model, params, adtype)
+        info_str = join(
+            [
+                "Testing AD correctness",
+                " AD function : $(ad_function)",
+                " backend     : $(adtype)",
+                " model       : $(model.f)",
+                " params      : $(params)",
+                " actual      : $((value, grad))",
+                " expected    : $((value_true, grad_true))",
+            ],
+            "\n",
+        )
+        @info info_str
+        @test value ≈ value_true atol = value_atol
+        @test grad ≈ grad_true atol = grad_atol
+    end
+end
+
+end # module DynamicPPL.TestUtils.AD
diff --git a/test/ad.jl b/test/ad.jl
index 17981cf2a..d94dce0a1 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -6,29 +6,29 @@
         varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            f = DynamicPPL.LogDensityFunction(m, varinfo)
-
-            # use ForwardDiff result as reference
-            ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
-                ADTypes.AutoForwardDiff(; chunksize=0), f
-            )
             # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
             # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
-            θ = convert(Vector{Float64}, varinfo[:])
-            logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
+            params = convert(Vector{Float64}, varinfo[:])
+            # Use ForwardDiff as reference AD backend
+            ref_logp, ref_grad = DynamicPPL.TestUtils.AD.ad_ldp(
+                m, params, ADTypes.AutoForwardDiff()
+            )
+
+            # Test correctness of all other backends
             @testset "$adtype" for adtype in [
                 ADTypes.AutoReverseDiff(; compile=false),
                 ADTypes.AutoReverseDiff(; compile=true),
                 ADTypes.AutoMooncake(; config=nothing),
            ]
+                @info "Testing AD correctness: $(m.f), $(adtype), $(short_varinfo_name(varinfo))"
+
                 # Mooncake can't currently handle something that is going on in
                 # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
                 if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
                     @test_broken 1 == 0
                 else
-                    ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
-                    _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
+                    logp, grad = DynamicPPL.TestUtils.AD.ad_ldp(m, params, adtype)
+                    @test logp ≈ ref_logp
                     @test grad ≈ ref_grad
                 end
             end
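
For an AD package with neither LogDensityProblemsAD.jl nor DifferentiationInterface.jl integration, the workflow described in the `make_function` docstring might look roughly like the sketch below; `coinflip` and `YourADPackage` are placeholders, and ForwardDiff via `ad_ldp` serves as the known-good reference.

```julia
using DynamicPPL, Distributions, ADTypes, ForwardDiff
using DynamicPPL.TestUtils.AD: ad_ldp, make_function, make_params

# Placeholder model, for illustration only.
@model function coinflip(y)
    p ~ Beta(2, 2)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end
model = coinflip([1, 1, 0])

f = make_function(model)    # maps a flattened parameter vector to the log density
params = make_params(model) # a prior draw in the same flattened layout

# A custom AD engine would differentiate `f` at `params` here, e.g.
# `value, grad = YourADPackage.gradient(f, params)` (hypothetical API), and the
# result can then be checked against a supported backend:
ref_value, ref_grad = ad_ldp(model, params, AutoForwardDiff())
@assert f(params) ≈ ref_value
```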
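
The `test_correctness` helper wraps this comparison in `@test`s, so it can be dropped straight into a downstream test suite. A minimal sketch, assuming a toy `gdemo` model and checking ReverseDiff against a ForwardDiff reference:

```julia
using DynamicPPL, Distributions, ADTypes, ForwardDiff, ReverseDiff
using DynamicPPL.TestUtils.AD: ad_ldp, test_correctness

# Toy model, assumed for illustration.
@model function gdemo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# Compare ReverseDiff (compiled and uncompiled) against ForwardDiff,
# evaluating at a parameter vector sampled from the prior.
test_correctness(
    ad_ldp,
    gdemo(),
    [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)],
    AutoForwardDiff(),
)
```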