Skip to content

Commit

Permalink
Fix bugs and generalize code for ComponentArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ElOceanografo committed Oct 22, 2024
1 parent a46273f commit 2cadf50
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 31 deletions.
41 changes: 17 additions & 24 deletions src/MarginalLogDensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,17 @@ struct MarginalLogDensity{
hess_prep::TE
end

_generic_eachindex(x) = eachindex(x)
_generic_eachindex(x::ComponentArray) = keys(x)

function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox();
hess_adtype=nothing, sparsity_detector=DenseSparsityDetector(method.adtype, atol=sqrt(eps())),
coloring_algorithm=GreedyColoringAlgorithm())
iv = setdiff(eachindex(u), iw)
iv = setdiff(_generic_eachindex(u), iw)
w = u[iw]
v = u[iv]
p2 = (p=data, v=v)
f(w, p2) = -logdensity(merge_parameters(p2.v, w, iv, iw), p2.p)
f(w, p2) = -logdensity(merge_parameters(p2.v, w, iv, iw, u), p2.p)
f_opt = OptimizationFunction(f, method.adtype; method.opt_func_kwargs...)
prob = OptimizationProblem(f_opt, w, p2)
cache = init(prob, method.solver)
Expand All @@ -213,16 +216,9 @@ function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox();
H, hess_adtype, prep)
end

function MarginalLogDensity(logdensity, u::ComponentArray, iw::Vector{Symbol},
args...; kwargs...)
iw1 = reduce(vcat, label2index(u, label) for label in iw)
u1 = Vector(u)
MarginalLogDensity(logdensity, u1, iw1, args..., kwargs...)
end

function Base.show(io::IO, mld::MarginalLogDensity)
T = typeof(mld.method).name.name
str = "MarginalLogDensity of function $(repr(mld.logdensity))\nIntegrating $(length(mld.iw))/$(length(mld.u)) variables via $(T)"
str = "MarginalLogDensity of function $(repr(mld.logdensity))\nIntegrating $(nmarginal(mld))/$(dimension(mld)) variables via $(T)"
write(io, str)
end

Expand Down Expand Up @@ -252,23 +248,22 @@ cached_hessian(mld::MarginalLogDensity) = mld.H
Splice together the estimated (fixed) parameters `v` and marginalized (random) parameters
`w` into the single parameter vector `u`, based on their indices `iv` and `iw`.
"""
function merge_parameters(v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw) where {T1,T2}
N = length(v) + length(w)
u = Vector{promote_type(T1, T2)}(undef, N)
u[iv] .= v
u[iw] .= w
return u
function merge_parameters(v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw, u) where {T1,T2}
u1 = convert.(promote_type(T1, T2), u)
u1[iv] .= v
u1[iw] .= w
return u1
end

function ChainRulesCore.rrule(::typeof(merge_parameters),
v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw) where {T1,T2}
u = merge_parameters(v, w, iv, iw)
v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw, u) where {T1,T2}
u1 = merge_parameters(v, w, iv, iw, u)
function merge_parameters_pullback(ubar)
vbar = ubar[iv]
wbar = ubar[iw]
return (NoTangent(), vbar, wbar, NoTangent(), NoTangent())
return (NoTangent(), vbar, wbar, NoTangent(), NoTangent(), NoTangent())
end
return u, merge_parameters_pullback
return u1, merge_parameters_pullback
end


Expand All @@ -295,14 +290,13 @@ function modal_hessian!(mld::MarginalLogDensity, w, p2)
end

function _marginalize(mld, v, data, method::LaplaceApprox, verbose)
p2 = (; p=data, v)
p2 = (; p = data, v = collect(v))
verbose && println("Finding mode...")
wopt, objective = optimize_marginal!(mld, p2)
verbose && println("Calculating hessian...")
modal_hessian!(mld, wopt, p2)
verbose && println("Integrating...")
nw = length(mld.iw)
integral = -objective + (0.5nw) * log(2π) - 0.5logabsdet(mld.H)[1]
integral = -objective + (0.5nmarginal(mld)) * log(2π) - 0.5logabsdet(mld.H)[1]
verbose && println("Done!")
return integral#, sol
end
Expand All @@ -321,7 +315,6 @@ function _marginalize(mld, v, data, method::Cubature, verbose)
p2 = (; p=data, v)
if method.lower == nothing || method.upper == nothing
wopt, _ = optimize_marginal!(mld, p2)
println(wopt)
h = hessdiag(w -> mld.f_opt(w, p2), wopt)
se = 1 ./ sqrt.(h)
upper = wopt .+ method.* se
Expand Down
19 changes: 12 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ u_component = ComponentArray(v = v, w = w)
end
for mld in mlds
@test all(mld.u .== u)
@test all(u .== merge_parameters(v, w, iv, iw))
@test all(u .== merge_parameters(v, w, iv, iw, u))
v1, w1 = split_parameters(mld.u, mld.iv, mld.iw)
@test all(v1 .== v)
@test all(w1 .== w)
Expand All @@ -66,10 +66,10 @@ u_component = ComponentArray(v = v, w = w)
mld1 = MarginalLogDensity(ld, u_component, iw_symbol)
mld2 = MarginalLogDensity(ld, u_vector, iw_indices)
@test dimension(mld1) == dimension(mld2)
@test imarginal(mld1) == imarginal(mld2)
@test all(mld1.u[iw_symbol] .== mld2.u[iw_indices])

@test all(mld1.u .== u_vector)
@test all(u .== merge_parameters(v, w, iv, iw))
@test all(u .== merge_parameters(v, w, iv, iw, u_component))
v1, w1 = split_parameters(mld1.u, mld1.iv, mld1.iw)
v2, w2 = split_parameters(mld2.u, mld2.iv, mld2.iw)
@test all(v1 .== v2)
Expand All @@ -94,15 +94,19 @@ u_component = ComponentArray(v = v, w = w)
end

@testset "Custom ChainRules" begin
v = fill(1, 3)
w = fill(2, 4)
v = fill(1.0, 3)
w = fill(2.0, 4)
iv = 1:3
iw = 4:7
test_rrule(merge_parameters, v, w, iv, iw)
u = zeros(length(v) + length(w))
u[iv] .= v
u[iw] .= w
test_rrule(merge_parameters, v, w, iv, iw, u)
end

@testset "Dense approximations" begin
x = 1.0:3.0
x_component = ComponentVector(v = x[iv], w = x[iw])
mld_laplace = MarginalLogDensity(ld, u, iw, (), LaplaceApprox())
mld_laplace_component = MarginalLogDensity(ld, u_component, [:w], (), LaplaceApprox())
lb = fill(-100.0, 2)
Expand All @@ -117,8 +121,9 @@ end

# analytical: against 1D Gaussian
logpdf_true = logpdf(dmarginal, x[only(iv)])
@test x[iv] == x_component.v
logpdf_laplace = mld_laplace(x[iv], ())
logpdf_laplace_component = mld_laplace_component(x[iv], ())
logpdf_laplace_component = mld_laplace_component(x_component[[:v]], ())
logpdf_cubature1 = mld_cubature1(x[iv], ())
logpdf_cubature2 = mld_cubature2(x[iv], ())

Expand Down

0 comments on commit 2cadf50

Please sign in to comment.