Skip to content

Commit

Permalink
Add interface-based testing (#22)
Browse files Browse the repository at this point in the history
* Add interace-based testing

* Add compat for RI

* Leverage RequiredInterfaces v0.1.3

* Add failing interface

* Test wrong interface
  • Loading branch information
gdalle authored Jul 11, 2023
1 parent 77e648c commit 5e53cb4
Show file tree
Hide file tree
Showing 13 changed files with 141 additions and 779 deletions.
27 changes: 26 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle", "Maxime Mouchet"]
version = "0.2.0"
version = "0.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,6 +10,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
Expand All @@ -25,7 +26,31 @@ ChainRulesCore = "1.16"
DensityInterface = "0.4"
Distributions = "0.25"
PrecompileTools = "1.1"
RequiredInterfaces = "0.1.3"
Requires = "1.3"
SimpleUnPack = "1.1"
StatsAPI = "1.6"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Distributions", "Documenter", "FiniteDifferences", "ForwardDiff", "HMMBase", "JET", "JuliaFormatter", "LinearAlgebra", "Pkg", "Random", "SimpleUnPack", "SparseArrays", "StaticArrays", "Statistics", "Test", "Zygote"]
12 changes: 9 additions & 3 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "120abd36dc09fb9f88978b795d680e2bb979c4ce"
project_hash = "d55ef3120e8c75c1ed5c15bb45fc95d39039c78f"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
Expand Down Expand Up @@ -314,10 +314,10 @@ uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566"
version = "2.8.1+1"

[[deps.HiddenMarkovModels]]
deps = ["ChainRulesCore", "DensityInterface", "Distributions", "LinearAlgebra", "PrecompileTools", "Random", "Requires", "SimpleUnPack", "StatsAPI"]
deps = ["ChainRulesCore", "DensityInterface", "Distributions", "LinearAlgebra", "PrecompileTools", "Random", "RequiredInterfaces", "Requires", "SimpleUnPack", "StatsAPI"]
path = ".."
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
version = "0.2.0"
version = "0.2.1"

[deps.HiddenMarkovModels.extensions]
HiddenMarkovModelsHMMBaseExt = "HMMBase"
Expand Down Expand Up @@ -764,6 +764,12 @@ git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691"
uuid = "05181044-ff0b-4ac5-8273-598c1e38db00"
version = "1.0.0"

[[deps.RequiredInterfaces]]
deps = ["InteractiveUtils", "Logging", "Test"]
git-tree-sha1 = "0431cf93378698d6ea99662b8bd188e59221d1b6"
uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
version = "0.1.3"

[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ if benchmarks_successful
insert!(pages, length(pages) - 1, "Benchmarks" => "benchmarks.md")
end

format = Documenter.HTML(;
fmt = Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://gdalle.github.io/HiddenMarkovModels.jl",
edit_link="main",
Expand All @@ -37,7 +37,7 @@ makedocs(;
authors="Guillaume Dalle, Maxime Mouchet and contributors",
repo="https://github.com/gdalle/HiddenMarkovModels.jl/blob/{commit}{path}#{line}",
sitename="HiddenMarkovModels.jl",
format=format,
format=fmt,
pages=pages,
linkcheck=true,
strict=false,
Expand Down
28 changes: 26 additions & 2 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
In the meantime, you can take a look at the files in `test`, which demonstrate more sophisticated ways to use the package.

```@repl tuto
using HiddenMarkovModels, Distributions
using HiddenMarkovModels
using Distributions
```

Constructing an HMM:
## Using the built-in HMM

Constructing a model:

```@repl tuto
function random_gaussian_hmm(N)
Expand Down Expand Up @@ -56,3 +59,24 @@ first(logL_evolution), last(logL_evolution)
transition_matrix(hmm_est)
[obs_distribution(hmm_est, i) for i in 1:length(hmm)]
```

## Making your own HMM

The built-in HMM is perfect when the initial state distribution `p`, transition matrix `A` and emission distributions `dists` are three separate objects, which means their re-estimation can be done separately.
But in some cases these parameters might be correlated.
For instance, you may want an HMM whose initial state distribution always corresponds to the equilibrium distribution associated with the transition matrix.

In such cases, it is necessary to implement a new subtype of [`AbstractHMM`](@ref) with all its required methods.
To ascertain that your type indeed satisfies the interface, you can use [RequiredInterfaces.jl](https://github.com/Seelengrab/RequiredInterfaces.jl) as follows:

```@repl tuto
using RequiredInterfaces: check_interface_implemented
struct EmptyHMM end
check_interface_implemented(AbstractHMM, HMM)
check_interface_implemented(AbstractHMM, EmptyHMM)
```

Note that this test does not check the `fit!` method.
Since it is only used in the Baum-Welch algorithm, it is an optional part of the `AbstractHMM` interface.
1 change: 1 addition & 0 deletions src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using Distributions:
using LinearAlgebra: Diagonal, dot, mul!
using PrecompileTools: @compile_workload, @setup_workload
using Random: AbstractRNG, GLOBAL_RNG
using RequiredInterfaces: @required
using Requires: @require
using SimpleUnPack: @unpack
using StatsAPI: StatsAPI, fit, fit!
Expand Down
7 changes: 7 additions & 0 deletions src/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ const AbstractHMM = AbstractHiddenMarkovModel

@inline DensityInterface.DensityKind(::AbstractHMM) = HasDensity()

@required AbstractHMM begin
Base.length(::AbstractHMM)
initial_distribution(::AbstractHMM)
transition_matrix(::AbstractHMM)
obs_distribution(::AbstractHMM, ::Integer)
end

"""
length(hmm::AbstractHMM)
Expand Down
39 changes: 31 additions & 8 deletions src/learning/baum_welch.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
function baum_welch!(hmm::AbstractHMM, obs_seqs; max_iterations, rtol)
function baum_welch!(
hmm::AbstractHMM, obs_seqs; max_iterations, rtol, check_loglikelihood_increasing=true
)
# Pre-allocate nearly all necessary memory
logB = loglikelihoods(hmm, obs_seqs[1])
fb = initialize_forward_backward(hmm, logB)
Expand Down Expand Up @@ -38,7 +40,7 @@ function baum_welch!(hmm::AbstractHMM, obs_seqs; max_iterations, rtol)
(logL_evolution[end] - logL_evolution[end - 1]) /
abs(logL_evolution[end - 1])
)
if progress < -eps(progress)
if check_loglikelihood_increasing && progress < -eps(progress)
error("Loglikelihood decreased in Baum-Welch")
elseif progress < rtol
break
Expand All @@ -50,32 +52,53 @@ function baum_welch!(hmm::AbstractHMM, obs_seqs; max_iterations, rtol)
end

"""
baum_welch(hmm_init, obs_seq; max_iterations, rtol)
baum_welch(
hmm_init, obs_seq;
max_iterations, rtol, check_loglikelihood_increasing
)
Apply the Baum-Welch algorithm to estimate the parameters of an HMM and return a tuple `(hmm, logL_evolution)`.
The procedure is based on a single observation sequence and initialized with `hmm_init`.
"""
function baum_welch(hmm_init::AbstractHMM, obs_seq; max_iterations=100, rtol=1e-3)
function baum_welch(
hmm_init::AbstractHMM,
obs_seq;
max_iterations=100,
rtol=1e-3,
check_loglikelihood_increasing=true,
)
hmm = deepcopy(hmm_init)
logL_evolution = baum_welch!(hmm, [obs_seq]; max_iterations, rtol)
logL_evolution = baum_welch!(
hmm, [obs_seq]; max_iterations, rtol, check_loglikelihood_increasing
)
return hmm, logL_evolution
end

"""
baum_welch(hmm_init, obs_seqs, nb_seqs; max_iterations, rtol)
baum_welch(
hmm_init, obs_seqs, nb_seqs;
max_iterations, rtol, check_loglikelihood_increasing
)
Apply the Baum-Welch algorithm to estimate the parameters of an HMM and return a tuple `(hmm, logL_evolution)`.
The procedure is based on multiple observation sequences and initialized with `hmm_init`.
"""
function baum_welch(
hmm_init::AbstractHMM, obs_seqs, nb_seqs::Integer; max_iterations=100, rtol=1e-3
hmm_init::AbstractHMM,
obs_seqs,
nb_seqs::Integer;
max_iterations=100,
rtol=1e-3,
check_loglikelihood_increasing=true,
)
if nb_seqs != length(obs_seqs)
throw(ArgumentError("nb_seqs != length(obs_seqs)"))
end
hmm = deepcopy(hmm_init)
logL_evolution = baum_welch!(hmm, obs_seqs; max_iterations, rtol)
logL_evolution = baum_welch!(
hmm, obs_seqs; max_iterations, rtol, check_loglikelihood_increasing
)
return hmm, logL_evolution
end
Loading

0 comments on commit 5e53cb4

Please sign in to comment.