From 2cadf50846cef946ebe3ff3906d77c3fdb5cd30c Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 22 Oct 2024 22:35:46 +1100 Subject: [PATCH] Fix bugs and generalize code for ComponentArrays --- src/MarginalLogDensities.jl | 41 +++++++++++++++---------------------- test/runtests.jl | 19 ++++++++++------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/src/MarginalLogDensities.jl b/src/MarginalLogDensities.jl index 498e741..d767de3 100644 --- a/src/MarginalLogDensities.jl +++ b/src/MarginalLogDensities.jl @@ -188,14 +188,17 @@ struct MarginalLogDensity{ hess_prep::TE end +_generic_eachindex(x) = eachindex(x) +_generic_eachindex(x::ComponentArray) = keys(x) + function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox(); hess_adtype=nothing, sparsity_detector=DenseSparsityDetector(method.adtype, atol=sqrt(eps())), coloring_algorithm=GreedyColoringAlgorithm()) - iv = setdiff(eachindex(u), iw) + iv = setdiff(_generic_eachindex(u), iw) w = u[iw] v = u[iv] p2 = (p=data, v=v) - f(w, p2) = -logdensity(merge_parameters(p2.v, w, iv, iw), p2.p) + f(w, p2) = -logdensity(merge_parameters(p2.v, w, iv, iw, u), p2.p) f_opt = OptimizationFunction(f, method.adtype; method.opt_func_kwargs...) prob = OptimizationProblem(f_opt, w, p2) cache = init(prob, method.solver) @@ -213,16 +216,9 @@ function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox(); H, hess_adtype, prep) end -function MarginalLogDensity(logdensity, u::ComponentArray, iw::Vector{Symbol}, - args...; kwargs...) - iw1 = reduce(vcat, label2index(u, label) for label in iw) - u1 = Vector(u) - MarginalLogDensity(logdensity, u1, iw1, args..., kwargs...) -end - function Base.show(io::IO, mld::MarginalLogDensity) T = typeof(mld.method).name.name - str = "MarginalLogDensity of function $(repr(mld.logdensity))\nIntegrating $(length(mld.iw))/$(length(mld.u)) variables via $(T)" + str = "MarginalLogDensity of function $(repr(mld.logdensity))\nIntegrating $(nmarginal(mld))/$(dimension(mld)) variables via $(T)" write(io, str) end @@ -252,23 +248,22 @@ cached_hessian(mld::MarginalLogDensity) = mld.H Splice together the estimated (fixed) parameters `v` and marginalized (random) parameters `w` into the single parameter vector `u`, based on their indices `iv` and `iw`. """ -function merge_parameters(v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw) where {T1,T2} - N = length(v) + length(w) - u = Vector{promote_type(T1, T2)}(undef, N) - u[iv] .= v - u[iw] .= w - return u +function merge_parameters(v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw, u) where {T1,T2} + u1 = convert.(promote_type(T1, T2), u) + u1[iv] .= v + u1[iw] .= w + return u1 end function ChainRulesCore.rrule(::typeof(merge_parameters), - v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw) where {T1,T2} - u = merge_parameters(v, w, iv, iw) + v::AbstractVector{T1}, w::AbstractVector{T2}, iv, iw, u) where {T1,T2} + u1 = merge_parameters(v, w, iv, iw, u) function merge_parameters_pullback(ubar) vbar = ubar[iv] wbar = ubar[iw] - return (NoTangent(), vbar, wbar, NoTangent(), NoTangent()) + return (NoTangent(), vbar, wbar, NoTangent(), NoTangent(), NoTangent()) end - return u, merge_parameters_pullback + return u1, merge_parameters_pullback end @@ -295,14 +290,13 @@ function modal_hessian!(mld::MarginalLogDensity, w, p2) end function _marginalize(mld, v, data, method::LaplaceApprox, verbose) - p2 = (; p=data, v) + p2 = (; p = data, v = collect(v)) verbose && println("Finding mode...") wopt, objective = optimize_marginal!(mld, p2) verbose && println("Calculating hessian...") modal_hessian!(mld, wopt, p2) verbose && println("Integrating...") - nw = length(mld.iw) - integral = -objective + (0.5nw) * log(2π) - 0.5logabsdet(mld.H)[1] + integral = -objective + (0.5nmarginal(mld)) * log(2π) - 0.5logabsdet(mld.H)[1] verbose && println("Done!") return integral#, sol end @@ -321,7 +315,6 @@ function _marginalize(mld, v, data, method::Cubature, verbose) p2 = (; p=data, v) if method.lower == nothing || method.upper == nothing wopt, _ = optimize_marginal!(mld, p2) - println(wopt) h = hessdiag(w -> mld.f_opt(w, p2), wopt) se = 1 ./ sqrt.(h) upper = wopt .+ method.nσ * se diff --git a/test/runtests.jl b/test/runtests.jl index 5b2ad6a..985bbc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,7 +52,7 @@ u_component = ComponentArray(v = v, w = w) end for mld in mlds @test all(mld.u .== u) - @test all(u .== merge_parameters(v, w, iv, iw)) + @test all(u .== merge_parameters(v, w, iv, iw, u)) v1, w1 = split_parameters(mld.u, mld.iv, mld.iw) @test all(v1 .== v) @test all(w1 .== w) @@ -66,10 +66,10 @@ u_component = ComponentArray(v = v, w = w) mld1 = MarginalLogDensity(ld, u_component, iw_symbol) mld2 = MarginalLogDensity(ld, u_vector, iw_indices) @test dimension(mld1) == dimension(mld2) - @test imarginal(mld1) == imarginal(mld2) + @test all(mld1.u[iw_symbol] .== mld2.u[iw_indices]) @test all(mld1.u .== u_vector) - @test all(u .== merge_parameters(v, w, iv, iw)) + @test all(u .== merge_parameters(v, w, iv, iw, u_component)) v1, w1 = split_parameters(mld1.u, mld1.iv, mld1.iw) v2, w2 = split_parameters(mld2.u, mld2.iv, mld2.iw) @test all(v1 .== v2) @@ -94,15 +94,19 @@ u_component = ComponentArray(v = v, w = w) end @testset "Custom ChainRules" begin - v = fill(1, 3) - w = fill(2, 4) + v = fill(1.0, 3) + w = fill(2.0, 4) iv = 1:3 iw = 4:7 - test_rrule(merge_parameters, v, w, iv, iw) + u = zeros(length(v) + length(w)) + u[iv] .= v + u[iw] .= w + test_rrule(merge_parameters, v, w, iv, iw, u) end @testset "Dense approximations" begin x = 1.0:3.0 + x_component = ComponentVector(v = x[iv], w = x[iw]) mld_laplace = MarginalLogDensity(ld, u, iw, (), LaplaceApprox()) mld_laplace_component = MarginalLogDensity(ld, u_component, [:w], (), LaplaceApprox()) lb = fill(-100.0, 2) @@ -117,8 +121,9 @@ end # analytical: against 1D Gaussian logpdf_true = logpdf(dmarginal, x[only(iv)]) + @test x[iv] == x_component.v logpdf_laplace = mld_laplace(x[iv], ()) - logpdf_laplace_component = mld_laplace_component(x[iv], ()) + logpdf_laplace_component = mld_laplace_component(x_component[[:v]], ()) logpdf_cubature1 = mld_cubature1(x[iv], ()) logpdf_cubature2 = mld_cubature2(x[iv], ())