Skip to content

Commit

Permalink
low level changes, transform fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Oct 15, 2024
1 parent 247b8e3 commit f81aa7a
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 110 deletions.
13 changes: 7 additions & 6 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ struct BNNODE{C, K, IT <: NamedTuple,
param::P
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
dataset::D
physdt::Float64
MCMCkwargs::H
Expand All @@ -107,7 +108,7 @@ struct BNNODE{C, K, IT <: NamedTuple,
verbose::Bool
end
function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05], phynewstd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,), nchains = 1,
init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Expand All @@ -121,7 +122,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
(chain = adapt(FromFluxAdaptor(false, false), chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
phystd, phynewstd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
Expand Down Expand Up @@ -186,9 +187,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
reltol = 1.0f-3,
verbose = false,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
maxiters = nothing,)
@unpack chain, l2std, phystd, phynewstd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
Expand All @@ -206,7 +206,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
strategy = strategy, dataset = dataset,
draw_samples = draw_samples,
init_params = init_params,
physdt = physdt, l2std = l2std,
physdt = physdt, phynewstd = phynewstd,
l2std = l2std,
phystd = phystd,
priorsNNw = priorsNNw,
param = param,
Expand Down
8 changes: 4 additions & 4 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,14 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
mcmc_chain = MCMCChains.Chains(matrix_samples')

@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ",
@info("Final Physics Log-likelihood : ",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]),
ℓπ.allstd))
@info("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
@info("Final Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end]))
@info("Final MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@info("Current L2_LOSSY : ",
@info("Final L2_LOSSY : ",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.allstd))
end
Expand Down
34 changes: 20 additions & 14 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
dataset::D
priors::P
phystd::Vector{Float64}
phynewstd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
Expand All @@ -20,7 +21,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
priors, phystd, phynewstd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Expand All @@ -36,6 +37,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
dataset,
priors,
phystd,
phynewstd,
l2std,
autodiff,
physdt,
Expand All @@ -45,7 +47,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
priors, phystd, phynewstd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple, estim_collocate)
new{
typeof(chain),
Expand All @@ -58,7 +60,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
prob,
chain, st, strategy,
dataset, priors,
phystd, l2std,
phystd, phynewstd,
l2std,
autodiff,
physdt,
extraparams,
Expand Down Expand Up @@ -136,10 +139,10 @@ function L2loss2(Tar::LogTargetDensity, θ)

physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
# can add phynewstd[i] for u[i]
physlogprob += logpdf(MvNormal(deri_physsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
(Tar.l2std[i] * 4.0) .*
(Tar.phynewstd[i]) .*
ones(length(nnsol[i, :]))))),
nnsol[i, :])
end
Expand All @@ -162,7 +165,7 @@ function L2LossData(Tar::LogTargetDensity, θ)

L2logprob = 0
for i in 1:length(Tar.prob.u0)
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
# for u[i] ith vector must be added to dataset, nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(
MvNormal(nn[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.l2std[i] .*
Expand Down Expand Up @@ -395,7 +398,7 @@ end
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining,
dataset = [nothing],init_params = nothing,
draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
Expand Down Expand Up @@ -466,6 +469,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
* `phynewstd`: standard deviation for the new loss function term (used in `L2loss2`)
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
* `autodiff`: Boolean Value for choice of Derivative Backend(default is numerical)
Expand All @@ -492,7 +496,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false,
Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Expand Down Expand Up @@ -558,7 +562,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
Expand All @@ -574,7 +578,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, initial_θ))
@info("Current gradient loss against dataset Log-likelihood : ",
L2loss2(ℓπ, initial_θ))
end

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Expand Down Expand Up @@ -624,11 +629,12 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
adaptor; progress = progress, verbose = verbose)

@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end]))
@info("Final Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Final Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Final MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end]))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, samples[end]))
@info("Final gradient loss against dataset Log-likelihood : ",
L2loss2(ℓπ, samples[end]))
end

# return a chain(basic chain),samples and stats
Expand Down
2 changes: 1 addition & 1 deletion src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
multioutput = chain isa AbstractArray
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) &&
(chain = Lux.transform.(chain))
(chain = [adapt(FromFluxAdaptor(false, false), chain_i) for chain_i in chain])
else
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
Expand Down
11 changes: 6 additions & 5 deletions test/BPINN_PDE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ end

sol1 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 400,
bcstd = [0.05, 0.05, 0.05, 0.05],
phystd = [0.05],
priorsNNw = (0.0, 1.0),
draw_samples = 200,
bcstd = [0.0025, 0.0025, 0.0025, 0.0025],
phystd = [0.005],
priorsNNw = (0.0, 0.5),
saveats = [1 / 100.0, 1 / 100.0])

xs = sol1.timepoints[1]
Expand All @@ -171,8 +171,9 @@ end
u_predict = pmean(sol1.ensemblesol[1])
u_real = [analytic_sol_func(xs[:, i][1], xs[:, i][2]) for i in 1:length(xs[1, :])]

@test mean(abs2.(u_predict .- u_real)) < 5e-3
@test all(abs.(u_predict .- u_real) .< 15e-3)
@test sum(abs2.(u_predict .- u_real)) < 0.1
@test u_predict ≈ u_real atol=0.1
end

@testset "Translating from Flux" begin
Expand Down
97 changes: 17 additions & 80 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ end
dataset = dataset,
draw_samples = 1000,
l2std = [0.1],
phystd = [0.03],
phystd = [0.01],
priorsNNw = (0.0,
1.0),
param = [
Expand All @@ -288,7 +288,8 @@ end
dataset = dataset,
draw_samples = 1000,
l2std = [0.1],
phystd = [0.03],
phystd = [0.01],
phynewstd = [0.01],
priorsNNw = (0.0,
1.0),
param = [
Expand All @@ -299,114 +300,50 @@ end
dataset = dataset,
draw_samples = 1000,
l2std = [0.1],
phystd = [0.03],
phystd = [0.01],
phynewstd = [0.05],
priorsNNw = (0.0,
1.0),
param = [
Normal(-7, 3)
], estim_collocate = true)
], numensemble = 200,
estim_collocate = true)

sol3lux_pestim = solve(prob, alg)

# testing timepoints
t = sol.t
#------------------------------ ahmc_bayesian_pinn_ode() call
# Mean of last 500 sampled parameter's curves(lux chains)[Ensemble predictions]
# Mean of last 200 sampled parameter's curves(lux chains)[Ensemble predictions]
θ = [vector_to_parameters(fhsampleslux12[i][1:(end - 1)], θinit)
for i in 750:length(fhsampleslux12)]
for i in 800:length(fhsampleslux12)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit)
for i in 750:length(fhsampleslux22)]
for i in 800:length(fhsampleslux22)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

@test mean(abs.(sol.u .- meanscurve2_2)) < 6e-2
@test mean(abs.(physsol1 .- meanscurve2_2)) < 6e-2
@test mean(abs.(sol.u .- meanscurve2_2)) < 5e-2
@test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2
@test mean(abs.(sol.u .- meanscurve2_1)) > mean(abs.(sol.u .- meanscurve2_2))
@test mean(abs.(physsol1 .- meanscurve2_1)) > mean(abs.(physsol1 .- meanscurve2_2))

# estimated parameters(lux chain)
param2 = mean(i[62] for i in fhsampleslux22[750:length(fhsampleslux22)])
@test abs(param2 - p) < abs(0.25 * p)
param2 = mean(i[62] for i in fhsampleslux22[800:length(fhsampleslux22)])
@test abs(param2 - p) < abs(0.2 * p)

param1 = mean(i[62] for i in fhsampleslux12[750:length(fhsampleslux12)])
@test abs(param1 - p) < abs(0.75 * p)
param1 = mean(i[62] for i in fhsampleslux12[800:length(fhsampleslux12)])
@test !(abs(param1 - p) < abs(0.2 * p))
@test abs(param2 - p) < abs(param1 - p)

#-------------------------- solve() call
# (lux chain)
@test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 0.1
@test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 5e-2
# estimated parameters(lux chain)
param3 = sol3lux_pestim.estimated_de_params[1]
@test abs(param3 - p) < abs(0.2 * p)
end

@testset "Example 4 - improvement" begin
# Right-hand side of the Lotka-Volterra predator-prey system.
# Returns a new vector `[dx, dy]` for state `u = (x, y)` and
# parameters `p = (α, β, γ, δ)`; `t` is accepted but not used.
function lotka_volterra(u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u

    # Evaluate differential equations.
    dx = (α - β * y) * x # prey
    dy = (δ * x - γ) * y # predator

    return [dx, dy]
end

# initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 4.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Solve using OrdinaryDiffEq.jl solver
dt = 0.2
solution = solve(prob, Tsit5(); saveat = dt)

times = solution.t
u = hcat(solution.u...)
x = u[1, :] + (0.8 .* randn(length(u[1, :])))
y = u[2, :] + (0.8 .* randn(length(u[2, :])))
dataset = [x, y, times]

chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh),
Lux.Dense(6, 2))

alg1 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.2, 0.2],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 1.0),
param = [
Normal(2, 0.5),
Normal(2, 0.5),
Normal(2, 0.5),
Normal(2, 0.5)])

alg2 = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.2, 0.2],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 1.0),
param = [
Normal(2, 0.5),
Normal(2, 0.5),
Normal(2, 0.5),
Normal(2, 0.5)], estim_collocate = true)

@time sol_pestim1 = solve(prob, alg1; saveat = dt)
@time sol_pestim2 = solve(prob, alg2; saveat = dt)

unsafe_comparisons(true)
bitvec = abs.(p .- sol_pestim1.estimated_de_params) .>
abs.(p .- sol_pestim2.estimated_de_params)
@test bitvec == ones(size(bitvec))
end

0 comments on commit f81aa7a

Please sign in to comment.