diff --git a/LICENSE.md b/LICENSE.md
index 04e9dcd..2cd4538 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -1,7 +1,7 @@
-Julia GaussianEP package Copyright (C) 2019 Alfredo Braunstein, Anna-Paola Muntoni, Andrea Pagnani and Mirko Pieropan (the "Authors"). All Rights Reserved.
+Julia GaussianEP package Copyright (C) 2019 Alfredo Braunstein, Giovanni Catania, Anna-Paola Muntoni, Andrea Pagnani and Mirko Pieropan (the "Authors"). All Rights Reserved.
 
 This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the Licence, or (at your option) any later version.
 
 This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 
-The file COPYING contains a copy of the GNU General Public License, also available at [https://www.gnu.org/licenses/gpl-3.0.en.html].
\ No newline at end of file
+The file COPYING contains a copy of the GNU General Public License, also available at [https://www.gnu.org/licenses/gpl-3.0.en.html].
diff --git a/src/Factor.jl b/src/Factor.jl
new file mode 100644
index 0000000..506de65
--- /dev/null
+++ b/src/Factor.jl
@@ -0,0 +1,31 @@
+export Factor
+
+
+abstract type Factor end
+
+
+"""
+    moments!(av, va, p0::T, h, J) where T <: Factor
+
+    input: ``p_0, h, J``
+
+    output: writes into `av` and `va` the mean and variance of
+
+    `` p(x) ∝ p_0(x) exp(-½⋅J⋅x² + h⋅x)``
+"""
+function moments!(av, va, p0::T, h, J) where T <: Factor
+    # no sensible default exists: concrete Factor types must implement this
+    error("moments! not implemented for $(typeof(p0))")
+end
+
+"""
+    learn!(p0::T, h, J) -> nothing
+
+    update parameters with a single learning gradient step (learning rate is stored in p0)
+"""
+function learn!(p0::T, h, J) where T <: Factor
+    # by default, do nothing
+    return
+end
+
diff --git a/src/FactorGraph.jl b/src/FactorGraph.jl
new file mode 100644
index 0000000..8e227e4
--- /dev/null
+++ b/src/FactorGraph.jl
@@ -0,0 +1,14 @@
+struct FactorGraph{T<:Real,F<:Factor}
+    factors::Vector{F}
+    idx::Vector{Vector{Int}}
+    N::Int
+    P::AbstractMatrix{T}
+    d::AbstractVector{T}
+end
+
+
+FactorGraph(factors::Vector{F}, idx::Vector{Vector{Int}}, N::Int) where {F<:Factor} = FactorGraph(factors, idx, N, Diagonal(ones(N)), zeros(N))
+
+FactorGraph(factors::Vector{F}, idx::Vector{Vector{Int}}, S::AbstractMatrix{T}, b::AbstractVector{T} = zeros(size(S,1))) where {T<:Real, F<:Factor} = FactorGraph(factors, idx, size(S,2), nullspace(Matrix(S)), S\b)
+
+
diff --git a/src/GaussianEP.jl b/src/GaussianEP.jl
index 851cc09..0f3c31a 100644
--- a/src/GaussianEP.jl
+++ b/src/GaussianEP.jl
@@ -1,13 +1,15 @@
 module GaussianEP
 
-export expectation_propagation, Term, EPState, EPOut
-export Prior, IntervalPrior, SpikeSlabPrior, BinaryPrior, GaussianPrior, PosteriorPrior, QuadraturePrior, AutoPrior, ThetaPrior
-
 using ExtractMacro, SpecialFunctions, LinearAlgebra
 
-include("Term.jl")
-include("priors.jl")
-include("expectation_propagation.jl")
+export FactorGraph, FactorGauss, EPState, expectation_propagation
+
 include("ProgressReporter.jl")
+include("Factor.jl")
+include("FactorGraph.jl")
+include("expectation_propagation.jl")
+include("univariate.jl")
+include("multivariate.jl")
+
 
 end # end module
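The `Factor` / `moments!` / optional `learn!` protocol introduced in `Factor.jl` is the whole interface a new prior has to satisfy. As a quick illustration of how a downstream user would plug into it, here is a hypothetical single-variable Gaussian factor; `MyGaussFactor` and its fixed precision `λ` are invented for this sketch and are not part of the diff:

```julia
using GaussianEP, LinearAlgebra

# Hypothetical factor p0(x) ∝ exp(-λx²/2) acting on one variable.
struct MyGaussFactor <: Factor
    λ::Float64          # fixed precision of the factor
end

# The tilted distribution p(x) ∝ p0(x)·exp(-½Jx² + hx) is Gaussian with
# precision J + λ and natural parameter h, so its moments are closed-form.
# h, J arrive as 1-element containers; av, va are filled in place.
function GaussianEP.moments!(av, va, p0::MyGaussFactor, h, J)
    Jt = J[] + p0.λ
    av[] = h[] / Jt     # mean of the tilted distribution
    va[] = 1 / Jt       # its variance
    return
end
# learn! is optional: the do-nothing default above applies.
```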
diff --git a/src/Term.jl b/src/Term.jl
deleted file mode 100644
index 2c2db88..0000000
--- a/src/Term.jl
+++ /dev/null
@@ -1,41 +0,0 @@
-"""
-This type represents an interaction term in the energy function of the form
-
-``β_i (\\frac12 x'Ax + x'y + c) + M_i \\log β_i``
-
-The complete energy function is given by
-
-``∑_i β_i (\\frac12 x' A_i x + x' y_i + c_i) + M_i log β_i``
-
-as is represented by an Vector{Term}. Note that c and M are only needed for paramenter learning
-"""
-mutable struct Term{T <: Real}
-    A::Matrix{T}
-    y::Vector{T}
-    c::T
-    β::T
-    # for parameter learning
-    δβ::T
-    M::Int
-end
-
-Term(A,y,β = 1.0) = Term(A,y,0.0,β,0.0,0)
-
-function (t::Term)(v::Vector)
-    return v⋅(t.A*v-2*t.y) + t.c
-end
-
-function updateβ(t::Term{T}, v) where T
-    if t.δβ > 0
-        t.β = t.δβ * t.M / t(v) + (1-t.δβ) * t.β
-    end
-end
-
-function sum!(A::Matrix{T}, y::Vector{T}, H::Vector{Term{T}}) where T <: Real
-    fill!(A, zero(T))
-    fill!(y, zero(T))
-    for i=1:length(H)
-        A .+= H[i].β * H[i].A
-        y .+= H[i].β * H[i].y
-    end
-end
diff --git a/src/expectation_propagation.jl b/src/expectation_propagation.jl
index 59613e1..64b6a0a 100644
--- a/src/expectation_propagation.jl
+++ b/src/expectation_propagation.jl
@@ -1,77 +1,88 @@
 using Random, LinearAlgebra, ExtractMacro
 
-```@meta
-CurrentModule = GaussianEP
-```
-function update_err!(dst, i, val)
-    r=abs(val - dst[i])
-    dst[i] = val
+function update!(old::Array{T}, new::Array{T}, ρ::T = zero(T))::T where {T<:Real}
+    r = maximum(abs, new - old)
+    old .*= ρ
+    old .+= (1 - ρ) * new
     return r
 end
 
-"""
-    Instantaneous state of an expectation propagation run.
-"""
-struct EPState{T<:AbstractFloat}
-    A::Matrix{T}
-    y::Vector{T}
-    Σ::Matrix{T}
-    v::Vector{T}
-    av::Vector{T}
-    va::Vector{T}
-    a::Vector{T}
-    μ::Vector{T}
-    b::Vector{T}
-    s::Vector{T}
+struct EPState{T <: Real, F <: Factor}
+    Σ  :: Matrix{T}
+    μ  :: Vector{T}
+    J  :: Vector{Matrix{T}}
+    h  :: Vector{Vector{T}}
+    Jc :: Vector{Matrix{T}}
+    hc :: Vector{Vector{T}}
+    Jt :: Vector{Matrix{T}}
+    ht :: Vector{Vector{T}}
+    FG :: FactorGraph{T,F}
 end
 
-EPState{T}(N, Nx = N) where {T <: AbstractFloat} = EPState{T}(Matrix{T}(undef,Nx,Nx), zeros(T,Nx), Matrix{T}(undef,Nx,Nx), zeros(T,Nx),zeros(T,N), zeros(T,N), zeros(T,N), zeros(T,N), ones(T,N), ones(T,N))
-"""
-Output of EP algorithm
+eye(::Type{T}, n::Integer) where T = Matrix(T(1)*I, n, n)
 
-"""
-struct EPOut{T<:AbstractFloat}
-    av::Vector{T}
-    va::Vector{T}
-    μ::Vector{T}
-    s::Vector{T}
-    converged::Symbol
-    state::EPState{T}
+function EPState(FG::FactorGraph{T,F}) where {T <: Real, F <: Factor}
+    d(a) = length(FG.idx[a])
+    M, N = length(FG.idx), FG.N
+    return EPState{T,F}(eye(T, N), zeros(T, N),
+                        [eye(T, d(a)) for a=1:M], [zeros(T, d(a)) for a=1:M],
+                        [eye(T, d(a)) for a=1:M], [zeros(T, d(a)) for a=1:M],
+                        [eye(T, d(a)) for a=1:M], [zeros(T, d(a)) for a=1:M],
+                        FG)
 end
-function EPOut(s, converged::Symbol) where {T <: AbstractFloat}
-    converged ∈ (:converged,:unconverged) || error("$converged is not a valid symbol")
-    return EPOut(s.av,s.va, s.μ,s.s,converged,s)
+
+function update!(state::EPState{T}, ψ::Factor, a::Integer, ρ::T, epsvar::T = zero(T)) where {T <: Real}
+    @extract state : Σ μ J h Jc hc Jt ht FG
+    ∂a = FG.idx[a]
+    # Jc[a], hc[a] are the cavity coefficients (marginal minus factor a)
+    hca, Jca = hc[a], Jc[a]
+    Jca .= (Σ[∂a, ∂a]+epsvar*I)\I .- J[a]
+    hca .= Σ[∂a, ∂a]\μ[∂a] .- h[a]
+    # Jt[a], ht[a] receive the tilted covariance and mean
+    hta, Jta = ht[a], Jt[a]
+    moments!(hta, Jta, ψ, hca, Jca)
+    # Jta, hta are now total exponents (natural parameters)
+    Jta .= (Jta+epsvar*I)\I
+    hta .= Jta*hta
+    # Jta - Jca, hta - hca are the new approximated factors
+    ε = max(update!(J[a], Jta .- Jca, ρ), update!(h[a], hta .- hca, ρ))
+    # learn factor parameters, if ψ supports it
+    learn!(ψ, hca, Jca)
+    return ε
 end
 
 """
-    expectation_propagation(H::Vector{Term{T}},
-                            P0::Vector{Prior},
-                            F::AbstractMatrix{T} = zeros(0,length(P0)),
-                            d::Vector{T} = zeros(size(F,1));
-                            maxiter::Int = 2000,
-                            callback = (x...)->nothing,
-                            # state::EPState{T} = EPState{T}(sum(size(F)), size(F)[2]),
-                            damp::T = 0.9,
-                            epsconv::T = 1e-6,
-                            maxvar::T = 1e50,
-                            minvar::T = 1e-50,
-                            inverter::Function = inv) where {T <: Real, P <: Prior}
+    expectation_propagation(FG::FactorGraph;
+                            maxiter::Int = 2000,
+                            callback = (state,iter,ε)->nothing,
+                            damp::T = zero(T),
+                            epsconv::T = 1e-6,
+                            epsvar::T = zero(T),
+                            state::EPState{T} = EPState(FG),
+                            inverter = inv) -> (state, converged, iter, ε)
 
 EP for approximate inference of
 
-``P(\\bf{x})=\\frac1Z exp(-\\frac12\\bf{x}' A \\bf{x} + \\bf{x'} \\bf{y}))×\\prod_i p_{i}(x_i)``
+``P(\\bf{x})=\\frac1Z \\prod_a ψ_{a}(x_a)``
 
 Arguments:
 
-* `A::Array{Term{T}}`: Gaussian Term (involving only x)
-* `P0::Array{Prior}`: Prior terms (involving x and y)
-* `F::AbstractMatrix{T}`: If included, the unknown becomes ``(\\bf{x},\\bf{y})^T`` and a term ``\\delta(F \\bf{x}+\\bf{d}-\\bf{y})`` is added.
+* `FG::FactorGraph`: the factor graph, i.e. the factors ``ψ_a`` and the variables each of them touches
+
+Optional Arguments (stored in `FG`):
+
+* `P::AbstractMatrix{T}`: projector onto the affine subspace ``x = Pz + d``
+* `d::AbstractVector{T}`: constant shift
 
 Optional named arguments:
 
 * `maxiter::Int = 2000`: maximum number of iterations
-* `callback = (x...)->nothing`: your own function to report progress, see [`ProgressReporter`](@ref)
-* `state::EPState{T} = EPState{T}(sum(size(F)), size(F)[2])`: If supplied, all internal state is updated here
+* `callback = (state,iter,ε)->nothing`: your own function to report progress, see [`ProgressReporter`](@ref); return anything other than `nothing` to stop early
+* `state::EPState{T} = EPState(FG)`: If supplied, all internal state is updated here
-* `damp::T = 0.9`: damping parameter
+* `damp::T = 0.0`: damping parameter
 * `epsconv::T = 1e-6`: convergence criterion
-* `maxvar::T = 1e50`: maximum variance
-* `minvar::T = 1e-50`: minimum variance
+* `epsvar::T = 0.0`: regularizer added to variances before matrix inversion
@@ -80,82 +91,62 @@ Optional named arguments:
 
 # Example
 
```jldoctest
-julia> t=Term(zeros(2,2),zeros(2),1.0)
-Term{Float64}([0.0 0.0; 0.0 0.0], [0.0, 0.0], 0.0, 1.0, 0.0, 0)
-
-julia> P=[IntervalPrior(i...) for i in [(0,1),(0,1),(-2,2)]]
-3-element Array{IntervalPrior{Int64},1}:
- IntervalPrior{Int64}(0, 1)
- IntervalPrior{Int64}(0, 1)
- IntervalPrior{Int64}(-2, 2)
+julia> FG=FactorGraph([FactorInterval(a,b) for (a,b) in [(0,1),(0,1),(-2,2)]], [[i] for i=1:3], [1.0 -1.0 -1.0])
+FactorGraph(Factor[FactorInterval{Int64}(0, 1), FactorInterval{Int64}(0, 1), FactorInterval{Int64}(-2, 2)], Array{Int64,1}[[1], [2], [3]], 3)
 
-julia> F=[1.0 -1.0];
+julia> using LinearAlgebra
 
-julia> res = expectation_propagation([t], P, F)
-GaussianEP.EPOut{Float64}([0.499997, 0.499997, 3.66527e-15], [0.083325, 0.083325, 0.204301], [0.489862, 0.489862, 3.66599e-15], [334.018, 334.018, 0.204341], :converged, EPState{Float64}([9.79055 -0.00299477; -0.00299477 9.79055], [0.0, 0.0], [0.102139 3.12427e-5; 3.12427e-5 0.102139], [0.489862, 0.489862], [0.499997, 0.499997, 3.66527e-15], [0.083325, 0.083325, 0.204301], [0.490876, 0.490876, -1.86785e-17], [0.489862, 0.489862, 3.66599e-15], [0.100288, 0.100288, 403.599], [334.018, 334.018, 0.204341]))
+julia> res = expectation_propagation(FG)
+(EPState{Float64}([0.0833329 1.00114e-6 0.0833319; 1.00114e-6 0.0833329 -0.0833319; 0.0833319 -0.0833319 0.166664], [0.499994, 0.499994, 1.39058e-13], Array{Float64,2}[[11.9999], [11.9999], [0.00014416]], Array{Float64,1}[[5.99988], [5.99988], [-1.14443e-13]], Array{Float64,2}[[1.0], [1.0], [1.0]], Array{Float64,1}[[0.0], [0.0], [0.0]], FactorGraph(Factor[FactorInterval{Int64}(0, 1), FactorInterval{Int64}(0, 1), FactorInterval{Int64}(-2, 2)], Array{Int64,1}[[1], [2], [3]], 3)), :converged, 162, 9.829257408000558e-7)
```
+
+Note on subspace restriction
+
+P(x) ∝ ∫dz δ(x-Pz-d) ∏ₐψₐ(xₐ),  with x = Pz + d
+Q(x) ∝ ∏ₐϕₐ(xₐ)
+     ∝ exp(-½ xᵀAx + xᵀy)
+     ∝ ∫dz δ(x-Pz-d) Q(z)
+Q(z) ∝ exp(-½ (Pz+d)ᵀA(Pz+d) + (Pz+d)ᵀy)
+     ∝ exp(-½ zᵀPᵀAPz - zᵀPᵀAd - ½dᵀAd + zᵀPᵀy + dᵀy)
+     ∝ exp(-½ zᵀPᵀAPz + zᵀPᵀ(y - Ad))
+Σz = (PᵀAP)⁻¹
+μz = (PᵀAP)⁻¹Pᵀ(y-Ad)
+Σx = P(PᵀAP)⁻¹Pᵀ
+μx = Pμz + d
+   = P((PᵀAP)⁻¹Pᵀ(y-Ad)) + d
+   = Σx(y-Ad) + d
 """
-function expectation_propagation(H::Vector{Term{T}}, P0::Vector{P}, F::AbstractMatrix{T} = zeros(T,0,length(P0)), d::AbstractVector{T} = zeros(T,size(F,1));
-                maxiter::Int = 2000,
-                callback = (x...)->nothing,
-                state::EPState{T} = EPState{T}(sum(size(F)), size(F)[2]),
-                damp::T = 0.9,
-                epsconv::T = 1e-6,
-                maxvar::T = 1e50,
-                minvar::T = 1e-50,
-                inverter::Function = inv) where {T <: Real, P <: Prior}
-    @extract state A y Σ v av va a μ b s
-    Ny,Nx = size(F)
-    N = Nx + Ny
-    @assert size(P0,1) == N
-    Fp = copy(F')
+function expectation_propagation(FG::FactorGraph{T,F};
+                maxiter::Integer = 2000,
+                callback = (x...)->nothing,
+                damp::T = zero(T),
+                epsconv::T = 1e-6,
+                inverter = inv,
+                epsvar::T = zero(T),
+                state::EPState{T} = EPState(FG)) where {F<:Factor, T<:Real}
+
+    @extract state : Σ μ J h
+    N, M = FG.N, length(FG.factors)
+    A, y = zeros(T,N,N), zeros(T,N)
+    ε = 0.0
     for iter = 1:maxiter
-        sum!(A,y,H)
-        Δμ, Δs, Δav, Δva = 0.0, 0.0, 0.0, 0.0
-        A .+= Diagonal(1 ./ b[1:Nx]) .+ Fp * Diagonal(1 ./ b[Nx+1:end]) * F
-        Σ .= inverter(A)
-        v .= Σ * (y .+ a[1:Nx] ./ b[1:Nx] .+ Fp * ((a[Nx+1:end]-d) ./ b[Nx+1:end]))
-        for i in 1:N
-            if i <= Nx
-                ss = clamp(Σ[i,i], minvar, maxvar)
-                vv = v[i]
-            else
-                x = Fp[:, i-Nx]
-                ss = clamp(dot(x, Σ*x), minvar, maxvar)
-                vv = dot(x, v) + d[i-Nx]
-            end
-
-            if ss < b[i]
-                Δs = max(Δs, update_err!(s, i, clamp(1/(1/ss - 1/b[i]), minvar, maxvar)))
-                Δμ = max(Δμ, update_err!(μ, i, s[i] * (vv/ss - a[i]/b[i])))
-            else
-                ss == b[i] && @warn "infinite var, ss = $ss"
-                Δs = max(Δs, update_err!(s, i, maxvar))
-                Δμ = max(Δμ, update_err!(μ, i, 0))
-            end
-            tav, tva = moments(P0[i], μ[i], sqrt(s[i]));
-            Δav = max(Δav, update_err!(av, i, tav))
-            Δva = max(Δva, update_err!(va, i, tva))
-            (isnan(av[i]) || isnan(va[i])) && @warn "avnew = $(av[i]) varnew = $(va[i])"
-
-            new_b = clamp(1/(1/va[i] - 1/s[i]), minvar, maxvar)
-            new_a = av[i] + new_b * (av[i] - μ[i])/s[i]
-            a[i] = damp * a[i] + (1 - damp) * new_a
-            b[i] = damp * b[i] + (1 - damp) * new_b
+        A .= 0.0
+        y .= 0.0
+        for a in 1:M
+            ∂a = FG.idx[a]
+            A[∂a, ∂a] .+= J[a]
+            y[∂a] .+= h[a]
         end
-
-        # learn prior's params
-        for i in randperm(N)
-            gradient(P0[i], μ[i], sqrt(s[i]));
-        end
-        # learn β params
-        for i in 1:length(H)
-            updateβ(H[i], av[1:Nx])
-        end
-        callback(av,Δav,epsconv,maxiter,H,P0)
-        if Δav < epsconv
-            return EPOut(state, :converged)
+        Σ .= FG.P*inverter(FG.P'*A*FG.P)*FG.P'
+        μ .= Σ * (y .- A*FG.d) .+ FG.d
+        ε = 0.0
+        for a=1:M
+            ε = max(ε, update!(state, FG.factors[a], a, damp, epsvar))
         end
+        callback(state,iter,ε) != nothing && break
+        ε < epsconv && return (state, :converged, iter, ε)
    end
-    return EPOut(state, :unconverged)
+    return (state, :unconverged, maxiter, ε)
 end
+
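The subspace-restriction identities in the docstring above are easy to sanity-check numerically. A small standalone sketch (not part of the diff; `A`, `P`, `d`, `y` are arbitrary test values):

```julia
using LinearAlgebra, Random

Random.seed!(0)
N, K = 5, 3
P = randn(N, K)                  # projector columns: x = P*z + d
d = randn(N)
B = randn(N, N); A = B'B + I     # positive-definite precision matrix
y = randn(N)

# moments of Q(z) ∝ exp(-½ zᵀPᵀAPz + zᵀPᵀ(y - A d))
Σz = inv(P' * A * P)
μz = Σz * (P' * (y - A * d))

# pushed forward to x = Pz + d, as in the note
Σx = P * Σz * P'
μx = P * μz + d
@assert μx ≈ Σx * (y - A * d) + d   # the closed form the code uses
```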
diff --git a/src/legacy.jl b/src/legacy.jl
new file mode 100644
index 0000000..3ec5179
--- /dev/null
+++ b/src/legacy.jl
@@ -0,0 +1,145 @@
+
+"""
+This type represents an interaction term in the energy function of the form
+
+``β_i (\\frac12 x'Ax + x'y + c) + M_i \\log β_i``
+
+The complete energy function is given by
+
+``∑_i β_i (\\frac12 x' A_i x + x' y_i + c_i) + M_i \\log β_i``
+
+and is represented by a Vector{Term}. Note that c and M are only needed for parameter learning
+"""
+mutable struct Term{T <: Real}
+    A::Matrix{T}
+    y::Vector{T}
+    c::T
+    β::T
+    # for parameter learning
+    δβ::T
+    M::Int
+end
+
+Term(A,y,β = 1.0) = Term(A,y,0.0,β,0.0,0)
+
+function (t::Term)(v::Vector)
+    return v⋅(t.A*v-2*t.y) + t.c
+end
+
+function updateβ(t::Term{T}, v) where T
+    if t.δβ > 0
+        t.β = t.δβ * t.M / t(v) + (1-t.δβ) * t.β
+    end
+end
+
+function sum!(A::Matrix{T}, y::Vector{T}, H::Vector{Term{T}}) where T <: Real
+    fill!(A, zero(T))
+    fill!(y, zero(T))
+    for i=1:length(H)
+        A .+= H[i].β * H[i].A
+        y .+= H[i].β * H[i].y
+    end
+end
+function update_err!(dst, i, val)
+    r=abs(val - dst[i])
+    dst[i] = val
+    return r
+end
+
+"""
+    Instantaneous state of an expectation propagation run.
+""" +struct EPState{T<:AbstractFloat} + A::Matrix{T} + y::Vector{T} + Σ::Matrix{T} + v::Vector{T} + av::Vector{T} + va::Vector{T} + a::Vector{T} + μ::Vector{T} + b::Vector{T} + s::Vector{T} +end +EPState{T}(N, Nx = N) where {T <: AbstractFloat} = EPState{T}(Matrix{T}(undef,Nx,Nx), zeros(T,Nx), Matrix{T}(undef,Nx,Nx), zeros(T,Nx),zeros(T,N), zeros(T,N), zeros(T,N), zeros(T,N), ones(T,N), ones(T,N)) + +""" +Output of EP algorithm + +""" +struct EPOut{T<:AbstractFloat} + av::Vector{T} + va::Vector{T} + μ::Vector{T} + s::Vector{T} + converged::Symbol + state::EPState{T} +end +function EPOut(s, converged::Symbol) where {T <: AbstractFloat} + converged ∈ (:converged,:unconverged) || error("$converged is not a valid symbol") + return EPOut(s.av,s.va, s.μ,s.s,converged,s) +end +function expectation_propagation_legacy(H::Vector{Term{T}}, P0::Vector{P}, F::AbstractMatrix{T} = zeros(T,0,length(P0)), d::AbstractVector{T} = zeros(T,size(F,1)); + maxiter::Int = 2000, + callback = (x...)->nothing, + state::EPState{T} = EPState{T}(sum(size(F)), size(F)[2]), + damp::T = 0.9, + epsconv::T = 1e-6, + maxvar::T = 1e50, + minvar::T = 1e-50, + inverter::Function = inv) where {T <: Real, P <: Prior} + @extract state A y Σ v av va a μ b s + Ny,Nx = size(F) + N = Nx + Ny + @assert size(P0,1) == N + Fp = copy(F') + for iter = 1:maxiter + sum!(A,y,H) + Δμ, Δs, Δav, Δva = 0.0, 0.0, 0.0, 0.0 + A .+= Diagonal(1 ./ b[1:Nx]) .+ Fp * Diagonal(1 ./ b[Nx+1:end]) * F + Σ .= inverter(A) + v .= Σ * (y .+ a[1:Nx] ./ b[1:Nx] .+ Fp * ((a[Nx+1:end]-d) ./ b[Nx+1:end])) + for i in 1:N + if i <= Nx + ss = clamp(Σ[i,i], minvar, maxvar) + vv = v[i] + else + x = Fp[:, i-Nx] + ss = clamp(dot(x, Σ*x), minvar, maxvar) + vv = dot(x, v) + d[i-Nx] + end + + if ss < b[i] + Δs = max(Δs, update_err!(s, i, clamp(1/(1/ss - 1/b[i]), minvar, maxvar))) + Δμ = max(Δμ, update_err!(μ, i, s[i] * (vv/ss - a[i]/b[i]))) + else + ss == b[i] && @warn "infinite var, ss = $ss" + Δs = max(Δs, update_err!(s, i, maxvar)) + Δμ = max(Δμ, update_err!(μ, i, 0)) + end + tav, tva = moments(P0[i], μ[i], sqrt(s[i])); + Δav = max(Δav, update_err!(av, i, tav)) + Δva = max(Δva, update_err!(va, i, tva)) + (isnan(av[i]) || isnan(va[i])) && @warn "avnew = $(av[i]) varnew = $(va[i])" + + new_b = clamp(1/(1/va[i] - 1/s[i]), minvar, maxvar) + new_a = av[i] + new_b * (av[i] - μ[i])/s[i] + a[i] = damp * a[i] + (1 - damp) * new_a + b[i] = damp * b[i] + (1 - damp) * new_b + end + + # learn prior's params + for i in randperm(N) + gradient(P0[i], μ[i], sqrt(s[i])); + end + # learn β params + for i in 1:length(H) + updateβ(H[i], av[1:Nx]) + end + callback(av,Δav,epsconv,maxiter,H,P0) + if Δav < epsconv + return EPOut(state, :converged) + end + end + return EPOut(state, :unconverged) +end diff --git a/src/multivariate.jl b/src/multivariate.jl new file mode 100644 index 0000000..c7ebcf8 --- /dev/null +++ b/src/multivariate.jl @@ -0,0 +1,21 @@ +struct FactorGauss <: Factor + J::Matrix{Float64} + h::Vector{Float64} + β::Float64 + δβ::Float64 +end + +FactorGauss(J, h, β = 1.0) = FactorGauss(J, h, β, 0.0) + +function update!(state::EPState{T}, ψ::FactorGauss, a::Integer, ρ::T, epsvar::T) where {T<:Real} + @extract state : J h Σ μ + @assert size(J[a]) == size(ψ.J) && size(h[a]) == size(ψ.h) + if ψ.δβ > 0 + ψ.β = ψ.δβ * size(J,1) / (0.5*μ'J*μ-h'μ) + (1-ψ.δβ) * ψ.β + end + J[a][1] == ψ.J[1]*ψ.β && return 0.0 + J[a] .= ψ.J*ψ.β + h[a] .= ψ.h*ψ.β + return 1.0 +end + diff --git a/src/priors.jl b/src/univariate.jl similarity index 54% rename from src/priors.jl rename to src/univariate.jl index 
diff --git a/src/priors.jl b/src/univariate.jl
similarity index 54%
rename from src/priors.jl
rename to src/univariate.jl
index bd67544..bc53083 100644
--- a/src/priors.jl
+++ b/src/univariate.jl
@@ -1,37 +1,12 @@
 using FastGaussQuadrature, ForwardDiff
 
-Φ(x) = 0.5*(1+erf(x/sqrt(2.0)))
-ϕ(x) = exp(-x.^2/2)/sqrt(2π)
+export FactorInterval, FactorSpikeSlab, FactorBinary, FactorPosterior, FactorQuadrature, FactorAuto, FactorTheta
 
-"""
-Abstract Univariate Prior type
-"""
-abstract type Prior end
 
-"""
-    moments(p0::T, μ, σ) where T <:Prior -> (mean, variance)
+abstract type FactorUnivariate <: Factor end
 
-    input: ``p_0, μ, σ``
+moments!(av::Vector, va::Matrix, ψ::FactorUnivariate, h, J) = (p = moments(ψ, h, J); av[]=p[1]; va[]=p[2]; return)
 
-    output: mean and variance of
-
-    `` p(x) ∝ p_0(x) \\mathcal{N}(x;μ,σ) ``
-"""
-function moments(p0::T, μ, σ) where T <: Prior
-    error("undefined moment calculation, assuming uniform prior")
-    return μ,σ^2
-end
-
-"""
-
-    gradient(p0::T, μ, σ) -> nothing
-
-    update parameters with a single learning gradient step (learning rate is stored in p0)
-"""
-function gradient(p0::T, μ, σ) where T <: Prior
-    #by default, do nothing
-    return
-end
 
 """
 Interval prior
 
 Parameters: l,u
 
 `` p_0(x) = \\frac{1}{u-l}\\mathbb{I}[l\\leq x\\leq u] ``
 """
-struct IntervalPrior{T<:Real} <: Prior
+struct FactorInterval{T<:Real} <: FactorUnivariate
     l::T
     u::T
 end
 
-function moments(p0::IntervalPrior,μ,σ)
+Φ(x) = 0.5*(1+erf(x/sqrt(2.0)))
+ϕ(x) = exp(-x^2/2)/sqrt(2π)
+
+
+function moments(p0::FactorInterval,h,J)
+    J, h = J[], h[]
+    if J <= 0
+        return 0.5 * (p0.l + p0.u), (p0.u - p0.l)^2/12
+    end
+    σ = 1/sqrt(J)
+    μ = h/J
     xl = (p0.l - μ)/σ
     xu = (p0.u - μ)/σ
     minval = min(abs(xl), abs(xu))
-    if xu - xl < 1e-10
-        return 0.5 * (xu + xl), -1
-    end
 
     if minval <= 6.0 || xl * xu <= 0
         ϕu, Φu, ϕl, Φl = ϕ(xu), Φ(xu), ϕ(xl), Φ(xl)
@@ -79,9 +64,9 @@ Spike-and-slab prior
 
 Parameters: ρ,λ
 
-`` p_0(x) ∝ (1-ρ) δ(x) + ρ \\mathcal{N}(x;0,λ^{-1}) ``
+`` p_0(x) = (1-ρ) δ(x) + ρ \\mathcal{N}(x;0,λ^{-1}) ``
 """
-mutable struct SpikeSlabPrior{T<:Real} <: Prior
+mutable struct FactorSpikeSlab{T<:Real} <: FactorUnivariate
     ρ::T
     λ::T
     δρ::T
@@ -92,29 +77,18 @@ end
 """
 ``p = \\frac1{(ℓ+1)((1/ρ-1) e^{-\\frac12 (μ/σ)^2 (2-\\frac1{1+ℓ})}\\sqrt{1+\\frac1{ℓ}}+1)}``
 """
-function moments(p0::SpikeSlabPrior,μ,σ)
-#=
-    s2 = σ^2
-    d = 1 + p0.λ * s2;
-    sd = 1 / (1/s2 + p0.λ);
-    n = μ^2/(2*d*s2);
-    Z = sqrt(p0.λ * sd) * p0.ρ;
-    f = 1 + (1-p0.ρ) * exp(-n) / Z;
-    av = μ / (d * f);
-    va = (sd + (μ / d)^2 ) / f - av^2;
-    #p0 = (1 - p0.params.ρ) * exp(-n) / (Z + (1-p0.params.ρ).*exp(-n));
-    =#
-    ℓ0 = p0.λ * σ^2
-    ℓ = 1 + ℓ0;
-    z = ℓ * (1 + (1/p0.ρ-1) * exp(-0.5*(μ/σ)^2/ℓ) * sqrt(ℓ/ℓ0))
-    av = μ / z;
-    va = (σ^2 + μ^2*(1/ℓ - 1/z)) / z;
-    return av, va
+function moments(p0::FactorSpikeSlab,h,J)
+    J, h = J[], h[]
+    l = J + p0.λ
+    z = l * (1 + (1/p0.ρ-1) * exp(-h^2/2l) * sqrt(l/p0.λ))
+    return h / z, (1 + h^2*(1/l - 1/z)) / z
 end
 
-function gradient(p0::SpikeSlabPrior, μ, σ)
-    s = σ^2
+function learn!(p0::FactorSpikeSlab, h, J)
+    J, h = J[], h[]
+    s = 1/J
+    μ = h/J
     d = 1 + p0.λ * s;
     q = sqrt(p0.λ * s / d);
     f = exp(-μ^2 / (2s*d));
@@ -130,72 +104,62 @@ function gradient(p0::SpikeSlabPrior, μ, σ)
         p0.λ += p0.δλ * num/den;
         p0.λ = max(p0.λ, 0)
     end
+    nothing
 end
 
 """
-Binary Prior
+Binary Factor
 
 p_0(x) ∝ ρ δ(x-x_0) + (1-ρ) δ(x-x_1)
 """
-struct BinaryPrior{T<:Real} <: Prior
+struct FactorBinary{T<:Real} <: FactorUnivariate
     x0::T
     x1::T
     ρ::T
 end
 
-function moments(p0::BinaryPrior, μ, σ)
-    arg = -(σ^2 / 2) * (-p0.x0^2 - 2*(p0.x1 -p0.x0) * μ + p0.x1^2);
-    earg = exp(arg)
-    Z = p0.ρ / earg + (1-p0.ρ);
-    av = p0.ρ * p0.x0 / earg + (1-p0.ρ) * p0.x1;
-    mom2 = p0.ρ * (p0.x0^2) / earg + (1-p0.ρ) * (p0.x1^2);
+function moments(p0::FactorBinary, h, J)
+    J = J[1]; h = h[1]
+    w = exp((-1/2*J*(p0.x0+p0.x1)+h)*(p0.x0-p0.x1))
+    Z = p0.ρ * w + (1-p0.ρ);
+    av = p0.ρ * p0.x0 * w + (1-p0.ρ) * p0.x1;
+    mom2 = p0.ρ * (p0.x0^2) * w + (1-p0.ρ) * (p0.x1^2);
     if (isnan(Z) || isinf(Z))
-        Z = p0.ρ + (1-p0.ρ) * earg;
-        av = p0.ρ * p0.x0 + (1-p0.ρ) * p0.x1 * earg;
-        mom2 = p0.ρ * (p0.x0^2) + (1-p0.ρ) * p0.x1 * earg;
+        Z = p0.ρ + (1-p0.ρ) / w;
+        av = p0.ρ * p0.x0 + (1-p0.ρ) * p0.x1 / w;
+        mom2 = p0.ρ * (p0.x0^2) + (1-p0.ρ) * p0.x1^2 / w;
     end
     av /= Z;
     mom2 /= Z;
     va = mom2 - av.^2;
-    return av,va
-end
-
-
-struct GaussianPrior{T<:Real} <: Prior
-    μ::T
-    β::T
-    δβ::T
+    return av, va
 end
 
-function moments(p0::GaussianPrior, μ, σ)
-    s = 1/(1/σ^2 + p0.β)
-    return s*(μ/σ^2 + p0.μ * p0.β), s
-end
 
 """
-This is a fake Prior that can be used to fix experimental moments
+This is a fake Factor that can be used to fix experimental moments
 
 Parameters: μ, v (variance, not std)
 """
-struct PosteriorPrior{T<:Real} <: Prior
+struct FactorPosterior{T<:Real} <: FactorUnivariate
     μ::T
     v::T
 end
 
-function moments(p0::PosteriorPrior, μ, σ)
+function moments(p0::FactorPosterior, h, J)
     return p0.μ,p0.v
 end
 
-struct QuadraturePrior{T<:Real} <: Prior
+struct FactorQuadrature{T<:Real} <: FactorUnivariate
     f
     X::Vector{T}
     W0::Vector{T}
     W1::Vector{T}
     W2::Vector{T}
-    function (QuadraturePrior)(f; a::T=-1.0, b::T=1.0, points::Int64=1000) where {T<:Real}
+    function (FactorQuadrature)(f; a::T=-1.0, b::T=1.0, points::Int64=1000) where {T<:Real}
         X,W = gausslegendre(points)
         X = 0.5*(X*(b-a).+(b+a))
         W .*= 0.5*(b-a)
@@ -206,15 +170,16 @@ struct QuadraturePrior{T<:Real} <: Prior
     end
 end
 
-function moments(p0::QuadraturePrior, μ, σ)
-    v = map(x->exp(-(x-μ)^2/2σ^2),p0.X)
+function moments(p0::FactorQuadrature, h, J)
+    J = J[]; h = h[]
+    v = map(x->exp(-J/2*x^2+h*x),p0.X)
     z0 = v ⋅ p0.W0
     av = (v ⋅ p0.W1)/z0
     va = (v ⋅ p0.W2)/z0 - av^2
     return av, va
 end
 
-mutable struct AutoPrior{T<:Real} <: Prior
+mutable struct FactorAuto{T<:Real} <: FactorUnivariate
     #real arguments
     f
     P::Vector{T}
@@ -228,7 +193,7 @@ mutable struct AutoPrior{T<:Real} <: Prior
     FXW::Vector{T}
     DFXW::Matrix{T}
     cfg::Any
-    function (AutoPrior)(f, P::Vector{T}, dP::Vector{T} = zeros(length(P)), a=-1, b=1, points=1000) where {T<:Real}
+    function (FactorAuto)(f, P::Vector{T}, dP::Vector{T} = zeros(length(P)), a=-1, b=1, points=1000) where {T<:Real}
         X,W = gausslegendre(points)
         X = 0.5*(X*(b-a).+(b+a))
         W .*= 0.5*(b-a)
@@ -237,17 +202,17 @@ mutable struct AutoPrior{T<:Real} <: Prior
     end
 end
 
-function moments(p0::AutoPrior, μ, σ)
+function moments(p0::FactorAuto, h, J)
     do_update!(p0)
-    s22 = 2σ^2
-    v = p0.FXW .* map(x->exp(-(x-μ)^2 / s22), p0.X)
+    J = J[1]; h = h[1]
+    v = p0.FXW .* map(x->exp(-J/2 * x^2 + h*x), p0.X)
     v .*= 1/sum(v)
     av = v ⋅ p0.X
     va = (v ⋅ p0.X2) - av^2
     return av, va
 end
 
-function do_update!(p0::AutoPrior)
+function do_update!(p0::FactorAuto)
     p0.P == p0.oldP && return
     copy!(p0.FXW, p0.f.([[x;p0.P] for x in p0.X]) .* p0.W)
     for i in 1:length(p0.X)
@@ -256,8 +221,10 @@ function do_update!(p0::AutoPrior)
     copy!(p0.oldP, p0.P)
 end
 
-function gradient(p0::AutoPrior, μ, σ)
-    s22 = 2σ^2
+function learn!(p0::FactorAuto, h, J)
+    J = J[]; h = h[]
+    s22 = 2/J
+    μ = h/J
     v = map(x->exp(-(x-μ)^2 / s22), p0.X)
     z = sum(v)
     v ./= z
@@ -279,13 +246,15 @@ end
 
 """
-A θ(x) prior
+A θ(x) factor
 """
-struct ThetaPrior <: Prior end
+struct FactorTheta <: FactorUnivariate end
 
-function moments(::ThetaPrior,μ,σ)
-    α=μ/σ
-    av=μ+pdf_cf(α)*σ
-    var=σ^2*(1-α*pdf_cf(α)-pdf_cf(α)^2)
-    return av,var
+function moments(::FactorTheta,h,J)
+    J, h = J[], h[]
+    μ = h/J
+    α = h/sqrt(J)
+    av = μ+pdf_cf(α)/sqrt(J)
+    va = 1/J*(1-α*pdf_cf(α)-pdf_cf(α)^2)
+    return av,va
 end
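Throughout `univariate.jl`, the factors now receive the natural parameters `(h, J)` of the Gaussian cavity, ``exp(-½Jx² + hx)``, rather than its mean and standard deviation. A tiny sketch of the conversion, mirroring what the updated tests below do (not part of the diff; the interval bounds and cavity values are the ones the tests use):

```julia
using GaussianEP

# Cavity in moment form: mean μ, std σ  →  natural form: J = 1/σ², h = μ/σ².
μ, σ = -4.0, 1.0
J, h = 1/σ^2, μ/σ^2

# Moments of the cavity truncated to [0, 1000]; scalars are accepted
# here because x[] on a Number returns the number itself.
av, va = GaussianEP.moments(FactorInterval(0.0, 1000.0), h, J)
```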
diff --git a/test/ep.jl b/test/ep.jl
index 71d9aea..0ed96ed 100644
--- a/test/ep.jl
+++ b/test/ep.jl
@@ -1,20 +1,18 @@
 module TestEP
 
-using GaussianEP, Test
+using GaussianEP, Test, LinearAlgebra
 
 function simple_ep_test()
-    t=Term(zeros(2,2),zeros(2),1.0)
-    P=[IntervalPrior(i...) for i in [(0,1),(0,1),(-2,2)]]
-    F=[1.0 -1.0]
-    av0 = [0.4999974709003177,0.4999974709003177,3.665273196564082e-15]
-    va0 = [0.08332501737195087, 0.08332501737195087, 0.2043006364495929]
-    μ0 = [0.4898618134668008,0.4898618134668008,3.665993043660408e-15]
-    s0 = [334.0179053087342,334.0179053087342,0.20434113796777062]
-    res = expectation_propagation([t], P, F)
-    @test sum(abs, res.av - av0) < 1e-9
-    @test sum(abs, res.va - va0) < 1e-9
-    @test sum(abs, res.μ - μ0) < 1e-9
-    @test sum(abs, res.s - s0) < 1e-9
-    # test precision to 1e-9 to comply with 32 bit
+    N = 3
+    factors = [FactorInterval(a,b) for (a,b) in [(0,1),(0,1),(-2,2)]]
+    idx = [[i] for i in 1:N]
+    S = [1.0 -1.0 -1.0]
+    FG = FactorGraph(factors, idx, S)
+    av0 = Float64[1/2, 1/2, 0]
+    va0 = Float64[1/12, 1/12, 1/6]
+    state,status,iter,ε = expectation_propagation(FG, epsconv=1e-8)
+    @test state.μ ≈ av0 atol=1e-5
+    @test diag(state.Σ) ≈ va0 atol=1e-5
+    @test status === :converged
 end
diff --git a/test/priors.jl b/test/priors.jl
index a3ba944..888333a 100644
--- a/test/priors.jl
+++ b/test/priors.jl
@@ -3,18 +3,20 @@ using GaussianEP,Test
 
 # test spike and slab
 
-spikeandslabmom(μ,σ,ρ,λ) = GaussianEP.moments(SpikeSlabPrior(ρ,λ,0.0,0.0),μ,σ)
+spikeandslabmom(μ,σ,ρ,λ) = GaussianEP.moments(FactorSpikeSlab(ρ,λ,0.0,0.0),μ/σ^2,1/σ^2)
+
+col((a,b)) = [a[],b[]]
 
 function spike_and_slab_test()
-    @test isapprox.(spikeandslabmom(-1.9,13.0,1.0,0.12), (-0.0892857142857146,7.94172932330828),atol=1e-12) == (true,true)
-    @test isapprox.(spikeandslabmom(0.0,13.0,0.0,0.12), (0.0,0.0)) == (true,true)
-    @test isapprox.(spikeandslabmom(6.0,2.0,0.2,2.0), (0.1865693402764403,0.2139510016374009),atol=1e-12) == (true,true)
+    @test col(spikeandslabmom(-1.9,13.0,1.0,0.12)) ≈ [-0.0892857142857146,7.94172932330828] atol=1e-12
+    @test col(spikeandslabmom(0.0,13.0,0.0,0.12)) ≈ [0.0,0.0] atol=1e-12
+    @test col(spikeandslabmom(6.0,2.0,0.2,2.0)) ≈ [0.1865693402764403,0.2139510016374009] atol=1e-12
 end
 
-intervalmom(μ,σ,lb,ub) = GaussianEP.moments(IntervalPrior(lb,ub),μ,σ)
+intervalmom(μ,σ,lb,ub) = GaussianEP.moments(FactorInterval(lb,ub),μ/σ^2,1/σ^2)
 
 function uniform_test()
-    @test isapprox.(intervalmom(-4.0,1.0,0.0,1000.0), (0.2256071444894706679029638,0.0466728383974225474739583319205),atol=1e-12) == (true,true)
-    @test isapprox.(intervalmom(-5.0,1.0,0.0,1000.0), (0.1865039670969923790710964794,0.0326964346120545146234803723928052),atol=1e-8) == (true,true) # this test require a lower precision ... maybe a better asymptotic expansion
+    @test col(intervalmom(-4.0,1.0,0.0,1000.0)) ≈ [0.2256071444894706679029638,0.0466728383974225474739583319205] atol=1e-12
+    @test col(intervalmom(-5.0,1.0,0.0,1000.0)) ≈ [0.1865039670969923790710964794,0.0326964346120545146234803723928052] atol=1e-8 # this test requires lower precision ... maybe a better asymptotic expansion
 end
 
 spike_and_slab_test()
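As a closing illustration, here is the new API end to end as a standalone script. It simply replays the constrained-interval setup from the jldoctest and the updated test above, so the expected values are the ones checked there; nothing else is assumed:

```julia
using GaussianEP, LinearAlgebra

# Three box-constrained variables subject to the linear constraint
# x1 - x2 - x3 = 0, encoded through the nullspace projector of S.
factors = [FactorInterval(a, b) for (a, b) in [(0,1), (0,1), (-2,2)]]
idx = [[i] for i in 1:3]
S = [1.0 -1.0 -1.0]
FG = FactorGraph(factors, idx, S)

state, status, iter, ε = expectation_propagation(FG; epsconv = 1e-8)
println(status, " after ", iter, " iterations (ε = ", ε, ")")
println("means:     ", state.μ)        # ≈ [1/2, 1/2, 0]
println("variances: ", diag(state.Σ))  # ≈ [1/12, 1/12, 1/6]
```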