Skip to content

Commit

Permalink
Improve merging (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Oct 6, 2024
1 parent c9feabb commit 6e4340f
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 68 deletions.
24 changes: 11 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This has some advantages over other sampling procedures:
- In some cases, sampling with the techniques implemented in this library can bring considerable performance gains, since
the population of items doesn't need to be previously stored in memory.

## Brief overview of the functionalities
## Overview of the functionalities

The `itsample` function allows to consume all the stream at once and return the sample collected:

Expand All @@ -33,6 +33,7 @@ julia> itsample(st, 5)
96
91
```

In some cases, one needs to control the updates the `ReservoirSample` will be subject to. In this case
you can simply use the `fit!` function to update the reservoir:

Expand Down Expand Up @@ -71,39 +72,36 @@ julia> rng = Xoshiro(42);

julia> iter = Iterators.filter(x -> x != 10, 1:10^7);

julia> wv(el) = 1.0;
julia> wv(el) = Float64(el);

julia> @btime itsample($rng, $iter, 10^4, AlgRSWRSKIP());
12.457 ms (4 allocations: 156.34 KiB)
12.301 ms (6 allocations: 156.38 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=true);
134.152 ms (20 allocations: 146.91 MiB)
92.936 ms (35 allocations: 290.93 MiB)

julia> @btime itsample($rng, $iter, 10^4, AlgL());
8.262 ms (2 allocations: 78.17 KiB)
12.719 ms (3 allocations: 78.19 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=false);
138.054 ms (27 allocations: 147.05 MiB)
93.544 ms (41 allocations: 291.08 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, AlgWRSWRSKIP());
14.479 ms (15 allocations: 547.23 KiB)
18.672 ms (22 allocations: 547.34 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=true);
343.936 ms (49 allocations: 675.21 MiB)
377.567 ms (83 allocations: 963.26 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, AlgAExpJ());
30.523 ms (6 allocations: 234.62 KiB)
37.600 ms (8 allocations: 234.55 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false);
294.242 ms (43 allocations: 370.19 MiB)
258.426 ms (74 allocations: 658.24 MiB)
```

Some more performance comparisons in respect to `StatsBase` methods are in the [benchmark](https://github.com/JuliaDynamics/StreamSampling.jl/blob/main/benchmark/) folder.



## Contributing

Contributions are welcome! If you encounter any issues, have suggestions for improvements, or would like to add new
features, feel free to open an issue or submit a pull request.

6 changes: 4 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
This is the API page of the package. For a general overview of the functionalities
consult the [ReadMe](https://github.com/JuliaDynamics/StreamSampling.jl).

## General functionalities
## General Functionalities

```@docs
ReservoirSample
fit!
merge!
merge
empty!
value
ordvalue
nobs
itsample
```

## Sampling algorithms
## Sampling Algorithms

```@docs
StreamSampling.AlgR
Expand Down
26 changes: 25 additions & 1 deletion src/SamplingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
Initializes a reservoir sample which can then be fitted with [`fit!`](@ref).
The first signature represents a sample where only a single element is collected.
A weight function `wfunc` can be passed to apply weighted sampling. Look at the
[`Algorithms`](@ref) section for the supported methods.
[`Sampling Algorithms`](@ref) section for the supported methods. If `ordered` is
true, the reservoir sample values can be retrived in the order they were collected
with [`ordvalue`](@ref).
"""
function ReservoirSample(T, method::ReservoirAlgorithm = AlgRSWRSKIP())
return ReservoirSample(Random.default_rng(), T, method, MutSample())
Expand Down Expand Up @@ -86,6 +88,28 @@ function Base.empty!(::AbstractReservoirSample)
error("Abstract Version")
end

"""
Base.merge!(rs::AbstractReservoirSample, rs::AbstractReservoirSample...)
Updates the first reservoir sample by merging its value with the values
of the other samples. Currently only supported for samples with replacement.
"""
function Base.merge!(::AbstractReservoirSample)
error("Abstract Version")
end


"""
Base.merge(rs::AbstractReservoirSample...)
Creates a new reservoir sample by merging the values
of the samples passed. Currently only supported for sample
with replacement.
"""
function OnlineStatsBase.merge(::AbstractReservoirSample)
error("Abstract Version")
end

"""
itsample([rng], iter, method = AlgRSWRSKIP())
itsample([rng], iter, wfunc, method = AlgWRSWRSKIP())
Expand Down
27 changes: 27 additions & 0 deletions src/SamplingReduction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

const SMWR = Union{SampleMultiAlgRSWRSKIP, SampleMultiAlgWRSWRSKIP}

reduce_samples(t) = error()
function reduce_samples(t, ss::T...) where {T<:SMWR}
nt = length(ss)
v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt)
ns = rand(ss[1].rng, Multinomial(length(value(ss[1])), get_ps(ss...)))
Threads.@threads for i in 1:nt
v[i] = sample(ss[i].rng, value(ss[i]), ns[i]; replace = false)
end
return reduce(vcat, v)
end

function get_ps(ss::SampleMultiAlgRSWRSKIP...)
sum_w = sum(getfield(s, :seen_k) for s in ss)
return [s.seen_k/sum_w for s in ss]
end
function get_ps(ss::SampleMultiAlgWRSWRSKIP...)
sum_w = sum(getfield(s, :state) for s in ss)
return [s.state/sum_w for s in ss]
end

get_type_rs(::TypeS, s1::T, ss::T...) where {T<:SMWR} = eltype(value(s1))
function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T<:SMWR}
return Union{eltype(value(s1)), Union{(eltype(value(s)) for s in ss)...}}
end
3 changes: 3 additions & 0 deletions src/SamplingUtils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

struct TypeS end
struct TypeUnion end

@hybrid struct RefVal{T}
value::T
RefVal{T}() where T = new{T}()
Expand Down
3 changes: 2 additions & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using OnlineStatsBase
using Random
using StatsBase

export fit!, value, ordvalue, nobs, itsample
export fit!, merge!, value, ordvalue, nobs, itsample
export AbstractReservoirSample, ReservoirSample
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP

Expand Down Expand Up @@ -92,6 +92,7 @@ include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")
include("WeightedSamplingSingle.jl")
include("WeightedSamplingMulti.jl")
include("SamplingReduction.jl")
include("precompile.jl")

end
66 changes: 16 additions & 50 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
if s.seen_k === n
s = @inline recompute_skip!(s, n)
end
elseif s.skip_k < 0
elseif s.skip_k < s.seen_k
j = rand(s.rng, 1:n)
@inbounds s.value[j] = el
update_order!(s, j)
Expand All @@ -105,7 +105,7 @@ end
s.value[i] = new_values[i]
end
end
elseif s.skip_k < 0
elseif s.skip_k < s.seen_k
p = 1/s.seen_k
z = (1-p)^(n-3)
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
Expand Down Expand Up @@ -151,23 +151,21 @@ function update_state!(s::SampleMultiAlgR)
end
function update_state!(s::SampleMultiAlgL)
@update s.seen_k += 1
@update s.skip_k -= 1
return s
end
function update_state!(s::AbstractWrReservoirSampleMulti)
@update s.seen_k += 1
@update s.skip_k -= 1
return s
end

function recompute_skip!(s::AbstractWorReservoirSampleMulti, n)
@update s.state += randexp(s.rng)
@update s.skip_k = -ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
@update s.skip_k = s.seen_k-ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
return s
end
function recompute_skip!(s::AbstractWrReservoirSampleMulti, n)
q = rand(s.rng)^(1/n)
@update s.skip_k = ceil(Int, s.seen_k/q - s.seen_k - 1)
@update s.skip_k = ceil(Int, s.seen_k/q)-1
return s
end

Expand Down Expand Up @@ -200,54 +198,22 @@ function update_order_multi!(s::SampleMultiOrdAlgRSWRSKIP, r, j)
end

is_ordered(s::SampleMultiOrdAlgRSWRSKIP) = true
is_ordered(s::AbstractWrReservoirSampleMulti) = false
is_ordered(s::SampleMultiAlgRSWRSKIP) = false

function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
shuffle!(s1.rng, s1.value)
shuffle!(s2.rng, s2.value)
n_tot = n1 + n2
p = n2 / n_tot
value = create_new_res_vec(s1, s2, p, len1)
s_merged = typeof(s1)(0, n_tot, s1.rng, value, nothing)
recompute_skip!(s_merged, len1)
return s_merged
function Base.merge(ss::SampleMultiAlgRSWRSKIP...)
newvalue = reduce_samples(TypeUnion(), ss...)
skip_k = sum(getfield(s, :skip_k) for s in ss)
seen_k = sum(getfield(s, :seen_k) for s in ss)
return SampleMultiAlgRSWRSKIP_Mut(skip_k, seen_k, ss[1].rng, newvalue, nothing)
end

function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
shuffle!(s1.rng, s1.value)
shuffle!(s2.rng, s2.value)
n_tot = n1 + n2
p = n2 / n_tot
merge_res_vec!(s1, s2, p, len1, n_tot)
recompute_skip!(s1, len1)
return s1
end

function check_merging_support(s1, s2)
len1, len2 = length(s1.value), length(s2.value)
len1 != len2 && error("Merging samples with different sizes is not supported")
n1, n2 = nobs(s1), nobs(s2)
n1 < len1 || n2 < len2 && error("Merging samples with different sizes is not supported")
return len1, len2, n1, n2
end

function create_new_res_vec(s1, s2, p, len1)
value = similar(s1.value)
@inbounds for j in 1:len1
value[j] = rand(s1.rng) < p ? s2.value[j] : s1.value[j]
end
return value
end

function merge_res_vec!(s1, s2, p, len1, n_tot)
@inbounds for j in 1:len1
if rand(s1.rng) < p
s1.value[j] = s2.value[j]
end
function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRSWRSKIP...)
newvalue = reduce_samples(TypeS(), s1, ss...)
for i in 1:length(newvalue)
@inbounds s1.value[i] = newvalue[i]
end
s1.seen_k = n_tot
s1.skip_k += sum(getfield(s, :skip_k) for s in ss)
s1.seen_k += sum(getfield(s, :seen_k) for s in ss)
return s1
end

Expand Down
2 changes: 1 addition & 1 deletion test/merge_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
end
end
s_merged = merge(s1, s2)
res[value(s_merged)...] += 1
res[shuffle!(rng, value(s_merged))...] += 1
end
cases = m1 == AlgRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size)
ps_exact = [1/cases for _ in 1:cases]
Expand Down

0 comments on commit 6e4340f

Please sign in to comment.