diff --git a/src/primitives/logdensity.jl b/src/primitives/logdensity.jl index 81fa4cd..4d13367 100644 --- a/src/primitives/logdensity.jl +++ b/src/primitives/logdensity.jl @@ -43,3 +43,55 @@ end end end +############################################################################### + + +function MeasureBase.logdensity_def(c::ConditionalModel{A,B,M}, x=NamedTuple()) where {A,B,M} + _logdensity_def(M, Model(c), argvals(c), observations(c), x) +end + +export sourceLogdensityDef + +sourceLogdensityDef(m::AbstractModel) = sourceLogdensityDef()(Model(m)) + +# function Base.convert(nt::NamedTuple, args...) +# @show nt +# @show args +# for (n, t) in enumerate(stacktrace()) +# print("\t",n,". ") +# println(t) +# end + +# end + +function sourceLogdensityDef() + function(_m::Model) + proc(_m, st :: Assign) = :($(st.x) = $(st.rhs)) + proc(_m, st :: Return) = nothing + proc(_m, st :: LineNumber) = nothing + function proc(_m, st :: Sample) + x = st.x + rhs = st.rhs + @q begin + _ℓ += Soss.logdensity_def($rhs, $x) + $x = Soss.predict($rhs, $x) + end + end + + wrap(kernel) = @q begin + _ℓ = 0.0 + $kernel + return _ℓ + end + + buildSource(_m, proc, wrap) |> MacroTools.flatten + end +end + +@gg function _logdensity_def(M::Type{<:TypeLevel}, _m::Model, _args, _data, _pars) + body = type2model(_m) |> sourceLogdensityDef() |> loadvals(_args, _data, _pars) + @under_global from_type(_unwrap_type(M)) @q let M + $body + end +end +