Skip to content

Commit

Permalink
Start working on OnlineStatsBase Interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Sep 14, 2024
1 parent 8ee4f55 commit 3fb5fa5
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 82 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "StreamSampling"
uuid = "ff63dad9-3335-55d8-95ec-f8139d39e468"
version = "0.3.8"
version = "0.3.9"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -15,6 +16,7 @@ julia = "1.8"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
OnlineStatsBase = "1"
PrecompileTools = "1"
Random = "1"
StatsBase = "0.32, 0.33, 0.34"
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ consult the [ReadMe](https://github.com/JuliaDynamics/StreamSampling.jl).
```@docs
ReservoirSample
update!
reset!
empty!
value
ordered_value
itsample
Expand Down
12 changes: 2 additions & 10 deletions src/SortedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ function sortedindices_sample(rng, iter, n::Int;
if replace
return sample(rng, reservoir, n, ordered=ordered)
else
if ordered
return reservoir
else
return shuffle!(rng, reservoir)
end
return ordered ? reservoir : shuffle!(rng, reservoir)
end
end
reservoir = Vector{iter_type}(undef, n)
Expand All @@ -36,11 +32,7 @@ function sortedindices_sample(rng, iter, n::Int;
end
i += 1
end
if ordered
return reservoir
else
return shuffle!(rng, reservoir)
end
return ordered ? reservoir : shuffle!(rng, reservoir)
end

function skip_ahead_no_end(iter, state, n)
Expand Down
51 changes: 24 additions & 27 deletions src/StreamSampling.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module StreamSampling

using Accessors
import Accessors
using DataStructures
using Distributions
using OnlineStatsBase
using Random
using StatsBase

Expand Down Expand Up @@ -39,18 +40,6 @@ struct AlgARes <: ReservoirAlgorithm end
struct AlgAExpJ <: ReservoirAlgorithm end
struct AlgWRSWRSKIP <: ReservoirAlgorithm end


macro imm_reset(e)
s = e.args[1].args[1]
esc(quote
if ismutabletype(typeof($s))
$e
else
StreamSampling.Accessors.@reset $e
end
end)
end

"""
Implements random sampling without replacement.
Expand Down Expand Up @@ -99,16 +88,18 @@ const algWRSWRSKIP = AlgWRSWRSKIP()

export algL, algR, algRSWRSKIP, algARes, algAExpJ, algWRSWRSKIP

include("SortedSamplingSingle.jl")
include("SortedSamplingMulti.jl")
include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")
include("WeightedSamplingSingle.jl")
include("WeightedSamplingMulti.jl")
include("precompile.jl")
macro reset(e)
s = e.args[1].args[1]
esc(quote
if ismutabletype(typeof($s))
$e
else
$StreamSampling.Accessors.@reset $e
end
end)
end

"""
ReservoirSample([rng], T, method = algL)
ReservoirSample([rng], T, n::Int, method = algL; ordered = false)
Expand All @@ -121,7 +112,6 @@ function ReservoirSample end
export ReservoirSample

"""
update!(rs::AbstractReservoirSample, el)
update!(rs::AbstractReservoirSample, el, w::Float64)
Expand All @@ -134,17 +124,16 @@ function update! end
export update!

"""
reset!(rs::AbstractReservoirSample)
Base.empty!(rs::AbstractReservoirSample)
Resets the reservoir sample to its initial state.
Useful to avoid allocating a new sample in some cases.
"""
function reset! end

export reset!
function Base.empty!(::AbstractReservoirSample)
error("Abstract Version")
end

"""
value(rs::AbstractReservoirSample)
Returns the elements collected in the sample at the current
Expand Down Expand Up @@ -208,4 +197,12 @@ before starting the sampling.
"""
function sortedindices_sample end

include("SortedSamplingSingle.jl")
include("SortedSamplingMulti.jl")
include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")
include("WeightedSamplingSingle.jl")
include("WeightedSamplingMulti.jl")
include("precompile.jl")

end
10 changes: 5 additions & 5 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ end
s = @inline update_state!(s)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
if s.seen_k == n
if s.seen_k === n
s = @inline recompute_skip!(s, n)
end
elseif s.skip_k < 0
Expand All @@ -115,7 +115,7 @@ end
s = @inline update_state!(s)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
if s.seen_k == n
if s.seen_k === n
s = recompute_skip!(s, n)
new_values = sample(s.rng, s.value, n, ordered=is_ordered(s))
@inbounds for i in 1:n
Expand Down Expand Up @@ -146,17 +146,17 @@ end
return s
end

function reset!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR})
function Base.empty!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR})
s.seen_k = 0
return s
end
function reset!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL})
function Base.empty!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL})
s.state = 0.0
s.skip_k = 0
s.seen_k = 0
return s
end
function reset!(s::Union{SampleMultiAlgRSWRSKIP, SampleMultiOrdAlgRSWRSKIP})
function Base.empty!(s::Union{SampleMultiAlgRSWRSKIP, SampleMultiOrdAlgRSWRSKIP})
s.skip_k = 0
s.seen_k = 0
return s
Expand Down
2 changes: 1 addition & 1 deletion src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
return s
end

function reset!(s::SampleSingleAlgR)
function Base.empty!(s::SampleSingleAlgR)
s.seen_k = 0
s.skip_k = 0
return s
Expand Down
22 changes: 11 additions & 11 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,39 +242,39 @@ end
return s
end

function reset!(s::Union{MutSampleMultiAlgARes, MutSampleMultiOrdAlgARes})
function Base.empty!(s::Union{MutSampleMultiAlgARes, MutSampleMultiOrdAlgARes})
s.seen_k = 0
empty!(s.value)
sizehint!(s.value, s.n)
return s
end
function reset!(s::Union{MutSampleMultiAlgAExpJ, MutSampleMultiOrdAlgAExpJ})
function Base.empty!(s::Union{MutSampleMultiAlgAExpJ, MutSampleMultiOrdAlgAExpJ})
s.state = 0.0
s.min_priority = 0.0
s.seen_k = 0
empty!(s.value)
sizehint!(s.value, s.n)
return s
end
function reset!(s::Union{MutSampleMultiAlgWRSWRSKIP, MutSampleMultiOrdAlgWRSWRSKIP})
function Base.empty!(s::Union{MutSampleMultiAlgWRSWRSKIP, MutSampleMultiOrdAlgWRSWRSKIP})
s.state = 0.0
s.skip_w = 0.0
s.seen_k = 0
return s
end

function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w)
@imm_reset s.seen_k += 1
@reset s.seen_k += 1
return s
end
function update_state!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, w)
@imm_reset s.seen_k += 1
@imm_reset s.state -= w
@reset s.seen_k += 1
@reset s.state -= w
return s
end
function update_state!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, w)
@imm_reset s.seen_k += 1
@imm_reset s.state += w
@reset s.seen_k += 1
@reset s.state += w
return s
end

Expand All @@ -284,13 +284,13 @@ function compute_skip_priority(s, w)
end

function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ})
@imm_reset s.min_priority = last(first(s.value))
@imm_reset s.state = -randexp(s.rng)/log(s.min_priority)
@reset s.min_priority = last(first(s.value))
@reset s.state = -randexp(s.rng)/log(s.min_priority)
return s
end
function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n)
q = rand(s.rng)^(1/n)
@imm_reset s.skip_w = s.state/q
@reset s.skip_w = s.state/q
return s
end

Expand Down
36 changes: 18 additions & 18 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ mutable struct RefVal{T}
end

struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
total_w::Float64
skip_w::Float64
rng::R
rvalue::RefVal{T}
end
mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
total_w::Float64
skip_w::Float64
const rng::R
value::T
MutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
MutSampleSingleAlgAExpJ{T,R}(total_w, skip_w, rng) where {T,R} = new{T,R}(total_w, skip_w, rng)
end
const SampleSingleAlgAExpJ = Union{ImmutSampleSingleAlgAExpJ, MutSampleSingleAlgAExpJ}

Expand All @@ -28,41 +28,41 @@ function ReservoirSample(rng::R, T, ::AlgAExpJ, ::ImmutSample) where {R<:Abstrac
end

function value(s::AbstractWeightedReservoirSampleSingle)
s.state === 0.0 && return nothing
return get_val(s)
s.total_w === 0.0 && return nothing
return get_value(s)
end

@inline function update!(s::SampleSingleAlgAExpJ, el, weight)
@imm_reset s.state += weight
if s.skip_w <= s.state
@imm_reset s.skip_w = s.state/rand(s.rng)
s = set_val(s, el)
@reset s.total_w += weight
if s.skip_w <= s.total_w
@reset s.skip_w = s.total_w/rand(s.rng)
s = reset_value!(s, el)
end
return s
end

function reset!(s::MutSampleSingleAlgAExpJ)
s.state = 0.0
function Base.empty!(s::MutSampleSingleAlgAExpJ)
s.total_w = 0.0
s.skip_w = 0.0
return s
end

get_val(s::ImmutSampleSingleAlgAExpJ) = s.rvalue.value
function set_val(s::ImmutSampleSingleAlgAExpJ, el)
@reset s.rvalue.value = el
get_value(s::MutSampleSingleAlgAExpJ) = s.value
get_value(s::ImmutSampleSingleAlgAExpJ) = s.rvalue.value

function reset_value!(s::MutSampleSingleAlgAExpJ, el)
s.value = el
return s
end
get_val(s::MutSampleSingleAlgAExpJ) = s.value
function set_val(s::MutSampleSingleAlgAExpJ, el)
s.value = el
function reset_value!(s::ImmutSampleSingleAlgAExpJ, el)
@reset s.rvalue.value = el
return s
end

function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, wv, method)
end

function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
s = ReservoirSample(rng, iter_type, method, ims)
Expand Down
16 changes: 8 additions & 8 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ using PrecompileTools
update_s_no_weights!(rs, iter) = for x in iter update!(rs, x) end
update_s!(rs, iter) = for x in iter update!(rs, x, wv(x)) end
@compile_workload let
rs = ReservoirSample(Int, algR)
rs = ReservoirSample(Int, AlgR())
update_s_no_weights!(rs, iter)
rs = ReservoirSample(Int, algAExpJ)
rs = ReservoirSample(Int, AlgAExpJ())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, algR)
rs = ReservoirSample(Int, 2, AlgR())
update_s_no_weights!(rs, iter)
rs = ReservoirSample(Int, 2, algL)
rs = ReservoirSample(Int, 2, AlgL())
update_s_no_weights!(rs, iter)
rs = ReservoirSample(Int, 2, algRSWRSKIP)
rs = ReservoirSample(Int, 2, AlgRSWRSKIP())
update_s_no_weights!(rs, iter)
rs = ReservoirSample(Int, 2, algARes)
rs = ReservoirSample(Int, 2, AlgARes())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, algAExpJ)
rs = ReservoirSample(Int, 2, AlgAExpJ())
update_s!(rs, iter)
rs = ReservoirSample(Int, 2, algWRSWRSKIP)
rs = ReservoirSample(Int, 2, AlgWRSWRSKIP())
update_s!(rs, iter)
end
end

0 comments on commit 3fb5fa5

Please sign in to comment.