diff --git a/Project.toml b/Project.toml index cf9d210cc..ef00c58b5 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] @@ -58,7 +56,6 @@ DifferentiationInterface = "0.6.39" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" -ForwardDiff = "0.10" JET = "0.9" KernelAbstractions = "< 0.9.32" LinearAlgebra = "1.6" diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index a10e9eae6..000000000 --- a/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,27 +0,0 @@ -module DynamicPPLForwardDiffExt - -using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems -using ForwardDiff - -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk - -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true -standardtag(::ADTypes.AutoForwardDiff) = false - -# Allow Turing tag in gradient etc. calls of the log density function -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::DynamicPPL.LogDensityFunction, - ::AbstractArray{W}, -) where {V,W} - return true -end -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, - ::AbstractArray{W}, -) where {V,W} - return true -end - -end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 00bd3e080..666537a57 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -4,6 +4,10 @@ import DifferentiationInterface as DI LogDensityFunction A callable representing a log density function of a `model`. +`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface, +but only to 0th-order, i.e. it is only possible to calculate the log density, +and not its gradient. If you need to calculate the gradient as well, you have +to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object. # Fields $(FIELDS) @@ -55,16 +59,6 @@ struct LogDensityFunction{V,M,C} context::C end -# TODO: Deprecate. -function LogDensityFunction( - varinfo::AbstractVarInfo, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) - return LogDensityFunction(varinfo, model, SamplingContext(sampler, context)) -end - function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), @@ -94,11 +88,6 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model end -# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time -# we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) -hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) - """ getparams(f::LogDensityFunction) @@ -122,7 +111,26 @@ end LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) # LogDensityProblems interface: gradient (1st order) -struct LogDensityFunctionWithGrad{V,M,C,TAD} +""" + LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType) + +A callable representing a log density function of a `model`. +`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl +interface to 1st-order, meaning that you can both calculate the log density +using + + LogDensityProblems.logdensity(f, x) + +and its gradient using + + LogDensityProblems.logdensity_and_gradient(f, x) + +where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters. + +# Fields +$(FIELDS) +""" +struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType} ldf::LogDensityFunction{V,M,C} adtype::TAD prep::DI.GradientPrep diff --git a/test/ad.jl b/test/ad.jl index 7bab1032d..348c2b64e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -27,7 +27,9 @@ @test_broken 1 == 0 else ldf_with_grad = DynamicPPL.LogDensityFunctionWithGrad(f, adtype) - logp, grad = LogDensityProblems.logdensity_and_gradient(f, x) + logp, grad = LogDensityProblems.logdensity_and_gradient( + ldf_with_grad, x + ) @test grad ≈ ref_grad @test logp ≈ ref_logp end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index 8de28046b..000000000 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -@testset "tag" begin - for chunksize in (nothing, 0, 1, 10) - ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) - standardtag = if !isdefined(Base, :get_extension) - DynamicPPL.DynamicPPLForwardDiffExt.standardtag - else - Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag - end - @test standardtag(ad) - for tag in (false, 0, 1) - @test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag)) - end - end -end