From 1b781abd6bd959dc0884119e577cad5e45caf9ff Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Sun, 29 Oct 2023 14:13:55 +0100 Subject: [PATCH] Add Custom Kernel --- Project.toml | 9 ++-- src/Baytes.jl | 32 ++++++++++++-- test/runtests.jl | 1 + test/test-construction.jl | 5 +++ test/test-custom.jl | 90 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 test/test-custom.jl diff --git a/Project.toml b/Project.toml index 3b56944..ff08467 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Baytes" uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f" authors = ["Patrick Aschermayr "] -version = "0.3.14" +version = "0.3.15" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -36,17 +36,18 @@ MCMCDiagnosticTools = "0.3" ModelWrappers = "0.5" PrettyTables = "2" ProgressMeter = "1.7" -SimpleUnPack = "1" Random = "1.9" +SimpleUnPack = "1" Statistics = "1.9" julia = "^1.9" [extras] +BaytesDiff = "12a76ff9-393d-487f-8b39-e615b97e2f77" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"] +test = ["Test", "BaytesDiff", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"] diff --git a/src/Baytes.jl b/src/Baytes.jl index 32ae6f9..c135a91 100644 --- a/src/Baytes.jl +++ b/src/Baytes.jl @@ -3,7 +3,21 @@ module Baytes ################################################################################ #Import modules -import BaytesCore: BaytesCore, update!, infer, results, init, init!, propose, propose!, propagate! +import BaytesCore: + BaytesCore, + update!, + infer, + results, + init, + init!, + propose, + propose!, + propagate!, + generate, + generate_showvalues, + get_result, + result! + using BaytesCore: BaytesCore, AbstractAlgorithm, @@ -25,7 +39,6 @@ using BaytesCore: Batch, SubSampled, adjust, - generate_showvalues, TemperingMethod, IterationTempering, JointTempering, @@ -41,7 +54,8 @@ using BaytesCore: import ModelWrappers: ModelWrappers, sample, - sample! + sample!, + predict using ModelWrappers: ModelWrappers, @@ -66,6 +80,12 @@ using ModelWrappers: UnflattenTypes, UnflattenStrict, UnflattenFlexible +#= +using BaytesDiff: + BaytesDiff, + ℓObjectiveResult, + ℓDensityResult +=# using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim @@ -202,7 +222,11 @@ export Optimizer, OptimConstructor, OptimDefault, - + OptimLBFG, + CustomAlgorithmDefault, + CustomAlgorithm, + CustomAlgorithmConstructor, + ## BaytesSMC SMC, SMCDefault, diff --git a/test/runtests.jl b/test/runtests.jl index 606aeb2..8fb7ede 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,5 @@ include("testhelper/TestHelper.jl") # Run Tests @testset "All tests" begin include("test-construction.jl") + include("test-custom.jl") end diff --git a/test/test-construction.jl b/test/test-construction.jl index 36ede49..4d6aa9d 100644 --- a/test/test-construction.jl +++ b/test/test-construction.jl @@ -373,6 +373,11 @@ using Optim, NLSolversBase end end +############################################################################################ +# Check if Custom Sampler works + + + ############################################################################################ #Utility @testset "Utility, maxiterations" begin diff --git a/test/test-custom.jl b/test/test-custom.jl new file mode 100644 index 0000000..8cfd817 --- /dev/null +++ b/test/test-custom.jl @@ -0,0 +1,90 @@ +############################################################################################ +# Models to be used in construction +objectives = [ + Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault()), data_uv), + Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault(; output = Float32)), data_uv) +] +Nchains = 4 +tempermethods = [ + IterationTempering(Float64, UpdateFalse(), 1.0, 1000), + IterationTempering(Float64, UpdateTrue(), 1.0, 1000), + JointTempering(Float64, UpdateFalse(), .5, Float64(Nchains), Nchains), + JointTempering(Float64, UpdateTrue(), .5, Float64(Nchains), Nchains) +] + +#= +iter = 2 +tempermethod = tempermethods[iter] +=# + +## Add custom Step for propagate +using BaytesDiff +import BaytesOptim: BaytesOptim, propagate +## Extend Custom Method +function propagate( + _rng::Random.AbstractRNG, algorithm::CustomAlgorithm, objective::Objective{<:ModelWrapper{MyBaseModel}}) + logobjective = BaytesDiff.ℓDensityResult(objective) + #logobjective.θᵤ[1] = 5 + logobjective.θᵤ[1] = rand() + return logobjective +end + +############################################################################################ +@testset "Sampling, type conversion" begin + for tempermethod in tempermethods + for iter in eachindex(objectives) + #println(tempermethod, " ", iter) + sampledefault = SampleDefault(; + dataformat=Batch(), + tempering=deepcopy(tempermethod), #IterationTempering(Float64, UpdateFalse(), 1.0, 1000), + chains=4, + iterations=100, + burnin=max(1, Int64(floor(10/10))), + thinning = 1, + safeoutput=false, + printoutput=false, + printdefault=PrintDefault(), + report=ProgressReport(; + bar=false, + log=SilentLog() + ), + ) + temperupdate = sampledefault.tempering.adaption + _obj = deepcopy(objectives[iter]) + _flattentype = _obj.model.info.reconstruct.default.output + + # Create Custom Algorithm + def = CustomAlgorithmDefault(; + generated=UpdateTrue() + ) + opt = CustomAlgorithm( + _rng, + _obj, + def, + ) + + ## Sample on its own + customconstruct = CustomAlgorithm(:μ) #CustomAlgorithm(keys(_obj.model.val)) + trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct ; default = deepcopy(sampledefault)) + trace.val + + ## Combine with MCMC + mcmc = MCMC(NUTS,(:σ,); stepsize = ConfigStepsize(;stepsizeadaption = UpdateFalse())) + trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct, mcmc ; default = deepcopy(sampledefault)) + trace.val + + ## Use as Propagation Kernel in SMC + ibis = SMCConstructor(customconstruct, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0)) + trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault)) + trace.val + ## Always update Gradient Result if new data is added + #!NOTE: But after first iteration, can capture results + @test isa(trace.summary.info.captured, UpdateFalse) + ## Continue sampling + newdat = randn(_rng, length(_obj.data)+100) + trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms) + #!NOTE: But after first iteration, can capture results + @test isa(trace2.summary.info.captured, UpdateFalse) + end + end +end