Skip to content

Commit

Permalink
Use HybridStruct (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Sep 27, 2024
1 parent 8bafd84 commit e384e32
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 64 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.3.10"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HybridStructs = "49057fa9-d513-5ef6-ae80-2dc68a70a2bd"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -16,6 +17,7 @@ julia = "1.8"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
HybridStructs = "0.2"
OnlineStatsBase = "1"
PrecompileTools = "1"
Random = "1"
Expand Down
3 changes: 2 additions & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module StreamSampling
import Accessors
using DataStructures
using Distributions
using HybridStructs
using OnlineStatsBase
using Random
using StatsBase
Expand Down Expand Up @@ -97,7 +98,7 @@ macro reset(e)
if ismutabletype(typeof($s))
$e
else
$StreamSampling.Accessors.@reset $e
$StreamSampling.Accessors.@update $e
end
end)
end
Expand Down
92 changes: 33 additions & 59 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,25 @@

const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T

struct ImmutSampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
seen_k::Int
n::Int
rng::R
value::BH
end
mutable struct MutSampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
seen_k::Int
n::Int
const rng::R
const value::BH
end
const SampleMultiAlgARes = Union{ImmutSampleMultiAlgARes, MutSampleMultiAlgARes}
const SampleMultiOrdAlgARes = Union{ImmutSampleMultiAlgARes{<:OrdWeighted}, MutSampleMultiAlgARes{<:OrdWeighted}}
const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}}

struct ImmutSampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
state::Float64
min_priority::Float64
seen_k::Int
n::Int
rng::R
value::BH
end
mutable struct MutSampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
state::Float64
min_priority::Float64
seen_k::Int
const n::Int
const rng::R
const value::BH
end
const SampleMultiAlgAExpJ = Union{ImmutSampleMultiAlgAExpJ, MutSampleMultiAlgAExpJ}
const SampleMultiOrdAlgAExpJ = Union{ImmutSampleMultiAlgAExpJ{<:OrdWeighted}, MutSampleMultiAlgAExpJ{<:OrdWeighted}}
const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}}

struct ImmutSampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti
state::Float64
skip_w::Float64
seen_k::Int
rng::R
weights::Vector{Float64}
value::Vector{T}
ord::O
end
mutable struct MutSampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti
@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirSampleMulti
state::Float64
skip_w::Float64
seen_k::Int
Expand All @@ -53,62 +28,61 @@ mutable struct MutSampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractWeightedWrReservoirS
const value::Vector{T}
const ord::O
end
const SampleMultiAlgWRSWRSKIP = Union{ImmutSampleMultiAlgWRSWRSKIP, MutSampleMultiAlgWRSWRSKIP}
const SampleMultiOrdAlgWRSWRSKIP = Union{ImmutSampleMultiAlgWRSWRSKIP{<:Vector}, MutSampleMultiAlgWRSWRSKIP{<:Vector}}
const SampleMultiOrdAlgWRSWRSKIP = Union{SampleMultiAlgWRSWRSKIP_Immut{<:Vector}, SampleMultiAlgWRSWRSKIP_Mut{<:Vector}}

function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return MutSampleMultiAlgAExpJ(0.0, 0.0, 0, n, rng, value)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return MutSampleMultiAlgAExpJ(0.0, 0.0, 0, n, rng, value)
return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return ImmutSampleMultiAlgAExpJ(0.0, 0.0, 0, n, rng, value)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return ImmutSampleMultiAlgAExpJ(0.0, 0.0, 0, n, rng, value)
return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return MutSampleMultiAlgARes(0, n, rng, value)
return SampleMultiAlgARes_Mut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::MutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return MutSampleMultiAlgARes(0, n, rng, value)
return SampleMultiAlgARes_Mut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Ord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return ImmutSampleMultiAlgARes(0, n, rng, value)
return SampleMultiAlgARes_Immut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgARes, ::ImmutSample, ::Unord)
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[])
sizehint!(value, n)
return ImmutSampleMultiAlgARes(0, n, rng, value)
return SampleMultiAlgARes_Immut(0, n, rng, value)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ms::MutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ::MutSample, ::Ord)
ord = collect(1:n)
return MutSampleMultiAlgWRSWRSKIP(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ms::MutSample, ::Unord)
return MutSampleMultiAlgWRSWRSKIP(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ::MutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ims::ImmutSample, ::Ord)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ::ImmutSample, ::Ord)
ord = collect(1:n)
return ImmutSampleMultiAlgWRSWRSKIP(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ims::ImmutSample, ::Unord)
return ImmutSampleMultiAlgWRSWRSKIP(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP, ::ImmutSample, ::Unord)
return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing)
end

@inline function update!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w)
Expand Down Expand Up @@ -181,39 +155,39 @@ end
return s
end

function Base.empty!(s::MutSampleMultiAlgARes)
function Base.empty!(s::SampleMultiAlgARes_Mut)
s.seen_k = 0
empty!(s.value)
sizehint!(s.value, s.n)
return s
end
function Base.empty!(s::MutSampleMultiAlgAExpJ)
function Base.empty!(s::SampleMultiAlgAExpJ_Mut)
s.state = 0.0
s.min_priority = 0.0
s.seen_k = 0
empty!(s.value)
sizehint!(s.value, s.n)
return s
end
function Base.empty!(s::MutSampleMultiAlgWRSWRSKIP)
function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut)
s.state = 0.0
s.skip_w = 0.0
s.seen_k = 0
return s
end

function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w)
@reset s.seen_k += 1
@update s.seen_k += 1
return s
end
function update_state!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, w)
@reset s.seen_k += 1
@reset s.state -= w
@update s.seen_k += 1
@update s.state -= w
return s
end
function update_state!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, w)
@reset s.seen_k += 1
@reset s.state += w
@update s.seen_k += 1
@update s.state += w
return s
end

Expand All @@ -223,13 +197,13 @@ function compute_skip_priority(s, w)
end

function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ})
@reset s.min_priority = last(first(s.value))
@reset s.state = -randexp(s.rng)/log(s.min_priority)
@update s.min_priority = last(first(s.value))
@update s.state = -randexp(s.rng)/log(s.min_priority)
return s
end
function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n)
q = rand(s.rng)^(1/n)
@reset s.skip_w = s.state/q
@update s.skip_w = s.state/q
return s
end

Expand Down
8 changes: 4 additions & 4 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ function value(s::AbstractWeightedReservoirSampleSingle)
end

@inline function update!(s::SampleSingleAlgAExpJ, el, weight)
@reset s.seen_k += 1
@reset s.total_w += weight
@update s.seen_k += 1
@update s.total_w += weight
if s.skip_w <= s.total_w
@reset s.skip_w = s.total_w/rand(s.rng)
@update s.skip_w = s.total_w/rand(s.rng)
s = reset_value!(s, el)
end
return s
Expand All @@ -59,7 +59,7 @@ function reset_value!(s::MutSampleSingleAlgAExpJ, el)
return s
end
function reset_value!(s::ImmutSampleSingleAlgAExpJ, el)
@reset s.rvalue.value = el
@update s.rvalue.value = el
return s
end

Expand Down

0 comments on commit e384e32

Please sign in to comment.