diff --git a/src/priors.jl b/src/priors.jl index f288abe..e0d4c3a 100644 --- a/src/priors.jl +++ b/src/priors.jl @@ -1,4 +1,5 @@ using FastGaussQuadrature, ForwardDiff +import Distributions: TruncatedNormal, mean, var Φ(x) = 0.5*(1+erf(x/sqrt(2.0))) ϕ(x) = exp(-x.^2/2)/sqrt(2π) @@ -45,35 +46,13 @@ struct IntervalPrior{T<:Real} <: Prior u::T end -function moments(p0::IntervalPrior,μ,σ) - xl = (p0.l - μ)/σ - xu = (p0.u - μ)/σ - minval = min(abs(xl), abs(xu)) - - if xu - xl < 1e-10 - return 0.5 * (xu + xl), -1 - end - - if minval <= 6.0 || xl * xu <= 0 - ϕu, Φu, ϕl, Φl = ϕ(xu), Φ(xu), ϕ(xl), Φ(xl) - av = (ϕl - ϕu) / (Φu - Φl) - mom2 = (xl * ϕl - xu * ϕu) / (Φu - Φl) - else - Δ = (xu^2 - xl^2) * 0.5 - if Δ > 40.0 - av = xl^5 / (3.0 - xl^2 + xl^4) - mom2 = xl^6 / (3.0 - xl^2 + xl^4) - else - eΔ = exp(Δ) - av = (xl * xu)^5 * (1. - eΔ) / (-eΔ * (3.0 - xl^2 + xl^4) * xu^5 + xl^5 * (3.0 - xu^2 + xu^4)) - mom2 = (xl * xu)^5 * (xu - xl * eΔ) / (-eΔ * (3.0 - xl^2 + xl^4) * xu^5 + xl^5 * (3.0 - xu^2 + xu^4)) - end - end - va = mom2 - av^2 - return μ + av * σ, σ^2 * (1 + va) +function moments(p0::IntervalPrior{T},μ,σ) where T<:Real + pr = TruncatedNormal(μ,σ,p0.l,p0.u) + return mean(pr),max(zero(T),var(pr)) + # the max fix an error in Distributions + # issue #827 in Distributions.jl end - """ Spike-and-slab prior