Remove dot_tilde pipeline #804

Merged Feb 18, 2025 (18 commits)
Changes from 13 commits
70 changes: 69 additions & 1 deletion HISTORY.md
@@ -4,6 +4,74 @@

**Breaking**

### `.~` right hand side must be a univariate distribution

Previously we allowed statements like

```julia
x .~ [Normal(), Gamma()]
```

where the right hand side of a `.~` was an array of distributions, and ones like

```julia
x .~ MvNormal(fill(0.0, 2), I)
```

where the right hand side was a multivariate distribution.

These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ Normal()
```

The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read.

If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of

```julia
x .~ [Normal(), Gamma()]
x .~ Normal.(y)
x .~ MvNormal(fill(0.0, 2), I)
```

do

```julia
x ~ product_distribution([Normal(), Gamma()])
x ~ product_distribution(Normal.(y))
x ~ MvNormal(fill(0.0, 2), I)
```

This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: with `.~`, each `x[i]` is stored as a separate variable, whereas with `~` the whole of `x` is stored as a single multivariate variable. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as

```julia
dists = Normal.(y)
for i in 1:length(dists)
    x[i] ~ dists[i]
end
```
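
As a rough illustration of why `product_distribution` can stand in for an array of univariate distributions, the sketch below compares the two outside of any model. It assumes Distributions.jl is installed; the values of `y` and `xs` are made up for the example, and nothing here touches DynamicPPL internals.

```julia
using Distributions

y = [0.0, 1.0]                        # made-up means for the example
dists = Normal.(y)                    # array of univariate distributions
joint = product_distribution(dists)   # one multivariate distribution

xs = [0.3, -0.2]                      # made-up point at which to evaluate
elementwise = sum(logpdf.(dists, xs)) # sum of the per-element log densities
joint_lp = logpdf(joint, xs)          # log density of the joint distribution
```

Because the components of a product distribution are independent, the joint log density should agree with the sum of the element-wise log densities up to floating-point error.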

Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example,

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ MvNormal(fill(0, 2), I)
```

should be replaced with something like

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
for i in 1:3, j in 1:4
    x[:, i, j] ~ MvNormal(fill(0, 2), I)
end
```

This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side.
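
The loop that a `.~` statement expands to can be pictured in plain Julia. The sketch below is not DynamicPPL code: `fill_elementwise!` is a hypothetical helper, and the call to `draw()` stands in for the `~` statement, which only exists inside `@model`. It only illustrates the iteration pattern `Iterators.product(axes(x)...)`, which visits every index of an array of any dimension exactly once.

```julia
# Hypothetical helper: visit every element of `x` the way the expanded
# `.~` loop does, calling `draw()` once per element.
function fill_elementwise!(x::AbstractArray, draw)
    visited = 0
    for idx in Iterators.product(axes(x)...)
        x[idx...] = draw()   # stand-in for `x[idx...] ~ dist`
        visited += 1
    end
    return visited
end

x = Array{Float64,3}(undef, 2, 3, 4)
n = fill_elementwise!(x, () -> 1.0)   # n counts the visited elements
```

Each element gets its own statement, which is consistent with the note above that `.~` stores each element as a separate variable.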

### Remove indexing by samplers

This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s were being indexed with samplers. In particular,
@@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia
- `unflatten` no longer accepts a sampler as an argument
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
- `keys(::VarInfo)` no longer accepts a sampler as an argument
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
- `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument.

### Reverse prefixing order

4 changes: 0 additions & 4 deletions Project.toml
@@ -26,7 +26,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -35,7 +34,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
@@ -44,7 +42,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
ADTypes = "1"
@@ -74,5 +71,4 @@ OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"
2 changes: 0 additions & 2 deletions docs/src/api.md
@@ -447,10 +447,8 @@ DynamicPPL.Experimental.is_suitable_varinfo

```@docs
tilde_assume
dot_tilde_assume
```

```@docs
tilde_observe
dot_tilde_observe
```
25 changes: 0 additions & 25 deletions ext/DynamicPPLZygoteRulesExt.jl

This file was deleted.

4 changes: 0 additions & 4 deletions src/DynamicPPL.jl
@@ -101,13 +101,9 @@ export AbstractVarInfo,
PrefixContext,
ConditionContext,
assume,
dot_assume,
observe,
dot_observe,
tilde_assume,
tilde_observe,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
95 changes: 43 additions & 52 deletions src/compiler.jl
@@ -161,7 +161,16 @@
"""
isliteral(e) = false
isliteral(::Number) = true
isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)
function isliteral(e::Expr)
    # In the special case that the expression is of the form `abc[blahblah]`, we consider it
    # to be a literal if `abc` is a literal. This is necessary for cases like
    #     [1.0, 2.0][idx...] ~ Normal()
    # which are generated when turning `.~` expressions into loops over `~` expressions.
    if e.head == :ref
        return isliteral(e.args[1])
    end
    return !isempty(e.args) && all(isliteral, e.args)
end
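
To see what the new `:ref` branch changes, the sketch below runs the removed one-line definition and the new definition side by side on the kind of expression the expanded `.~` loop produces. The name `isliteral_old` and the three result variables are introduced here only for comparison; the function bodies are copied from this diff.

```julia
# Old definition (removed in this PR), renamed only for this comparison.
isliteral_old(e) = false
isliteral_old(::Number) = true
isliteral_old(e::Expr) = !isempty(e.args) && all(isliteral_old, e.args)

# New definition, as added in this PR.
isliteral(e) = false
isliteral(::Number) = true
function isliteral(e::Expr)
    if e.head == :ref
        return isliteral(e.args[1])
    end
    return !isempty(e.args) && all(isliteral, e.args)
end

indexed_literal = :([1.0, 2.0][idx...])   # literal array indexed by a loop variable
indexed_variable = :(x[idx...])           # ordinary variable indexed the same way

old_result = isliteral_old(indexed_literal)   # the symbol `idx` makes this fail
new_result = isliteral(indexed_literal)       # only the indexed object is inspected
var_result = isliteral(indexed_variable)      # `x` is a symbol, so not a literal
```

With the old definition the index `idx` is itself checked and is not a literal, so an indexed literal array was not recognised as a literal; the new branch looks only at the object being indexed.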

"""
check_tilde_rhs(x)
@@ -172,7 +181,7 @@
function check_tilde_rhs(@nospecialize(x))
return throw(
ArgumentError(
"the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s",
"the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel",
),
)
end
@@ -184,6 +193,31 @@
return Sampleable{typeof(model),AutoPrefix}(model)
end

"""
check_dot_tilde_rhs(x)

Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`.
"""
function check_dot_tilde_rhs(@nospecialize(x))
    return throw(
        ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`")
    )
end
function check_dot_tilde_rhs(::AbstractArray{<:Distribution})
    msg = """
        As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \
        Please use `product_distribution` instead, or write a loop if necessary. \
        See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \
        details.\
        """
    return throw(ArgumentError(msg))
end
check_dot_tilde_rhs(x::UnivariateDistribution) = x
function check_dot_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
    model = check_dot_tilde_rhs(x.model)
    return Sampleable{typeof(model),AutoPrefix}(model)
end
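
The `check_dot_tilde_rhs` methods rely on multiple dispatch: a catch-all method throws, a more specific method throws a more helpful message for arrays, and the one accepted type passes through unchanged. A minimal, dependency-free sketch of the same pattern follows. `ToyDist`, `ToyUnivariate`, `ToyMultivariate`, and `check_rhs` are hypothetical stand-ins, not DynamicPPL or Distributions.jl names.

```julia
# Hypothetical stand-in types for the distribution hierarchy.
abstract type ToyDist end
struct ToyUnivariate <: ToyDist end
struct ToyMultivariate <: ToyDist end

# Fallback: anything not explicitly allowed is rejected.
function check_rhs(@nospecialize(x))
    return throw(ArgumentError("right-hand side must be a ToyUnivariate"))
end
# Arrays of distributions get a more specific, actionable message.
function check_rhs(::AbstractArray{<:ToyDist})
    return throw(ArgumentError("arrays of distributions are not allowed; use a loop"))
end
# The single accepted case is returned unchanged.
check_rhs(x::ToyUnivariate) = x

accepted = check_rhs(ToyUnivariate())

# Record the error message for each rejected input instead of crashing.
array_msg = try
    check_rhs([ToyUnivariate(), ToyUnivariate()])
    ""
catch err
    err.msg
end
multivariate_msg = try
    check_rhs(ToyMultivariate())
    ""
catch err
    err.msg
end
```

Julia picks the most specific applicable method, so the array case never reaches the generic fallback and the user sees the migration hint instead of the generic error.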

"""
unwrap_right_vn(right, vn)

@@ -356,11 +390,8 @@
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
),
return generate_mainbody!(
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
)
end

@@ -487,56 +518,16 @@
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value
@gensym dist left_axes idx
return quote
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
)
$value
$dist = DynamicPPL.check_dot_tilde_rhs($right)
$left_axes = axes($left)
for $idx in Iterators.product($left_axes...)
$left[$idx...] ~ $dist
end
end
end

function generate_dot_tilde_assume(left, right, vn)
# We don't need to use `Setfield.@set` here since
# `.=` is always going to be inplace + needs `left` to
# be something that supports `.=`.
@gensym value
return quote
$value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
)
$left .= $value
$value
end
end

# Note that we cannot use `MacroTools.isdef` because
# of https://github.com/FluxML/MacroTools.jl/issues/154.
"""