Skip to content

Commit

Permalink
Don't include lhs of := in results of predict() (#766)
Browse files Browse the repository at this point in the history
* Don't include lhs of := in results of predict()

* Bump minor version

* Remove unused constructor

* Add a test for `values_as_in_model(rng, model, ...)`

---------

Co-authored-by: Xianda Sun <[email protected]>
  • Loading branch information
penelopeysm and sunxd3 authored Jan 3, 2025
1 parent b7fd9ea commit 3d18cfc
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.32.2"
version = "0.33.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function DynamicPPL.predict(
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
vcat,
Expand Down
25 changes: 14 additions & 11 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ $(TYPEDFIELDS)
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
context::C
end

ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
function ValuesAsInModelContext(context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), context)
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, child)
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
end

is_extracting_values(context::ValuesAsInModelContext) = true
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
function is_extracting_values(context::AbstractContext)
return is_extracting_values(NodeTrait(context), context)
end
Expand Down Expand Up @@ -114,8 +114,8 @@ function dot_tilde_assume(
end

"""
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
Expand All @@ -132,6 +132,7 @@ of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
Expand Down Expand Up @@ -183,24 +184,26 @@ false
julia> # Approach 2: Extract realizations using `values_as_in_model`.
# (✓) `values_as_in_model` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function values_as_in_model(
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(context)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
end
11 changes: 9 additions & 2 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,17 @@ module Issue537 end
@test haskey(varinfo, @varname(x))
@test !haskey(varinfo, @varname(y))

# While `values_as_in_model` should contain both `x` and `y`.
values = values_as_in_model(model, deepcopy(varinfo))
# While `values_as_in_model` should contain both `x` and `y`, if
# include_colon_eq is set to `true`.
values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test haskey(values, @varname(y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end
end

Expand Down
120 changes: 77 additions & 43 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
realizations = values_as_in_model(model, varinfo)
# We can set the include_colon_eq arg to false because none of
# the demo models contain :=. The behaviour when
# include_colon_eq is true is tested in test/compiler.jl
realizations = values_as_in_model(model, false, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns vns_found == vns vns_found
Expand All @@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "check that sampling obeys rng if passed" begin
@model function f()
x ~ Normal(0)
return y ~ Normal(x)
end
model = f()
# Call values_as_in_model with the rng
values = values_as_in_model(Random.Xoshiro(43), model, false)
# Check that they match the values that would be used if vi was seeded
# with that seed instead
expected_vi = VarInfo(Random.Xoshiro(43), model)
for vn in keys(values)
@test values[vn] == expected_vi[vn]
end
end
end

@testset "Erroneous model call" begin
Expand Down Expand Up @@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

@testset "predict" begin
@testset "with MCMCChains.Chains" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal* x[i], σ)
end
# Insert a := block to test that it is not included in predictions
return σ2 := σ^2
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal.* x, σ^2 * I)
end

# Construct a chain with 'sampled values' of β
ground_truth_β = 2
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])

# Generate predictions from that chain
xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01

# Ensure that `rng` is respected
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
# Also test a vectorized model
@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal.* x, σ^2 * I)
end
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)

# Multiple chains
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)
@testset "variables in chain" begin
# Note that this also checks that variables on the lhs of :=,
# such as σ2, are not included in the resulting chain
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
end

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@testset "accuracy" begin
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
@testset "ensure that rng is respected" begin
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))
end

@testset "accuracy on vectorized model" begin
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
end

@testset "prediction from multiple chains" begin
# Normal linreg model
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
)
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(
mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ground_truth_β * xs_test[2] atol = 0.01
end

# Vectorized linreg model
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred_vec[1] ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ground_truth_β * xs_test[2] atol = 0.01
end
end
end

@testset "with AbstractVector{<:AbstractVarInfo}" begin
Expand Down

0 comments on commit 3d18cfc

Please sign in to comment.