Skip to content

Commit

Permalink
Add Custom Kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Oct 29, 2023
1 parent 558618f commit 1b781ab
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 8 deletions.
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr <[email protected]>"]
version = "0.3.14"
version = "0.3.15"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -36,17 +36,18 @@ MCMCDiagnosticTools = "0.3"
ModelWrappers = "0.5"
PrettyTables = "2"
ProgressMeter = "1.7"
SimpleUnPack = "1"
Random = "1.9"
SimpleUnPack = "1"
Statistics = "1.9"
julia = "^1.9"

[extras]
BaytesDiff = "12a76ff9-393d-487f-8b39-e615b97e2f77"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"]
test = ["Test", "BaytesDiff", "Distributions", "NLSolversBase", "Optim", "ForwardDiff"]
32 changes: 28 additions & 4 deletions src/Baytes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@ module Baytes

################################################################################
#Import modules
import BaytesCore: BaytesCore, update!, infer, results, init, init!, propose, propose!, propagate!
import BaytesCore:
BaytesCore,
update!,
infer,
results,
init,
init!,
propose,
propose!,
propagate!,
generate,
generate_showvalues,
get_result,
result!

using BaytesCore:
BaytesCore,
AbstractAlgorithm,
Expand All @@ -25,7 +39,6 @@ using BaytesCore:
Batch,
SubSampled,
adjust,
generate_showvalues,
TemperingMethod,
IterationTempering,
JointTempering,
Expand All @@ -41,7 +54,8 @@ using BaytesCore:
import ModelWrappers:
ModelWrappers,
sample,
sample!
sample!,
predict

using ModelWrappers:
ModelWrappers,
Expand All @@ -66,6 +80,12 @@ using ModelWrappers:
UnflattenTypes,
UnflattenStrict,
UnflattenFlexible
#=
using BaytesDiff:
BaytesDiff,
ℓObjectiveResult,
ℓDensityResult
=#

using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim

Expand Down Expand Up @@ -202,7 +222,11 @@ export
Optimizer,
OptimConstructor,
OptimDefault,

OptimLBFG,
CustomAlgorithmDefault,
CustomAlgorithm,
CustomAlgorithmConstructor,

## BaytesSMC
SMC,
SMCDefault,
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ include("testhelper/TestHelper.jl")
# Run Tests
@testset "All tests" begin
    # Core construction/sampling tests for the exported interface.
    include("test-construction.jl")
    # Tests for the custom propagation kernel (file added in this commit).
    include("test-custom.jl")
end
5 changes: 5 additions & 0 deletions test/test-construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,11 @@ using Optim, NLSolversBase
end
end

############################################################################################
# Check if Custom Sampler works



############################################################################################
#Utility
@testset "Utility, maxiterations" begin
Expand Down
90 changes: 90 additions & 0 deletions test/test-custom.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
############################################################################################
# Models to be used in construction
# Two copies of the same model/data, differing only in the flattening output type
# (Float64 default vs Float32) so type conversion is exercised downstream.
objectives = [
    Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault()), data_uv),
    Objective(ModelWrapper(MyBaseModel(), myparameter1, (;), FlattenDefault(; output = Float32)), data_uv)
]
# Number of chains used for the JointTempering configurations below.
Nchains = 4
# Tempering strategies to sweep over: iteration- and joint-tempering, each with
# adaption disabled (UpdateFalse) and enabled (UpdateTrue).
tempermethods = [
    IterationTempering(Float64, UpdateFalse(), 1.0, 1000),
    IterationTempering(Float64, UpdateTrue(), 1.0, 1000),
    JointTempering(Float64, UpdateFalse(), .5, Float64(Nchains), Nchains),
    JointTempering(Float64, UpdateTrue(), .5, Float64(Nchains), Nchains)
]

#=
iter = 2
tempermethod = tempermethods[iter]
=#

## Add custom Step for propagate
using BaytesDiff
import BaytesOptim: BaytesOptim, propagate
## Extend Custom Method
function propagate(
    _rng::Random.AbstractRNG, algorithm::CustomAlgorithm, objective::Objective{<:ModelWrapper{MyBaseModel}})
# Custom propagation kernel for `MyBaseModel`: builds a fresh density result for
# `objective` and perturbs the first unconstrained parameter with a random draw.
# Returns the mutated `ℓDensityResult`.
    logobjective = BaytesDiff.ℓDensityResult(objective)
    # Draw from the supplied RNG (the original used global `rand()`, which ignores
    # `_rng` and breaks reproducibility of seeded test runs).
    logobjective.θᵤ[1] = rand(_rng)
    return logobjective
end

############################################################################################
@testset "Sampling, type conversion" begin
    # Sweep every tempering method against both model precisions (Float64/Float32).
    for tempermethod in tempermethods
        for iter in eachindex(objectives)
            #println(tempermethod, " ", iter)
            # Deterministic, silent sampling configuration shared by all runs below.
            sampledefault = SampleDefault(;
                dataformat=Batch(),
                tempering=deepcopy(tempermethod),
                chains=4,
                iterations=100,
                burnin=max(1, Int64(floor(10/10))),
                thinning = 1,
                safeoutput=false,
                printoutput=false,
                printdefault=PrintDefault(),
                report=ProgressReport(;
                    bar=false,
                    log=SilentLog()
                ),
            )
            temperupdate = sampledefault.tempering.adaption
            _obj = deepcopy(objectives[iter])
            _flattentype = _obj.model.info.reconstruct.default.output

            # Create Custom Algorithm with generated quantities enabled.
            def = CustomAlgorithmDefault(;
                generated=UpdateTrue()
            )
            opt = CustomAlgorithm(
                _rng,
                _obj,
                def,
            )

            ## Sample on its own
            customconstruct = CustomAlgorithm() #CustomAlgorithm(keys(_obj.model.val))
            trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct ; default = deepcopy(sampledefault))
            trace.val

            ## Combine with MCMC
            #!NOTE(review): the committed line was `MCMC(NUTS,(,); ...)` — `(,)` is not
            # valid Julia syntax. Target all model parameters via `keys(_obj.model.val)`
            # (the pattern hinted at above for CustomAlgorithm); confirm intended symbols.
            mcmc = MCMC(NUTS, keys(_obj.model.val); stepsize = ConfigStepsize(;stepsizeadaption = UpdateFalse()))
            trace, algorithms = sample(_rng, _obj.model, _obj.data, customconstruct, mcmc ; default = deepcopy(sampledefault))
            trace.val

            ## Use as Propagation Kernel in SMC
            ibis = SMCConstructor(customconstruct, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0))
            trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault))
            trace.val
            ## Always update Gradient Result if new data is added
            #!NOTE: But after first iteration, can capture results
            @test isa(trace.summary.info.captured, UpdateFalse)
            ## Continue sampling on extended data; captured flag must persist.
            newdat = randn(_rng, length(_obj.data)+100)
            trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms)
            #!NOTE: But after first iteration, can capture results
            @test isa(trace2.summary.info.captured, UpdateFalse)
        end
    end
end

2 comments on commit 1b781ab

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94345

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.15 -m "<description of version>" 1b781abd6bd959dc0884119e577cad5e45caf9ff
git push origin v0.3.15

Please sign in to comment.