Remove DynamicPPLForwardDiffExt
penelopeysm committed Feb 13, 2025
1 parent 8de4742 commit 7cb38f3
Showing 5 changed files with 27 additions and 61 deletions.
3 changes: 0 additions & 3 deletions Project.toml
@@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
-DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
@@ -58,7 +56,6 @@ DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "< 0.9.32"
LinearAlgebra = "1.6"
27 changes: 0 additions & 27 deletions ext/DynamicPPLForwardDiffExt.jl

This file was deleted.

40 changes: 24 additions & 16 deletions src/logdensityfunction.jl
@@ -4,6 +4,10 @@ import DifferentiationInterface as DI
LogDensityFunction
A callable representing a log density function of a `model`.
+`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
+but only to 0th-order, i.e. it is only possible to calculate the log density,
+and not its gradient. If you need to calculate the gradient as well, you have
+to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
# Fields
$(FIELDS)
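
For orientation, a minimal sketch of the 0th-order usage this docstring describes. The toy model, the variable names, and the parameter values are illustrative only and are not part of this commit:

    using DynamicPPL, Distributions, LogDensityProblems

    # A toy model; any DynamicPPL model would do.
    @model function demo(y)
        mu ~ Normal(0, 1)
        y ~ Normal(mu, 1)
    end

    f = DynamicPPL.LogDensityFunction(demo(0.5))

    # 0th-order interface: the log density at a parameter vector, no gradient.
    x = [0.1]  # one entry per parameter (only `mu` here)
    LogDensityProblems.logdensity(f, x)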
@@ -55,16 +59,6 @@ struct LogDensityFunction{V,M,C}
context::C
end

-# TODO: Deprecate.
-function LogDensityFunction(
-    varinfo::AbstractVarInfo,
-    model::Model,
-    sampler::AbstractSampler,
-    context::AbstractContext,
-)
-    return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
-end

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
@@ -94,11 +88,6 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end

-# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
-# we need to define these annoying methods to ensure that we stay compatible with everything.
-getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
-hassampler(f::LogDensityFunction) = hassampler(getcontext(f))

"""
getparams(f::LogDensityFunction)
@@ -122,7 +111,26 @@ end
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

# LogDensityProblems interface: gradient (1st order)
-struct LogDensityFunctionWithGrad{V,M,C,TAD}
+"""
+    LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
+A callable representing a log density function of a `model`.
+`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
+interface to 1st-order, meaning that you can both calculate the log density
+using
+    LogDensityProblems.logdensity(f, x)
+and its gradient using
+    LogDensityProblems.logdensity_and_gradient(f, x)
+where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
+# Fields
+$(FIELDS)
+"""
+struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
ldf::LogDensityFunction{V,M,C}
adtype::TAD
prep::DI.GradientPrep
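To see how the pieces fit together after this change, here is a minimal sketch of the 1st-order wrapper documented above. The toy model and parameter values are again illustrative, and AutoForwardDiff is only one possible adtype: with the extension removed, the caller picks an ADTypes backend and DifferentiationInterface drives it, so ForwardDiff (or whichever backend is chosen) is assumed to be available in the user's own environment.

    using DynamicPPL, Distributions, ADTypes, ForwardDiff, LogDensityProblems

    # Same toy model as in the earlier sketch.
    @model function demo(y)
        mu ~ Normal(0, 1)
        y ~ Normal(mu, 1)
    end

    f = DynamicPPL.LogDensityFunction(demo(0.5))

    # Pair the 0th-order log density with an AD backend chosen via ADTypes.
    g = DynamicPPL.LogDensityFunctionWithGrad(f, ADTypes.AutoForwardDiff())

    x = [0.1]  # parameter vector (only `mu` here)
    logp, grad = LogDensityProblems.logdensity_and_gradient(g, x)

    # The wrapper does not change the density itself.
    logp ≈ LogDensityProblems.logdensity(f, x)

This mirrors the fix in test/ad.jl below, where logdensity_and_gradient is now called on the LogDensityFunctionWithGrad wrapper rather than on the plain LogDensityFunction.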
4 changes: 3 additions & 1 deletion test/ad.jl
@@ -27,7 +27,9 @@
@test_broken 1 == 0
else
ldf_with_grad = DynamicPPL.LogDensityFunctionWithGrad(f, adtype)
-logp, grad = LogDensityProblems.logdensity_and_gradient(f, x)
+logp, grad = LogDensityProblems.logdensity_and_gradient(
+    ldf_with_grad, x
+)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
end
14 changes: 0 additions & 14 deletions test/ext/DynamicPPLForwardDiffExt.jl

This file was deleted.
