Skip to content

Commit

Permalink
Fix use of arrays of Distributions (#245)
Browse files Browse the repository at this point in the history
This PR fixes #28 (comment) and allows to use arbitrary arrays of `Distribution`s. This was already allowed in the context implementations but prevented by a check in the code generated by the `@model` macro.

Additionally, the PR replaces the hard-coded check with a `check_tilde_rhs` function which, IMO, makes the code a bit simpler and easier to read. Moreover, a bug in the `dot_assume` implementation for arrays of Distributions is fixed.

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
devmotion and devmotion committed May 18, 2021
1 parent 4c17629 commit f7531ba
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.18"
version = "0.10.19"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
136 changes: 87 additions & 49 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
"Distributions."

const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

Expand Down Expand Up @@ -38,6 +35,20 @@ end
# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

"""
check_tilde_rhs(x)
Check if the right-hand side `x` of a `~` is a `Distribution` or an array of
`Distributions`, then return `x`.
"""
function check_tilde_rhs(@nospecialize(x))
return throw(ArgumentError(
"the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s"
))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x

#################
# Main Compiler #
#################
Expand Down Expand Up @@ -225,34 +236,47 @@ Generate an `observe` expression for data variables and `assume` expression for
variables.
"""
function generate_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
__rng__, __context__, __sampler__, $tmpright, $vn, $inds, __varinfo__
)
else
$(DynamicPPL.tilde_observe)(
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
end
$(DynamicPPL.tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
)
end
end

# If the LHS is a literal, it is always an observation
# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$(top...)
$(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
end
end
end

Expand All @@ -262,34 +286,48 @@ end
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# If the LHS is a literal, it is always an observation
if !(left isa Symbol || left isa Expr)
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
__rng__, __context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
else
$(DynamicPPL.dot_tilde_observe)(
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
end
$(DynamicPPL.dot_tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
)
end
end

# If the LHS is a literal, it is always an observation
# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$(top...)
$(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
__rng__,
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
end
end
end

Expand Down
4 changes: 1 addition & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,7 @@ function dot_observe(
increment_num_produce!(vi)
@debug "dists = $dists"
@debug "value = $value"
return sum(zip(dists, value)) do (d, v)
Distributions.loglikelihood(d, v)
end
return sum(Distributions.loglikelihood.(dists, value))
end
function dot_observe(
spl::Sampler,
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.1.2"
AbstractPPL = "0.1.3"
Bijectors = "0.8.2, 0.9"
Distributions = "0.24, 0.25"
DistributionsAD = "0.6.3"
Expand Down
16 changes: 13 additions & 3 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,11 @@ end
vi2 = VarInfo(f2())
vi3 = VarInfo(f3())
@test haskey(vi1.metadata, :y)
@test vi1.metadata.y.vns[1] == VarName(:y)
@test vi1.metadata.y.vns[1] == VarName{:y}()
@test haskey(vi2.metadata, :y)
@test vi2.metadata.y.vns[1] == VarName(:y, ((2,), (Colon(), 1)))
@test vi2.metadata.y.vns[1] == VarName{:y}(((2,), (Colon(), 1)))
@test haskey(vi3.metadata, :y)
@test vi3.metadata.y.vns[1] == VarName(:y, ((1,),))
@test vi3.metadata.y.vns[1] == VarName{:y}(((1,),))
end
@testset "custom tilde" begin
@model demo() = begin
Expand Down Expand Up @@ -313,4 +313,14 @@ end
end
@test demo2()() == 42
end

@testset "check_tilde_rhs" begin
@test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())

x = Normal()
@test DynamicPPL.check_tilde_rhs(x) === x

x = [Laplace(), Normal(), MvNormal(3, 1.0)]
@test DynamicPPL.check_tilde_rhs(x) === x
end
end
25 changes: 25 additions & 0 deletions test/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,29 @@

test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext())
end

# https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577
@testset "arrays of distributions" begin
@model function test(x, y)
y .~ Normal.(x)
end

for ysize in ((2,), (2, 3), (2, 3, 4))
# drop trailing dimensions
for xsize in ntuple(i -> ysize[1:i], length(ysize))
x = randn(xsize)
y = randn(ysize)
z = logjoint(test(x, y), VarInfo())
@test z sum(logpdf.(Normal.(x), y))
end

# singleton dimensions
for xsize in ntuple(i -> (ysize[1:(i-1)]..., 1, ysize[(i+1):end]...), length(ysize))
x = randn(xsize)
y = randn(ysize)
z = logjoint(test(x, y), VarInfo())
@test z sum(logpdf.(Normal.(x), y))
end
end
end
end

2 comments on commit f7531ba

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/36947

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.19 -m "<description of version>" f7531ba2cefadc56be31b30fc092ffdb603d9262
git push origin v0.10.19

Please sign in to comment.