Skip to content

Commit

Permalink
Add Random.AbstractRNG type annotations (fixing dot_tilde_assume ambi…
Browse files Browse the repository at this point in the history
…guity)
  • Loading branch information
penelopeysm committed Jan 8, 2025
1 parent 653c9c5 commit bcb52e0
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi)
return left, acclogp_observe!!(context, vi, logp)
end

function assume(rng, spl::Sampler, dist)
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

Expand Down Expand Up @@ -291,14 +291,18 @@ end
function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi)
return dot_assume(right, left, vns, vi)
end
function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi)
function dot_tilde_assume(
::IsLeaf, rng::Random.AbstractRNG, ::AbstractContext, sampler, right, left, vns, vi
)
return dot_assume(rng, sampler, right, vns, left, vi)
end

function dot_tilde_assume(::IsParent, context::AbstractContext, args...)
return dot_tilde_assume(childcontext(context), args...)
end
function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...)
function dot_tilde_assume(
::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args...
)
return dot_tilde_assume(rng, childcontext(context), args...)
end

Expand Down Expand Up @@ -371,7 +375,7 @@ function dot_assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::MultivariateDistribution,
vns::AbstractVector{<:VarName},
Expand Down Expand Up @@ -404,7 +408,7 @@ function dot_assume(
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vns::AbstractArray{<:VarName},
Expand All @@ -416,7 +420,9 @@ function dot_assume(
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
function dot_assume(
rng::Random.AbstractRNG, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any
)
return error(
"[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement"
)
Expand All @@ -436,7 +442,7 @@ function _maybe_invlink_broadcast(vi, vn, dist)
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractVector{<:VarName},
dist::MultivariateDistribution,
Expand Down Expand Up @@ -478,7 +484,7 @@ function get_and_set_val!(
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractArray{<:VarName},
dists::Union{Distribution,AbstractArray{<:Distribution}},
Expand Down

0 comments on commit bcb52e0

Please sign in to comment.