From 0554f8bde29a97f9336ddedd7ca803aebb897b96 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Mon, 7 Oct 2024 03:08:11 +0200 Subject: [PATCH] More merging (#96) --- src/SamplingInterface.jl | 51 +++++++-------- src/WeightedSamplingMulti.jl | 88 +++++++++++++++----------- src/WeightedSamplingSingle.jl | 16 ++--- src/precompile.jl | 17 ++--- test/weighted_sampling_multi_tests.jl | 4 +- test/weighted_sampling_single_tests.jl | 4 +- 6 files changed, 98 insertions(+), 82 deletions(-) diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index af9b725..7f87e6c 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -1,16 +1,14 @@ """ ReservoirSample([rng], T, method = AlgRSWRSKIP()) - ReservoirSample([rng], T, wfunc, method = AlgWRSWRSKIP()) ReservoirSample([rng], T, n::Int, method = AlgL(); ordered = false) - ReservoirSample([rng], T, wfunc, n::Int, method = AlgAExpJ(); ordered = false) Initializes a reservoir sample which can then be fitted with [`fit!`](@ref). The first signature represents a sample where only a single element is collected. -A weight function `wfunc` can be passed to apply weighted sampling. Look at the -[`Sampling Algorithms`](@ref) section for the supported methods. If `ordered` is -true, the reservoir sample values can be retrived in the order they were collected -with [`ordvalue`](@ref). +If `ordered` is true, the reservoir sample values can be retrived in the order +they were collected with [`ordvalue`](@ref). + +Look at the [`Sampling Algorithms`](@ref) section for the supported methods. """ function ReservoirSample(T, method::ReservoirAlgorithm = AlgRSWRSKIP()) return ReservoirSample(Random.default_rng(), T, method, MutSample()) @@ -18,12 +16,6 @@ end function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = AlgRSWRSKIP()) return ReservoirSample(rng, T, method, MutSample()) end -function ReservoirSample(T, wv, method::ReservoirAlgorithm = AlgWRSWRSKIP()) - return ReservoirSample(Random.default_rng(), T, wv, method, MutSample()) -end -function ReservoirSample(rng::AbstractRNG, T, wv, method::ReservoirAlgorithm = AlgWRSWRSKIP()) - return ReservoirSample(rng, T, wv, method, MutSample()) -end Base.@constprop :aggressive function ReservoirSample(T, n::Integer, method::ReservoirAlgorithm=AlgL(); ordered = false) return ReservoirSample(Random.default_rng(), T, n, method, MutSample(), ordered ? Ord() : Unord()) @@ -32,21 +24,17 @@ Base.@constprop :aggressive function ReservoirSample(rng::AbstractRNG, T, n::Int method::ReservoirAlgorithm=AlgL(); ordered = false) return ReservoirSample(rng, T, n, method, MutSample(), ordered ? Ord() : Unord()) end -Base.@constprop :aggressive function ReservoirSample(T, wv, n::Integer, - method::ReservoirAlgorithm=algAExpJ(); ordered = false) - return ReservoirSample(Random.default_rng(), T, wv, n, method, MutSample(), ordered ? Ord() : Unord()) -end -Base.@constprop :aggressive function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, - method::ReservoirAlgorithm=algAExpJ(); ordered = false) - return ReservoirSample(rng, T, wv, n, method, MutSample(), ordered ? Ord() : Unord()) -end """ fit!(rs::AbstractReservoirSample, el) + fit!(rs::AbstractReservoirSample, el, w) Updates the reservoir sample by taking into account the element passed. +If the sampling is weighted also the weight of the elements needs to be +passed. """ @inline OnlineStatsBase.fit!(s::AbstractReservoirSample, el) = OnlineStatsBase._fit!(s, el) +@inline OnlineStatsBase.fit!(s::AbstractReservoirSample, el, w) = OnlineStatsBase._fit!(s, el, w) """ value(rs::AbstractReservoirSample) @@ -166,13 +154,13 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me end end function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter)) - s = ReservoirSample(rng, iter_type, wv, method, ImmutSample()) - return update_all!(s, iter) + s = ReservoirSample(rng, iter_type, method, ImmutSample()) + return update_all!(s, iter, wv) end Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, wv::Function, n::Int, method = AlgAExpJ(); iter_type = infer_eltype(iter), ordered = false) - s = ReservoirSample(rng, iter_type, wv, n, method, ImmutSample(), ordered ? Ord() : Unord()) - return update_all!(s, iter, ordered) + s = ReservoirSample(rng, iter_type, n, method, ImmutSample(), ordered ? Ord() : Unord()) + return update_all!(s, iter, ordered, wv) end function update_all!(s, iter) @@ -181,9 +169,22 @@ function update_all!(s, iter) end return value(s) end -function update_all!(s, iter, ordered) +function update_all!(s, iter, wv) + for x in iter + s = fit!(s, x, wv(x)) + end + return value(s) +end +function update_all!(s, iter, ordered::Bool) for x in iter s = fit!(s, x) end return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) end +function update_all!(s, iter, ordered, wv) + for x in iter + s = fit!(s, x, wv(x)) + end + return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) +end + diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 5fcae95..41a4bca 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -1,27 +1,25 @@ const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T -@hybrid struct SampleMultiAlgARes{BH,R,F} <: AbstractWeightedWorReservoirSampleMulti +@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti seen_k::Int n::Int const rng::R const value::BH - wv::F end const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgAExpJ{BH,R,F} <: AbstractWeightedWorReservoirSampleMulti +@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti state::Float64 min_priority::Float64 seen_k::Int const n::Int const rng::R const value::BH - wv::F end const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R,F} <: AbstractWeightedWrReservoirSampleMulti +@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti state::Float64 skip_w::Float64 seen_k::Int @@ -29,68 +27,66 @@ const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, S const weights::Vector{Float64} const value::Vector{T} const ord::O - wv::F end const SampleMultiOrdAlgWRSWRSKIP = Union{SampleMultiAlgWRSWRSKIP_Immut{<:Vector}, SampleMultiAlgWRSWRSKIP_Mut{<:Vector}} -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value, wv) + return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value, wv) + return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value, wv) + return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value, wv) + return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::MutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Ord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) sizehint!(value, n) - return SampleMultiAlgARes_Mut(0, n, rng, value, wv) + return SampleMultiAlgARes_Mut(0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::MutSample, ::Unord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Unord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) sizehint!(value, n) - return SampleMultiAlgARes_Mut(0, n, rng, value, wv) + return SampleMultiAlgARes_Mut(0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::ImmutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Ord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) sizehint!(value, n) - return SampleMultiAlgARes_Immut(0, n, rng, value, wv) + return SampleMultiAlgARes_Immut(0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgARes, ::ImmutSample, ::Unord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Unord) value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) sizehint!(value, n) - return SampleMultiAlgARes_Immut(0, n, rng, value, wv) + return SampleMultiAlgARes_Immut(0, n, rng, value) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord) ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord, wv) + return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord) - return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing, wv) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord) + return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord) ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord, wv) + return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) end -function ReservoirSample(rng::AbstractRNG, T, wv, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord) - return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing, wv) +function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord) + return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el) +@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w) n = s.n - w = s.wv(el) s = @inline update_state!(s, w) priority = -randexp(s.rng)/w if s.seen_k <= n @@ -104,9 +100,8 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el) +@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w) n = s.n - w = s.wv(el) s = @inline update_state!(s, w) if s.seen_k <= n priority = exp(-randexp(s.rng)/w) @@ -122,9 +117,8 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el) +@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w) n = length(s.value) - w = s.wv(el) s = @inline update_state!(s, w) if s.seen_k <= n @inbounds s.value[s.seen_k] = el @@ -182,6 +176,28 @@ function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut) return s end + +function Base.merge(ss::SampleMultiAlgWRSWRSKIP...) + newvalue = reduce_samples(TypeUnion(), ss...) + skip_w = sum(getfield(s, :skip_w) for s in ss) + state = sum(getfield(s, :state) for s in ss) + seen_k = sum(getfield(s, :seen_k) for s in ss) + s = SampleMultiAlgWRSWRSKIP_Mut(state, skip_w, seen_k, ss[1].rng, Float64[], newvalue, nothing) + return s +end + +function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgWRSWRSKIP...) + newvalue = reduce_samples(TypeS(), s1, ss...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] + end + s1.skip_w += sum(getfield(s, :skip_w) for s in ss) + s1.state += sum(getfield(s, :state) for s in ss) + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + empty!(s1.weights) + return s1 +end + function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w) @update s.seen_k += 1 return s diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index 7c2d3e0..bc11738 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -1,18 +1,17 @@ -@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R,F} <: AbstractWeightedReservoirSampleSingle +@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractWeightedReservoirSampleSingle seen_k::Int total_w::Float64 skip_w::Float64 const rng::R rvalue::RT - wv::F end -function ReservoirSample(rng::R, T, wv, ::AlgWRSWRSKIP, ::MutSample) where {R<:AbstractRNG} - return SampleSingleAlgWRSWRSKIP_Mut(0, 0.0, 0.0, rng, RefVal_Immut{T}(), wv) +function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::MutSample) where {R<:AbstractRNG} + return SampleSingleAlgWRSWRSKIP_Mut(0, 0.0, 0.0, rng, RefVal_Immut{T}()) end -function ReservoirSample(rng::R, T, wv, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:AbstractRNG} - return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}(), wv) +function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:AbstractRNG} + return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}()) end function OnlineStatsBase.value(s::AbstractWeightedReservoirSampleSingle) @@ -20,10 +19,9 @@ function OnlineStatsBase.value(s::AbstractWeightedReservoirSampleSingle) return get_value(s) end -@inline function OnlineStatsBase._fit!(s::SampleSingleAlgWRSWRSKIP, el) +@inline function OnlineStatsBase._fit!(s::SampleSingleAlgWRSWRSKIP, el, w) @update s.seen_k += 1 - weight = s.wv(el) - @update s.total_w += weight + @update s.total_w += w if s.skip_w <= s.total_w @update s.skip_w = s.total_w/rand(s.rng) reset_value!(s, el) diff --git a/src/precompile.jl b/src/precompile.jl index ad2178b..5c69139 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -5,22 +5,23 @@ using PrecompileTools iter = Iterators.filter(x -> x != 10, 1:20); wv(el) = 1.0 update_s!(rs, iter) = for x in iter fit!(rs, x) end + update_s!(rs, iter, wv) = for x in iter fit!(rs, x, wv(x)) end @compile_workload let rs = ReservoirSample(Int, AlgRSWRSKIP()) update_s!(rs, iter) - rs = ReservoirSample(Int, wv, AlgWRSWRSKIP()) - update_s!(rs, iter) + rs = ReservoirSample(Int, AlgWRSWRSKIP()) + update_s!(rs, iter, wv) rs = ReservoirSample(Int, 2, AlgR()) update_s!(rs, iter) rs = ReservoirSample(Int, 2, AlgL()) update_s!(rs, iter) rs = ReservoirSample(Int, 2, AlgRSWRSKIP()) update_s!(rs, iter) - rs = ReservoirSample(Int, wv, 2, AlgARes()) - update_s!(rs, iter) - rs = ReservoirSample(Int, wv, 2, AlgAExpJ()) - update_s!(rs, iter) - rs = ReservoirSample(Int, wv, 2, AlgWRSWRSKIP()) - update_s!(rs, iter) + rs = ReservoirSample(Int, 2, AlgARes()) + update_s!(rs, iter, wv) + rs = ReservoirSample(Int, 2, AlgAExpJ()) + update_s!(rs, iter, wv) + rs = ReservoirSample(Int, 2, AlgWRSWRSKIP()) + update_s!(rs, iter, wv) end end diff --git a/test/weighted_sampling_multi_tests.jl b/test/weighted_sampling_multi_tests.jl index c6988bd..5969485 100644 --- a/test/weighted_sampling_multi_tests.jl +++ b/test/weighted_sampling_multi_tests.jl @@ -67,9 +67,9 @@ end @test ordered ? issorted(s) : true iter = Iterators.filter(x -> x != b + 1, a:b+1) - rs = ReservoirSample(Int, weight, 5, method; ordered = ordered) + rs = ReservoirSample(Int, 5, method; ordered = ordered) for x in iter - fit!(rs, x) + fit!(rs, x, weight(x)) end @test length(value(rs)) == 5 @test all(x -> a <= x <= b, value(rs)) diff --git a/test/weighted_sampling_single_tests.jl b/test/weighted_sampling_single_tests.jl index cc3c0af..a500973 100644 --- a/test/weighted_sampling_single_tests.jl +++ b/test/weighted_sampling_single_tests.jl @@ -9,9 +9,9 @@ @test a <= z <= b iter = Iterators.filter(x -> x != b + 1, a:b+1) - rs = ReservoirSample(Int, wv, method) + rs = ReservoirSample(Int, method) for x in iter - fit!(rs, x) + fit!(rs, x, wv(x)) end @test a <= value(rs) <= b @test nobs(rs) == 100