From 57fb55fc46f9c0e0ca31bfd548c2518f3332446c Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 3 Sep 2022 23:38:54 +0200 Subject: [PATCH 1/5] Error when negative weights or zero sum are used when sampling These can give misleading results. Checking the sum is cheap, but checking for negative weights is relatively costly. Therefore, compute this information the first time it is requested, and store it in the weights vector like the sum. `efraimidis_ares_wsample_norep!` and `efraimidis_aexpj_wsample_norep!` already checked these, but throwing different exception types. Harmonize exceptions across algorithms as they can all be called by `sample`. --- src/sampling.jl | 27 +++++++++++++++++++-------- src/weights.jl | 42 +++++++++++++++++++++++++++++++++++++----- test/weights.jl | 40 +++++++++++++++++++++++++++++++++++++++- test/wsampling.jl | 15 +++++++++++++++ 4 files changed, 110 insertions(+), 14 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index ea19a9306..b8045e6f5 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -584,7 +584,10 @@ Optionally specify a random number generator `rng` as the first argument function sample(rng::AbstractRNG, wv::AbstractWeights) 1 == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) - t = rand(rng) * sum(wv) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + s = sum(wv) + s > 0 || throw(ArgumentError("sum of weights must be greater than 0")) + t = rand(rng) * s n = length(wv) i = 1 cw = wv[1] @@ -621,6 +624,8 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("non 1-based arrays are not supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) for i = 1:length(x) x[i] = a[sample(rng, wv)] end @@ -710,6 +715,8 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, throw(ArgumentError("non 1-based arrays are not supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) # create alias table ap = Vector{Float64}(undef, n) @@ -749,6 +756,8 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) w = Vector{Float64}(undef, n) copyto!(w, wv) @@ -795,6 +804,8 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) # calculate keys for all items keys = randexp(rng, n) @@ -845,14 +856,14 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for _s in 1:n s = _s w = wv.values[s] - w < 0 && error("Negative weight found in weight vector at index $s") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s")) if w > 0 i += 1 pq[i] = (w/randexp(rng) => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) + i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @@ -860,7 +871,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for i in s+1:n w = wv.values[i] - w < 0 && error("Negative weight found in weight vector at index $i") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i")) w > 0 || continue key = w/randexp(rng) @@ -918,14 +929,14 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for _s in 1:n s = _s w = wv.values[s] - w < 0 && error("Negative weight found in weight vector at index $s") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s")) if w > 0 i += 1 pq[i] = (w/randexp(rng) => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) + i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @@ -934,7 +945,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for i in s+1:n w = wv.values[i] - w < 0 && error("Negative weight found in weight vector at index $i") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i")) w > 0 || continue X -= w X <= 0 || continue @@ -991,7 +1002,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs end end else - k <= n || error("Cannot draw $k samples from $n samples without replacement.") + k <= n || throw(ArgumentError("Cannot draw $k samples from $n samples without replacement.")) efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) end return x diff --git a/src/weights.jl b/src/weights.jl index 78091c2ae..27cc70ed8 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -5,20 +5,24 @@ abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: Abstrac @weights name Generates a new generic weight type with specified `name`, which subtypes `AbstractWeights` -and stores the `values` (`V<:RealVector`) and `sum` (`S<:Real`). +and stores the `values` (`V<:RealVector`), the pre-computed `sum` (`S<:Real`) and +whether all values are `positive`. """ macro weights(name) return quote mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V} values::V sum::S - function $(esc(name)){S, T, V}(values, sum) where {S<:Real, T<:Real, V<:AbstractVector{T}} + positive::Union{Bool, Missing} + function $(esc(name)){S, T, V}(values, sum, positive) where {S<:Real, T<:Real, V<:AbstractVector{T}} isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) - return new{S, T, V}(values, sum) + return new{S, T, V}(values, sum, positive) end end - $(esc(name))(values::AbstractVector{T}, sum::S) where {S<:Real, T<:Real} = $(esc(name)){S, T, typeof(values)}(values, sum) - $(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values)) + $(esc(name))(values::AbstractVector{T}, + sum::S=Base.sum(values), + positive::Union{Bool, Missing}=missing) where {S<:Real, T<:Real} = + $(esc(name)){S, T, typeof(values)}(values, sum, positive) end end @@ -53,9 +57,34 @@ Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) wv.values[i] = v wv.sum = sum + wv.positive = missing v end +function Base.all(f::Base.Fix2{typeof(>=)}, wv::AbstractWeights) + if iszero(f.x) + if ismissing(wv.positive) + # sum is significantly faster than all when no entries are negative + wv.positive = sum(<(0), wv.values) == 0 + end + return wv.positive + else + return all(f, wv.values) + end +end + +function Base.any(f::Base.Fix2{typeof(<)}, wv::AbstractWeights) + if iszero(f.x) + if ismissing(wv.positive) + # sum is significantly faster than all when no entries are negative + wv.positive = sum(<(0), wv.values) == 0 + end + return !wv.positive + else + return any(f, wv.values) + end +end + """ varcorrection(n::Integer, corrected=false) @@ -333,6 +362,9 @@ end Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len) +Base.all(f::Base.Fix2{typeof(>=)}, wv::UnitWeights{T}) where {T} = one(T) >= f.x +Base.any(f::Base.Fix2{typeof(<)}, wv::UnitWeights{T}) where {T} = one(T) < f.x + """ uweights(s::Integer) uweights(::Type{T}, s::Integer) where T<:Real diff --git a/test/weights.jl b/test/weights.jl index 52142efd8..5982fbdec 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -42,7 +42,6 @@ weight_funcs = (weights, aweights, fweights, pweights) @test_throws ArgumentError f([0.1, Inf]) @test_throws ArgumentError f([0.1, NaN]) - end @testset "$f, setindex!" for f in weight_funcs @@ -125,6 +124,45 @@ end @test Base.dataids(wv) == () end +@testset "Fast-path all(<=, wv) and any(<, wv)" begin + for f in weight_funcs + @test all(>=(0), f([1, 2])) + @test all(>=(0), f([-0.0, 0.0])) + @test !all(>=(0), f([1, -2])) + @test !all(>=(0), f([1, NaN])) + @test !any(<(0), f([1, 2])) + @test !any(<(0), f([-0.0, 0.0])) + @test any(<(0), f([1, -2])) + @test !any(<(0), f([1, NaN])) + @test any(<(0), f([-1, NaN])) + + @test all(>=(1), [2, 3, 4]) + @test !all(>=(1), [0, 1, 2]) + @test any(<(3), [2, 3, 4]) + @test !any(<(1), [1, 2, 3]) + + wv = f([1.0, 2.0, 3.0]) + @test all(>=(0), wv) + @test !any(<(0), wv) + wv[2] = -0.0 + @test all(>=(0), wv) + @test !any(<(0), wv) + wv[2] = -1.0 + @test !all(>=(0), wv) + @test any(<(0), wv) + wv[2] = 1.0 + @test all(>=(0), wv) + @test !any(<(0), wv) + end + + @test all(>=(0), uweights(2)) + @test !any(<(0), uweights(2)) + @test all(>=(1), uweights(2)) + @test !any(<(1), uweights(2)) + @test !all(>=(2), uweights(2)) + @test any(<(2), uweights(2)) +end + ## wsum @testset "wsum" begin diff --git a/test/wsampling.jl b/test/wsampling.jl index d1de4c855..cd9210372 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -161,5 +161,20 @@ end # This corner case should theoretically succeed # but it currently fails as Base.mightalias is not smart enough @test_broken f(y, weights(view(x, 5:6)), view(x, 2:4)) + + # Check that negative weights are not allowed + if f === efraimidis_ares_wsample_norep! || f === efraimidis_aexpj_wsample_norep! + y[3] = -0.0 + @test_throws ArgumentError f(x, weights(y), z) + else + y[3] = -0.0 + f(x, weights(y), z) + end + y[3] = -1.0 + @test_throws ArgumentError f(x, weights(y), z) + + # Check that sum of weights cannot be zero + @test_throws ArgumentError f(x, weights(fill(0.0, 10)), z) + @test_throws ArgumentError f(x, weights(fill(-0.0, 10)), z) end end \ No newline at end of file From 81560a331b2d9ce75d69da8fcc923009577eebf4 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 3 Sep 2022 23:53:47 +0200 Subject: [PATCH 2/5] Test fixes to work on master --- test/weights.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/weights.jl b/test/weights.jl index 5982fbdec..c26ec6a30 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -129,12 +129,9 @@ end @test all(>=(0), f([1, 2])) @test all(>=(0), f([-0.0, 0.0])) @test !all(>=(0), f([1, -2])) - @test !all(>=(0), f([1, NaN])) @test !any(<(0), f([1, 2])) @test !any(<(0), f([-0.0, 0.0])) @test any(<(0), f([1, -2])) - @test !any(<(0), f([1, NaN])) - @test any(<(0), f([-1, NaN])) @test all(>=(1), [2, 3, 4]) @test !all(>=(1), [0, 1, 2]) From 3c1e60dc25e2e7ef529edc19451d1cb8151d2963 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 4 Sep 2022 00:02:22 +0200 Subject: [PATCH 3/5] Another fix --- src/weights.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weights.jl b/src/weights.jl index 27cc70ed8..67589ee23 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -14,7 +14,7 @@ macro weights(name) values::V sum::S positive::Union{Bool, Missing} - function $(esc(name)){S, T, V}(values, sum, positive) where {S<:Real, T<:Real, V<:AbstractVector{T}} + function $(esc(name)){S, T, V}(values, sum, positive=missing) where {S<:Real, T<:Real, V<:AbstractVector{T}} isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) return new{S, T, V}(values, sum, positive) end From c8a2129e9979ed94fd593d4c2820d63c29fb9324 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 4 Sep 2022 00:06:30 +0200 Subject: [PATCH 4/5] And another one --- test/sampling.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/sampling.jl b/test/sampling.jl index 27fcd2d3c..bb2947a32 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -213,10 +213,10 @@ x = vcat([sample(a, wv, 4, replace=false) for j in 1:10000]...) @test maximum(x) == 10 @test maximum(abs, proportions(x) .- 0.25) == 0 -@test_throws DimensionMismatch sample(a, wv, 5, replace=false) +@test_throws ArgumentError sample(a, wv, 5, replace=false) wv = Weights([zeros(5); 1:4; -1]) -@test_throws ErrorException sample(a, wv, 1, replace=false) +@test_throws ArgumentError sample(a, wv, 1, replace=false) #### weighted sampling with dimension From 18775955cd14097a216ee0a9211e6f54a0f1c3eb Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 4 Sep 2022 12:13:01 +0200 Subject: [PATCH 5/5] Fixes --- src/weights.jl | 27 ++++++++++++++------------- test/weights.jl | 16 ++++++++++++---- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 67589ee23..465a57211 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -6,23 +6,23 @@ abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: Abstrac Generates a new generic weight type with specified `name`, which subtypes `AbstractWeights` and stores the `values` (`V<:RealVector`), the pre-computed `sum` (`S<:Real`) and -whether all values are `positive`. +whether any values are `negative`. """ macro weights(name) return quote mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V} values::V sum::S - positive::Union{Bool, Missing} - function $(esc(name)){S, T, V}(values, sum, positive=missing) where {S<:Real, T<:Real, V<:AbstractVector{T}} + negative::Union{Bool, Missing} + function $(esc(name)){S, T, V}(values, sum, negative=missing) where {S<:Real, T<:Real, V<:AbstractVector{T}} isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) - return new{S, T, V}(values, sum, positive) + return new{S, T, V}(values, sum, negative) end end $(esc(name))(values::AbstractVector{T}, sum::S=Base.sum(values), - positive::Union{Bool, Missing}=missing) where {S<:Real, T<:Real} = - $(esc(name)){S, T, typeof(values)}(values, sum, positive) + negative::Union{Bool, Missing}=missing) where {S<:Real, T<:Real} = + $(esc(name)){S, T, typeof(values)}(values, sum, negative) end end @@ -57,17 +57,18 @@ Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) wv.values[i] = v wv.sum = sum - wv.positive = missing + wv.negative = v < zero(v) ? true : + wv.negative === false ? false : missing v end function Base.all(f::Base.Fix2{typeof(>=)}, wv::AbstractWeights) if iszero(f.x) - if ismissing(wv.positive) + if wv.negative === missing # sum is significantly faster than all when no entries are negative - wv.positive = sum(<(0), wv.values) == 0 + wv.negative = sum(<(0), wv.values) > 0 end - return wv.positive + return !wv.negative else return all(f, wv.values) end @@ -75,11 +76,11 @@ end function Base.any(f::Base.Fix2{typeof(<)}, wv::AbstractWeights) if iszero(f.x) - if ismissing(wv.positive) + if wv.negative === missing # sum is significantly faster than all when no entries are negative - wv.positive = sum(<(0), wv.values) == 0 + wv.negative = sum(<(0), wv.values) > 0 end - return !wv.positive + return wv.negative else return any(f, wv.values) end diff --git a/test/weights.jl b/test/weights.jl index c26ec6a30..e47b05b91 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -133,10 +133,10 @@ end @test !any(<(0), f([-0.0, 0.0])) @test any(<(0), f([1, -2])) - @test all(>=(1), [2, 3, 4]) - @test !all(>=(1), [0, 1, 2]) - @test any(<(3), [2, 3, 4]) - @test !any(<(1), [1, 2, 3]) + @test all(>=(1), f([2, 3, 4])) + @test !all(>=(1), f([0, 1, 2])) + @test any(<(3), f([2, 3, 4])) + @test !any(<(1), f([1, 2, 3])) wv = f([1.0, 2.0, 3.0]) @test all(>=(0), wv) @@ -150,6 +150,14 @@ end wv[2] = 1.0 @test all(>=(0), wv) @test !any(<(0), wv) + + wv = f([1.0, 2.0, 3.0]) + wv[2] = -1.0 + @test !all(>=(0), wv) + @test any(<(0), wv) + wv[2] = 1.0 + @test all(>=(0), wv) + @test !any(<(0), wv) end @test all(>=(0), uweights(2))