Skip to content

Commit

Permalink
Make joint prior forward map type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed Jan 20, 2025
1 parent 86dd9b0 commit 24c8ff2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
40 changes: 20 additions & 20 deletions src/likelihoods/joint_prior.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Joint prior
"""
JointPrior{modelPriorType<:AbstractSimulatorPrior,likPriorTypes,axesType,lnames} <: AbstractSimulatorPrior
JointPrior{modelPriorType<:AbstractSimulatorPrior,likNames,likPriorTypes,axesType} <: AbstractSimulatorPrior
Represents the "joint" prior `p(θₘ,θₗ)` where `θ = [θₘ θₗ]` are the full set of parameters in the joint;
distribution `p(x,θ)`. θₘ are the model (simulator) parameters and θₗ are the noise/error model parameters.
"""
struct JointPrior{modelPriorType<:AbstractSimulatorPrior,likPriorTypes,axesType,lnames} <: AbstractSimulatorPrior
struct JointPrior{modelPriorType<:AbstractSimulatorPrior,likNames,likPriorTypes,axesType} <: AbstractSimulatorPrior
model::modelPriorType
lik::NamedTuple{lnames,likPriorTypes}
lik::NamedTuple{likNames,likPriorTypes}
ax::axesType
end

Expand Down Expand Up @@ -56,25 +56,25 @@ function logprob(jp::JointPrior, θ::ComponentVector)
end
logprob(jp::JointPrior, θ::AbstractVector) = logprob(jp, ComponentVector(θ, jp.ax))

function forward_map(jp::JointPrior, θ::ComponentVector)
ϕ_m = forward_map(jp.model, θ.model)
ϕ_lik = map(n -> forward_map(jp.lik[n], θ[n]), keys(jp.lik))
ϕ = vcat(ϕ_m, ϕ_lik...)
return ComponentVector(ϕ, jp.ax)
@generated function forward_map(jp::JointPrior{<:Any,lnames}, θ::ComponentVector) where {lnames}
ϕ_args = map(lnames) do n
:(forward_map(jp.lik[$(QuoteNode(n))], θ[$(QuoteNode(n))]))
end
quote
ϕ_m = forward_map(jp.model, θ.model)
ϕ = vcat(ϕ_m, $(ϕ_args...))
return ComponentVector(ϕ, jp.ax)
end
end
forward_map(jp::JointPrior, θ::AbstractVector) = forward_map(jp, ComponentVector(θ, jp.ax))

function unconstrained_forward_map(jp::JointPrior, ζ::ComponentVector)
# get inverse bijectors
f_m = inverse(bijector(jp.model))
f_lik = map(inverse bijector, jp.lik)
# apply bijections
θ_m = f_m.model)
θ_lik = ComponentVector(; map(n -> n => f_lik[n](ζ[n]), keys(jp.lik))...)
# apply forward maps
ϕ_m = forward_map(jp.model, θ_m)
ϕ_lik = map(n -> forward_map(jp.lik[n], θ_lik[n]), keys(jp.lik))
ϕ = vcat(ϕ_m, ϕ_lik...)
return ComponentVector(ϕ, jp.ax)
f = inverse(bijector(jp))
θ = ComponentArray(f(ζ), jp.ax)
# apply forward map
return forward_map(jp.model, θ)
end

function unconstrained_forward_map(jp::JointPrior, θ::AbstractVector)
return unconstrained_forward_map(jp, ComponentVector(θ, jp.ax))
end
unconstrained_forward_map(jp::JointPrior, θ::AbstractVector) = unconstrained_forward_map(jp, ComponentVector(θ, jp.ax))
20 changes: 20 additions & 0 deletions test/likelihood_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using SimulationBasedInference

using Random
using Test

@testset "Joint Prior" begin
observable = SimulatorObservable(:test, identity, 0.0, 0.0:1.0, (1,))
p_prior = prior(:p, LogNormal(0,1))
noise_scale_prior = prior(, LogNormal(0,1))
data = randn(MersenneTwister(1234), 10)
lik = SimulatorLikelihood(IsoNormal, observable, data, noise_scale_prior)
jp = JointPrior(p_prior, lik)
ξ = rand(MersenneTwister(1234), jp)
@test length(ξ) == 2
@test hasproperty(ξ, :model)
@test hasproperty(ξ, :test)
@test hasproperty.model, :p)
θ = @inferred SBI.unconstrained_forward_map(jp, [0.0,0.0])
@test θ [1.0,1.0]
end
3 changes: 2 additions & 1 deletion test/problem_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using LinearAlgebra
using LogDensityProblems
using NonlinearSolve
using OrdinaryDiffEq
using Random
using Test

@testset "Forward ODEProblem" begin
Expand Down Expand Up @@ -40,7 +41,7 @@ end
forwardprob = SimulatorForwardProblem(odeprob, observable)
α_prior = prior(, LogNormal(0,1))
noise_scale_prior = prior(, Exponential(1.0))
data = randn(10)
data = randn(MersenneTwister(1234), 10)
lik = SimulatorLikelihood(IsoNormal, observable, data, noise_scale_prior)
inferenceprob = SimulatorInferenceProblem(forwardprob, Tsit5(), α_prior, lik)
u = copy(inferenceprob.u0)
Expand Down

0 comments on commit 24c8ff2

Please sign in to comment.