From cc5597ac06fbe5fa80596ae46efe4cda293af043 Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Sat, 25 Nov 2023 19:27:17 +0100
Subject: [PATCH] add crosschainmean
---
Project.toml | 2 +-
src/sampling/inference.jl | 23 +++++++++++++++++++++++
test/test-construction.jl | 4 +++-
3 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index 9108371..8b9a7cb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr "]
-version = "0.3.15"
+version = "0.3.16"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
diff --git a/src/sampling/inference.jl b/src/sampling/inference.jl
index e95513d..5a5b8fa 100644
--- a/src/sampling/inference.jl
+++ b/src/sampling/inference.jl
@@ -295,6 +295,27 @@ function trace_to_posteriormean(
)
end
+"""
+$(SIGNATURES)
+Change trace.val to 3d Array and return Posterior mean as NamedTuple and as Vector for each parameter for each MCMC iteration across MCMC kernels
+
+# Examples
+```julia
+```
+
+"""
+function trace_to_crosschainmean(trace, transform)
+ @unpack tagged = transform
+ # Flatten Trace of different chains to 3-D Array
+ val3d = trace_to_3DArray(trace, transform)
+ # Get Average across chain parameter for each smc/mcmc iteration
+ crosschain_vec = [ map(iter -> mean(view(val3d, time, :, iter)), Base.OneTo(size(val3d, 3))) for time in Base.OneTo(size(val3d, 1)) ]
+ # Map back each vector to a NamedTuple with the across chain parameter mean estimates
+ crosschain_nt = [ ModelWrappers.unflatten(tagged.info.reconstruct, crosschain_vec[iter]) for iter in eachindex(crosschain_vec)]
+ # Return vector and NamedTuple
+ return crosschain_vec, crosschain_nt
+end
+
############################################################################################
"""
$(SIGNATURES)
@@ -461,6 +482,7 @@ export
trace_to_3DArrayᵤ,
trace_to_2DArrayᵤ,
trace_to_posteriormean,
+ trace_to_crosschainmean,
get_chainvals,
get_chaindiagnostics,
@@ -470,4 +492,5 @@ export
val_to_2DArray,
val_to_2DArrayᵤ,
Array2D_to_NamedTuple
+# array_to_posteriormean
diff --git a/test/test-construction.jl b/test/test-construction.jl
index 4d6aa9d..5e10b3a 100644
--- a/test/test-construction.jl
+++ b/test/test-construction.jl
@@ -52,6 +52,7 @@ tempermethod = tempermethods[iter]
## Inference Section
transform = Baytes.TraceTransform(trace, _obj.model)
postmean = trace_to_posteriormean(trace, transform)
+ trace_to_crosschainmean(trace, transform)
post3D = trace_to_3DArray(trace, transform)
post3Dᵤ = trace_to_3DArrayᵤ(trace, transform)
@@ -70,7 +71,7 @@ tempermethod = tempermethods[iter]
_vals2d = val_to_2DArray(_vals, _transform)
_vals2dᵤ = val_to_2DArrayᵤ(_vals, _transform)
_tup2d = Array2D_to_NamedTuple(_vals2d, _tagged)
-
+
###
#Check trace transforms
@@ -331,6 +332,7 @@ using Optim, NLSolversBase
## Inference Section
transform = Baytes.TraceTransform(trace, _obj.model)
postmean = trace_to_posteriormean(trace, transform)
+ trace_to_crosschainmean(trace, transform)
post3D = trace_to_3DArray(trace, transform)
post3Dᵤ = trace_to_3DArrayᵤ(trace, transform)