From 1b781abd6bd959dc0884119e577cad5e45caf9ff Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Sun, 29 Oct 2023 14:13:55 +0100
Subject: [PATCH] Add Custom Kernel
---
Project.toml | 9 ++--
src/Baytes.jl | 32 ++++++++++++--
test/runtests.jl | 1 +
test/test-construction.jl | 5 +++
test/test-custom.jl | 90 +++++++++++++++++++++++++++++++++++++++
5 files changed, 129 insertions(+), 8 deletions(-)
create mode 100644 test/test-custom.jl
diff --git a/Project.toml b/Project.toml
index 3b56944..ff08467 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr "]
-version = "0.3.14"
+version = "0.3.15"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -36,17 +36,18 @@ MCMCDiagnosticTools = "0.3"
ModelWrappers = "0.5"
PrettyTables = "2"
ProgressMeter = "1.7"
-SimpleUnPack = "1"
Random = "1.9"
+SimpleUnPack = "1"
Statistics = "1.9"
julia = "^1.9"
[extras]
+BaytesDiff = "12a76ff9-393d-487f-8b39-e615b97e2f77"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["Test", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"]
+test = ["Test", "BaytesDiff", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"]
diff --git a/src/Baytes.jl b/src/Baytes.jl
index 32ae6f9..c135a91 100644
--- a/src/Baytes.jl
+++ b/src/Baytes.jl
@@ -3,7 +3,21 @@ module Baytes
################################################################################
#Import modules
-import BaytesCore: BaytesCore, update!, infer, results, init, init!, propose, propose!, propagate!
+import BaytesCore:
+ BaytesCore,
+ update!,
+ infer,
+ results,
+ init,
+ init!,
+ propose,
+ propose!,
+ propagate!,
+ generate,
+ generate_showvalues,
+ get_result,
+ result!
+
using BaytesCore:
BaytesCore,
AbstractAlgorithm,
@@ -25,7 +39,6 @@ using BaytesCore:
Batch,
SubSampled,
adjust,
- generate_showvalues,
TemperingMethod,
IterationTempering,
JointTempering,
@@ -41,7 +54,8 @@ using BaytesCore:
import ModelWrappers:
ModelWrappers,
sample,
- sample!
+ sample!,
+ predict
using ModelWrappers:
ModelWrappers,
@@ -66,6 +80,12 @@ using ModelWrappers:
UnflattenTypes,
UnflattenStrict,
UnflattenFlexible
+#=
+using BaytesDiff:
+ BaytesDiff,
+ ℓObjectiveResult,
+ ℓDensityResult
+=#
using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim
@@ -202,7 +222,11 @@ export
Optimizer,
OptimConstructor,
OptimDefault,
-
+ OptimLBFG,
+ CustomAlgorithmDefault,
+ CustomAlgorithm,
+ CustomAlgorithmConstructor,
+
## BaytesSMC
SMC,
SMCDefault,
diff --git a/test/runtests.jl b/test/runtests.jl
index 606aeb2..8fb7ede 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,4 +24,5 @@ include("testhelper/TestHelper.jl")
# Run Tests
@testset "All tests" begin
include("test-construction.jl")
+ include("test-custom.jl")
end
diff --git a/test/test-construction.jl b/test/test-construction.jl
index 36ede49..4d6aa9d 100644
--- a/test/test-construction.jl
+++ b/test/test-construction.jl
@@ -373,6 +373,11 @@ using Optim, NLSolversBase
end
end
+############################################################################################
+# Check if Custom Sampler works -- see test/test-custom.jl (included from runtests.jl)
+
+
+
############################################################################################
#Utility
@testset "Utility, maxiterations" begin
diff --git a/test/test-custom.jl b/test/test-custom.jl
new file mode 100644
index 0000000..8cfd817
--- /dev/null
+++ b/test/test-custom.jl
@@ -0,0 +1,90 @@
+############################################################################################
+# Models to be used in the custom-kernel tests
+objectives = [
+ Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault()), data_uv),
+ Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault(; output = Float32)), data_uv)
+]
+Nchains = 4
+tempermethods = [
+ IterationTempering(Float64, UpdateFalse(), 1.0, 1000),
+ IterationTempering(Float64, UpdateTrue(), 1.0, 1000),
+ JointTempering(Float64, UpdateFalse(), .5, Float64(Nchains), Nchains),
+ JointTempering(Float64, UpdateTrue(), .5, Float64(Nchains), Nchains)
+]
+
+#=
+iter = 2
+tempermethod = tempermethods[iter]
+=#
+
+## Add custom Step for propagate
+using BaytesDiff
+import BaytesOptim: BaytesOptim, propagate
+## Extend Custom Method
+function propagate(
+ _rng::Random.AbstractRNG, algorithm::CustomAlgorithm, objective::Objective{<:ModelWrapper{MyBaseModel}})
+ logobjective = BaytesDiff.ℓDensityResult(objective)
+ # Jitter the first unconstrained parameter so every propagate call yields a fresh draw
+ logobjective.θᵤ[1] = rand()
+ return logobjective
+end
+
+############################################################################################
+@testset "Custom kernel sampling, type conversion" begin
+ for tempermethod in tempermethods
+ for iter in eachindex(objectives)
+ #println(tempermethod, " ", iter)
+ sampledefault = SampleDefault(;
+ dataformat=Batch(),
+ tempering=deepcopy(tempermethod), #IterationTempering(Float64, UpdateFalse(), 1.0, 1000),
+ chains=4,
+ iterations=100,
+ burnin=max(1, Int64(floor(10/10))),
+ thinning = 1,
+ safeoutput=false,
+ printoutput=false,
+ printdefault=PrintDefault(),
+ report=ProgressReport(;
+ bar=false,
+ log=SilentLog()
+ ),
+ )
+ temperupdate = sampledefault.tempering.adaption
+ _obj = deepcopy(objectives[iter])
+ _flattentype = _obj.model.info.reconstruct.default.output
+
+ # Create Custom Algorithm
+ def = CustomAlgorithmDefault(;
+ generated=UpdateTrue()
+ )
+ opt = CustomAlgorithm(
+ _rng,
+ _obj,
+ def,
+ )
+
+ ## Sample on its own
+ customconstruct = CustomAlgorithm(:μ) #CustomAlgorithm(keys(_obj.model.val))
+ trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct ; default = deepcopy(sampledefault))
+ trace.val
+
+ ## Combine with MCMC
+ mcmc = MCMC(NUTS,(:σ,); stepsize = ConfigStepsize(;stepsizeadaption = UpdateFalse()))
+ trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct, mcmc ; default = deepcopy(sampledefault))
+ trace.val
+
+ ## Use as Propagation Kernel in SMC
+ ibis = SMCConstructor(customconstruct, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0))
+ trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault))
+ trace.val
+ ## Always update Gradient Result if new data is added
+ #!NOTE: But after first iteration, can capture results
+ @test isa(trace.summary.info.captured, UpdateFalse)
+ ## Continue sampling
+ newdat = randn(_rng, length(_obj.data)+100)
+ trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms)
+ #!NOTE: But after first iteration, can capture results
+ @test isa(trace2.summary.info.captured, UpdateFalse)
+ end
+ end
+end