diff --git a/test/BPINN_PDE_tests.jl b/test/BPINN_PDE_tests.jl new file mode 100644 index 000000000..b49c35cfb --- /dev/null +++ b/test/BPINN_PDE_tests.jl @@ -0,0 +1,231 @@ +@testitem "BPINN PDE I: 2D Periodic System" tags=[:pdebpinn] begin + using MCMCChains, Lux, ModelingToolkit, Distributions, OrdinaryDiffEq, + AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, + ComponentArrays + import ModelingToolkit: Interval, infimum, supremum + + Random.seed!(100) + + @parameters t + @variables u(..) + Dt = Differential(t) + eq = Dt(u(t)) - cospi(2t) ~ 0 + bcs = [u(0.0) ~ 0.0] + domains = [t ∈ Interval(0.0, 2.0)] + + chainl = Chain(Dense(1, 6, tanh), Dense(6, 1)) + initl, st = Lux.setup(Random.default_rng(), chainl) + @named pde_system = PDESystem(eq, bcs, domains, [t], [u(t)]) + + # non adaptive case + discretization = BayesianPINN([chainl], GridTraining([0.01])) + + sol1 = ahmc_bayesian_pinn_pde( + pde_system, discretization; draw_samples = 1500, bcstd = [0.02], + phystd = [0.01], priorsNNw = (0.0, 1.0), saveats = [1 / 50.0]) + + analytic_sol_func(u0, t) = u0 + sinpi(2t) / (2pi) + ts = vec(sol1.timepoints[1]) + u_real = [analytic_sol_func(0.0, t) for t in ts] + u_predict = pmean(sol1.ensemblesol[1]) + + @test u_predict≈u_real atol=0.5 + @test mean(u_predict .- u_real) < 0.1 +end + +@testitem "BPINN PDE II: 1D ODE" tags=[:pdebpinn] begin + using MCMCChains, Lux, ModelingToolkit, Distributions, OrdinaryDiffEq, + AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, + ComponentArrays + import ModelingToolkit: Interval, infimum, supremum + + Random.seed!(100) + + @parameters θ + @variables u(..) + Dθ = Differential(θ) + + # 1D ODE + eq = Dθ(u(θ)) ~ θ^3 + 2.0f0 * θ + (θ^2) * ((1.0f0 + 3 * (θ^2)) / (1.0f0 + θ + (θ^3))) - + u(θ) * (θ + ((1.0f0 + 3.0f0 * (θ^2)) / (1.0f0 + θ + θ^3))) + + # Initial and boundary conditions + bcs = [u(0.0) ~ 1.0f0] + + # Space and time domains + domains = [θ ∈ Interval(0.0f0, 1.0f0)] + + # Discretization + dt = 0.1f0 + + # Neural network + chain = Chain(Dense(1, 12, σ), Dense(12, 1)) + + discretization = BayesianPINN([chain], GridTraining([0.01])) + + @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) + + sol1 = ahmc_bayesian_pinn_pde( + pde_system, discretization; draw_samples = 500, bcstd = [0.1], + phystd = [0.05], priorsNNw = (0.0, 10.0), saveats = [1 / 100.0]) + + analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2 + ts = [infimum(d.domain):(dt / 10):supremum(d.domain) for d in domains][1] + u_real = [analytic_sol_func(t) for t in ts] + u_predict = pmean(sol1.ensemblesol[1]) + @test u_predict≈u_real atol=0.8 +end + +@testitem "BPINN PDE III: 3rd Degree ODE" tags=[:pdebpinn] begin + using MCMCChains, Lux, ModelingToolkit, Distributions, OrdinaryDiffEq, + AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, + ComponentArrays + import ModelingToolkit: Interval, infimum, supremum + + Random.seed!(100) + + @parameters x + @variables u(..), Dxu(..), Dxxu(..), O1(..), O2(..) + Dxxx = Differential(x)^3 + Dx = Differential(x) + + + # ODE + eq = Dx(Dxxu(x)) ~ cospi(x) + + # Initial and boundary conditions + ep = (cbrt(eps(eltype(Float64))))^2 / 6 + + bcs = [ + u(0.0) ~ 0.0, + u(1.0) ~ cospi(1.0), + Dxu(1.0) ~ 1.0, + Dxu(x) ~ Dx(u(x)) + ep * O1(x), + Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x) + ] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0)] + + # Discretization + dt = 0.1f0 + + # Neural network + inner = 20 + chain = Chain(Dense(1, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ), + Dense(inner, inner, σ), Dense(inner, inner, σ), Dense(inner, 1)) + + strategy = GridTraining(dt) + ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64 + + discretization = PhysicsInformedNN(chain, strategy; init_params = ps) + + @named pde_system = PDESystem(eq, bcs, domains, [x], + [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)]) + + sol1 = ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 200, + bcstd = [0.01, 0.01, 0.01, 0.01, 0.01], phystd = [0.005], + priorsNNw = (0.0, 10.0), saveats = [1 / 100.0, 1 / 100.0]) + + analytic_sol_func(x) = (π * x * (-x + (π^2) * (2 * x - 3) + 1) - sinpi(x)) / (π^3) + + u_predict = pmean(sol1.ensemblesol[1]) + xs = [infimum(d.domain):(dt / 10):supremum(d.domain) for d in domains][1] + u_real = [analytic_sol_func(x) for x in xs] + @test u_predict≈u_real atol=0.5 +end + +@testitem "BPINN PDE IV: 2D Poisson" tags=[:pdebpinn] begin + using MCMCChains, Lux, ModelingToolkit, Distributions, OrdinaryDiffEq, + AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, + ComponentArrays + import ModelingToolkit: Interval, infimum, supremum + + Random.seed!(100) + + @parameters x y + @variables u(..) + Dxx = Differential(x)^2 + Dyy = Differential(y)^2 + + # 2D PDE + eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y) + + # Boundary conditions + bcs = [ + u(0, y) ~ 0.0, + u(1, y) ~ 0.0, + u(x, 0) ~ 0.0, + u(x, 1) ~ 0.0 + ] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] + + # Discretization + dt = 0.1f0 + + # Neural network + chain = Chain(Dense(2, 9, σ), Dense(9, 9, σ), Dense(9, 1)) + + dx = 0.04 + discretization = PhysicsInformedNN(chain, GridTraining(dx)) + + @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) + + sol = ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 200, + bcstd = [0.003, 0.003, 0.003, 0.003], phystd = [0.003], + priorsNNw = (0.0, 10.0), saveats = [1 / 100.0, 1 / 100.0]) + + xs = sol.timepoints[1] + analytic_sol_func(x, y) = (sinpi(x) * sinpi(y)) / (2pi^2) + + u_predict = pmean(sol.ensemblesol[1]) + u_real = [analytic_sol_func(xs[:, i][1], xs[:, i][2]) for i in 1:length(xs[1, :])] + @test u_predict≈u_real rtol=0.5 +end + +@testitem "Translating from Flux" tags=[:pdebpinn] begin + using MCMCChains, Lux, ModelingToolkit, Distributions, OrdinaryDiffEq, + AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, + ComponentArrays + import ModelingToolkit: Interval, infimum, supremum + import Flux + + Random.seed!(100) + + @parameters θ + @variables u(..) + Dθ = Differential(θ) + + # 1D ODE + eq = Dθ(u(θ)) ~ θ^3 + 2.0f0 * θ + (θ^2) * ((1.0f0 + 3 * (θ^2)) / (1.0f0 + θ + (θ^3))) - + u(θ) * (θ + ((1.0f0 + 3.0f0 * (θ^2)) / (1.0f0 + θ + θ^3))) + + # Initial and boundary conditions + bcs = [u(0.0) ~ 1.0f0] + + # Space and time domains + domains = [θ ∈ Interval(0.0f0, 1.0f0)] + + # Discretization + dt = 0.1f0 + + # Neural network + chain = Flux.Chain(Flux.Dense(1, 12, Flux.σ), Flux.Dense(12, 1)) + + discretization = PhysicsInformedNN(chain, GridTraining(dt)) + @test discretization.chain isa Lux.AbstractLuxLayer + + @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) + + sol = ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 500, + bcstd = [0.1], phystd = [0.05], priorsNNw = (0.0, 10.0), saveats = [1 / 100.0]) + + analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2 + ts = [infimum(d.domain):(dt / 10):supremum(d.domain) for d in domains][1] + u_real = [analytic_sol_func(t) for t in ts] + u_predict = pmean(sol.ensemblesol[1]) + + @test u_predict≈u_real atol=0.8 +end diff --git a/test/BPINN_PDE_tests_wip.jl b/test/BPINN_PDE_tests_wip.jl deleted file mode 100644 index cbb8ffa46..000000000 --- a/test/BPINN_PDE_tests_wip.jl +++ /dev/null @@ -1,187 +0,0 @@ -using Test, MCMCChains, Lux, ModelingToolkit, ForwardDiff, Distributions, OrdinaryDiffEq, - AdvancedHMC, Statistics, Random, Functors, NeuralPDE, MonteCarloMeasurements, - ComponentArrays -import ModelingToolkit: Interval, infimum, supremum -import Flux - -Random.seed!(100) - -@testset "Example 1: 2D Periodic System" begin - # Cos(pi*t) example - @parameters t - @variables u(..) - Dt = Differential(t) - eqs = Dt(u(t)) - cos(2 * π * t) ~ 0 - bcs = [u(0) ~ 0.0] - domains = [t ∈ Interval(0.0, 2.0)] - chainl = Chain(Dense(1, 6, tanh), Dense(6, 1)) - initl, st = Lux.setup(Random.default_rng(), chainl) - @named pde_system = PDESystem(eqs, bcs, domains, [t], [u(t)]) - - # non adaptive case - discretization = BayesianPINN([chainl], GridTraining([0.01])) - - sol1 = ahmc_bayesian_pinn_pde( - pde_system, discretization; draw_samples = 1500, bcstd = [0.02], - phystd = [0.01], priorsNNw = (0.0, 1.0), saveats = [1 / 50.0]) - - analytic_sol_func(u0, t) = u0 + sin(2 * π * t) / (2 * π) - ts = vec(sol1.timepoints[1]) - u_real = [analytic_sol_func(0.0, t) for t in ts] - u_predict = pmean(sol1.ensemblesol[1]) - @test u_predict≈u_real atol=0.5 - @test mean(u_predict .- u_real) < 0.1 -end - -@testset "Example 2: 1D ODE" begin - @parameters θ - @variables u(..) - Dθ = Differential(θ) - - # 1D ODE - eq = Dθ(u(θ)) ~ θ^3 + 2 * θ + (θ^2) * ((1 + 3 * (θ^2)) / (1 + θ + (θ^3))) - - u(θ) * (θ + ((1 + 3 * (θ^2)) / (1 + θ + θ^3))) - - # Initial and boundary conditions - bcs = [u(0.0) ~ 1.0] - - # Space and time domains - domains = [θ ∈ Interval(0.0, 1.0)] - - # Neural network - chain = Chain(Dense(1, 12, σ), Dense(12, 1)) - - discretization = BayesianPINN([chain], GridTraining([0.01])) - - @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) - - sol1 = ahmc_bayesian_pinn_pde( - pde_system, discretization; draw_samples = 500, bcstd = [0.1], - phystd = [0.05], priorsNNw = (0.0, 10.0), saveats = [1 / 100.0]) - - analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2 - ts = sol1.timepoints[1] - u_real = vec([analytic_sol_func(t) for t in ts]) - u_predict = pmean(sol1.ensemblesol[1]) - @test u_predict≈u_real atol=0.8 -end - -@testset "Example 3: 3rd Degree ODE" begin - @parameters x - @variables u(..), Dxu(..), Dxxu(..), O1(..), O2(..) - Dxxx = Differential(x)^3 - Dx = Differential(x) - - # ODE - eq = Dx(Dxxu(x)) ~ cos(pi * x) - - # Initial and boundary conditions - ep = (cbrt(eps(eltype(Float64))))^2 / 6 - - bcs = [u(0.0) ~ 0.0, - u(1.0) ~ cos(pi), - Dxu(1.0) ~ 1.0, - Dxu(x) ~ Dx(u(x)) + ep * O1(x), - Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x)] - - # Space and time domains - domains = [x ∈ Interval(0.0, 1.0)] - - # Neural network - chain = [ - Chain(Dense(1, 10, tanh), Dense(10, 10, tanh), Dense(10, 1)), - Chain(Dense(1, 10, tanh), Dense(10, 10, tanh), Dense(10, 1)), - Chain(Dense(1, 10, tanh), Dense(10, 10, tanh), Dense(10, 1)), - Chain(Dense(1, 4, tanh), Dense(4, 1)), - Chain(Dense(1, 4, tanh), Dense(4, 1)) - ] - - discretization = BayesianPINN(chain, GridTraining(0.01)) - - @named pde_system = PDESystem(eq, bcs, domains, [x], - [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)]) - - sol1 = ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 200, - bcstd = [0.01, 0.01, 0.01, 0.01, 0.01], phystd = [0.005], - priorsNNw = (0.0, 10.0), saveats = [1 / 100.0]) - - analytic_sol_func(x) = (π * x * (-x + (π^2) * (2 * x - 3) + 1) - sin(π * x)) / (π^3) - - u_predict = pmean(sol1.ensemblesol[1]) - xs = vec(sol1.timepoints[1]) - u_real = [analytic_sol_func(x) for x in xs] - @test u_predict≈u_real atol=0.5 -end - -@testset "Example 4: 2D Poissons equation" begin - @parameters x y - @variables u(..) - Dxx = Differential(x)^2 - Dyy = Differential(y)^2 - - # 2D PDE - eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y) - - # Boundary conditions - bcs = [u(0, y) ~ 0.0, u(1, y) ~ 0.0, - u(x, 0) ~ 0.0, u(x, 1) ~ 0.0] - - # Space and time domains - domains = [x ∈ Interval(0.0, 1.0), - y ∈ Interval(0.0, 1.0)] - - # Neural network - dim = 2 # number of dimensions - chain = Chain(Dense(dim, 9, σ), Dense(9, 9, σ), Dense(9, 1)) - - # Discretization - dx = 0.04 - discretization = BayesianPINN([chain], GridTraining(dx)) - - @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) - - sol1 = ahmc_bayesian_pinn_pde(pde_system, discretization; draw_samples = 200, - bcstd = [0.003, 0.003, 0.003, 0.003], phystd = [0.003], - priorsNNw = (0.0, 10.0), saveats = [1 / 100.0, 1 / 100.0]) - - xs = sol1.timepoints[1] - analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2) - - u_predict = pmean(sol1.ensemblesol[1]) - u_real = [analytic_sol_func(xs[:, i][1], xs[:, i][2]) for i in 1:length(xs[1, :])] - @test u_predict≈u_real atol=1.5 -end - -@testset "Translating from Flux" begin - @parameters θ - @variables u(..) - Dθ = Differential(θ) - - # 1D ODE - eq = Dθ(u(θ)) ~ θ^3 + 2 * θ + (θ^2) * ((1 + 3 * (θ^2)) / (1 + θ + (θ^3))) - - u(θ) * (θ + ((1 + 3 * (θ^2)) / (1 + θ + θ^3))) - - # Initial and boundary conditions - bcs = [u(0.0) ~ 1.0] - - # Space and time domains - domains = [θ ∈ Interval(0.0, 1.0)] - - # Neural network - chain = Flux.Chain(Flux.Dense(1, 12, Flux.σ), Flux.Dense(12, 1)) - - discretization = BayesianPINN([chain], GridTraining([0.01])) - @test discretization.chain[1] isa AbstractLuxLayer - - @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) - - sol1 = ahmc_bayesian_pinn_pde( - pde_system, discretization; draw_samples = 500, bcstd = [0.1], - phystd = [0.05], priorsNNw = (0.0, 10.0), saveats = [1 / 100.0]) - - analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2 - ts = sol1.timepoints[1] - u_real = vec([analytic_sol_func(t) for t in ts]) - u_predict = pmean(sol1.ensemblesol[1]) - @test u_predict≈u_real atol=0.8 -end diff --git a/test/NNPDE_cuda_tests.jl b/test/NNPDE_cuda_tests.jl index be72b79ce..8de4163c5 100644 --- a/test/NNPDE_cuda_tests.jl +++ b/test/NNPDE_cuda_tests.jl @@ -16,7 +16,7 @@ export gpud, callback end @testitem "1D ODE - CUDA" tags=[:cuda] setup=[CUDATestSetup] begin - using Lux, Optimization, OptimizationOptimisers, Random + using Lux, Optimization, OptimizationOptimisers, Random, ComponentArrays import ModelingToolkit: Interval, infimum, supremum Random.seed!(100) @@ -60,7 +60,7 @@ end end @testitem "1D PDE Dirichlet BC - CUDA" tags=[:cuda] setup=[CUDATestSetup] begin - using Lux, Optimization, OptimizationOptimisers, Random + using Lux, Optimization, OptimizationOptimisers, Random, ComponentArrays import ModelingToolkit: Interval, infimum, supremum Random.seed!(100) @@ -106,7 +106,8 @@ end end @testitem "1D PDE Neumann BC - CUDA" tags=[:cuda] setup=[CUDATestSetup] begin - using Lux, Optimization, OptimizationOptimisers, Random, QuasiMonteCarlo + using Lux, Optimization, OptimizationOptimisers, Random, QuasiMonteCarlo, + ComponentArrays import ModelingToolkit: Interval, infimum, supremum Random.seed!(100) @@ -156,7 +157,7 @@ end end @testitem "2D PDE - CUDA" tags=[:cuda] setup=[CUDATestSetup] begin - using Lux, Optimization, OptimizationOptimisers, Random + using Lux, Optimization, OptimizationOptimisers, Random, ComponentArrays import ModelingToolkit: Interval, infimum, supremum Random.seed!(100) diff --git a/test/direct_function_tests.jl b/test/direct_function_tests.jl index 3e101afa1..022cbae2a 100644 --- a/test/direct_function_tests.jl +++ b/test/direct_function_tests.jl @@ -63,7 +63,7 @@ end prob = discretize(pde_system, discretization) res = solve(prob, OptimizationOptimisers.Adam(0.01), maxiters = 500) prob = remake(prob, u0 = res.u) - res = solve(prob, OptimizationOptimJL.BFGS(), maxiters = 1000) + res = solve(prob, BFGS(), maxiters = 1000) dx = 0.01 xs = collect(x0:dx:x_end) func_s = func(xs) @@ -101,9 +101,9 @@ end symprob.loss_functions.full_loss_function(symprob.flat_init_params, nothing) res = solve(prob, OptimizationOptimisers.Adam(0.01), maxiters = 500) prob = remake(prob, u0 = res.u) - res = solve(prob, OptimizationOptimJL.BFGS(), maxiters = 1000) + res = solve(prob, BFGS(), maxiters = 1000) prob = remake(prob, u0 = res.u) - res = solve(prob, OptimizationOptimJL.BFGS(), maxiters = 500) + res = solve(prob, BFGS(), maxiters = 500) phi = discretization.phi xs = collect(x0:0.1:x_end) ys = collect(y0:0.1:y_end)