Skip to content

Commit

Permalink
update BPINN_tests.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Oct 19, 2024
1 parent fbf2463 commit d0962a7
Showing 1 changed file with 1 addition and 110 deletions.
111 changes: 1 addition & 110 deletions test/BPINN_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,113 +363,4 @@ end
bitvec = abs.(p .- sol_pestim1.estimated_de_params) .>
abs.(p .- sol_pestim2.estimated_de_params)
@test_broken bitvec == ones(size(bitvec))
end

using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux,
AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements
import Flux
using NeuralPDE

Random.seed!(100)

function lotka_volterra(u, p, t)
# Model parameters.
α, β, γ, δ = p
# Current state.
x, y = u

# Evaluate differential equations.
dx =- β * y) * x # prey
dy =* x - γ) * y # predator

return [dx, dy]
end

# initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 4.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Solve using OrdinaryDiffEq.jl solver
dt = 0.2
solution = solve(prob, Tsit5(); saveat = dt)

times = solution.t
u = hcat(solution.u...)
x = u[1, :] + (0.8 .* randn(length(u[1, :])))
y = u[2, :] + (0.8 .* randn(length(u[2, :])))
dataset = [x, y, times]

chain = Chain(Dense(1, 6, tanh), Dense(6, 6, tanh), Dense(6, 2))

alg1 = BNNODE(chain; dataset = dataset, draw_samples = 1000,
l2std = [0.2, 0.2], phystd = [0.1, 0.1], priorsNNw = (0.0, 1.0),
param = [Normal(2, 0.5), Normal(2, 0.5), Normal(2, 0.5), Normal(2, 0.5)], progress = true, verbose = true)

alg2 = BNNODE(chain; dataset = dataset, draw_samples = 1000,
l2std = [0.2, 0.2], phystd = [0.1, 0.1], phynewstd = [0.1, 0.1], priorsNNw = (0.0, 1.0),
param = [Normal(2, 0.5), Normal(2, 0.5), Normal(2, 0.5), Normal(2, 0.5)],
estim_collocate = true, progress = true, verbose = true)

@time sol_pestim1 = solve(prob, alg1; saveat = dt)
@time sol_pestim2 = solve(prob, alg2; saveat = dt)

unsafe_comparisons(true)
bitvec = abs.(p .- sol_pestim1.estimated_de_params) .>
abs.(p .- sol_pestim2.estimated_de_params)
bitvec == ones(size(bitvec))

using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux,
AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements
using NeuralPDE
import Flux

Random.seed!(100)

linear_analytic = (u0, p, t) -> u0 + sin(2 * π * t) / (2 * π)
linear = (u, p, t) -> cos(2 * π * t)
tspan = (0.0, 2.0)
u0 = 0.0
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan)
p = prob.p

# Numerical and Analytical Solutions: testing ahmc_bayesian_pinn_ode()
ta = range(tspan[1], tspan[2], length = 300)
u = [linear_analytic(u0, nothing, ti) for ti in ta]
= collect(Float64, Array(u) + 0.02 * randn(size(u)))
time = vec(collect(Float64, ta))
physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)]

# testing points for solve() call must match saveat(1/50.0) arg
ta0 = range(tspan[1], tspan[2], length = 101)
u1 = [linear_analytic(u0, nothing, ti) for ti in ta0]
x̂1 = collect(Float64, Array(u1) + 0.02 * randn(size(u1)))
time1 = vec(collect(Float64, ta0))
physsol0_1 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

chainlux = Chain(Dense(1, 7, tanh), Dense(7, 1))
θinit, st = Lux.setup(Random.default_rng(), chainlux)

fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode(
prob, chainlux, draw_samples = 2500, progress = true,
estim_collocate = true, verbose = true)

alg = BNNODE(chainlux, draw_samples = 2500)
sol1lux = solve(prob, alg)

# testing points
t = time
# Mean of last 500 sampled parameter's curves[Ensemble predictions]
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:length(fhsamples)]
luxar = [chainlux(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

# --------------------- ahmc_bayesian_pinn_ode() call
@test mean(abs.(x̂ .- meanscurve)) < 0.05
@test mean(abs.(physsol1 .- meanscurve)) < 0.005

#--------------------- solve() call
@test mean(abs.(x̂1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
@test mean(abs.(physsol0_1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
end

0 comments on commit d0962a7

Please sign in to comment.