Add AD testing utilities #799

Open · wants to merge 2 commits into release-0.35
5 changes: 5 additions & 0 deletions HISTORY.md
@@ -49,6 +49,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia

This change also affects sampling in Turing.jl.

### New features

The `DynamicPPL.TestUtils.AD` module now contains several functions for testing the correctness of automatic differentiation of log densities.
Please refer to the DynamicPPL documentation for more details.

## 0.34.2

- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
8 changes: 3 additions & 5 deletions Project.toml
@@ -12,11 +12,10 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -56,14 +55,13 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "< 0.9.32"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "< 0.9.32"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
9 changes: 9 additions & 0 deletions docs/src/api.md
@@ -219,6 +219,15 @@ DynamicPPL.TestUtils.update_values!!
DynamicPPL.TestUtils.test_values
```

To test whether automatic differentiation is working correctly, the following methods can be used:

```@docs
DynamicPPL.TestUtils.AD.ad_ldp
DynamicPPL.TestUtils.AD.ad_di
DynamicPPL.TestUtils.AD.make_function
DynamicPPL.TestUtils.AD.make_params
```
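
As an illustrative sketch only (not part of this PR's diff), these utilities could be exercised on a toy model; the model definition, the observed value, and the choice of backend below are assumptions made purely for the example:

```julia
using DynamicPPL, Distributions, ADTypes
import ForwardDiff  # the backend package must be loaded for AutoForwardDiff to work

# Hypothetical toy model, defined only for this sketch.
@model function demo(y)
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo(0.5)   # condition on a made-up observation
params = [0.1]      # value of `x` as a flat parameter vector

# Log density and gradient via LogDensityProblemsAD, mirroring how Turing computes them.
value, grad = DynamicPPL.TestUtils.AD.ad_ldp(model, params, AutoForwardDiff())
```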

## Debugging Utilities

DynamicPPL provides a few methods for checking the validity of a model definition.
3 changes: 2 additions & 1 deletion src/DynamicPPL.jl
@@ -189,7 +189,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
@@ -199,6 +198,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")

1 change: 1 addition & 0 deletions src/test_utils.jl
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")
include("test_utils/ad.jl")

end
201 changes: 201 additions & 0 deletions src/test_utils/ad.jl
@@ -0,0 +1,201 @@
module AD

import ADTypes: AbstractADType
import DifferentiationInterface as DI
import ..DynamicPPL: DynamicPPL, Model, LogDensityFunction, VarInfo, AbstractVarInfo
import LogDensityProblems: logdensity, logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
import Random: Random, AbstractRNG
import Test: @test

export make_function, make_params, ad_ldp, ad_di, test_correctness

"""
flipped_logdensity(θ, ldf)

Flips the order of arguments for `logdensity` to match the signature needed
for DifferentiationInterface.jl.
"""
flipped_logdensity(θ, ldf) = logdensity(ldf, θ)

"""
ad_ldp(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model)
)

Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, using the implementation of
`logdensity_and_gradient` in the LogDensityProblemsAD.jl package.

The `varinfo` argument is optional and is used to provide the container
structure for the parameters. Note that the _parameters_ inside the `varinfo`
argument itself are overridden by the `params` argument. This argument defaults
to [`DynamicPPL.VarInfo`](@ref), which is the default container structure used
throughout the Turing ecosystem; however, you can provide e.g.
[`DynamicPPL.SimpleVarInfo`](@ref) if you want to use a different container
structure.

Returns a tuple `(value, gradient)` where `value <: Real` is the logdensity
of the model evaluated at `params`, and `gradient <: Vector{<:Real}` is the
gradient of the logdensity with respect to `params`.

Note that DynamicPPL.jl and Turing.jl currently use LogDensityProblemsAD.jl
throughout, and hence this function most closely mimics the usage of AD within
the Turing ecosystem.

For some AD backends such as Mooncake.jl, LogDensityProblemsAD.jl simply defers
to the DifferentiationInterface.jl package. In such a case, `ad_ldp` simplifies
to `ad_di` (in that if `ad_di` passes, one should expect `ad_ldp` to pass as
well).

However, there are other AD backends which still have custom code in
LogDensityProblemsAD.jl (such as ForwardDiff.jl). For these backends, `ad_di`
may yield different results compared to `ad_ldp`, and the behaviour of `ad_di`
is in such cases not guaranteed to be consistent with the behaviour of
Turing.jl.

See also: [`ad_di`](@ref).
"""
function ad_ldp(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
vi::AbstractVarInfo=VarInfo(model),
)
ldf = LogDensityFunction(model, vi)
# Note that the implementation of logdensity takes care of setting the
# parameters in vi to the correct values (using unflatten)
return logdensity_and_gradient(ADgradient(adtype, ldf), params)
end
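
For instance (a sketch only, reusing the hypothetical `demo(0.5)` model from the example earlier on this page and assuming ForwardDiff is loaded), the optional fourth argument lets the evaluation run with an explicitly constructed container:

```julia
model = demo(0.5)                                    # hypothetical toy model from the earlier sketch
vi = DynamicPPL.VarInfo(model)                       # the default container, made explicit here
params = DynamicPPL.TestUtils.AD.make_params(model)  # parameters sampled from the prior
value, grad = DynamicPPL.TestUtils.AD.ad_ldp(model, params, ADTypes.AutoForwardDiff(), vi)
```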

"""
ad_di(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model)
)

Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, directly using
DifferentiationInterface.jl.

See the notes in [`ad_ldp`](@ref) for more details on the differences between
`ad_di` and `ad_ldp`.
"""
function ad_di(
model::Model,
params::Vector{<:Real},
adtype::AbstractADType,
vi::AbstractVarInfo=VarInfo(model),
)
ldf = LogDensityFunction(model, vi)
# Note that the implementation of logdensity takes care of setting the
# parameters in vi to the correct values (using unflatten)
prep = DI.prepare_gradient(flipped_logdensity, adtype, params, DI.Constant(ldf))
return DI.value_and_gradient(flipped_logdensity, prep, adtype, params, DI.Constant(ldf))
end
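
As an illustration (again not part of the diff, with the same hypothetical `demo(0.5)` model and ForwardDiff loaded), `ad_di` takes the same arguments as `ad_ldp`, so a quick cross-check of the two code paths might look like this:

```julia
model = demo(0.5)
params = DynamicPPL.TestUtils.AD.make_params(model)
val_di, grad_di = DynamicPPL.TestUtils.AD.ad_di(model, params, ADTypes.AutoForwardDiff())
val_ldp, grad_ldp = DynamicPPL.TestUtils.AD.ad_ldp(model, params, ADTypes.AutoForwardDiff())
# For a working backend the two routes are expected to agree numerically,
# even when LogDensityProblemsAD does not defer to DifferentiationInterface.
val_di ≈ val_ldp && grad_di ≈ grad_ldp
```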

"""
make_function(model, varinfo::AbstractVarInfo=VarInfo(model))

Generate the function to be differentiated. Specifically,
`make_function(model)` returns a function which takes a single argument
`params` and returns the logdensity of `model` evaluated at `params`.

The `varinfo` parameter is optional and is used to determine the structure of
the varinfo used during evaluation. See the [`ad_ldp`](@ref) function for more
details on the `varinfo` argument.

If you have an AD package that does not have integrations with either
LogDensityProblemsAD.jl (in which case you can use [`ad_ldp`](@ref)) or
DifferentiationInterface.jl (in which case you can use [`ad_di`](@ref)), you
can test whether your AD package works with Turing.jl models using:

```julia
f = make_function(model)
params = make_params(model)
value, grad = YourADPackage.gradient(f, params)
```

and compare the results against those obtained from either `ad_ldp` or `ad_di` with
an existing AD package that _is_ supported.

See also: [`make_params`](@ref).
"""
function make_function(model::Model, vi::AbstractVarInfo=VarInfo(model))
# TODO: Can we simplify this even further by inlining the definition of
# logdensity?
return Base.Fix1(logdensity, LogDensityFunction(model, vi))
end

"""
make_params(model, rng::Random.AbstractRNG=Random.default_rng())

Generate a vector of parameters sampled from the prior distribution of `model`.
This can be used as the input to the function to be differentiated. See
[`make_function`](@ref) for more details.
"""
function make_params(model::Model, rng::AbstractRNG=Random.default_rng())
return VarInfo(rng, model)[:]
end
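
For example (a sketch in which ForwardDiff stands in for an AD package lacking both integrations, again reusing the hypothetical `demo(0.5)` model), the closure and parameter vector can be fed straight into the package's own gradient call:

```julia
import ForwardDiff

model = demo(0.5)
f = DynamicPPL.TestUtils.AD.make_function(model)
params = DynamicPPL.TestUtils.AD.make_params(model)

value = f(params)                        # log density at `params`
grad = ForwardDiff.gradient(f, params)   # gradient from the AD package under test
```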

"""
test_correctness(
ad_function,
model::Model,
adtypes::Vector{<:ADTypes.AbstractADType},
reference_adtype::ADTypes.AbstractADType,
rng::Random.AbstractRNG=Random.default_rng(),
params::Vector{<:Real}=VarInfo(rng, model)[:];
value_atol=1e-6,
grad_atol=1e-6
)

Test the correctness of all the AD backend `adtypes` for the model `model`
using the implementation `ad_function`. `ad_function` should be either
[`ad_ldp`](@ref) or [`ad_di`](@ref), or a custom function that has the same
signature.

The test is performed by calculating the logdensity and its gradient using all
the AD backends, and comparing the results against those obtained with the
reference AD backend `reference_adtype`.

The parameters can either be passed explicitly using the `params` argument, or can
be sampled from the prior distribution of the model using the `rng` argument.
"""
function test_correctness(
ad_function,
model::Model,
adtypes::Vector{<:AbstractADType},
reference_adtype::AbstractADType,
rng::AbstractRNG=Random.default_rng(),
params::Vector{<:Real}=VarInfo(rng, model)[:];
value_atol=1e-6,
grad_atol=1e-6,
)
value_true, grad_true = ad_function(model, params, reference_adtype)
for adtype in adtypes
value, grad = ad_function(model, params, adtype)
info_str = join(
[
"Testing AD correctness",
" AD function : $(ad_function)",
" backend : $(adtype)",
" model : $(model.f)",
" params : $(params)",
" actual : $((value, grad))",
" expected : $((value_true, grad_true))",
],
"\n",
)
@info info_str
@test value ≈ value_true atol = value_atol
@test grad ≈ grad_true atol = grad_atol
end
end

end # module DynamicPPL.TestUtils.AD
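
Finally, as a usage sketch (illustrative only, not part of the diff; the backend list is an assumption), `test_correctness` is meant to be called from a test suite, comparing each backend against a ForwardDiff reference across the bundled demo models:

```julia
using Test, ADTypes, DynamicPPL
import ForwardDiff, ReverseDiff

@testset "AD correctness (sketch)" begin
    for model in DynamicPPL.TestUtils.DEMO_MODELS
        DynamicPPL.TestUtils.AD.test_correctness(
            DynamicPPL.TestUtils.AD.ad_ldp,
            model,
            [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)],
            AutoForwardDiff(),
        )
    end
end
```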
40 changes: 17 additions & 23 deletions test/ad.jl
@@ -1,36 +1,30 @@
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
TESTED_ADTYPES = [
ADTypes.AutoReverseDiff(; compile=false),
ADTypes.AutoReverseDiff(; compile=true),
ADTypes.AutoMooncake(; config=nothing),
]

@testset "AD correctness" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@info "Testing AD for $(m.f)"
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
# convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
# reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
params = convert(Vector{Float64}, varinfo[:])
# Use ForwardDiff as reference AD backend
ref_logp, ref_grad = DynamicPPL.TestUtils.AD.ad_ldp(
m, params, ADTypes.AutoForwardDiff()
)

@testset "$adtype" for adtype in [
ADTypes.AutoReverseDiff(; compile=false),
ADTypes.AutoReverseDiff(; compile=true),
ADTypes.AutoMooncake(; config=nothing),
]
# Mooncake can't currently handle something that is going on in
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
@test_broken 1 == 0
else
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ≈ ref_grad
end
@testset "$adtype" for adtype in TESTED_ADTYPES
logp, grad = DynamicPPL.TestUtils.AD.ad_ldp(m, params, adtype)
[Author review comment on the preceding line: missing varinfo argument]

@test logp ≈ ref_logp
@test grad ≈ ref_grad
end
end
end