Add AD testing utilities
penelopeysm committed Feb 5, 2025
1 parent 1366440 commit eac98e1
Showing 7 changed files with 231 additions and 16 deletions.
5 changes: 5 additions & 0 deletions HISTORY.md
@@ -49,6 +49,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia
This change also affects sampling in Turing.jl.
### New features
The `DynamicPPL.TestUtils.AD` module now contains several functions for testing the correctness of automatic differentiation of log densities.
Please refer to the DynamicPPL documentation for more details.
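For illustration, a minimal sketch of how these utilities might be called — the `demo` model below is hypothetical, and ForwardDiff.jl is assumed to be installed as the AD backend:

```julia
using DynamicPPL, Distributions
using DynamicPPL.TestUtils.AD: ad_ldp, make_params
using ADTypes: AutoForwardDiff
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff

# A hypothetical two-parameter model, purely for illustration.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()
params = make_params(model)  # sample a parameter vector from the prior
value, grad = ad_ldp(model, params, AutoForwardDiff())
```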
## 0.34.2
- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
8 changes: 3 additions & 5 deletions Project.toml
@@ -12,11 +12,10 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -56,14 +55,13 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "< 0.9.32"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "< 0.9.32"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
9 changes: 9 additions & 0 deletions docs/src/api.md
@@ -219,6 +219,15 @@ DynamicPPL.TestUtils.update_values!!
DynamicPPL.TestUtils.test_values
```

To test whether automatic differentiation is working correctly, the following methods can be used:

```@docs
DynamicPPL.TestUtils.AD.ad_ldp
DynamicPPL.TestUtils.AD.ad_di
DynamicPPL.TestUtils.AD.make_function
DynamicPPL.TestUtils.AD.make_params
```
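
As a hedged sketch (assuming `model` is any DynamicPPL model and ReverseDiff.jl is loaded), the two entry points can be compared directly:

```julia
using DynamicPPL.TestUtils.AD: ad_ldp, ad_di, make_params
using ADTypes: AutoReverseDiff
import ReverseDiff  # the backend package must be loaded

params = make_params(model)
# Route 1: through LogDensityProblemsAD.jl (mirrors Turing.jl's current usage)
v1, g1 = ad_ldp(model, params, AutoReverseDiff())
# Route 2: directly through DifferentiationInterface.jl
v2, g2 = ad_di(model, params, AutoReverseDiff())
# For a well-behaved backend the two should agree up to floating-point error.
isapprox(v1, v2) && isapprox(g1, g2)
```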

## Debugging Utilities

DynamicPPL provides a few methods for checking the validity of a model definition.
3 changes: 2 additions & 1 deletion src/DynamicPPL.jl
@@ -189,7 +189,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
@@ -199,6 +198,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")

1 change: 1 addition & 0 deletions src/test_utils.jl
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")
include("test_utils/ad.jl")

end
201 changes: 201 additions & 0 deletions src/test_utils/ad.jl
@@ -0,0 +1,201 @@
module AD

import ADTypes: AbstractADType
import DifferentiationInterface as DI
import ..DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
import LogDensityProblems: logdensity, logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
import Random: Random, AbstractRNG
import Test: @test

export make_function, make_params, ad_ldp, ad_di, test_correctness

"""
flipped_logdensity(θ, ldf)
Flips the order of arguments for `logdensity` to match the signature needed
for DifferentiationInterface.jl.
"""
flipped_logdensity(θ, ldf) = logdensity(ldf, θ)

"""
ad_ldp(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model)
)
Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, using the implementation of
`logdensity_and_gradient` in the LogDensityProblemsAD.jl package.
The `varinfo` argument is optional and is used to provide the container
structure for the parameters. Note that the _parameters_ inside the `varinfo`
argument itself are overridden by the `params` argument. This argument defaults
to [`DynamicPPL.VarInfo`](@ref), which is the default container structure used
throughout the Turing ecosystem; however, you can provide e.g.
[`DynamicPPL.SimpleVarInfo`](@ref) if you want to use a different container
structure.
Returns a tuple `(value, gradient)` where `value <: Real` is the logdensity
of the model evaluated at `params`, and `gradient <: Vector{<:Real}` is the
gradient of the logdensity with respect to `params`.
Note that DynamicPPL.jl and Turing.jl currently use LogDensityProblemsAD.jl
throughout, and hence this function most closely mimics the usage of AD within
the Turing ecosystem.
For some AD backends such as Mooncake.jl, LogDensityProblemsAD.jl simply defers
to the DifferentiationInterface.jl package. In such a case, `ad_ldp` simplifies
to `ad_di` (in that if `ad_di` passes, one should expect `ad_ldp` to pass as
well).
However, there are other AD backends which still have custom code in
LogDensityProblemsAD.jl (such as ForwardDiff.jl). For these backends, `ad_di`
may yield different results compared to `ad_ldp`, and the behaviour of `ad_di`
is in such cases not guaranteed to be consistent with the behaviour of
Turing.jl.
See also: [`ad_di`](@ref).
"""
function ad_ldp(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
vi::AbstractVarInfo=VarInfo(model),
)
ldf = LogDensityFunction(model, vi)
# Note that the implementation of logdensity takes care of setting the
# parameters in vi to the correct values (using unflatten)
return logdensity_and_gradient(ADgradient(adtype, ldf), params)
end
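
A sketch of the optional `varinfo` argument in use — here a linked (unconstrained) `VarInfo`, which mirrors how gradient-based samplers evaluate the model; the `link!!` call and the availability of ForwardDiff.jl are assumptions, not part of this file:

```julia
using DynamicPPL
using ADTypes: AutoForwardDiff
import ForwardDiff

# Assumed: `model` is any DynamicPPL model with constrained parameters.
vi = DynamicPPL.link!!(VarInfo(model), model)  # transform to unconstrained space
params = vi[:]                                 # parameter vector in linked space
value, grad = ad_ldp(model, params, AutoForwardDiff(), vi)
```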

"""
ad_di(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model)
)
Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, directly using
DifferentiationInterface.jl.
See the notes in [`ad_ldp`](@ref) for more details on the differences between
`ad_di` and `ad_ldp`.
"""
function ad_di(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
vi::AbstractVarInfo=VarInfo(model),
)
ldf = LogDensityFunction(model, vi)
# Note that the implementation of logdensity takes care of setting the
# parameters in vi to the correct values (using unflatten)
prep = DI.prepare_gradient(flipped_logdensity, adtype, params, DI.Constant(ldf))
return DI.value_and_gradient(flipped_logdensity, prep, adtype, params, DI.Constant(ldf))
end
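
A sketch of the deferral point mentioned in the `ad_ldp` docstring — for a backend such as Mooncake.jl (assumed to be loaded), where LogDensityProblemsAD.jl hands off to DifferentiationInterface.jl, both functions should return matching results:

```julia
using ADTypes: AutoMooncake
import Mooncake  # the backend package must be loaded

adtype = AutoMooncake(; config=nothing)
params = make_params(model)                  # `model`: any DynamicPPL model
v_ldp, g_ldp = ad_ldp(model, params, adtype)
v_di, g_di = ad_di(model, params, adtype)
isapprox(v_ldp, v_di) && isapprox(g_ldp, g_di)
```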

"""
make_function(model, varinfo::AbstractVarInfo=VarInfo(model))
Generate the function to be differentiated. Specifically,
`make_function(model)` returns a function which takes a single argument
`params` and returns the logdensity of `model` evaluated at `params`.
The `varinfo` parameter is optional and is used to determine the structure of
the varinfo used during evaluation. See the [`ad_ldp`](@ref) function for more
details on the `varinfo` argument.
If you have an AD package that does not have integrations with either
LogDensityProblemsAD.jl (in which case you can use [`ad_ldp`](@ref)) or
DifferentiationInterface.jl (in which case you can use [`ad_di`](@ref)), you
can test whether your AD package works with Turing.jl models using:
```julia
f = make_function(model)
params = make_params(model)
value, grad = YourADPackage.gradient(f, params)
```
and compare the results against those obtained from either `ad_ldp` or `ad_di` for
an existing AD package that _is_ supported.
See also: [`make_params`](@ref).
"""
function make_function(model::Model, vi::AbstractVarInfo=VarInfo(model))
# TODO: Can we simplify this even further by inlining the definition of
# logdensity?
return Base.Fix1(logdensity, LogDensityFunction(model, vi))
end

"""
make_params(model, rng::Random.AbstractRNG=Random.default_rng())
Generate a vector of parameters sampled from the prior distribution of `model`.
This can be used as the input to the function to be differentiated. See
[`make_function`](@ref) for more details.
"""
function make_params(model::Model, rng::AbstractRNG=Random.default_rng())
return VarInfo(rng, model)[:]
end
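
A concrete version of the sketch in the `make_function` docstring, with ForwardDiff.jl standing in for `YourADPackage` (the direct `ForwardDiff.gradient` call is only for illustration; any package exposing a gradient function would do):

```julia
import ForwardDiff
using ADTypes: AutoForwardDiff

f = make_function(model)      # `model`: any DynamicPPL model
params = make_params(model)
grad = ForwardDiff.gradient(f, params)

# Compare against a supported route to check the result:
_, ref_grad = ad_ldp(model, params, AutoForwardDiff())
isapprox(grad, ref_grad)
```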

"""
test_correctness(
ad_function,
model::Model,
adtypes::Vector{<:ADTypes.AbstractADType},
reference_adtype::ADTypes.AbstractADType,
rng::Random.AbstractRNG=Random.default_rng(),
params::Vector{<:Real}=VarInfo(rng, model)[:];
value_atol=1e-6,
grad_atol=1e-6
)
Test the correctness of all the AD backend `adtypes` for the model `model`
using the implementation `ad_function`. `ad_function` should be either
[`ad_ldp`](@ref) or [`ad_di`](@ref), or a custom function that has the same
signature.
The test is performed by calculating the logdensity and its gradient using all
the AD backends, and comparing the results against those obtained with the
reference AD backend `reference_adtype`.
The parameters can either be passed explicitly using the `params` argument, or can
be sampled from the prior distribution of the model using the `rng` argument.
"""
function test_correctness(
ad_function,
model::Model,
adtypes::Vector{<:AbstractADType},
reference_adtype::AbstractADType,
rng::AbstractRNG=Random.default_rng(),
params::Vector{<:Real}=VarInfo(rng, model)[:];
value_atol=1e-6,
grad_atol=1e-6,
)
value_true, grad_true = ad_function(model, params, reference_adtype)
for adtype in adtypes
value, grad = ad_function(model, params, adtype)
info_str = join(
[
"Testing AD correctness",
" AD function : $(ad_function)",
" backend : $(adtype)",
" model : $(model.f)",
" params : $(params)",
" actual : $((value, grad))",
" expected : $((value_true, grad_true))",
],
"\n",
)
@info info_str
@test value ≈ value_true atol = value_atol
@test grad ≈ grad_true atol = grad_atol
end
end

end # module DynamicPPL.TestUtils.AD
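
A sketch of `test_correctness` in use, checking two backends against a ForwardDiff reference; it assumes the backend packages are loaded, and is best run inside a `@testset`, since the function records results with `Test.@test`:

```julia
using DynamicPPL.TestUtils.AD: ad_ldp, test_correctness
using ADTypes: AutoForwardDiff, AutoReverseDiff, AutoMooncake
using Test

@testset "AD correctness" begin
    test_correctness(
        ad_ldp,
        model,  # any DynamicPPL model
        [AutoReverseDiff(; compile=false), AutoMooncake(; config=nothing)],
        AutoForwardDiff(),
    )
end
```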
20 changes: 10 additions & 10 deletions test/ad.jl
@@ -6,29 +6,29 @@
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
# convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
# reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
params = convert(Vector{Float64}, varinfo[:])
# Use ForwardDiff as reference AD backend
ref_logp, ref_grad = DynamicPPL.TestUtils.AD.ad_ldp(
m, params, ADTypes.AutoForwardDiff()
)

# Test correctness of all other backends
@testset "$adtype" for adtype in [
ADTypes.AutoReverseDiff(; compile=false),
ADTypes.AutoReverseDiff(; compile=true),
ADTypes.AutoMooncake(; config=nothing),
]
@info "Testing AD correctness: $(m.f), $(adtype), $(short_varinfo_name(varinfo))"

# Mooncake can't currently handle something that is going on in
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
@test_broken 1 == 0
else
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
logp, grad = DynamicPPL.TestUtils.AD.ad_ldp(m, params, adtype)
@test logp ≈ ref_logp
@test grad ≈ ref_grad
end
end
