From 54be9e074ea0e8565d31b962ad40fab7d5d54479 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:30:26 +0200 Subject: [PATCH] Improve interface (#97) --- src/StreamSampling.jl | 16 ---------------- src/UnweightedSamplingMulti.jl | 20 ++++++++++---------- src/UnweightedSamplingSingle.jl | 4 ++-- src/WeightedSamplingMulti.jl | 25 ++++++++++++------------- src/WeightedSamplingSingle.jl | 4 ++-- 5 files changed, 26 insertions(+), 43 deletions(-) diff --git a/src/StreamSampling.jl b/src/StreamSampling.jl index 030c87c..e754f8c 100644 --- a/src/StreamSampling.jl +++ b/src/StreamSampling.jl @@ -20,22 +20,6 @@ struct Unord end abstract type AbstractReservoirSample <: OnlineStat{Any} end -# unweighted cases -abstract type AbstractReservoirSampleSingle <: AbstractReservoirSample end -abstract type AbstractReservoirSampleMulti <: AbstractReservoirSample end -abstract type AbstractWorReservoirSampleMulti <: AbstractReservoirSampleMulti end -abstract type AbstractOrdWorReservoirSampleMulti <: AbstractWorReservoirSampleMulti end -abstract type AbstractWrReservoirSampleMulti <: AbstractReservoirSampleMulti end -abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMulti end - -# weighted cases -abstract type AbstractWeightedReservoirSample <: AbstractReservoirSample end -abstract type AbstractWeightedReservoirSampleSingle <: AbstractWeightedReservoirSample end -abstract type AbstractWeightedReservoirSampleMulti <: AbstractWeightedReservoirSample end -abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractWeightedReservoirSample end -abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractWeightedReservoirSample end -abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractWeightedReservoirSample end - abstract type ReservoirAlgorithm end """ diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 93c4253..4fd700d 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -1,5 +1,5 @@ -@hybrid struct SampleMultiAlgR{O,T,R} <: AbstractWorReservoirSampleMulti +@hybrid struct SampleMultiAlgR{O,T,R} <: AbstractReservoirSample seen_k::Int const rng::R const value::Vector{T} @@ -7,7 +7,7 @@ end const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector} -@hybrid struct SampleMultiAlgL{O,T,R} <: AbstractWorReservoirSampleMulti +@hybrid struct SampleMultiAlgL{O,T,R} <: AbstractReservoirSample state::Float64 skip_k::Int seen_k::Int @@ -17,7 +17,7 @@ const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector} end const SampleMultiOrdAlgL = SampleMultiAlgL{<:Vector} -@hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractWrReservoirSampleMulti +@hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractReservoirSample skip_k::Int seen_k::Int const rng::R @@ -93,7 +93,7 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::AbstractWrReservoirSampleMulti, el) +@inline function OnlineStatsBase._fit!(s::SampleMultiAlgRSWRSKIP, el) n = length(s.value) s = @inline update_state!(s) if s.seen_k <= n @@ -153,17 +153,17 @@ function update_state!(s::SampleMultiAlgL) @update s.seen_k += 1 return s end -function update_state!(s::AbstractWrReservoirSampleMulti) +function update_state!(s::SampleMultiAlgRSWRSKIP) @update s.seen_k += 1 return s end -function recompute_skip!(s::AbstractWorReservoirSampleMulti, n) +function recompute_skip!(s::SampleMultiAlgL, n) @update s.state += randexp(s.rng) @update s.skip_k = s.seen_k-ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n))) return s end -function recompute_skip!(s::AbstractWrReservoirSampleMulti, n) +function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) q = rand(s.rng)^(1/n) @update s.skip_k = ceil(Int, s.seen_k/q)-1 return s @@ -182,7 +182,7 @@ function choose(n, p, q, z) return quantile(b, q) end -update_order!(s::AbstractWorReservoirSampleMulti, j) = nothing +update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing function update_order!(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL}, j) s.ord[j] = nobs(s) end @@ -217,14 +217,14 @@ function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRS return s1 end -function OnlineStatsBase.value(s::AbstractWorReservoirSampleMulti) +function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL}) if nobs(s) < length(s.value) return s.value[1:nobs(s)] else return s.value end end -function OnlineStatsBase.value(s::AbstractWrReservoirSampleMulti) +function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP) if nobs(s) < length(s.value) return sample(s.rng, s.value[1:nobs(s)], length(s.value)) else diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index c6e7ac0..ef5bbac 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -1,5 +1,5 @@ -@hybrid struct SampleSingleAlgRSWRSKIP{RT,R} <: AbstractReservoirSampleSingle +@hybrid struct SampleSingleAlgRSWRSKIP{RT,R} <: AbstractReservoirSample seen_k::Int skip_k::Int const rng::R @@ -40,7 +40,7 @@ function Base.empty!(s::SampleSingleAlgRSWRSKIP) return s end -function Base.merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSampleSingle) +function Base.merge(s1::SampleSingleAlgRSWRSKIP, s2::SampleSingleAlgRSWRSKIP) n1, n2 = nobs(s1), nobs(s2) n_tot = n1 + n2 value = rand(s1.rng) < n1/n_tot ? s1.rvalue : s2.rvalue diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 41a4bca..73185c2 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -1,7 +1,7 @@ const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T -@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti +@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractReservoirSample seen_k::Int n::Int const rng::R @@ -9,7 +9,7 @@ const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(la end const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti +@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractReservoirSample state::Float64 min_priority::Float64 seen_k::Int @@ -19,7 +19,7 @@ const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, Sam end const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti +@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractReservoirSample state::Float64 skip_w::Float64 seen_k::Int @@ -100,7 +100,7 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w) +@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJ, el, w) n = s.n s = @inline update_state!(s, w) if s.seen_k <= n @@ -117,7 +117,7 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w) +@inline function OnlineStatsBase._fit!(s::SampleMultiAlgWRSWRSKIP, el, w) n = length(s.value) s = @inline update_state!(s, w) if s.seen_k <= n @@ -176,7 +176,6 @@ function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut) return s end - function Base.merge(ss::SampleMultiAlgWRSWRSKIP...) newvalue = reduce_samples(TypeUnion(), ss...) skip_w = sum(getfield(s, :skip_w) for s in ss) @@ -198,16 +197,16 @@ function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgW return s1 end -function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w) +function update_state!(s::SampleMultiAlgARes, w) @update s.seen_k += 1 return s end -function update_state!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, w) +function update_state!(s::SampleMultiAlgAExpJ, w) @update s.seen_k += 1 @update s.state -= w return s end -function update_state!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, w) +function update_state!(s::SampleMultiAlgWRSWRSKIP, w) @update s.seen_k += 1 @update s.state += w return s @@ -218,12 +217,12 @@ function compute_skip_priority(s, w) return exp(log(rand(s.rng, Uniform(t,1)))/w) end -function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}) +function recompute_skip!(s::SampleMultiAlgAExpJ) @update s.min_priority = last(first(s.value)) @update s.state = -randexp(s.rng)/log(s.min_priority) return s end -function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n) +function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n) q = rand(s.rng)^(1/n) @update s.skip_w = s.state/q return s @@ -248,14 +247,14 @@ end is_ordered(s::SampleMultiOrdAlgWRSWRSKIP) = true is_ordered(s::SampleMultiAlgWRSWRSKIP) = false -function OnlineStatsBase.value(s::AbstractWeightedWorReservoirSampleMulti) +function OnlineStatsBase.value(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}) if nobs(s) < s.n return first.(s.value.valtree[1:nobs(s)]) else return first.(s.value.valtree) end end -function OnlineStatsBase.value(s::AbstractWeightedWrReservoirSampleMulti) +function OnlineStatsBase.value(s::SampleMultiAlgWRSWRSKIP) if nobs(s) < length(s.value) return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value)) else diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index bc11738..75291cb 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -1,5 +1,5 @@ -@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractWeightedReservoirSampleSingle +@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractReservoirSample seen_k::Int total_w::Float64 skip_w::Float64 @@ -14,7 +14,7 @@ function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:Abs return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}()) end -function OnlineStatsBase.value(s::AbstractWeightedReservoirSampleSingle) +function OnlineStatsBase.value(s::SampleSingleAlgWRSWRSKIP) s.seen_k === 0 && return nothing return get_value(s) end