From 1c128ee10d6aec36c19a53cd0570fd778257744d Mon Sep 17 00:00:00 2001 From: Robin Krom Date: Wed, 30 Aug 2023 17:06:18 +0200 Subject: [PATCH 1/3] asl: adding additional symbolic loss functions This commit adds additional symbolic loss functions. These are loss functions analogous to those obtained from PDEs or boundary conditions, but with the integrand given explicitly in symbolic form. --- Project.toml | 1 + src/adaptive_losses.jl | 51 ++++-- src/discretize.jl | 238 +++++++++++++++++-------- src/pinn_types.jl | 33 +++- src/symbolic_utilities.jl | 8 +- src/training_strategies.jl | 90 ++++++---- test/additional_symbolic_loss_tests.jl | 172 ++++++++++++++++++ test/runtests.jl | 3 +- 8 files changed, 462 insertions(+), 134 deletions(-) create mode 100644 test/additional_symbolic_loss_tests.jl diff --git a/Project.toml b/Project.toml index 1013977bad..5a66ccc06e 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl index b37023da7c..1a97f80e49 100644 --- a/src/adaptive_losses.jl +++ b/src/adaptive_losses.jl @@ -15,6 +15,7 @@ end """ ```julia NonAdaptiveLoss{T}(; pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) ``` @@ -24,31 +25,34 @@ change during optimization """ mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss pde_loss_weights::Vector{T} + asl_loss_weights::Vector{T} bc_loss_weights::Vector{T} additional_loss_weights::Vector{T} SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) where { T <: Real } - new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), + new(vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T), vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T)) end end # default to Float64 -SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1, +SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights, + asl_loss_weights = asl_loss_weights, bc_loss_weights = bc_loss_weights, additional_loss_weights = additional_loss_weights) end function generate_adaptive_loss_function(pinnrep::PINNRepresentation, adaloss::NonAdaptiveLoss, - pde_loss_functions, bc_loss_functions) - function null_nonadaptive_loss(θ, pde_losses, bc_losses) + pde_loss_functions, asl_loss_functions, bc_loss_functions) + function null_nonadaptive_loss(θ, pde_losses, asl_losses, bc_losses) nothing end end @@ -58,6 +62,7 @@ end """ ```julia GradientScaleAdaptiveLoss(reweight_every; weight_change_inertia = 0.9, pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) ``` @@ -90,30 +95,34 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss reweight_every::Int64 weight_change_inertia::T pde_loss_weights::Vector{T} + asl_loss_weights::Vector{T} bc_loss_weights::Vector{T} 
additional_loss_weights::Vector{T} SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every; weight_change_inertia = 0.9, pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) where { T <: Real } new(convert(Int64, reweight_every), convert(T, weight_change_inertia), - vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), - vectorify(additional_loss_weights, T)) + vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T), + vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T)) end end # default to Float64 SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every; weight_change_inertia = 0.9, pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) GradientScaleAdaptiveLoss{Float64}(reweight_every; weight_change_inertia = weight_change_inertia, pde_loss_weights = pde_loss_weights, + asl_loss_weights = asl_loss_weights, bc_loss_weights = bc_loss_weights, additional_loss_weights = additional_loss_weights) end @@ -136,7 +145,7 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation, nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7) - bc_loss_weights_proposed = pde_grads_max ./ + bc_loss_weights_proposed = pde_asl_grads_max ./ (bc_grads_mean .+ nonzero_divisor_eps) adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+ @@ -160,8 +169,10 @@ end ```julia function MiniMaxAdaptiveLoss(reweight_every; pde_max_optimiser = Flux.ADAM(1e-4), + asl_max_optimiser = Flux.ADAM(1e-4), bc_max_optimiser = Flux.ADAM(0.5), pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) ``` @@ -191,65 +202,81 @@ https://arxiv.org/abs/2009.04544 """ mutable struct MiniMaxAdaptiveLoss{T <: Real, PDE_OPT <: Flux.Optimise.AbstractOptimiser, + ASL_OPT <: Flux.Optimise.AbstractOptimiser, BC_OPT <: Flux.Optimise.AbstractOptimiser} <: AbstractAdaptiveLoss reweight_every::Int64 pde_max_optimiser::PDE_OPT + asl_max_optimiser::ASL_OPT bc_max_optimiser::BC_OPT pde_loss_weights::Vector{T} + asl_loss_weights::Vector{T} bc_loss_weights::Vector{T} additional_loss_weights::Vector{T} SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T, - PDE_OPT, BC_OPT}(reweight_every; + PDE_OPT, ASL_OPT, BC_OPT}(reweight_every; pde_max_optimiser = Flux.ADAM(1e-4), + asl_max_optimiser = Flux.ADAM(1e-4), bc_max_optimiser = Flux.ADAM(0.5), pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) where { T <: Real, PDE_OPT <: Flux.Optimise.AbstractOptimiser, + ASL_OPT <: + Flux.Optimise.AbstractOptimiser, BC_OPT <: Flux.Optimise.AbstractOptimiser } - new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser), + new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser), convert(ASL_OPT, asl_max_optimiser), convert(BC_OPT, bc_max_optimiser), vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T), - vectorify(additional_loss_weights, T)) + vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T)) end end # default to Float64, ADAM, ADAM, ADAM SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every; pde_max_optimiser = Flux.ADAM(1e-4), + asl_max_optimiser = Flux.ADAM(1e-4), bc_max_optimiser = Flux.ADAM(0.5), pde_loss_weights = 1, + asl_loss_weights = 1, bc_loss_weights = 1, additional_loss_weights = 1) MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser), typeof(asl_max_optimiser), typeof(bc_max_optimiser)}(reweight_every; pde_max_optimiser 
= pde_max_optimiser, + asl_max_optimiser = asl_max_optimiser, bc_max_optimiser = bc_max_optimiser, pde_loss_weights = pde_loss_weights, + asl_loss_weights = asl_loss_weights, bc_loss_weights = bc_loss_weights, additional_loss_weights = additional_loss_weights) end function generate_adaptive_loss_function(pinnrep::PINNRepresentation, adaloss::MiniMaxAdaptiveLoss, - pde_loss_functions, bc_loss_functions) + pde_loss_functions, asl_loss_functions, bc_loss_functions) pde_max_optimiser = adaloss.pde_max_optimiser + asl_max_optimiser = adaloss.asl_max_optimiser bc_max_optimiser = adaloss.bc_max_optimiser iteration = pinnrep.iteration - function run_minimax_adaptive_loss(θ, pde_losses, bc_losses) + function run_minimax_adaptive_loss(θ, pde_losses, asl_losses, bc_losses) if iteration[1] % adaloss.reweight_every == 0 Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights, -pde_losses) + Flux.Optimise.update!(asl_max_optimiser, adaloss.asl_loss_weights, + -asl_losses) Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses) logvector(pinnrep.logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", iteration[1]) + logvector(pinnrep.logger, adaloss.asl_loss_weights, + "adaptive_loss/asl_loss_weights", iteration[1]) logvector(pinnrep.logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1]) diff --git a/src/discretize.jl b/src/discretize.jl index c8412b2d15..bd6f177f5c 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -204,16 +204,16 @@ strategy. """ function generate_training_sets end -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array, +function generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, _indvars::Array, _depvars::Array) depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars, + _depvars) + return generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars) end # Generate training set in the domain and on the boundary -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict, +function generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, dict_indvars::Dict, dict_depvars::Dict) if dx isa Array dxs = dx @@ -249,19 +249,26 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) end - pde_vars = get_variables(eqs, dict_indvars, dict_depvars) - pde_args = get_argument(eqs, dict_indvars, dict_depvars) + function get_eqs_train_sets(eqs) + eqs_vars = get_variables(eqs, dict_indvars, dict_depvars) + eqs_args = get_argument(eqs, dict_indvars, dict_depvars) - pde_train_set = adapt(eltypeθ, - hcat(vec(map(points -> collect(points), - Iterators.product(bc_data...)))...)) + eqs_train_set = adapt(eltypeθ, + hcat(vec(map(points -> collect(points), + Iterators.product(bc_data...)))...)) - pde_train_sets = map(pde_args) do bt - span = map(b -> get(dict_var_span_, b, b), bt) - _set = adapt(eltypeθ, - hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) + eqs_train_sets = map(eqs_args) do bt + span = map(b -> get(dict_var_span_, b, b), bt) + _set = adapt(eltypeθ, + hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) + end + return eqs_train_sets end - [pde_train_sets, bcs_train_sets] + + pde_train_sets = get_eqs_train_sets(eqs) + asl_train_sets = get_eqs_train_sets(asl) + + 
[pde_train_sets, asl_train_sets, bcs_train_sets] end """ @@ -274,35 +281,38 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. """ function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy) +function get_bounds(domains, eqs, asl, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy) depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) + return get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) end -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, +function get_bounds(domains, eqs, asl, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy::QuadratureTraining) depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) + return get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, +function get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy::QuadratureTraining) dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) - pde_args = get_argument(eqs, dict_indvars, dict_depvars) - - pde_lower_bounds = map(pde_args) do pd - span = map(p -> get(dict_lower_bound, p, p), pd) - map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span) - end - pde_upper_bounds = map(pde_args) do pd - span = map(p -> get(dict_upper_bound, p, p), pd) - map(s -> adapt(eltypeθ, s) - cbrt(eps(eltypeθ)), span) + function get_eqs_bounds(eqs) + eqs_args = get_argument(eqs, dict_indvars, dict_depvars) + eqs_lower_bounds = map(eqs_args) do pd + span = map(p -> get(dict_lower_bound, p, p), pd) + map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span) + end + eqs_upper_bounds = map(eqs_args) do pd + span = map(p -> get(dict_upper_bound, p, p), pd) + map(s -> adapt(eltypeθ, s) - cbrt(eps(eltypeθ)), span) + end + return [eqs_lower_bounds, eqs_upper_bounds] end - pde_bounds = [pde_lower_bounds, pde_upper_bounds] + pde_bounds = get_eqs_bounds(eqs) + asl_bounds = get_eqs_bounds(asl) bound_vars = get_variables(bcs, dict_indvars, dict_depvars) @@ -314,10 +324,10 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - [pde_bounds, bcs_bounds] + [pde_bounds, asl_bounds, bcs_bounds] end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) +function get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) dx = 1 / strategy.points dict_span = Dict([Symbol(d.variables) => [ infimum(d.domain) + dx, @@ -325,20 +335,19 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str ] for d in domains]) # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] - pde_args = get_argument(eqs, dict_indvars, dict_depvars) - pde_bounds = map(pde_args) do pde_arg - bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) - bds = eltypeθ.(bds) - bds[1, :], bds[2, :] - end - - bound_args = get_argument(bcs, dict_indvars, dict_depvars) - bcs_bounds = map(bound_args) do bound_arg - bds = mapreduce(s -> get(dict_span, s, fill(s, 
2)), hcat, bound_arg) - bds = eltypeθ.(bds) - bds[1, :], bds[2, :] + function get_eqs_bounds(eqs) + eqs_args = get_argument(eqs, dict_indvars, dict_depvars) + eqs_bounds = map(eqs_args) do eqs_arg + bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, eqs_arg) + bds = eltypeθ.(bds) + bds[1, :], bds[2, :] + end + return eqs_bounds end - return pde_bounds, bcs_bounds + pde_bounds = get_eqs_bounds(eqs) + asl_bounds = get_eqs_bounds(asl) + bcs_bounds = get_eqs_bounds(bcs) + return pde_bounds, asl_bounds, bcs_bounds end function get_numeric_integral(pinnrep::PINNRepresentation) @@ -404,6 +413,7 @@ For more information, see `discretize` and `PINNRepresentation`. function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) eqs = pde_system.eqs + asl = discretization.additional_symb_loss bcs = pde_system.bcs chain = discretization.chain @@ -515,61 +525,73 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, eqs = [eqs] end - pde_indvars = if strategy isa QuadratureTraining - get_argument(eqs, dict_indvars, dict_depvars) - else - get_variables(eqs, dict_indvars, dict_depvars) + if !(asl isa Array) + asl = [asl] end - bc_indvars = if strategy isa QuadratureTraining - get_argument(bcs, dict_indvars, dict_depvars) - else - get_variables(bcs, dict_indvars, dict_depvars) + function get_eqs_indvars(eqs) + eqs_indvars = if strategy isa QuadratureTraining + get_argument(eqs, dict_indvars, dict_depvars) + else + get_variables(eqs, dict_indvars, dict_depvars) + end + return eqs_indvars end + pde_indvars = get_eqs_indvars(eqs) + bc_indvars = get_eqs_indvars(bcs) + asl_indvars = get_eqs_indvars(asl) pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars) bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars) + asl_integration_vars = get_integration_variables(asl, dict_indvars, dict_depvars) - pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, + pinnrep = PINNRepresentation(eqs, asl, bcs, domains, eq_params, defaults, + default_p, param_estim, additional_loss, adaloss, depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input, logger, multioutput, iteration, init_params, flat_init_params, phi, derivative, - strategy, pde_indvars, bc_indvars, pde_integration_vars, - bc_integration_vars, nothing, nothing, nothing, nothing) + strategy, pde_indvars, asl_indvars, bc_indvars, pde_integration_vars, + asl_integration_vars, bc_integration_vars, nothing, nothing, nothing, nothing, nothing) integral = get_numeric_integral(pinnrep) - symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq; - bc_indvars = pde_indvar) - for (eq, pde_indvar) in zip(eqs, pde_indvars, - pde_integration_vars)] - - symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc; - bc_indvars = bc_indvar) - for (bc, bc_indvar) in zip(bcs, bc_indvars, - bc_integration_vars)] + function build_symbolic_loss_functions(eqs, eqs_indvars, eqs_integration_vars) + symbolic_eqs_loss_functions = [build_symbolic_loss_function(pinnrep, eq; + bc_indvars = eqs_indvar) + for (eq, eqs_indvar) in zip(eqs, eqs_indvars, + eqs_integration_vars)] + return symbolic_eqs_loss_functions + end + symbolic_pde_loss_functions = build_symbolic_loss_functions(eqs, pde_indvars, pde_integration_vars) + symbolic_asl_loss_functions = build_symbolic_loss_functions(asl, asl_indvars, asl_integration_vars) + symbolic_bc_loss_functions = build_symbolic_loss_functions(bcs, bc_indvars, bc_integration_vars) pinnrep.integral = integral 
pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions + pinnrep.symbolic_asl_loss_functions = symbolic_asl_loss_functions pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions - datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar) - for (eq, pde_indvar, integration_indvar) in zip(eqs, - pde_indvars, - pde_integration_vars)] + function build_datafree_eqs_loss_functions(eqs, eqs_indvars, eqs_integration_vars) + return [build_loss_function(pinnrep, eq, eq_indvar) + for (eq, eq_indvar, integration_indvar) in zip(eqs, + eqs_indvars, + eqs_integration_vars)] + end - datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar) - for (bc, bc_indvar, integration_indvar) in zip(bcs, - bc_indvars, - bc_integration_vars)] + datafree_pde_loss_functions = build_datafree_eqs_loss_functions(eqs, pde_indvars, pde_integration_vars) + datafree_asl_loss_functions = build_datafree_eqs_loss_functions(asl, asl_indvars, asl_integration_vars) + datafree_bc_loss_functions = build_datafree_eqs_loss_functions(bcs, bc_indvars, bc_integration_vars) + + pde_loss_functions, asl_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep, + strategy, + datafree_pde_loss_functions, + datafree_asl_loss_functions, + datafree_bc_loss_functions) - pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep, - strategy, - datafree_pde_loss_functions, - datafree_bc_loss_functions) # setup for all adaptive losses num_pde_losses = length(pde_loss_functions) + num_asl_losses = length(asl_loss_functions) num_bc_losses = length(bc_loss_functions) # assume one single additional loss function if there is one. this means that the user needs to lump all their functions into a single one, num_additional_loss = additional_loss isa Nothing ? 0 : 1 @@ -578,12 +600,14 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* adaloss.pde_loss_weights + adaloss.asl_loss_weights = ones(adaloss_T, num_asl_losses) .* adaloss.asl_loss_weights adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .* adaloss.additional_loss_weights reweight_losses_func = generate_adaptive_loss_function(pinnrep, adaloss, pde_loss_functions, + asl_loss_functions, bc_loss_functions) function get_likelihood_estimate_function(discretization::PhysicsInformedNN) @@ -597,6 +621,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, ChainRulesCore.@ignore_derivatives if self_increment iteration[1] += 1 end + # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them + # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions. + pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions] + asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions] + bc_losses = num_bc_losses == 0 ? 
adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions] ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, @@ -665,12 +694,21 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, iteration[1]) end end + ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, + asl_losses, bc_losses) end return full_weighted_loss end + weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses + weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses + weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses return full_loss_function end + sum_weighted_pde_losses = sum(weighted_pde_losses) + sum_weighted_asl_losses = sum(weighted_asl_losses) + sum_weighted_bc_losses = sum(weighted_bc_losses) + weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses function get_likelihood_estimate_function(discretization::BayesianPINN) dataset_pde, dataset_bc = discretization.dataset @@ -758,13 +796,55 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, end return full_loss_function + ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0 + logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses", + iteration[1]) + logvector(pinnrep.logger, asl_losses, "unweighted_loss/asl_losses", + iteration[1]) + logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1]) + logvector(pinnrep.logger, weighted_pde_losses, + "weighted_loss/weighted_pde_losses", + iteration[1]) + logvector(pinnrep.logger, weighted_asl_losses, + "weighted_loss/weighted_asl_losses", + iteration[1]) + logvector(pinnrep.logger, weighted_bc_losses, + "weighted_loss/weighted_bc_losses", + iteration[1]) + if !(additional_loss isa Nothing) + logscalar(pinnrep.logger, weighted_additional_loss_val, + "weighted_loss/weighted_additional_loss", iteration[1]) + end + logscalar(pinnrep.logger, sum_weighted_pde_losses, + "weighted_loss/sum_weighted_pde_losses", iteration[1]) + logscalar(pinnrep.logger, sum_weighted_bc_losses, + "weighted_loss/sum_weighted_bc_losses", iteration[1]) + logscalar(pinnrep.logger, sum_weighted_asl_losses, + "weighted_loss/sum_weighted_asl_losses", iteration[1]) + logscalar(pinnrep.logger, full_weighted_loss, + "weighted_loss/full_weighted_loss", + iteration[1]) + logvector(pinnrep.logger, adaloss.pde_loss_weights, + "adaptive_loss/pde_loss_weights", + iteration[1]) + logvector(pinnrep.logger, adaloss.asl_loss_weights, + "adaptive_loss/asl_loss_weights", + iteration[1]) + logvector(pinnrep.logger, adaloss.bc_loss_weights, + "adaptive_loss/bc_loss_weights", + iteration[1]) + end end + + return full_weighted_loss end full_loss_function = get_likelihood_estimate_function(discretization) pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions, - full_loss_function, additional_loss, - datafree_pde_loss_functions, - datafree_bc_loss_functions) + asl_loss_functions, + full_loss_function, additional_loss, + datafree_pde_loss_functions, + datafree_asl_loss_functions, + datafree_bc_loss_functions) return pinnrep diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 48c8f46da9..3a260bc554 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -34,6 +34,7 @@ PhysicsInformedNN(chain, phi = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = nothing, adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -78,7 +79,7 @@ methodology. 
* `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments which are splatted to the `OptimizationProblem` on `solve`. """ -struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN +struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPINN chain::Any strategy::T init_params::P @@ -86,6 +87,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN derivative::DER param_estim::PE additional_loss::AL + additional_symb_loss::ASL adaptive_loss::ADA logger::LOG log_options::LogOptions @@ -101,6 +103,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN derivative = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = [], adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -140,6 +143,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN _derivative, param_estim, additional_loss, + additional_symb_loss, adaptive_loss, logger, log_options, @@ -276,6 +280,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN _derivative, param_estim, additional_loss, + additional_symb_loss, adaptive_loss, logger, log_options, @@ -284,7 +289,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN multioutput, dataset, kwargs) - end + end end """ @@ -303,6 +308,10 @@ mutable struct PINNRepresentation """ eqs::Any """ + The additional symbolic loss functions + """ + asl::Any + """ The boundary condition equations """ bcs::Any @@ -404,6 +413,10 @@ mutable struct PINNRepresentation """ ??? """ + asl_indvars::Any + """ + ??? + """ bc_indvars::Any """ ??? @@ -412,6 +425,10 @@ mutable struct PINNRepresentation """ ??? """ + asl_integration_vars::Any + """ + ??? + """ bc_integration_vars::Any """ ??? @@ -422,6 +439,10 @@ mutable struct PINNRepresentation """ symbolic_pde_loss_functions::Any """ + The additional loss functions as represented in Julia AST + """ + symbolic_asl_loss_functions::Any + """ The boundary condition loss functions as represented in Julia AST """ symbolic_bc_loss_functions::Any @@ -450,6 +471,10 @@ struct PINNLossFunctions """ pde_loss_functions::Any """ + The additional symbolic loss functions + """ + asl_loss_functions::Any + """ The full loss function, combining the PDE and boundary condition loss functions. This is the loss function that is used by the optimizer. """ @@ -463,6 +488,10 @@ struct PINNLossFunctions """ datafree_pde_loss_functions::Any """ + The pre-data version of the additional symbolic loss function + """ + datafree_asl_loss_functions::Any + """ The pre-data version of the BC loss function """ datafree_bc_loss_functions::Any diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 9161f3c365..b14c264e5f 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -303,7 +303,8 @@ Example: [(derivative(phi1, u1, [x, y], [[ε,0]], 1, θ1) + 4 * derivative(phi2, u, [x, y], [[0,ε]], 1, θ2)) - 0, (derivative(phi2, u2, [x, y], [[ε,0]], 1, θ2) + 9 * derivative(phi1, u, [x, y], [[0,ε]], 1, θ1)) - 0] """ -function parse_equation(pinnrep::PINNRepresentation, eq) +# Parse an equation +function parse_equation(pinnrep::PINNRepresentation, eq::Equation) eq_lhs = isequal(expand_derivatives(eq.lhs), 0) ? eq.lhs : expand_derivatives(eq.lhs) eq_rhs = isequal(expand_derivatives(eq.rhs), 0) ? 
eq.rhs : expand_derivatives(eq.rhs) left_expr = transform_expression(pinnrep, toexpr(eq_lhs)) @@ -313,6 +314,11 @@ function parse_equation(pinnrep::PINNRepresentation, eq) loss_func = :($left_expr .- $right_expr) end +# Parse an energy +function parse_equation(pinnrep::PINNRepresentation, eq) + loss_func = _dot_(transform_expression(pinnrep, toexpr(eq))) +end + function get_indvars_ex(bc_indvars) # , dict_this_eq_indvars) i_ = 1 indvars_ex = map(bc_indvars) do u diff --git a/src/training_strategies.jl b/src/training_strategies.jl index d4edd26bfc..2b46bf4b25 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -53,28 +53,31 @@ end function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, + datafree_asl_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep dx = strategy.dx eltypeθ = eltype(pinnrep.flat_init_params) - train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ, + train_sets = generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars) # the points in the domain and on the boundary - pde_train_sets, bcs_train_sets = train_sets - pde_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)), - pde_train_sets) - bcs_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)), - bcs_train_sets) - pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy) - for (_loss, _set) in zip(datafree_pde_loss_function, - pde_train_sets)] - - bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy) - for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)] - - pde_loss_functions, bc_loss_functions + pde_train_sets, asl_train_sets, bcs_train_sets = train_sets + + all_loss_functions = + map([(pde_train_sets, datafree_pde_loss_function), + (asl_train_sets, datafree_asl_loss_function), + (bcs_train_sets, datafree_bc_loss_function)]) do (train_sets, datafree_loss_function) + + train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)), + train_sets) + loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy) + for (_loss, _set) in zip(datafree_loss_function, + train_sets)] + return loss_functions + end + return Tuple(all_loss_functions) end function get_loss_function(loss_function, train_set, eltypeθ, strategy::GridTraining; @@ -113,22 +116,25 @@ end function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::StochasticTraining, datafree_pde_loss_function, + datafree_asl_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) - pde_bounds, bcs_bounds = bounds + pde_bounds, asl_bounds, bcs_bounds = bounds - pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) - for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)] + all_loss_functions = map([(datafree_pde_loss_function, pde_bounds), + (datafree_asl_loss_function, asl_bounds), + (datafree_bc_loss_function, bcs_bounds)]) do 
(datafree_loss_function, bounds) - bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) - for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)] + return [get_loss_function(_loss, bound, eltypeθ, strategy) + for (_loss, bound) in zip(datafree_loss_function, bounds)] + end - pde_loss_functions, bc_loss_functions + return Tuple(all_loss_functions) end function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining; @@ -195,17 +201,20 @@ end function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuasiRandomTraining, datafree_pde_loss_function, + datafree_asl_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) - pde_bounds, bcs_bounds = bounds + pde_bounds, asl_bounds, bcs_bounds = bounds pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)] + asl_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) + for (_loss, bound) in zip(datafree_asl_loss_function, asl_bounds)] strategy_ = QuasiRandomTraining(strategy.bcs_points; sampling_alg = strategy.sampling_alg, @@ -214,7 +223,8 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_) for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)] - pde_loss_functions, bc_loss_functions + + pde_loss_functions, asl_loss_functions, bc_loss_functions end function get_loss_function(loss_function, bound, eltypeθ, strategy::QuasiRandomTraining; @@ -288,22 +298,24 @@ end function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuadratureTraining, datafree_pde_loss_function, + datafree_asl_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) - pde_bounds, bcs_bounds = bounds - - lbs, ubs = pde_bounds - pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) - for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)] - lbs, ubs = bcs_bounds - bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) - for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)] + pde_bounds, asl_bounds, bcs_bounds = bounds + + all_loss_functions = map([(datafree_pde_loss_function, pde_bounds), + (datafree_asl_loss_function, asl_bounds), + (datafree_bc_loss_function, bcs_bounds)]) do (datafree_loss_function, bounds) + lbs, ubs = bounds + return [get_loss_function(_loss, lb, ub, eltypeθ, strategy) + for (_loss, lb, ub) in zip(datafree_loss_function, lbs, ubs)] + end - pde_loss_functions, bc_loss_functions + return Tuple(all_loss_functions) end function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::QuadratureTraining; @@ -366,4 +378,4 @@ function get_loss_function(loss_function, 
train_set, eltypeθ, strategy::WeightedIntervalTraining; τ = nothing) loss = (θ) -> mean(abs2, loss_function(train_set, θ)) -end \ No newline at end of file +end diff --git a/test/additional_symbolic_loss_tests.jl b/test/additional_symbolic_loss_tests.jl new file mode 100644 index 0000000000..27c154ae97 --- /dev/null +++ b/test/additional_symbolic_loss_tests.jl @@ -0,0 +1,172 @@ +using NeuralPDE: DomainSets +using Random +using Test +using ComponentArrays +using OptimizationOptimisers +using NeuralPDE +using LinearAlgebra +using Lux +import ModelingToolkit: Interval + +@parameters x0 x1 x2 x3 +@variables ρ01(..) ρ02(..) ρ03(..) ρ12(..) ρ13(..) ρ23(..) + +# the 4-torus +domain = [ + x0 ∈ Interval(0.0, 1.0), + x1 ∈ Interval(0.0, 1.0), + x2 ∈ Interval(0.0, 1.0), + x3 ∈ Interval(0.0, 1.0), +] + +∂₀ = Differential(x0) +∂₁ = Differential(x1) +∂₂ = Differential(x2) +∂₃ = Differential(x3) + +d₂(ρ) = [ + # the comments show the signed permutations of the indices + + #(0,1,2) + (2,0,1) + (1,2,0) - (1,0,2) - (2,1,0) - (0,2,1) + 2 * ∂₀(ρ[4]) - 2 * ∂₁(ρ[2]) + 2 * ∂₂(ρ[1]), + #(0,1,3) + (3,0,1) + (1,3,0) - (1,0,3) - (0,3,1) - (3,1,0) + 2 * ∂₀(ρ[5]) - 2 * ∂₁(ρ[3]) + 2 * ∂₃(ρ[1]), + #(0,2,3) + (3,0,2) + (2,3,0) - (2,0,3) - (0,3,2) - (3,2,0) + 2 * ∂₀(ρ[6]) - 2 * ∂₂(ρ[3]) + 2 * ∂₃(ρ[2]), + #(1,2,3) + (3,1,2) + (2,3,1) - (2,1,3) - (1,3,2) - (3,2,1) + 2 * ∂₁(ρ[6]) - 2 * ∂₂(ρ[5]) + 2 * ∂₃(ρ[4]), +] + +u(ρ) = ρ[1] * ρ[6] - ρ[2] * ρ[5] + ρ[3] * ρ[4] + +K₁(ρ) = 2(ρ[1] + ρ[6]) / u(ρ) +K₂(ρ) = 2(ρ[2] - ρ[5]) / u(ρ) +K₃(ρ) = 2(ρ[3] + ρ[4]) / u(ρ) + +K(ρ) = [ + K₁(ρ), + K₂(ρ), + K₃(ρ), +] + +# energy +fₑ(ρ) = (K(ρ)[1]^2 + K(ρ)[2]^2 + K(ρ)[3]^2) * u(ρ) + +energies = + let ρ = [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3), ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)] + [fₑ(ρ)] + end + +# periodic boundary conditions for the 4-torus +bcs = [ + ρ01(0.0, x1, x2, x3) ~ ρ01(1.0, x1, x2, x3), + ρ01(x0, 0.0, x2, x3) ~ ρ01(x0, 1.0, x2, x3), + ρ01(x0, x1, 0.0, x3) ~ ρ01(x0, x1, 1.0, x3), + ρ01(x0, x1, x2, 0.0) ~ ρ01(x0, x1, x2, 1.0), + ρ02(0.0, x1, x2, x3) ~ ρ02(1.0, x1, x2, x3), + ρ02(x0, 0.0, x2, x3) ~ ρ02(x0, 1.0, x2, x3), + ρ02(x0, x1, 0.0, x3) ~ ρ02(x0, x1, 1.0, x3), + ρ02(x0, x1, x2, 0.0) ~ ρ02(x0, x1, x2, 1.0), + ρ03(0.0, x1, x2, x3) ~ ρ03(1.0, x1, x2, x3), + ρ03(x0, 0.0, x2, x3) ~ ρ03(x0, 1.0, x2, x3), + ρ03(x0, x1, 0.0, x3) ~ ρ03(x0, x1, 1.0, x3), + ρ03(x0, x1, x2, 0.0) ~ ρ03(x0, x1, x2, 1.0), + ρ12(0.0, x1, x2, x3) ~ ρ12(1.0, x1, x2, x3), + ρ12(x0, 0.0, x2, x3) ~ ρ12(x0, 1.0, x2, x3), + ρ12(x0, x1, 0.0, x3) ~ ρ12(x0, x1, 1.0, x3), + ρ12(x0, x1, x2, 0.0) ~ ρ12(x0, x1, x2, 1.0), + ρ13(0.0, x1, x2, x3) ~ ρ13(1.0, x1, x2, x3), + ρ13(x0, 0.0, x2, x3) ~ ρ13(x0, 1.0, x2, x3), + ρ13(x0, x1, 0.0, x3) ~ ρ13(x0, x1, 1.0, x3), + ρ13(x0, x1, x2, 0.0) ~ ρ13(x0, x1, x2, 1.0), + ρ23(0.0, x1, x2, x3) ~ ρ23(1.0, x1, x2, x3), + ρ23(x0, 0.0, x2, x3) ~ ρ23(x0, 1.0, x2, x3), + ρ23(x0, x1, 0.0, x3) ~ ρ23(x0, x1, 1.0, x3), + ρ23(x0, x1, x2, 0.0) ~ ρ23(x0, x1, x2, 1.0), +] + +# equations for dρ = 0. 
+eqClosed(ρ) = d₂(ρ)[:] .~ 0 + +eqs = + let ρ = [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3), ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)] + vcat( + eqClosed(ρ), + ) + end + + +input_ = length(domain) +n = 16 + +ixToSym = Dict( + 1 => :ρ01, + 2 => :ρ02, + 3 => :ρ03, + 4 => :ρ12, + 5 => :ρ13, + 6 => :ρ23 +) + +chains = NamedTuple((ixToSym[ix], Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, 1))) for ix in 1:6) +chains0 = collect(chains) + +function test_donaldson_energy_loss_no_logs(ϵ, sym_prob, prob) + # pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions + # bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions + # energy_inner_loss_functions = sym_prob.loss_functions.asl_loss_functions + + ps = map(c -> Lux.setup(Random.default_rng(), c)[1], chains) |> ComponentArray .|> Float64 + prob1 = remake(prob; u0 = ComponentVector(depvar = ps)) + + callback(ϵ::Float64) = function(p, l) + # println("loss: ", l) + # println("pde_losses: ", map(l_ -> l_(p), pde_inner_loss_functions)) + # println("bcs_losses: ", map(l_ -> l_(p), bcs_inner_loss_functions)) + # println("energy losses: ", map(l_ -> l_(p), energy_inner_loss_functions)) + return l < ϵ + end + _sol = Optimization.solve(prob1, Adam(0.01); callback=callback(ϵ), maxiters = 1) + return true +end + + +@named pdesystem = PDESystem(eqs, bcs, domain, [x0, x1, x2, x3], + [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3), ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)] +) +discretization = PhysicsInformedNN(chains0, QuasiRandomTraining(1000)) +sym_prob = symbolic_discretize(pdesystem, discretization) +prob = discretize(pdesystem, discretization) +@info "testing additional symbolic loss functions: solver runs without additional symbolic losses." +@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob) + + +@named pdesystem1 = PDESystem([], bcs, domain, [x0, x1, x2, x3], + [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3), ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)] +) +discretization = PhysicsInformedNN(chains0, QuasiRandomTraining(1000); additional_symb_loss = energies) +sym_prob = symbolic_discretize(pdesystem1, discretization) +prob = discretize(pdesystem1, discretization) +@info "testing additional symbolic loss functions: quasi random training: solver runs with only additional symbolic loss function." +@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob) + +@named pdesystem2 = PDESystem(eqs, bcs, domain, [x0, x1, x2, x3], + [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3), ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)] +) +discretization = PhysicsInformedNN(chains0, StochasticTraining(1000); additional_symb_loss = energies) +sym_prob = symbolic_discretize(pdesystem2, discretization) +prob = discretize(pdesystem2, discretization) +@info "testing additional symbolic loss functions: stochastic training: solver runs with additional symbolic loss function and PDE system." +@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob) + +discretization = PhysicsInformedNN(chains0, GridTraining(0.1); additional_symb_loss = energies) +sym_prob = symbolic_discretize(pdesystem2, discretization) +prob = discretize(pdesystem2, discretization) +@info "testing additional symbolic loss functions: grid training: solver runs with additional symbolic loss function and PDE system." 
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob) + +discretization = PhysicsInformedNN(chains0, QuadratureTraining(); additional_symb_loss = energies) +sym_prob = symbolic_discretize(pdesystem1, discretization) +prob = discretize(pdesystem1, discretization) +@info "testing additional symbolic loss functions: quadrature training: solver runs with only additional symbolic loss function." +@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob) diff --git a/test/runtests.jl b/test/runtests.jl index 5d6ac6909e..391ea4f9a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ end if GROUP == "All" || GROUP == "NNPDE2" @time @safetestset "Additional Loss" begin include("additional_loss_tests.jl") end + @time @safetestset "Additional Symbolic Loss" begin include("additional_symbolic_loss_tests.jl") end @time @safetestset "Direction Function Approximation" begin include("direct_function_tests.jl") end end if GROUP == "All" || GROUP == "NeuralAdapter" @@ -65,4 +66,4 @@ @safetestset "NNPDE_gpu" begin include("NNPDE_tests_gpu.jl") end @safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end end -end \ No newline at end of file +end From 9af3137cfd1011e10d51e82713f64baea8e8d738 Mon Sep 17 00:00:00 2001 From: Robin Krom Date: Thu, 18 Jan 2024 13:08:14 +0100 Subject: [PATCH 2/3] rebase: fixes --- src/discretize.jl | 66 ++++++----------------------------------------- src/pinn_types.jl | 13 ++++++---- 2 files changed, 16 insertions(+), 63 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index bd6f177f5c..cd8e4efe41 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -613,31 +613,30 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, function get_likelihood_estimate_function(discretization::PhysicsInformedNN) function full_loss_function(θ, p) # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them - pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions] - bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions] + # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions. + pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions] + asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions] + bc_losses = num_bc_losses == 0 ? 
adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions] ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, - bc_losses) + asl_losses, bc_losses) end weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses + weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses sum_weighted_pde_losses = sum(weighted_pde_losses) + sum_weighted_asl_losses = sum(weighted_asl_losses) sum_weighted_bc_losses = sum(weighted_bc_losses) - weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses + weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses full_weighted_loss = if additional_loss isa Nothing weighted_loss_before_additional @@ -694,21 +693,12 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, iteration[1]) end end - ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, - asl_losses, bc_losses) end return full_weighted_loss end - weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses - weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses - weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses return full_loss_function end - sum_weighted_pde_losses = sum(weighted_pde_losses) - sum_weighted_asl_losses = sum(weighted_asl_losses) - sum_weighted_bc_losses = sum(weighted_bc_losses) - weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses function get_likelihood_estimate_function(discretization::BayesianPINN) dataset_pde, dataset_bc = discretization.dataset @@ -796,46 +786,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, end return full_loss_function - ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0 - logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses", - iteration[1]) - logvector(pinnrep.logger, asl_losses, "unweighted_loss/asl_losses", - iteration[1]) - logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1]) - logvector(pinnrep.logger, weighted_pde_losses, - "weighted_loss/weighted_pde_losses", - iteration[1]) - logvector(pinnrep.logger, weighted_asl_losses, - "weighted_loss/weighted_asl_losses", - iteration[1]) - logvector(pinnrep.logger, weighted_bc_losses, - "weighted_loss/weighted_bc_losses", - iteration[1]) - if !(additional_loss isa Nothing) - logscalar(pinnrep.logger, weighted_additional_loss_val, - "weighted_loss/weighted_additional_loss", iteration[1]) - end - logscalar(pinnrep.logger, sum_weighted_pde_losses, - "weighted_loss/sum_weighted_pde_losses", iteration[1]) - logscalar(pinnrep.logger, sum_weighted_bc_losses, - "weighted_loss/sum_weighted_bc_losses", iteration[1]) - logscalar(pinnrep.logger, sum_weighted_asl_losses, - "weighted_loss/sum_weighted_asl_losses", iteration[1]) - logscalar(pinnrep.logger, full_weighted_loss, - "weighted_loss/full_weighted_loss", - iteration[1]) - logvector(pinnrep.logger, adaloss.pde_loss_weights, - "adaptive_loss/pde_loss_weights", - iteration[1]) - logvector(pinnrep.logger, adaloss.asl_loss_weights, - "adaptive_loss/asl_loss_weights", - iteration[1]) - logvector(pinnrep.logger, adaloss.bc_loss_weights, - "adaptive_loss/bc_loss_weights", - iteration[1]) - end end - - return full_weighted_loss end full_loss_function = get_likelihood_estimate_function(discretization) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 3a260bc554..b156c83cb7 100644 --- a/src/pinn_types.jl +++ 
b/src/pinn_types.jl @@ -103,7 +103,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPIN derivative = nothing, param_estim = false, additional_loss = nothing, - additional_symb_loss = [], + additional_symb_loss = [], adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -136,14 +136,14 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPIN new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, + typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, strategy, init_params, _phi, _derivative, param_estim, additional_loss, - additional_symb_loss, + additional_symb_loss, adaptive_loss, logger, log_options, @@ -162,6 +162,7 @@ BayesianPINN(chain, phi = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = [], adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -211,7 +212,7 @@ methodology. * `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments. """ -struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN +struct BayesianPINN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, D, K} <: AbstractPINN chain::Any strategy::T init_params::P @@ -219,6 +220,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN derivative::DER param_estim::PE additional_loss::AL + additional_symb_loss::ASL adaptive_loss::ADA logger::LOG log_options::LogOptions @@ -235,6 +237,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN derivative = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = [], adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -272,7 +275,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset), + typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset), typeof(kwargs)}(chain, strategy, init_params, From 1bef6792d5bc649200143048131271be21b91772 Mon Sep 17 00:00:00 2001 From: Robin Krom Date: Sat, 27 Jan 2024 14:17:52 +0100 Subject: [PATCH 3/3] deps: add JuliaFormatter, OhMyREPL, TerminalPager --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 5a66ccc06e..396a3b1a44 100644 --- a/Project.toml +++ b/Project.toml @@ -19,12 +19,14 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +OhMyREPL = "5fb14364-9ced-5910-84b2-373655c76a03" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -38,6 +40,7 @@ Statistics = 
"10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +TerminalPager = "0c614874-6106-40ed-a7c2-2f1cd0bff883" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"