
improved BPINN solvers-2
AstitvaAggarwal committed Oct 19, 2024
1 parent c4330a7 commit fbf2463
Showing 6 changed files with 321 additions and 30 deletions.
14 changes: 8 additions & 6 deletions src/BPINN_ode.jl
@@ -3,7 +3,7 @@
"""
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
@@ -86,6 +86,7 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs <: NamedTuple
@@ -102,16 +103,16 @@ end

function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,),
nchains = 1, init_params = nothing,
phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false, autodiff = false, progress = false, verbose = false)
chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
return BNNODE(chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
phynewstd, dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose)
end
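
For context, a hypothetical usage sketch of the updated constructor on an inverse problem (not part of this commit; the ODE, chain sizes, dataset, and std values are all illustrative, and `phynewstd` only affects sampling when `estim_collocate = true`):

using NeuralPDE, Lux, OrdinaryDiffEq, Distributions

# du/dt = p*cos(2πt); true p = 1.0, exact solution sin(2πt)/(2π)
f(u, p, t) = p * cos(2π * t)
prob = ODEProblem(f, 0.0, (0.0, 2.0), 1.0)

# hypothetical noisy observations at 100 time points
ts = collect(range(0.0, 2.0; length = 100))
us = sin.(2π .* ts) ./ (2π) .+ 0.05 .* randn(100)
dataset = [us, ts]

chain = Chain(Dense(1, 6, tanh), Dense(6, 1))
alg = BNNODE(chain; dataset = dataset, draw_samples = 2000,
    l2std = [0.05], phystd = [0.05], phynewstd = [0.05],
    priorsNNw = (0.0, 3.0), param = [Normal(1.0, 2.0)],
    estim_collocate = true)
sol = solve(prob, alg)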

@@ -157,7 +158,7 @@ end
function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing,
timeseries_errors = true, save_everystep = true, adaptive = false,
abstol = 1.0f-6, reltol = 1.0f-3, verbose = false, saveat = 1 / 50.0,
maxiters = nothing, numensemble = floor(Int, alg.draw_samples / 3))
maxiters = nothing)
(; chain, param, strategy, draw_samples, numensemble, verbose) = alg

# ahmc_bayesian_pinn_ode needs param = [] so the full parameter vector can be built with a single vcat
@@ -168,7 +169,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt

mcmcchain, samples, statistics = ahmc_bayesian_pinn_ode(
prob, chain; strategy, alg.dataset, alg.draw_samples, alg.init_params,
alg.physdt, alg.l2std, alg.phystd, alg.priorsNNw, param, alg.nchains, alg.autodiff,
alg.physdt, alg.l2std, alg.phystd, alg.phynewstd,
alg.priorsNNw, param, alg.nchains, alg.autodiff,
Kernel = alg.kernel, alg.Adaptorkwargs, alg.Integratorkwargs,
alg.MCMCkwargs, alg.progress, alg.verbose, alg.estim_collocate)

145 changes: 135 additions & 10 deletions src/PDE_BPINN.jl
@@ -4,17 +4,91 @@
dataset <: Union{Nothing, Vector{<:Matrix{<:Real}}}
priors <: Vector{<:Distribution}
allstd::Vector{Vector{Float64}}
phynewstd::Vector{Float64}
names::Tuple
extraparams::Int
init_params <: Union{AbstractVector, NamedTuple, ComponentArray}
full_loglikelihood
L2_loss2
Φ
end

function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
# for parameter estimation it is necessary to use the multioutput case
return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) + priorlogpdf(ltd, θ) +
L2LossData(ltd, θ)
if ltd.L2_loss2 === nothing
    return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
           priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
else
    return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
           priorlogpdf(ltd, θ) + L2LossData(ltd, θ) +
           ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
end
end
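
In effect, the sampled log-density decomposes into physics, prior, and data terms, plus the new collocation term whenever L2_loss2 is provided; schematically (notation ours, not from the source):

\log \pi(\theta) = \mathcal{L}_{\mathrm{phys}}(\theta;\,\mathrm{allstd}) + \log p(\theta) + \mathcal{L}_{\mathrm{data}}(\theta;\,\mathrm{l2std}) \left[\, + \,\mathcal{L}_{\mathrm{new}}(\theta;\,\mathrm{phynewstd}) \,\right]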


# returns a nested vector of data-free collocation loss functions, one sub-vector per dataset row
function get_lossy(pinnrep, dataset, Dict_differentials)
eqs = pinnrep.eqs
depvars = pinnrep.depvars # depvar order matches the dataset

# Dict_differentials is filled with Differential operator => diff_i key-value pairs
# masking operation
eqs_new = substitute.(eqs, Ref(Dict_differentials))

to_subs, tobe_subs = get_symbols(dataset, depvars, eqs)

# For each dataset row, create a Dict mapping each depvar's num form to its value at the
# corresponding indvar coordinates, e.g. Dict(x(t) => 1.0496435863173237, y(t) => 1.9227770685615337).
# n_dicts = n_rows_dataset (or n_indvar_coords_dataset).
eq_subs = [Dict(tobe_subs[depvar] => to_subs[depvar][i] for depvar in depvars)
           for i in 1:size(dataset[1], 1)]

# for each dataset point (eq_sub dictionary), substitute into the masked equations;
# n_collocated_equations = n_rows_dataset (or n_indvar_coords_dataset)
masked_colloc_equations = [[substitute(eq, eq_sub) for eq in eqs_new]
for eq_sub in eq_subs]
# now we have a vector of collocated equations, one set per dataset point

# reverse dict for re-substituting values of Differential(t)(u(t)) etc
rev_Dict_differentials = Dict(value => key for (key, value) in Dict_differentials)

# unmask Differential terms in masked_colloc_equations
colloc_equations = [substitute.(masked_colloc_equation, Ref(rev_Dict_differentials))
for masked_colloc_equation in masked_colloc_equations]

# nested vector of datafree PDE loss functions (as in discretize.jl);
# each sub-vector holds one dataset point's datafree collocation loss functions,
# n_subvectors = n_rows_dataset (or n_indvar_coords_dataset).
# zip each collocated equation with the args for each build_loss_function call
datafree_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(
colloc_equation,
pinnrep.pde_indvars,
pinnrep.pde_integration_vars)]
for colloc_equation in colloc_equations]

return datafree_colloc_loss_functions
end
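
The masking trick can be illustrated in isolation (a minimal sketch assuming Symbolics.jl substitution semantics; the equation and the diff_1 placeholder are made up):

using Symbolics

@variables t x(t) diff_1
D = Differential(t)
eq = D(x) ~ 2.0 * x

# mask the differential so the depvar can later be replaced by a dataset value
Dict_differentials = Dict(D(x) => diff_1)
masked = substitute(eq, Dict_differentials)              # diff_1 ~ 2.0x(t)

# substitute a dataset value for x(t), then unmask the differential
collocated = substitute(masked, Dict(x => 1.05))         # diff_1 ~ 2.1
unmasked = substitute(collocated, Dict(diff_1 => D(x)))  # Differential(t)(x(t)) ~ 2.1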

function get_symbols(dataset, depvars, eqs)
# take only values of depvars from dataset
depvar_vals = [dataset_i[:, 1] for dataset_i in dataset]
# order of pinnrep.depvars, depvar_vals and BayesianPINN.dataset must be the same
to_subs = Dict(depvars .=> depvar_vals)

numform_vars = Symbolics.get_variables.(eqs)
Eq_vars = unique(reduce(vcat, numform_vars))
# collected the equations' depvars in num form, e.g. x(t), for use in substitute()

tobe_subs = Dict()
for a in depvars
for i in Eq_vars
expr = toexpr(i)
if (expr isa Expr) && (expr.args[1] == a)
tobe_subs[a] = i
end
end
end
# both symbolic and num forms of each depvar are now known, e.g. tobe_subs = Dict{Any, Any}(:y => y(t), :x => x(t))

return to_subs, tobe_subs
end
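
For a hypothetical two-species dataset observed at three time points, the two dictionaries would look like:

# illustrative values only
to_subs   = Dict(:x => [1.05, 1.21, 1.38], :y => [1.92, 1.80, 1.65])
tobe_subs = Dict(:x => x(t), :y => y(t))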

@views function setparameters(ltd::PDELogTargetDensity, θ)
@@ -180,8 +254,8 @@ end
"""
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
@@ -210,6 +284,7 @@ end
each dependent variable of interest.
* `phystd`: Vector of standard deviations of the BPINN prediction against the chosen
underlying PDE equations.
* `phynewstd`: Vector of standard deviations of the new collocation loss term.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen PDE's parameter's Distributions in case of Inverse problems.
@@ -235,14 +310,54 @@
"""
function ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000, bcstd = [0.01], l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
Kernel = HMC(0.1, 30), Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
pinnrep = symbolic_discretize(pde_system, discretization)
dataset_pde, dataset_bc = discretization.dataset

newloss = if Dict_differentials isa Nothing
nothing
else
datafree_colloc_loss_functions = get_lossy(pinnrep, dataset_pde, Dict_differentials)
# the number of collocation equations equals the number of indvar coordinates in the dataset
# TODO: add a case for parameters present in the BCs?

train_sets_pde = get_dataset_train_points(pde_system.eqs,
dataset_pde,
pinnrep)
colloc_train_sets = [[hcat(train_sets_pde[i][:, j]...)'
for i in eachindex(datafree_colloc_loss_functions[1])]
for j in eachindex(datafree_colloc_loss_functions)]

# For each datafree_colloc_loss_function, create a loss function by passing the dataset's
# indvar coordinates as train_sets_pde.
# GridTraining(0.1) is only a placeholder strategy; datafree_bc_loss_function and
# train_sets_bc must be nothing.
# The order of indvar coordinates matches the corresponding depvar values in the dataset
# passed to the get_lossy() call.
pde_loss_function_points = [merge_strategy_with_loglikelihood_function(
pinnrep,
GridTraining(0.1),
datafree_colloc_loss_functions[i],
nothing;
train_sets_pde = colloc_train_sets[i],
train_sets_bc = nothing)[1]
for i in eachindex(datafree_colloc_loss_functions)]

function L2_loss2(θ, phynewstd)
# pde_loss_function_points holds, per dataset point, the vector of PDE loss functions
# (the [1] above selected the PDE losses from the returned tuple)
pde_loglikelihoods = [sum([pde_loss_function(θ, phynewstd[i])
for (i, pde_loss_function) in enumerate(pde_loss_functions)])
for pde_loss_functions in pde_loss_function_points]

# bc_loglikelihoods = [sum([bc_loss_function(θ, phynewstd[i]) for (i, bc_loss_function) in enumerate(pde_loss_function_points[1])]) for pde_loss_function_points in pde_loss_functions]
# for (j, bc_loss_function) in enumerate(bc_loss_functions)]

return sum(pde_loglikelihoods)
end
end
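
Assuming each merged loss function returns the log-likelihood of one PDE residual at one dataset coordinate under a zero-mean Gaussian noise model (which is what the merge above sets up), the new term is, schematically,

\mathcal{L}_{\mathrm{new}}(\theta) = \sum_{j=1}^{n_{\mathrm{points}}} \sum_{i=1}^{n_{\mathrm{eqs}}} \log \mathcal{N}\!\left( r_i(\theta; x_j) \mid 0,\ \mathrm{phynewstd}[i]^2 \right)

where r_i(θ; x_j) is the residual of equation i at dataset coordinate x_j.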

# TODO: add overall functionality for BC dataset points (case of parametric BCs)?
if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
dataset = nothing
elseif dataset_bc isa Nothing
@@ -306,7 +421,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

# dimensions are the total number of params; initial_nnθ for Lux NamedTuples
ℓπ = PDELogTargetDensity(
nparameters, strategy, dataset, priors, [phystd, bcstd, l2std],
nparameters, strategy, dataset, priors, [phystd, bcstd, l2std], phynewstd,
names, ninv, initial_nnθ, full_weighted_loglikelihood, Φ)

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
@@ -322,6 +437,11 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, initial_θ))
@printf("Current MSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, initial_θ))
if !(newloss isa Nothing)
@printf("Current new loss : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, initial_θ),
ℓπ.phynewstd))
end
end

# parallel sampling option
@@ -370,11 +490,16 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;

if verbose
@printf("Sampling Complete.\n")
@printf("Current Physics Log-likelihood : %g\n",
@printf("Final Physics Log-likelihood : %g\n",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), ℓπ.allstd))
@printf("Current Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood : %g\n",
@printf("Final Prior Log-likelihood : %g\n", priorlogpdf(ℓπ, samples[end]))
@printf("Final MSE against dataset Log-likelihood : %g\n",
L2LossData(ℓπ, samples[end]))
if !(newloss isa Nothing)
@printf("Final L2_LOSSY : %g\n",
ℓπ.L2_loss2(setparameters(ℓπ, samples[end]),
ℓπ.phynewstd))
end
end

fullsolution = BPINNstats(mcmc_chain, samples, stats)
24 changes: 13 additions & 11 deletions src/advancedHMC_MCMC.jl
@@ -6,6 +6,7 @@
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
priors <: Vector{<:Distribution}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
@@ -97,7 +98,7 @@ suggested extra loss function for ODE solver case
for i in 1:length(ltd.prob.u0)
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(nnsol[i, :]))))),
Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
nnsol[i, :]
)
end
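
With this change the ODE-case collocation term scores the network solution's derivative against the physics derivative at the dataset points using phynewstd rather than phystd; per state i, schematically,

\mathcal{L}_{\mathrm{new},i}(\theta) = \log \mathcal{N}\!\left( \mathrm{nnsol}_i \mid \mathrm{deri\_physsol}_i,\ \mathrm{phynewstd}[i]^2 I \right)
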
@@ -263,7 +264,7 @@ end
"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 2.0),
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
@@ -336,6 +337,7 @@ In case you are only solving the equations for the solution, do not provide a dataset
~2/3 of draw samples)
* `l2std`: standard deviation of the BPINN prediction against L2 losses/dataset
* `phystd`: standard deviation of the BPINN prediction against the chosen underlying ODE system
* `phynewstd`: standard deviation of the new collocation loss term
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
@@ -366,10 +368,10 @@ In case you are only solving the equations for the solution, do not provide a dataset
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false,
Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false, estim_collocate = false)
@assert !isinplace(prob) "The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."
@@ -419,7 +421,7 @@ function ahmc_bayesian_pinn_ode(
smodel = StatefulLuxLayer{true}(chain, nothing, st)
# dimensions are the total number of params; initial_nnθ for Lux NamedTuples
ℓπ = LogTargetDensity(nparameters, prob, smodel, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)
phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

if verbose
@printf("Current Physics Log-likelihood: %g\n", physloglikelihood(ℓπ, initial_θ))
@@ -483,13 +485,13 @@ function ahmc_bayesian_pinn_ode(

if verbose
println("Sampling Complete.")
@printf("Current Physics Log-likelihood: %g\n",
@printf("Final Physics Log-likelihood: %g\n",
physloglikelihood(ℓπ, samples[end]))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Current MSE against dataset Log-likelihood: %g\n",
@printf("Final Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Final MSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, samples[end]))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
@printf("Final gradient loss against dataset Log-likelihood: %g\n",
L2loss2(ℓπ, samples[end]))
end
end
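
Equivalently, the sampler can be called directly (a sketch reusing the hypothetical prob, chain, and dataset from the BNNODE example above):

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(
    prob, chain; dataset = dataset, draw_samples = 2000,
    l2std = [0.05], phystd = [0.05], phynewstd = [0.05],
    priorsNNw = (0.0, 3.0), param = [Normal(1.0, 2.0)],
    estim_collocate = true)
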
1 change: 1 addition & 0 deletions src/discretize.jl
@@ -538,6 +538,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab
nothing, nothing
end

# this includes losses from dataset domain points as well as discretization points
function full_loss_function(θ, allstd::Vector{Vector{Float64}})
stdpdes, stdbcs, stdextra = allstd
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them