Skip to content

Commit

Permalink
Make HMMBase optional in the test suite (#111)
Browse files Browse the repository at this point in the history
* Put HMMBase in extension of HMMTest

* Split test suite

* Using

* Fix

* Fixes

* Increase test case sizes

* Fixes

* Fix

* Split better

* More fixes

* Rm show

* Fix

* Subtle change

* Fixes
  • Loading branch information
gdalle authored Sep 30, 2024
1 parent 0e92d51 commit 284189f
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 88 deletions.
17 changes: 10 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on:
push:
branches:
- main
tags: ['*']
tags: ["*"]
pull_request:
concurrency:
# Skip intermediate builds: always.
Expand All @@ -12,16 +12,19 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ github.event_name }}
name: Julia ${{ matrix.version }} - ${{ matrix.test_suite }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
os:
- ubuntu-latest
- "1.9"
- "1"
test_suite:
- "Standard"
- "HMMBase"
env:
JULIA_HMM_TEST_SUITE: ${{ matrix.test_suite }}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -36,4 +39,4 @@ jobs:
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: true
1 change: 1 addition & 0 deletions examples/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Enzyme: Enzyme
using ForwardDiff: ForwardDiff
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using HMMTest #src
using LinearAlgebra
using Random: Random, AbstractRNG
using StableRNGs
Expand Down
3 changes: 1 addition & 2 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ This is important to keep in mind when testing new models.
In many applications, we have access to various observation sequences of different lengths.
=#

nb_seqs = 300
nb_seqs = 1000
long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs];
typeof(long_obs_seqs)

Expand Down Expand Up @@ -258,6 +258,5 @@ hcat(initialization(hmm_est_concat), initialization(hmm))
# ## Tests #src

control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
4 changes: 2 additions & 2 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri
Let us build several sequences of variable lengths.
=#

control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:100];
control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];

obs_seq = reduce(vcat, obs_seqs)
Expand Down Expand Up @@ -151,5 +151,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2])

@test hmm_est.dist_coeffs[1] hmm.dist_coeffs[1] atol = 0.05 #src
@test hmm_est.dist_coeffs[2] hmm.dist_coeffs[2] atol = 0.05 #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
4 changes: 2 additions & 2 deletions examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,6 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)))

# ## Tests #src

@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.09, init=false) #src
@test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
5 changes: 2 additions & 3 deletions examples/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St

@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src

seq_ends = cumsum(rand(rng, 100:200, 100)); #src
seq_ends = cumsum(rand(rng, 100:200, 1000)); #src
control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
# https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
7 changes: 6 additions & 1 deletion libs/HMMTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ version = "0.1.0"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"

[extensions]
HMMTestHMMBaseExt = "HMMBase"
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
module HMMTestHMMBaseExt

function test_identical_hmmbase(
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using HMMTest
using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

function HMMTest.test_identical_hmmbase(
rng::AbstractRNG,
hmm::AbstractHMM,
T::Integer;
Expand Down Expand Up @@ -54,3 +63,5 @@ function test_identical_hmmbase(
end
end
end

end
4 changes: 2 additions & 2 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ using BenchmarkTools: @ballocated
using HiddenMarkovModels
using HiddenMarkovModels: AbstractVectorOrNTuple
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using JET: @test_opt, @test_call
using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

function test_identical_hmmbase end # in extension

export transpose_hmm
export test_equal_hmms, test_coherent_algorithms
export test_identical_hmmbase
Expand All @@ -19,7 +20,6 @@ export test_type_stability
include("utils.jl")
include("coherence.jl")
include("allocations.jl")
include("hmmbase.jl")
include("jet.jl")

end
2 changes: 1 addition & 1 deletion src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using DocStringExtensions
using FillArrays: Fill
using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent
using Random: Random, AbstractRNG, default_rng
using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals
using StatsAPI: StatsAPI, fit, fit!
Expand Down
2 changes: 1 addition & 1 deletion src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function baum_welch_has_converged(
logL, logL_prev = logL_evolution[end], logL_evolution[end - 1]
progress = logL - logL_prev
if loglikelihood_increasing && progress < min(0, -atol)
error("Loglikelihood decreased in Baum-Welch")
error("Loglikelihood decreased from $logL_prev to $logL in Baum-Welch")
elseif progress < atol
return true
end
Expand Down
2 changes: 1 addition & 1 deletion src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function obs_logdensities!(
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control
) where {T}
dists = obs_distributions(hmm, control)
@inbounds @simd for i in eachindex(logb, dists)
@simd for i in eachindex(logb, dists)
logb[i] = logdensityof(dists[i], obs)
end
@argcheck maximum(logb) < typemax(T)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/lightcategorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function StatsAPI.fit!(
@argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p)
w_tot = sum(w)
fill!(dist.p, zero(T1))
@inbounds @simd for i in eachindex(x, w)
@simd for i in eachindex(x, w)
dist.p[x[i]] += w[i]
end
dist.p ./= w_tot
Expand Down
8 changes: 4 additions & 4 deletions src/utils/lightdiagnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function DensityInterface.logdensityof(
) where {T1,T2,T3}
l = zero(promote_type(T1, T2, T3, eltype(x)))
l -= sum(dist.logσ) + log2π * length(x) / 2
@inbounds @simd for i in eachindex(x, dist.μ, dist.σ)
@simd for i in eachindex(x, dist.μ, dist.σ)
l -= abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i]))
end
return l
Expand All @@ -58,11 +58,11 @@ function StatsAPI.fit!(
w_tot = sum(w)
fill!(dist.μ, zero(T1))
fill!(dist.σ, zero(T2))
@inbounds @simd for i in eachindex(x, w)
dist.μ .+= x[i] .* w[i]
@simd for i in eachindex(x, w)
axpy!(w[i], x[i], dist.μ)
end
dist.μ ./= w_tot
@inbounds @simd for i in eachindex(x, w)
@simd for i in eachindex(x, w)
dist.σ .+= abs2.(x[i] .- dist.μ) .* w[i]
end
dist.σ .= sqrt.(dist.σ ./ w_tot)
Expand Down
91 changes: 59 additions & 32 deletions test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ using SparseArrays
using StableRNGs
using Test

rng = StableRNG(63)
TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard")

## Settings

T, K = 50, 200
T, K = 100, 200

init = [0.4, 0.6]
init_guess = [0.5, 0.5]
Expand All @@ -29,26 +29,31 @@ p_guess = [[0.7, 0.3], [0.3, 0.7]]

σ = ones(2)

rng = StableRNG(63)
control_seqs = [fill(nothing, rand(rng, T:(2T))) for k in 1:K];
control_seq = reduce(vcat, control_seqs);
seq_ends = cumsum(length.(control_seqs));

## Uncontrolled

@testset "Normal" begin
@testset verbose = true "Normal" begin
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "DiagNormal" begin
@testset verbose = true "DiagNormal" begin
dists = [MvNormal(μ[1], Diagonal(abs2.(σ))), MvNormal(μ[2], Diagonal(abs2.(σ)))]
dists_guess = [
MvNormal(μ_guess[1], Diagonal(abs2.(σ))), MvNormal(μ_guess[2], Diagonal(abs2.(σ)))
Expand All @@ -57,68 +62,90 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "LightCategorical" begin
@testset verbose = true "LightCategorical" begin
dists = [LightCategorical(p[1]), LightCategorical(p[2])]
dists_guess = [LightCategorical(p_guess[1]), LightCategorical(p_guess[2])]

hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "LightDiagNormal" begin
@testset verbose = true "LightDiagNormal" begin
dists = [LightDiagNormal(μ[1], σ), LightDiagNormal(μ[2], σ)]
dists_guess = [LightDiagNormal(μ_guess[1], σ), LightDiagNormal(μ_guess[2], σ)]

hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal (sparse)" begin
@testset verbose = true "Normal (sparse)" begin
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

hmm = HMM(init, sparse(trans), dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal transposed" begin # issue 99
@testset verbose = true "Normal transposed" begin # issue 99
dists = [Normal(μ[1][1]), Normal(μ[2][1])]
dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])]

hmm = transpose_hmm(HMM(init, trans, dists))
hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess))

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess)
end
end

@testset "Normal and Exponential" begin # issue 101
@testset verbose = true "Normal and Exponential" begin # issue 101
dists = [Normal(μ[1][1]), Exponential(1.0)]
dists_guess = [Normal(μ_guess[1][1]), Exponential(0.8)]

hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

test_identical_hmmbase(rng, hmm, T; hmm_guess)
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
end
end
Loading

0 comments on commit 284189f

Please sign in to comment.