Skip to content

Commit

Permalink
More merging (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Oct 7, 2024
1 parent 6e4340f commit 0554f8b
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 82 deletions.
51 changes: 26 additions & 25 deletions src/SamplingInterface.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@

"""
ReservoirSample([rng], T, method = AlgRSWRSKIP())
ReservoirSample([rng], T, wfunc, method = AlgWRSWRSKIP())
ReservoirSample([rng], T, n::Int, method = AlgL(); ordered = false)
ReservoirSample([rng], T, wfunc, n::Int, method = AlgAExpJ(); ordered = false)
Initializes a reservoir sample which can then be fitted with [`fit!`](@ref).
The first signature represents a sample where only a single element is collected.
A weight function `wfunc` can be passed to apply weighted sampling. Look at the
[`Sampling Algorithms`](@ref) section for the supported methods. If `ordered` is
true, the reservoir sample values can be retrived in the order they were collected
with [`ordvalue`](@ref).
If `ordered` is true, the reservoir sample values can be retrived in the order
they were collected with [`ordvalue`](@ref).
Look at the [`Sampling Algorithms`](@ref) section for the supported methods.
"""
function ReservoirSample(T, method::ReservoirAlgorithm = AlgRSWRSKIP())
return ReservoirSample(Random.default_rng(), T, method, MutSample())
end
function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = AlgRSWRSKIP())
return ReservoirSample(rng, T, method, MutSample())
end
function ReservoirSample(T, wv, method::ReservoirAlgorithm = AlgWRSWRSKIP())
return ReservoirSample(Random.default_rng(), T, wv, method, MutSample())
end
function ReservoirSample(rng::AbstractRNG, T, wv, method::ReservoirAlgorithm = AlgWRSWRSKIP())
return ReservoirSample(rng, T, wv, method, MutSample())
end
Base.@constprop :aggressive function ReservoirSample(T, n::Integer, method::ReservoirAlgorithm=AlgL();
ordered = false)
return ReservoirSample(Random.default_rng(), T, n, method, MutSample(), ordered ? Ord() : Unord())
Expand All @@ -32,21 +24,17 @@ Base.@constprop :aggressive function ReservoirSample(rng::AbstractRNG, T, n::Int
method::ReservoirAlgorithm=AlgL(); ordered = false)
return ReservoirSample(rng, T, n, method, MutSample(), ordered ? Ord() : Unord())
end
Base.@constprop :aggressive function ReservoirSample(T, wv, n::Integer,
method::ReservoirAlgorithm=algAExpJ(); ordered = false)
return ReservoirSample(Random.default_rng(), T, wv, n, method, MutSample(), ordered ? Ord() : Unord())
end
Base.@constprop :aggressive function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer,
method::ReservoirAlgorithm=algAExpJ(); ordered = false)
return ReservoirSample(rng, T, wv, n, method, MutSample(), ordered ? Ord() : Unord())
end

"""
fit!(rs::AbstractReservoirSample, el)
fit!(rs::AbstractReservoirSample, el, w)
Updates the reservoir sample by taking into account the element passed.
If the sampling is weighted also the weight of the elements needs to be
passed.
"""
@inline OnlineStatsBase.fit!(s::AbstractReservoirSample, el) = OnlineStatsBase._fit!(s, el)
@inline OnlineStatsBase.fit!(s::AbstractReservoirSample, el, w) = OnlineStatsBase._fit!(s, el, w)

"""
value(rs::AbstractReservoirSample)
Expand Down Expand Up @@ -166,13 +154,13 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me
end
end
function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter))
s = ReservoirSample(rng, iter_type, wv, method, ImmutSample())
return update_all!(s, iter)
s = ReservoirSample(rng, iter_type, method, ImmutSample())
return update_all!(s, iter, wv)
end
Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, wv::Function, n::Int, method = AlgAExpJ();
iter_type = infer_eltype(iter), ordered = false)
s = ReservoirSample(rng, iter_type, wv, n, method, ImmutSample(), ordered ? Ord() : Unord())
return update_all!(s, iter, ordered)
s = ReservoirSample(rng, iter_type, n, method, ImmutSample(), ordered ? Ord() : Unord())
return update_all!(s, iter, ordered, wv)
end

function update_all!(s, iter)
Expand All @@ -181,9 +169,22 @@ function update_all!(s, iter)
end
return value(s)
end
function update_all!(s, iter, ordered)
function update_all!(s, iter, wv)
for x in iter
s = fit!(s, x, wv(x))
end
return value(s)
end
function update_all!(s, iter, ordered::Bool)
for x in iter
s = fit!(s, x)
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
end
function update_all!(s, iter, ordered, wv)
for x in iter
s = fit!(s, x, wv(x))
end
return ordered ? ordvalue(s) : shuffle!(s.rng, value(s))
end

88 changes: 52 additions & 36 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,92 @@

const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T

@hybrid struct SampleMultiAlgARes{BH,R,F} <: AbstractWeightedWorReservoirSampleMulti
@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
seen_k::Int
n::Int
const rng::R
const value::BH
wv::F
end
const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}}

@hybrid struct SampleMultiAlgAExpJ{BH,R,F} <: AbstractWeightedWorReservoirSampleMulti
@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
state::Float64
min_priority::Float64
seen_k::Int
const n::Int
const rng::R
const value::BH
wv::F
end
const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}}

@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R,F} <: AbstractWeightedWrReservoirSampleMulti
@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti
state::Float64
skip_w::Float64
seen_k::Int
const rng::R
const weights::Vector{Float64}
const value::Vector{T}
const ord::O
wv::F
end
const SampleMultiOrdAlgWRSWRSKIP = Union{SampleMultiAlgWRSWRSKIP_Immut{<:Vector}, SampleMultiAlgWRSWRSKIP_Mut{<:Vector}}

function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value, wv)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value, wv)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value, wv)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value, wv)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::MutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiAlgARes_Mut(0, n, rng, value, wv)
return SampleMultiAlgARes_Mut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::MutSample, ::Unord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgARes_Mut(0, n, rng, value, wv)
return SampleMultiAlgARes_Mut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::ImmutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiAlgARes_Immut(0, n, rng, value, wv)
return SampleMultiAlgARes_Immut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::ImmutSample, ::Unord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgARes_Immut(0, n, rng, value, wv)
return SampleMultiAlgARes_Immut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord)
ord = collect(1:n)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord, wv)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing, wv)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord)
ord = collect(1:n)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord, wv)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
end
function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing, wv)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
end

@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el)
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w)
n = s.n
w = s.wv(el)
s = @inline update_state!(s, w)
priority = -randexp(s.rng)/w
if s.seen_k <= n
Expand All @@ -104,9 +100,8 @@ end
end
return s
end
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el)
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w)
n = s.n
w = s.wv(el)
s = @inline update_state!(s, w)
if s.seen_k <= n
priority = exp(-randexp(s.rng)/w)
Expand All @@ -122,9 +117,8 @@ end
end
return s
end
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el)
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w)
n = length(s.value)
w = s.wv(el)
s = @inline update_state!(s, w)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
Expand Down Expand Up @@ -182,6 +176,28 @@ function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut)
return s
end


function Base.merge(ss::SampleMultiAlgWRSWRSKIP...)
newvalue = reduce_samples(TypeUnion(), ss...)
skip_w = sum(getfield(s, :skip_w) for s in ss)
state = sum(getfield(s, :state) for s in ss)
seen_k = sum(getfield(s, :seen_k) for s in ss)
s = SampleMultiAlgWRSWRSKIP_Mut(state, skip_w, seen_k, ss[1].rng, Float64[], newvalue, nothing)
return s
end

function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgWRSWRSKIP...)
newvalue = reduce_samples(TypeS(), s1, ss...)
for i in 1:length(newvalue)
@inbounds s1.value[i] = newvalue[i]
end
s1.skip_w += sum(getfield(s, :skip_w) for s in ss)
s1.state += sum(getfield(s, :state) for s in ss)
s1.seen_k += sum(getfield(s, :seen_k) for s in ss)
empty!(s1.weights)
return s1
end

function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w)
@update s.seen_k += 1
return s
Expand Down
16 changes: 7 additions & 9 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@

@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R,F} <: AbstractWeightedReservoirSampleSingle
@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractWeightedReservoirSampleSingle
seen_k::Int
total_w::Float64
skip_w::Float64
const rng::R
rvalue::RT
wv::F
end

function ReservoirSample(rng::R, T, wv, ::AlgWRSWRSKIP, ::MutSample) where {R<:AbstractRNG}
return SampleSingleAlgWRSWRSKIP_Mut(0, 0.0, 0.0, rng, RefVal_Immut{T}(), wv)
function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::MutSample) where {R<:AbstractRNG}
return SampleSingleAlgWRSWRSKIP_Mut(0, 0.0, 0.0, rng, RefVal_Immut{T}())
end
function ReservoirSample(rng::R, T, wv, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:AbstractRNG}
return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}(), wv)
function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:AbstractRNG}
return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}())
end

function OnlineStatsBase.value(s::AbstractWeightedReservoirSampleSingle)
s.seen_k === 0 && return nothing
return get_value(s)
end

@inline function OnlineStatsBase._fit!(s::SampleSingleAlgWRSWRSKIP, el)
@inline function OnlineStatsBase._fit!(s::SampleSingleAlgWRSWRSKIP, el, w)
@update s.seen_k += 1
weight = s.wv(el)
@update s.total_w += weight
@update s.total_w += w
if s.skip_w <= s.total_w
@update s.skip_w = s.total_w/rand(s.rng)
reset_value!(s, el)
Expand Down
17 changes: 9 additions & 8 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@ using PrecompileTools
iter = Iterators.filter(x -> x != 10, 1:20);
wv(el) = 1.0
update_s!(rs, iter) = for x in iter fit!(rs, x) end
update_s!(rs, iter, wv) = for x in iter fit!(rs, x, wv(x)) end
@compile_workload let
rs = ReservoirSample(Int, AlgRSWRSKIP())
update_s!(rs, iter)
rs = ReservoirSample(Int, wv, AlgWRSWRSKIP())
update_s!(rs, iter)
rs = ReservoirSample(Int, AlgWRSWRSKIP())
update_s!(rs, iter, wv)
rs = ReservoirSample(Int, 2, AlgR())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, AlgL())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, AlgRSWRSKIP())
update_s!(rs, iter)
rs = ReservoirSample(Int, wv, 2, AlgARes())
update_s!(rs, iter)
rs = ReservoirSample(Int, wv, 2, AlgAExpJ())
update_s!(rs, iter)
rs = ReservoirSample(Int, wv, 2, AlgWRSWRSKIP())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, AlgARes())
update_s!(rs, iter, wv)
rs = ReservoirSample(Int, 2, AlgAExpJ())
update_s!(rs, iter, wv)
rs = ReservoirSample(Int, 2, AlgWRSWRSKIP())
update_s!(rs, iter, wv)
end
end
4 changes: 2 additions & 2 deletions test/weighted_sampling_multi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ end
@test ordered ? issorted(s) : true

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, weight, 5, method; ordered = ordered)
rs = ReservoirSample(Int, 5, method; ordered = ordered)
for x in iter
fit!(rs, x)
fit!(rs, x, weight(x))
end
@test length(value(rs)) == 5
@test all(x -> a <= x <= b, value(rs))
Expand Down
4 changes: 2 additions & 2 deletions test/weighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
@test a <= z <= b

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, wv, method)
rs = ReservoirSample(Int, method)
for x in iter
fit!(rs, x)
fit!(rs, x, wv(x))
end
@test a <= value(rs) <= b
@test nobs(rs) == 100
Expand Down

0 comments on commit 0554f8b

Please sign in to comment.