Skip to content

Commit

Permalink
Improve interface (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Oct 7, 2024
1 parent f4985a7 commit 54be9e0
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 43 deletions.
16 changes: 0 additions & 16 deletions src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,6 @@ struct Unord end

# Supertype shared by the reservoir sample structs of this package; it is itself an
# `OnlineStat{Any}`.
abstract type AbstractReservoirSample <: OnlineStat{Any} end

# unweighted cases
abstract type AbstractReservoirSampleSingle <: AbstractReservoirSample end
abstract type AbstractReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWorReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWorReservoirSampleMulti <: AbstractWorReservoirSampleMulti end
abstract type AbstractWrReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMulti end

# weighted cases
abstract type AbstractWeightedReservoirSample <: AbstractReservoirSample end
abstract type AbstractWeightedReservoirSampleSingle <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractWeightedReservoirSample end

abstract type ReservoirAlgorithm end

"""
Expand Down
20 changes: 10 additions & 10 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

@hybrid struct SampleMultiAlgR{O,T,R} <: AbstractWorReservoirSampleMulti
# State of an unweighted multi-element reservoir sample (`AlgR` variant, going by the name).
# NOTE(review): `@hybrid` presumably generates `_Mut`/`_Immut` versions of the struct, as the
# `SampleMultiAlgARes_Immut`/`SampleMultiAlgARes_Mut` names used elsewhere suggest -- confirm.
@hybrid struct SampleMultiAlgR{O,T,R} <: AbstractReservoirSample
# counter of processed elements (presumably what `nobs` reports -- TODO confirm)
seen_k::Int
# random number generator carried by the sample
const rng::R
# storage of the sampled elements; `OnlineStatsBase.value` returns it, or its first
# `nobs(s)` entries while fewer elements than `length(value)` have been observed
const value::Vector{T}
# ordering storage: `update_order!` writes `nobs(s)` into it when it is a Vector
# (the `SampleMultiOrdAlgR` case) and does nothing otherwise
const ord::O
end
# Alias for the `SampleMultiAlgR` instantiations whose first type parameter (the type of the
# `ord` field) is a subtype of `Vector`.
const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector}

@hybrid struct SampleMultiAlgL{O,T,R} <: AbstractWorReservoirSampleMulti
@hybrid struct SampleMultiAlgL{O,T,R} <: AbstractReservoirSample
state::Float64
skip_k::Int
seen_k::Int
Expand All @@ -17,7 +17,7 @@ const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector}
end
# Alias for the `SampleMultiAlgL` instantiations whose first type parameter `O` is a subtype
# of `Vector`.
const SampleMultiOrdAlgL = SampleMultiAlgL{<:Vector}

@hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractWrReservoirSampleMulti
@hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractReservoirSample
skip_k::Int
seen_k::Int
const rng::R
Expand Down Expand Up @@ -93,7 +93,7 @@ end
end
return s
end
@inline function OnlineStatsBase._fit!(s::AbstractWrReservoirSampleMulti, el)
@inline function OnlineStatsBase._fit!(s::SampleMultiAlgRSWRSKIP, el)
n = length(s.value)
s = @inline update_state!(s)
if s.seen_k <= n
Expand Down Expand Up @@ -153,17 +153,17 @@ function update_state!(s::SampleMultiAlgL)
@update s.seen_k += 1
return s
end
function update_state!(s::AbstractWrReservoirSampleMulti)
# Increment the `seen_k` counter of a `SampleMultiAlgRSWRSKIP` sample by one and return the
# sample. Callers should use the returned value (`_fit!` rebinds it with
# `s = update_state!(s)`).
function update_state!(s::SampleMultiAlgRSWRSKIP)
@update s.seen_k += 1
return s
end

function recompute_skip!(s::AbstractWorReservoirSampleMulti, n)
# Recompute `s.state` and `s.skip_k` for a `SampleMultiAlgL` sample and return the sample.
# `n` divides `s.state` in the skip formula (presumably the reservoir size -- confirm at the
# call site). Two `randexp` values are drawn from `s.rng`, in this order: the first is added
# to `s.state`, the second is used for `s.skip_k`.
function recompute_skip!(s::SampleMultiAlgL, n)
@update s.state += randexp(s.rng)
# For s.state/n > 0 the logarithm is negative, so the subtracted term is <= 0 and the new
# skip_k is not smaller than seen_k.
@update s.skip_k = s.seen_k-ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
return s
end
function recompute_skip!(s::AbstractWrReservoirSampleMulti, n)
function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)
q = rand(s.rng)^(1/n)
@update s.skip_k = ceil(Int, s.seen_k/q)-1
return s
Expand All @@ -182,7 +182,7 @@ function choose(n, p, q, z)
return quantile(b, q)
end

update_order!(s::AbstractWorReservoirSampleMulti, j) = nothing
# Generic `SampleMultiAlgR`/`SampleMultiAlgL` method: performs no work and returns `nothing`.
function update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j)
    return nothing
end
# Store the current observation count `nobs(s)` at position `j` of `s.ord` and return that
# count.
function update_order!(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL}, j)
    seen = nobs(s)
    s.ord[j] = seen
    return seen
end
Expand Down Expand Up @@ -217,14 +217,14 @@ function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRS
return s1
end

function OnlineStatsBase.value(s::AbstractWorReservoirSampleMulti)
# Return the sampled elements: the whole `s.value` vector once at least `length(s.value)`
# elements were observed, otherwise a copy of its first `nobs(s)` entries.
function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL})
    seen = nobs(s)
    seen >= length(s.value) && return s.value
    return s.value[1:seen]
end
function OnlineStatsBase.value(s::AbstractWrReservoirSampleMulti)
function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP)
if nobs(s) < length(s.value)
return sample(s.rng, s.value[1:nobs(s)], length(s.value))
else
Expand Down
4 changes: 2 additions & 2 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@hybrid struct SampleSingleAlgRSWRSKIP{RT,R} <: AbstractReservoirSampleSingle
@hybrid struct SampleSingleAlgRSWRSKIP{RT,R} <: AbstractReservoirSample
seen_k::Int
skip_k::Int
const rng::R
Expand Down Expand Up @@ -40,7 +40,7 @@ function Base.empty!(s::SampleSingleAlgRSWRSKIP)
return s
end

function Base.merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSampleSingle)
function Base.merge(s1::SampleSingleAlgRSWRSKIP, s2::SampleSingleAlgRSWRSKIP)
n1, n2 = nobs(s1), nobs(s2)
n_tot = n1 + n2
value = rand(s1.rng) < n1/n_tot ? s1.rvalue : s2.rvalue
Expand Down
25 changes: 12 additions & 13 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@

# Binary heap of `(T, Int64, Float64)` tuples ordered by the last tuple entry
# (`Base.Order.By{typeof(last), ...}`).
# NOTE(review): the entries presumably are (element, insertion index, priority) -- confirm.
const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T

@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
# State of a weighted multi-element reservoir sample (`ARes` variant, going by the name).
# The `_Immut`/`_Mut` versions of this struct are referenced by `SampleMultiOrdAlgARes`.
@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractReservoirSample
# counter incremented by `update_state!` for every processed element
seen_k::Int
# compared against `nobs(s)` in `OnlineStatsBase.value`
# (presumably the requested sample size -- TODO confirm)
n::Int
# random number generator carried by the sample
const rng::R
# container of the sampled entries; `OnlineStatsBase.value` reads its `valtree` field
const value::BH
end
# Union of the immutable and mutable `SampleMultiAlgARes` variants whose container type is an
# `OrdWeighted` heap.
const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}}

@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractReservoirSample
state::Float64
min_priority::Float64
seen_k::Int
Expand All @@ -19,7 +19,7 @@ const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, Sam
end
const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}}

@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti
@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractReservoirSample
state::Float64
skip_w::Float64
seen_k::Int
Expand Down Expand Up @@ -100,7 +100,7 @@ end
end
return s
end
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w)
@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJ, el, w)
n = s.n
s = @inline update_state!(s, w)
if s.seen_k <= n
Expand All @@ -117,7 +117,7 @@ end
end
return s
end
@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w)
@inline function OnlineStatsBase._fit!(s::SampleMultiAlgWRSWRSKIP, el, w)
n = length(s.value)
s = @inline update_state!(s, w)
if s.seen_k <= n
Expand Down Expand Up @@ -176,7 +176,6 @@ function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut)
return s
end


function Base.merge(ss::SampleMultiAlgWRSWRSKIP...)
newvalue = reduce_samples(TypeUnion(), ss...)
skip_w = sum(getfield(s, :skip_w) for s in ss)
Expand All @@ -198,16 +197,16 @@ function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgW
return s1
end

function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w)
# Increment the `seen_k` counter of a `SampleMultiAlgARes` sample by one and return the
# sample. The weight `w` is accepted but not used by this method.
function update_state!(s::SampleMultiAlgARes, w)
@update s.seen_k += 1
return s
end
function update_state!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, w)
# Register one more element of weight `w` on a `SampleMultiAlgAExpJ` sample: `seen_k` is
# incremented by one and `w` is subtracted from `state`. Returns the sample.
function update_state!(s::SampleMultiAlgAExpJ, w)
@update s.seen_k += 1
@update s.state -= w
return s
end
function update_state!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, w)
function update_state!(s::SampleMultiAlgWRSWRSKIP, w)
@update s.seen_k += 1
@update s.state += w
return s
Expand All @@ -218,12 +217,12 @@ function compute_skip_priority(s, w)
return exp(log(rand(s.rng, Uniform(t,1)))/w)
end

function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ})
# Refresh `min_priority` and `state` of a `SampleMultiAlgAExpJ` sample and return the sample.
# One `randexp` value is drawn from `s.rng`.
function recompute_skip!(s::SampleMultiAlgAExpJ)
# last component of the first element of `s.value`
# (presumably the top of a binary heap ordered by priority -- confirm)
@update s.min_priority = last(first(s.value))
# new state computed from the freshly stored min_priority
@update s.state = -randexp(s.rng)/log(s.min_priority)
return s
end
function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n)
function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n)
q = rand(s.rng)^(1/n)
@update s.skip_w = s.state/q
return s
Expand All @@ -248,14 +247,14 @@ end
# `is_ordered` is `true` for the `SampleMultiOrdAlgWRSWRSKIP` types and `false` for any other
# `SampleMultiAlgWRSWRSKIP`.
function is_ordered(s::SampleMultiOrdAlgWRSWRSKIP)
    return true
end
function is_ordered(s::SampleMultiAlgWRSWRSKIP)
    return false
end

function OnlineStatsBase.value(s::AbstractWeightedWorReservoirSampleMulti)
# Return the first component of the entries stored in `s.value.valtree`; while fewer than
# `s.n` elements were observed only the first `nobs(s)` entries are used.
function OnlineStatsBase.value(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ})
    seen = nobs(s)
    entries = s.value.valtree
    kept = seen < s.n ? entries[1:seen] : entries
    return first.(kept)
end
function OnlineStatsBase.value(s::AbstractWeightedWrReservoirSampleMulti)
function OnlineStatsBase.value(s::SampleMultiAlgWRSWRSKIP)
if nobs(s) < length(s.value)
return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value))
else
Expand Down
4 changes: 2 additions & 2 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractWeightedReservoirSampleSingle
@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractReservoirSample
seen_k::Int
total_w::Float64
skip_w::Float64
Expand All @@ -14,7 +14,7 @@ function ReservoirSample(rng::R, T, ::AlgWRSWRSKIP, ::ImmutSample) where {R<:Abs
return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}())
end

function OnlineStatsBase.value(s::AbstractWeightedReservoirSampleSingle)
# Return `nothing` while no element has been counted (`seen_k === 0`), otherwise the result
# of `get_value(s)`.
function OnlineStatsBase.value(s::SampleSingleAlgWRSWRSKIP)
    return s.seen_k === 0 ? nothing : get_value(s)
end
Expand Down

0 comments on commit 54be9e0

Please sign in to comment.