From 6e4340ff6ffe593c14b99d0c8287f59e9a9b4f82 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 6 Oct 2024 23:43:00 +0200 Subject: [PATCH] Improve merging (#95) --- README.md | 24 ++++++------- docs/src/index.md | 6 ++-- src/SamplingInterface.jl | 26 +++++++++++++- src/SamplingReduction.jl | 27 ++++++++++++++ src/SamplingUtils.jl | 3 ++ src/StreamSampling.jl | 3 +- src/UnweightedSamplingMulti.jl | 66 +++++++++------------------------- test/merge_tests.jl | 2 +- 8 files changed, 89 insertions(+), 68 deletions(-) create mode 100644 src/SamplingReduction.jl diff --git a/README.md b/README.md index bbe9a8d..e0c6626 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ This has some advantages over other sampling procedures: - In some cases, sampling with the techniques implemented in this library can bring considerable performance gains, since the population of items doesn't need to be previously stored in memory. -## Brief overview of the functionalities +## Overview of the functionalities The `itsample` function allows to consume all the stream at once and return the sample collected: @@ -33,6 +33,7 @@ julia> itsample(st, 5) 96 91 ``` + In some cases, one needs to control the updates the `ReservoirSample` will be subject to. 
In this case you can simply use the `fit!` function to update the reservoir: @@ -71,39 +72,36 @@ julia> rng = Xoshiro(42); julia> iter = Iterators.filter(x -> x != 10, 1:10^7); -julia> wv(el) = 1.0; +julia> wv(el) = Float64(el); julia> @btime itsample($rng, $iter, 10^4, AlgRSWRSKIP()); - 12.457 ms (4 allocations: 156.34 KiB) + 12.301 ms (6 allocations: 156.38 KiB) julia> @btime sample($rng, collect($iter), 10^4; replace=true); - 134.152 ms (20 allocations: 146.91 MiB) + 92.936 ms (35 allocations: 290.93 MiB) julia> @btime itsample($rng, $iter, 10^4, AlgL()); - 8.262 ms (2 allocations: 78.17 KiB) + 12.719 ms (3 allocations: 78.19 KiB) julia> @btime sample($rng, collect($iter), 10^4; replace=false); - 138.054 ms (27 allocations: 147.05 MiB) + 93.544 ms (41 allocations: 291.08 MiB) julia> @btime itsample($rng, $iter, $wv, 10^4, AlgWRSWRSKIP()); - 14.479 ms (15 allocations: 547.23 KiB) + 18.672 ms (22 allocations: 547.34 KiB) julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=true); - 343.936 ms (49 allocations: 675.21 MiB) + 377.567 ms (83 allocations: 963.26 MiB) julia> @btime itsample($rng, $iter, $wv, 10^4, AlgAExpJ()); - 30.523 ms (6 allocations: 234.62 KiB) + 37.600 ms (8 allocations: 234.55 KiB) julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false); - 294.242 ms (43 allocations: 370.19 MiB) + 258.426 ms (74 allocations: 658.24 MiB) ``` Some more performance comparisons in respect to `StatsBase` methods are in the [benchmark](https://github.com/JuliaDynamics/StreamSampling.jl/blob/main/benchmark/) folder. - - ## Contributing Contributions are welcome! If you encounter any issues, have suggestions for improvements, or would like to add new features, feel free to open an issue or submit a pull request. - diff --git a/docs/src/index.md b/docs/src/index.md index 4f637a6..f587264 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,11 +3,13 @@ This is the API page of the package. 
For a general overview of the functionalities consult the [ReadMe](https://github.com/JuliaDynamics/StreamSampling.jl). -## General functionalities +## General Functionalities ```@docs ReservoirSample fit! +merge! +merge empty! value ordvalue @@ -15,7 +17,7 @@ nobs itsample ``` -## Sampling algorithms +## Sampling Algorithms ```@docs StreamSampling.AlgR diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index ee6f35b..af9b725 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -8,7 +8,9 @@ Initializes a reservoir sample which can then be fitted with [`fit!`](@ref). The first signature represents a sample where only a single element is collected. A weight function `wfunc` can be passed to apply weighted sampling. Look at the -[`Algorithms`](@ref) section for the supported methods. +[Sampling Algorithms](@ref) section for the supported methods. If `ordered` is +true, the reservoir sample values can be retrieved in the order they were collected +with [`ordvalue`](@ref). """ function ReservoirSample(T, method::ReservoirAlgorithm = AlgRSWRSKIP()) return ReservoirSample(Random.default_rng(), T, method, MutSample()) @@ -86,6 +88,28 @@ function Base.empty!(::AbstractReservoirSample) error("Abstract Version") end +""" + Base.merge!(rs::AbstractReservoirSample, rss::AbstractReservoirSample...) + +Updates the first reservoir sample by merging its value with the values +of the other samples. Currently only supported for samples with replacement. +""" +function Base.merge!(::AbstractReservoirSample) + error("Abstract Version") +end + + +""" + Base.merge(rs::AbstractReservoirSample...) + +Creates a new reservoir sample by merging the values +of the samples passed. Currently only supported for samples +with replacement. 
+""" +function OnlineStatsBase.merge(::AbstractReservoirSample) + error("Abstract Version") +end + """ itsample([rng], iter, method = AlgRSWRSKIP()) itsample([rng], iter, wfunc, method = AlgWRSWRSKIP()) diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl new file mode 100644 index 0000000..1239309 --- /dev/null +++ b/src/SamplingReduction.jl @@ -0,0 +1,27 @@ + +const SMWR = Union{SampleMultiAlgRSWRSKIP, SampleMultiAlgWRSWRSKIP} + +reduce_samples(t) = error() +function reduce_samples(t, ss::T...) where {T<:SMWR} + nt = length(ss) + v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt) + ns = rand(ss[1].rng, Multinomial(length(value(ss[1])), get_ps(ss...))) + Threads.@threads for i in 1:nt + v[i] = sample(ss[i].rng, value(ss[i]), ns[i]; replace = false) + end + return reduce(vcat, v) +end + +function get_ps(ss::SampleMultiAlgRSWRSKIP...) + sum_w = sum(getfield(s, :seen_k) for s in ss) + return [s.seen_k/sum_w for s in ss] +end +function get_ps(ss::SampleMultiAlgWRSWRSKIP...) + sum_w = sum(getfield(s, :state) for s in ss) + return [s.state/sum_w for s in ss] +end + +get_type_rs(::TypeS, s1::T, ss::T...) where {T<:SMWR} = eltype(value(s1)) +function get_type_rs(::TypeUnion, s1::T, ss::T...) 
where {T<:SMWR} + return Union{eltype(value(s1)), Union{(eltype(value(s)) for s in ss)...}} +end \ No newline at end of file diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 56daf3c..d8ce5d5 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -1,4 +1,7 @@ +struct TypeS end +struct TypeUnion end + @hybrid struct RefVal{T} value::T RefVal{T}() where T = new{T}() diff --git a/src/StreamSampling.jl b/src/StreamSampling.jl index 0540dc9..030c87c 100644 --- a/src/StreamSampling.jl +++ b/src/StreamSampling.jl @@ -8,7 +8,7 @@ using OnlineStatsBase using Random using StatsBase -export fit!, value, ordvalue, nobs, itsample +export fit!, merge!, value, ordvalue, nobs, itsample export AbstractReservoirSample, ReservoirSample export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP @@ -92,6 +92,7 @@ include("UnweightedSamplingSingle.jl") include("UnweightedSamplingMulti.jl") include("WeightedSamplingSingle.jl") include("WeightedSamplingMulti.jl") +include("SamplingReduction.jl") include("precompile.jl") end diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 1192ce1..93c4253 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -85,7 +85,7 @@ end if s.seen_k === n s = @inline recompute_skip!(s, n) end - elseif s.skip_k < 0 + elseif s.skip_k < s.seen_k j = rand(s.rng, 1:n) @inbounds s.value[j] = el update_order!(s, j) @@ -105,7 +105,7 @@ end s.value[i] = new_values[i] end end - elseif s.skip_k < 0 + elseif s.skip_k < s.seen_k p = 1/s.seen_k z = (1-p)^(n-3) q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0)) @@ -151,23 +151,21 @@ function update_state!(s::SampleMultiAlgR) end function update_state!(s::SampleMultiAlgL) @update s.seen_k += 1 - @update s.skip_k -= 1 return s end function update_state!(s::AbstractWrReservoirSampleMulti) @update s.seen_k += 1 - @update s.skip_k -= 1 return s end function recompute_skip!(s::AbstractWorReservoirSampleMulti, n) @update s.state += randexp(s.rng) - 
@update s.skip_k = -ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n))) + @update s.skip_k = s.seen_k-ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n))) return s end function recompute_skip!(s::AbstractWrReservoirSampleMulti, n) q = rand(s.rng)^(1/n) - @update s.skip_k = ceil(Int, s.seen_k/q - s.seen_k - 1) + @update s.skip_k = ceil(Int, s.seen_k/q)-1 return s end @@ -200,54 +198,22 @@ function update_order_multi!(s::SampleMultiOrdAlgRSWRSKIP, r, j) end is_ordered(s::SampleMultiOrdAlgRSWRSKIP) = true -is_ordered(s::AbstractWrReservoirSampleMulti) = false +is_ordered(s::SampleMultiAlgRSWRSKIP) = false -function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti) - len1, len2, n1, n2 = check_merging_support(s1, s2) - shuffle!(s1.rng, s1.value) - shuffle!(s2.rng, s2.value) - n_tot = n1 + n2 - p = n2 / n_tot - value = create_new_res_vec(s1, s2, p, len1) - s_merged = typeof(s1)(0, n_tot, s1.rng, value, nothing) - recompute_skip!(s_merged, len1) - return s_merged +function Base.merge(ss::SampleMultiAlgRSWRSKIP...) + newvalue = reduce_samples(TypeUnion(), ss...) 
+ skip_k = sum(getfield(s, :skip_k) for s in ss) + seen_k = sum(getfield(s, :seen_k) for s in ss) + return SampleMultiAlgRSWRSKIP_Mut(skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti) - len1, len2, n1, n2 = check_merging_support(s1, s2) - shuffle!(s1.rng, s1.value) - shuffle!(s2.rng, s2.value) - n_tot = n1 + n2 - p = n2 / n_tot - merge_res_vec!(s1, s2, p, len1, n_tot) - recompute_skip!(s1, len1) - return s1 -end - -function check_merging_support(s1, s2) - len1, len2 = length(s1.value), length(s2.value) - len1 != len2 && error("Merging samples with different sizes is not supported") - n1, n2 = nobs(s1), nobs(s2) - n1 < len1 || n2 < len2 && error("Merging samples with different sizes is not supported") - return len1, len2, n1, n2 -end - -function create_new_res_vec(s1, s2, p, len1) - value = similar(s1.value) - @inbounds for j in 1:len1 - value[j] = rand(s1.rng) < p ? s2.value[j] : s1.value[j] - end - return value -end - -function merge_res_vec!(s1, s2, p, len1, n_tot) - @inbounds for j in 1:len1 - if rand(s1.rng) < p - s1.value[j] = s2.value[j] - end +function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRSWRSKIP...) + newvalue = reduce_samples(TypeS(), s1, ss...) + for i in 1:length(newvalue) + @inbounds s1.value[i] = newvalue[i] end - s1.seen_k = n_tot + s1.skip_k += sum(getfield(s, :skip_k) for s in ss) + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) return s1 end diff --git a/test/merge_tests.jl b/test/merge_tests.jl index 2024fed..f61ae70 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -16,7 +16,7 @@ end end s_merged = merge(s1, s2) - res[value(s_merged)...] += 1 + res[shuffle!(rng, value(s_merged))...] += 1 end cases = m1 == AlgRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size) ps_exact = [1/cases for _ in 1:cases]