diff --git a/Project.toml b/Project.toml
index 1013977bad..396a3b1a44 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,15 +19,16 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index b37023da7c..1a97f80e49 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -15,6 +15,7 @@ end
 """
 ```julia
 NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                     asl_loss_weights = 1,
                      bc_loss_weights = 1,
                      additional_loss_weights = 1)
 ```
@@ -24,31 +25,34 @@ change during optimization
 """
 mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                                                      asl_loss_weights = 1,
                                                       bc_loss_weights = 1,
                                                       additional_loss_weights = 1) where {
                                                                                           T <: Real
                                                                                           }
-        new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
+        new(vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T),
             vectorify(additional_loss_weights, T))
     end
 end

 # default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1,
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, asl_loss_weights = 1,
+                                               bc_loss_weights = 1,
                                                additional_loss_weights = 1)
     NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
+                             asl_loss_weights = asl_loss_weights,
                              bc_loss_weights = bc_loss_weights,
                              additional_loss_weights = additional_loss_weights)
 end

 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::NonAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
-    function null_nonadaptive_loss(θ, pde_losses, bc_losses)
+                                         pde_loss_functions, asl_loss_functions,
+                                         bc_loss_functions)
+    function null_nonadaptive_loss(θ, pde_losses, asl_losses, bc_losses)
         nothing
     end
 end
@@ -58,6 +62,7 @@ end
 ```julia
 GradientScaleAdaptiveLoss(reweight_every;
                           weight_change_inertia = 0.9,
                           pde_loss_weights = 1,
+                          asl_loss_weights = 1,
                           bc_loss_weights = 1,
                           additional_loss_weights = 1)
 ```
@@ -90,30 +95,34 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     reweight_every::Int64
     weight_change_inertia::T
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
                                                                 weight_change_inertia = 0.9,
                                                                 pde_loss_weights = 1,
+                                                                asl_loss_weights = 1,
                                                                 bc_loss_weights = 1,
                                                                 additional_loss_weights = 1) where {
                                                                                                     T <: Real
                                                                                                     }
         new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end
 # default to Float64
 SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
                                                          weight_change_inertia = 0.9,
                                                          pde_loss_weights = 1,
+                                                         asl_loss_weights = 1,
                                                          bc_loss_weights = 1,
                                                          additional_loss_weights = 1)
     GradientScaleAdaptiveLoss{Float64}(reweight_every;
                                        weight_change_inertia = weight_change_inertia,
                                        pde_loss_weights = pde_loss_weights,
+                                       asl_loss_weights = asl_loss_weights,
                                        bc_loss_weights = bc_loss_weights,
                                        additional_loss_weights = additional_loss_weights)
 end
@@ -136,7 +145,7 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
             nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) :
                                   convert(adaloss_T, 1e-7)
-            bc_loss_weights_proposed = pde_grads_max ./
+            bc_loss_weights_proposed = pde_asl_grads_max ./
                                        (bc_grads_mean .+ nonzero_divisor_eps)
             adaloss.bc_loss_weights .= weight_change_inertia .*
                                        adaloss.bc_loss_weights .+
@@ -160,8 +169,10 @@ end
 ```julia
 function MiniMaxAdaptiveLoss(reweight_every;
                              pde_max_optimiser = Flux.ADAM(1e-4),
+                             asl_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5),
                              pde_loss_weights = 1,
+                             asl_loss_weights = 1,
                              bc_loss_weights = 1,
                              additional_loss_weights = 1)
 ```
@@ -191,65 +202,81 @@ https://arxiv.org/abs/2009.04544
 """
 mutable struct MiniMaxAdaptiveLoss{T <: Real,
                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                   ASL_OPT <: Flux.Optimise.AbstractOptimiser,
                                    BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
        AbstractAdaptiveLoss
     reweight_every::Int64
     pde_max_optimiser::PDE_OPT
+    asl_max_optimiser::ASL_OPT
     bc_max_optimiser::BC_OPT
     pde_loss_weights::Vector{T}
+    asl_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
-                                                       PDE_OPT, BC_OPT}(reweight_every;
+                                                       PDE_OPT, ASL_OPT,
+                                                       BC_OPT}(reweight_every;
                                                                pde_max_optimiser = Flux.ADAM(1e-4),
+                                                               asl_max_optimiser = Flux.ADAM(1e-4),
                                                                bc_max_optimiser = Flux.ADAM(0.5),
                                                                pde_loss_weights = 1,
+                                                               asl_loss_weights = 1,
                                                                bc_loss_weights = 1,
                                                                additional_loss_weights = 1) where {
                                                                                                    T <: Real,
                                                                                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                                                                                   ASL_OPT <: Flux.Optimise.AbstractOptimiser,
                                                                                                    BC_OPT <: Flux.Optimise.AbstractOptimiser
                                                                                                    }
         new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
+            convert(ASL_OPT, asl_max_optimiser),
             convert(BC_OPT, bc_max_optimiser),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(asl_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end
 # default to Float64, ADAM, ADAM
 SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
                                                    pde_max_optimiser = Flux.ADAM(1e-4),
+                                                   asl_max_optimiser = Flux.ADAM(1e-4),
                                                    bc_max_optimiser = Flux.ADAM(0.5),
                                                    pde_loss_weights = 1,
+                                                   asl_loss_weights = 1,
                                                    bc_loss_weights = 1,
                                                   additional_loss_weights = 1)
-    MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
+    MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser), typeof(asl_max_optimiser),
                         typeof(bc_max_optimiser)}(reweight_every;
                                                   pde_max_optimiser = pde_max_optimiser,
+                                                  asl_max_optimiser = asl_max_optimiser,
                                                   bc_max_optimiser = bc_max_optimiser,
                                                   pde_loss_weights = pde_loss_weights,
+                                                  asl_loss_weights = asl_loss_weights,
                                                   bc_loss_weights = bc_loss_weights,
                                                   additional_loss_weights = additional_loss_weights)
 end

 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::MiniMaxAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+                                         pde_loss_functions, asl_loss_functions,
+                                         bc_loss_functions)
     pde_max_optimiser = adaloss.pde_max_optimiser
+    asl_max_optimiser = adaloss.asl_max_optimiser
     bc_max_optimiser = adaloss.bc_max_optimiser
     iteration = pinnrep.iteration
-    function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
+    function run_minimax_adaptive_loss(θ, pde_losses, asl_losses, bc_losses)
         if iteration[1] % adaloss.reweight_every == 0
             Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights,
                                   -pde_losses)
+            Flux.Optimise.update!(asl_max_optimiser, adaloss.asl_loss_weights,
+                                  -asl_losses)
             Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
             logvector(pinnrep.logger, adaloss.pde_loss_weights,
                       "adaptive_loss/pde_loss_weights", iteration[1])
+            logvector(pinnrep.logger, adaloss.asl_loss_weights,
+                      "adaptive_loss/asl_loss_weights", iteration[1])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
                       "adaptive_loss/bc_loss_weights", iteration[1])
diff --git a/src/discretize.jl b/src/discretize.jl
index c8412b2d15..cd8e4efe41 100644
--- a/src/discretize.jl
+++ b/src/discretize.jl
@@ -204,16 +204,16 @@ strategy.
 """
 function generate_training_sets end

-function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array,
+function generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, _indvars::Array,
                                 _depvars::Array)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
                                                                                _depvars)
-    return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars,
+    return generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, dict_indvars,
                                   dict_depvars)
 end

 # Generate training set in the domain and on the boundary
-function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict,
+function generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ, dict_indvars::Dict,
                                 dict_depvars::Dict)
     if dx isa Array
         dxs = dx
@@ -249,19 +249,26 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
                  hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
     end

-    pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
-    pde_args = get_argument(eqs, dict_indvars, dict_depvars)
+    function get_eqs_train_sets(eqs)
+        eqs_vars = get_variables(eqs, dict_indvars, dict_depvars)
+        eqs_args = get_argument(eqs, dict_indvars, dict_depvars)

-    pde_train_set = adapt(eltypeθ,
-                          hcat(vec(map(points -> collect(points),
-                                       Iterators.product(bc_data...)))...))
+        eqs_train_set = adapt(eltypeθ,
+                              hcat(vec(map(points -> collect(points),
+                                           Iterators.product(bc_data...)))...))

-    pde_train_sets = map(pde_args) do bt
-        span = map(b -> get(dict_var_span_, b, b), bt)
-        _set = adapt(eltypeθ,
-                     hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
+        eqs_train_sets = map(eqs_args) do bt
+            span = map(b -> get(dict_var_span_, b, b), bt)
+            _set = adapt(eltypeθ,
+                         hcat(vec(map(points -> collect(points),
+                                      Iterators.product(span...)))...))
+        end
+        return eqs_train_sets
     end
-    [pde_train_sets, bcs_train_sets]
+
+    pde_train_sets = get_eqs_train_sets(eqs)
+    asl_train_sets = get_eqs_train_sets(asl)
+
+    [pde_train_sets, asl_train_sets, bcs_train_sets]
 end

 """
@@ -274,35 +281,38 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining.
 """
 function get_bounds end

-function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy)
+function get_bounds(domains, eqs, asl, bcs, eltypeθ, _indvars::Array, _depvars::Array,
+                    strategy)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
                                                                                _depvars)
-    return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
+    return get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
 end

-function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array,
+function get_bounds(domains, eqs, asl, bcs, eltypeθ, _indvars::Array, _depvars::Array,
                     strategy::QuadratureTraining)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
                                                                                _depvars)
-    return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
+    return get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
 end

-function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
+function get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars,
                     strategy::QuadratureTraining)
     dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
     dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])

-    pde_args = get_argument(eqs, dict_indvars, dict_depvars)
-
-    pde_lower_bounds = map(pde_args) do pd
-        span = map(p -> get(dict_lower_bound, p, p), pd)
-        map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span)
-    end
-    pde_upper_bounds = map(pde_args) do pd
-        span = map(p -> get(dict_upper_bound, p, p), pd)
-        map(s -> adapt(eltypeθ, s) - cbrt(eps(eltypeθ)), span)
+    function get_eqs_bounds(eqs)
+        eqs_args = get_argument(eqs, dict_indvars, dict_depvars)
+        eqs_lower_bounds = map(eqs_args) do pd
+            span = map(p -> get(dict_lower_bound, p, p), pd)
+            map(s -> adapt(eltypeθ, s) + cbrt(eps(eltypeθ)), span)
+        end
+        eqs_upper_bounds = map(eqs_args) do pd
+            span = map(p -> get(dict_upper_bound, p, p), pd)
+            map(s -> adapt(eltypeθ, s) - cbrt(eps(eltypeθ)), span)
+        end
+        return [eqs_lower_bounds, eqs_upper_bounds]
     end
-    pde_bounds = [pde_lower_bounds, pde_upper_bounds]
+    pde_bounds = get_eqs_bounds(eqs)
+    asl_bounds = get_eqs_bounds(asl)

     bound_vars = get_variables(bcs, dict_indvars, dict_depvars)

@@ -314,10 +324,10 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
     end
     bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds]

-    [pde_bounds, bcs_bounds]
+    [pde_bounds, asl_bounds, bcs_bounds]
 end

-function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
+function get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
     dx = 1 / strategy.points
     dict_span = Dict([Symbol(d.variables) => [
                           infimum(d.domain) + dx,
@@ -325,20 +335,19 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str
                       ] for d in domains])

     # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains]
-    pde_args = get_argument(eqs, dict_indvars, dict_depvars)
-    pde_bounds = map(pde_args) do pde_arg
-        bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg)
-        bds = eltypeθ.(bds)
-        bds[1, :], bds[2, :]
-    end
-
-    bound_args = get_argument(bcs, dict_indvars, dict_depvars)
-    bcs_bounds = map(bound_args) do bound_arg
-        bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg)
-        bds = eltypeθ.(bds)
-        bds[1, :], bds[2, :]
+    function get_eqs_bounds(eqs)
+        eqs_args = get_argument(eqs, dict_indvars, dict_depvars)
+        eqs_bounds = map(eqs_args) do eqs_arg
+            bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, eqs_arg)
+            bds = eltypeθ.(bds)
+            bds[1, :], bds[2, :]
+        end
+        return eqs_bounds
     end
-    return pde_bounds, bcs_bounds
+    pde_bounds = get_eqs_bounds(eqs)
+    asl_bounds = get_eqs_bounds(asl)
+    bcs_bounds = get_eqs_bounds(bcs)
+    return pde_bounds, asl_bounds, bcs_bounds
 end

 function get_numeric_integral(pinnrep::PINNRepresentation)
@@ -404,6 +413,7 @@ For more information, see `discretize` and `PINNRepresentation`.
 function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                                        discretization::AbstractPINN)
     eqs = pde_system.eqs
+    asl = discretization.additional_symb_loss
     bcs = pde_system.bcs
     chain = discretization.chain
@@ -515,61 +525,73 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
         eqs = [eqs]
     end

-    pde_indvars = if strategy isa QuadratureTraining
-        get_argument(eqs, dict_indvars, dict_depvars)
-    else
-        get_variables(eqs, dict_indvars, dict_depvars)
+    if !(asl isa Array)
+        asl = [asl]
     end

-    bc_indvars = if strategy isa QuadratureTraining
-        get_argument(bcs, dict_indvars, dict_depvars)
-    else
-        get_variables(bcs, dict_indvars, dict_depvars)
+    function get_eqs_indvars(eqs)
+        eqs_indvars = if strategy isa QuadratureTraining
+            get_argument(eqs, dict_indvars, dict_depvars)
+        else
+            get_variables(eqs, dict_indvars, dict_depvars)
+        end
+        return eqs_indvars
     end
+    pde_indvars = get_eqs_indvars(eqs)
+    bc_indvars = get_eqs_indvars(bcs)
+    asl_indvars = get_eqs_indvars(asl)

     pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars)
     bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars)
+    asl_integration_vars = get_integration_variables(asl, dict_indvars, dict_depvars)

-    pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
+    pinnrep = PINNRepresentation(eqs, asl, bcs, domains, eq_params, defaults,
+                                 default_p,
                                  param_estim, additional_loss, adaloss, depvars, indvars,
                                  dict_indvars, dict_depvars, dict_depvar_input, logger,
                                  multioutput, iteration, init_params, flat_init_params, phi,
                                  derivative,
-                                 strategy, pde_indvars, bc_indvars, pde_integration_vars,
-                                 bc_integration_vars, nothing, nothing, nothing, nothing)
+                                 strategy, pde_indvars, asl_indvars, bc_indvars,
+                                 pde_integration_vars, asl_integration_vars,
+                                 bc_integration_vars, nothing, nothing, nothing, nothing,
+                                 nothing)

     integral = get_numeric_integral(pinnrep)

-    symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq;
-                                                                bc_indvars = pde_indvar)
-                                   for (eq, pde_indvar) in zip(eqs, pde_indvars,
-                                                               pde_integration_vars)]
-
-    symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc;
-                                                               bc_indvars = bc_indvar)
-                                  for (bc, bc_indvar) in zip(bcs, bc_indvars,
-                                                             bc_integration_vars)]
+    function build_symbolic_loss_functions(eqs, eqs_indvars, eqs_integration_vars)
+        symbolic_eqs_loss_functions = [build_symbolic_loss_function(pinnrep, eq;
+                                                                    bc_indvars = eqs_indvar)
+                                       for (eq, eqs_indvar) in zip(eqs, eqs_indvars,
+                                                                   eqs_integration_vars)]
+        return symbolic_eqs_loss_functions
+    end
+    symbolic_pde_loss_functions = build_symbolic_loss_functions(eqs, pde_indvars,
+                                                                pde_integration_vars)
+    symbolic_asl_loss_functions = build_symbolic_loss_functions(asl, asl_indvars,
+                                                                asl_integration_vars)
+    symbolic_bc_loss_functions = build_symbolic_loss_functions(bcs, bc_indvars,
+                                                               bc_integration_vars)

     pinnrep.integral = integral
     pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions
+    pinnrep.symbolic_asl_loss_functions = symbolic_asl_loss_functions
     pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions

-    datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar)
-                                   for (eq, pde_indvar, integration_indvar) in zip(eqs,
-                                                                                   pde_indvars,
-                                                                                   pde_integration_vars)]
+    function build_datafree_eqs_loss_functions(eqs, eqs_indvars, eqs_integration_vars)
+        return [build_loss_function(pinnrep, eq, eq_indvar)
+                for (eq, eq_indvar, integration_indvar) in zip(eqs,
+                                                               eqs_indvars,
+                                                               eqs_integration_vars)]
+    end
+
+    datafree_pde_loss_functions = build_datafree_eqs_loss_functions(eqs, pde_indvars,
+                                                                    pde_integration_vars)
+    datafree_asl_loss_functions = build_datafree_eqs_loss_functions(asl, asl_indvars,
+                                                                    asl_integration_vars)
+    datafree_bc_loss_functions = build_datafree_eqs_loss_functions(bcs, bc_indvars,
+                                                                   bc_integration_vars)

-    datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar)
-                                  for (bc, bc_indvar, integration_indvar) in zip(bcs,
-                                                                                 bc_indvars,
-                                                                                 bc_integration_vars)]
+    pde_loss_functions, asl_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep,
+                                                                                                  strategy,
+                                                                                                  datafree_pde_loss_functions,
+                                                                                                  datafree_asl_loss_functions,
+                                                                                                  datafree_bc_loss_functions)

-    pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep,
-                                                                              strategy,
-                                                                              datafree_pde_loss_functions,
-                                                                              datafree_bc_loss_functions)
     # setup for all adaptive losses
     num_pde_losses = length(pde_loss_functions)
+    num_asl_losses = length(asl_loss_functions)
     num_bc_losses = length(bc_loss_functions)
     # assume one single additional loss function if there is one. this means that the user needs to lump all their functions into a single one,
     num_additional_loss = additional_loss isa Nothing ? 0 : 1
@@ -578,19 +600,23 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     # this will error if the user has provided a number of initial weights that is more than 1 and doesn't match the number of loss functions
     adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* adaloss.pde_loss_weights
+    adaloss.asl_loss_weights = ones(adaloss_T, num_asl_losses) .* adaloss.asl_loss_weights
     adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights
     adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .*
                                       adaloss.additional_loss_weights

     reweight_losses_func = generate_adaptive_loss_function(pinnrep, adaloss,
                                                            pde_loss_functions,
+                                                           asl_loss_functions,
                                                            bc_loss_functions)

     function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
         function full_loss_function(θ, p)
             # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
-            pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
-            bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
+            # we need to type-annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic losses/boundary conditions.
+            pde_losses = num_pde_losses == 0 ? adaloss_T[] :
+                         [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
+            asl_losses = num_asl_losses == 0 ? adaloss_T[] :
+                         [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
+            bc_losses = num_bc_losses == 0 ? adaloss_T[] :
+                        [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]

             # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
             # that's why we prefer the user to maintain the increment in the outer loop callback during optimization
@@ -600,15 +626,17 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
             ChainRulesCore.@ignore_derivatives begin
                 reweight_losses_func(θ, pde_losses,
-                                     bc_losses)
+                                     asl_losses, bc_losses)
             end

             weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
+            weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
             weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses

             sum_weighted_pde_losses = sum(weighted_pde_losses)
+            sum_weighted_asl_losses = sum(weighted_asl_losses)
             sum_weighted_bc_losses = sum(weighted_bc_losses)
-            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses
+            weighted_loss_before_additional = sum_weighted_pde_losses +
+                                              sum_weighted_asl_losses +
+                                              sum_weighted_bc_losses

             full_weighted_loss = if additional_loss isa Nothing
                 weighted_loss_before_additional
@@ -762,9 +790,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     full_loss_function = get_likelihood_estimate_function(discretization)

     pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions,
-                                               full_loss_function, additional_loss,
-                                               datafree_pde_loss_functions,
-                                               datafree_bc_loss_functions)
+                                               asl_loss_functions,
+                                               full_loss_function, additional_loss,
+                                               datafree_pde_loss_functions,
+                                               datafree_asl_loss_functions,
+                                               datafree_bc_loss_functions)

     return pinnrep
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 48c8f46da9..b156c83cb7 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -34,6 +34,7 @@ PhysicsInformedNN(chain,
                   phi = nothing,
                   param_estim = false,
                   additional_loss = nothing,
+                  additional_symb_loss = [],
                   adaptive_loss = nothing,
                   logger = nothing,
                   log_options = LogOptions(),
@@ -78,7 +79,7 @@ methodology.
 * `iteration`: used to control the iteration counter???
 * `kwargs`: Extra keyword arguments which are splatted to the `OptimizationProblem` on
   `solve`.
""" -struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN +struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPINN chain::Any strategy::T init_params::P @@ -86,6 +87,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN derivative::DER param_estim::PE additional_loss::AL + additional_symb_loss::ASL adaptive_loss::ADA logger::LOG log_options::LogOptions @@ -101,6 +103,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN derivative = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = [], adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -133,13 +136,14 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, + typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, strategy, init_params, _phi, _derivative, param_estim, additional_loss, + additional_symb_loss, adaptive_loss, logger, log_options, @@ -158,6 +162,7 @@ BayesianPINN(chain, phi = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = nothing, adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -207,7 +212,7 @@ methodology. * `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments. """ -struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN +struct BayesianPINN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, D, K} <: AbstractPINN chain::Any strategy::T init_params::P @@ -215,6 +220,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN derivative::DER param_estim::PE additional_loss::AL + additional_symb_loss::ASL adaptive_loss::ADA logger::LOG log_options::LogOptions @@ -231,6 +237,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN derivative = nothing, param_estim = false, additional_loss = nothing, + additional_symb_loss = nothing, adaptive_loss = nothing, logger = nothing, log_options = LogOptions(), @@ -268,7 +275,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset), + typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset), typeof(kwargs)}(chain, strategy, init_params, @@ -276,6 +283,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN _derivative, param_estim, additional_loss, + additional_symb_loss, adaptive_loss, logger, log_options, @@ -284,7 +292,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN multioutput, dataset, kwargs) - end + end end """ @@ -303,6 +311,10 @@ mutable struct PINNRepresentation """ eqs::Any """ + The additional symbolic loss functions + """ + asl::Any + """ The boundary condition equations """ bcs::Any @@ -404,6 +416,10 @@ mutable struct PINNRepresentation """ ??? """ + asl_indvars::Any + """ + ??? + """ bc_indvars::Any """ ??? @@ -412,6 +428,10 @@ mutable struct PINNRepresentation """ ??? """ + asl_integration_vars::Any + """ + ??? + """ bc_integration_vars::Any """ ??? 
@@ -422,6 +442,10 @@ mutable struct PINNRepresentation
     """
     symbolic_pde_loss_functions::Any
     """
+    The additional symbolic loss functions as represented in the Julia AST
+    """
+    symbolic_asl_loss_functions::Any
+    """
     The boundary condition loss functions as represented in Julia AST
     """
     symbolic_bc_loss_functions::Any
@@ -450,6 +474,10 @@ struct PINNLossFunctions
     """
     pde_loss_functions::Any
     """
+    The additional symbolic loss functions
+    """
+    asl_loss_functions::Any
+    """
     The full loss function, combining the PDE and boundary condition loss functions.
     This is the loss function that is used by the optimizer.
     """
@@ -463,6 +491,10 @@ struct PINNLossFunctions
     """
     datafree_pde_loss_functions::Any
     """
+    The pre-data version of the additional symbolic loss functions
+    """
+    datafree_asl_loss_functions::Any
+    """
     The pre-data version of the BC loss function
     """
     datafree_bc_loss_functions::Any
diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl
index 9161f3c365..b14c264e5f 100644
--- a/src/symbolic_utilities.jl
+++ b/src/symbolic_utilities.jl
@@ -303,7 +303,8 @@ Example:
 [(derivative(phi1, u1, [x, y], [[ε,0]], 1, θ1) + 4 * derivative(phi2, u, [x, y], [[0,ε]], 1, θ2)) - 0,
  (derivative(phi2, u2, [x, y], [[ε,0]], 1, θ2) + 9 * derivative(phi1, u, [x, y], [[0,ε]], 1, θ1)) - 0]
 """
-function parse_equation(pinnrep::PINNRepresentation, eq)
+# Parse an equation
+function parse_equation(pinnrep::PINNRepresentation, eq::Equation)
     eq_lhs = isequal(expand_derivatives(eq.lhs), 0) ? eq.lhs : expand_derivatives(eq.lhs)
     eq_rhs = isequal(expand_derivatives(eq.rhs), 0) ? eq.rhs : expand_derivatives(eq.rhs)
     left_expr = transform_expression(pinnrep, toexpr(eq_lhs))
@@ -313,6 +314,11 @@ function parse_equation(pinnrep::PINNRepresentation, eq)
     loss_func = :($left_expr .- $right_expr)
 end

+# Parse a bare symbolic expression (e.g. an energy): the expression itself is the residual
+function parse_equation(pinnrep::PINNRepresentation, eq)
+    loss_func = _dot_(transform_expression(pinnrep, toexpr(eq)))
+end
+
 function get_indvars_ex(bc_indvars) # , dict_this_eq_indvars)
     i_ = 1
     indvars_ex = map(bc_indvars) do u
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index d4edd26bfc..2b46bf4b25 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -53,28 +53,31 @@ end
 function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
                                            strategy::GridTraining,
                                            datafree_pde_loss_function,
+                                           datafree_asl_loss_function,
                                            datafree_bc_loss_function)
-    @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
+    @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
     dx = strategy.dx
     eltypeθ = eltype(pinnrep.flat_init_params)

-    train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ,
+    train_sets = generate_training_sets(domains, dx, eqs, asl, bcs, eltypeθ,
                                         dict_indvars, dict_depvars)

     # the points in the domain and on the boundary
-    pde_train_sets, bcs_train_sets = train_sets
-    pde_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
-                            pde_train_sets)
-    bcs_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
-                            bcs_train_sets)
-    pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
-                          for (_loss, _set) in zip(datafree_pde_loss_function,
-                                                   pde_train_sets)]
-
-    bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
-                         for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]
-
-    pde_loss_functions, bc_loss_functions
+    pde_train_sets, asl_train_sets, bcs_train_sets = train_sets
+
+    all_loss_functions =
+        map([(pde_train_sets, datafree_pde_loss_function),
+             (asl_train_sets, datafree_asl_loss_function),
+             (bcs_train_sets, datafree_bc_loss_function)]) do (train_sets, datafree_loss_function)
+            train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)),
+                                train_sets)
+            loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
+                              for (_loss, _set) in zip(datafree_loss_function,
+                                                       train_sets)]
+            return loss_functions
+        end
+    return Tuple(all_loss_functions)
 end

 function get_loss_function(loss_function, train_set, eltypeθ, strategy::GridTraining;
@@ -113,22 +116,25 @@ end
 function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
                                            strategy::StochasticTraining,
                                            datafree_pde_loss_function,
+                                           datafree_asl_loss_function,
                                            datafree_bc_loss_function)
-    @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
+    @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep

     eltypeθ = eltype(pinnrep.flat_init_params)

-    bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
+    bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars,
                         strategy)
-    pde_bounds, bcs_bounds = bounds
+    pde_bounds, asl_bounds, bcs_bounds = bounds

-    pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
-                          for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]
+    all_loss_functions = map([(datafree_pde_loss_function, pde_bounds),
+                              (datafree_asl_loss_function, asl_bounds),
+                              (datafree_bc_loss_function, bcs_bounds)]) do (datafree_loss_function, bounds)
+        return [get_loss_function(_loss, bound, eltypeθ, strategy)
+                for (_loss, bound) in zip(datafree_loss_function, bounds)]
+    end

-    bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
-                         for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]
-
-    pde_loss_functions, bc_loss_functions
+    return Tuple(all_loss_functions)
 end

 function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining;
@@ -195,17 +201,20 @@ end
 function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
                                            strategy::QuasiRandomTraining,
                                            datafree_pde_loss_function,
+                                           datafree_asl_loss_function,
                                            datafree_bc_loss_function)
-    @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
+    @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep

     eltypeθ = eltype(pinnrep.flat_init_params)

-    bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
+    bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars,
                         strategy)
-    pde_bounds, bcs_bounds = bounds
+    pde_bounds, asl_bounds, bcs_bounds = bounds

     pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
                           for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]
+    asl_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+                          for (_loss, bound) in zip(datafree_asl_loss_function, asl_bounds)]

     strategy_ = QuasiRandomTraining(strategy.bcs_points;
                                     sampling_alg = strategy.sampling_alg,
@@ -214,7 +223,8 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
     bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_)
                          for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

-    pde_loss_functions, bc_loss_functions
+    pde_loss_functions, asl_loss_functions, bc_loss_functions
 end

 function get_loss_function(loss_function, bound, eltypeθ, strategy::QuasiRandomTraining;
@@ -288,22 +298,24 @@ end
 function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
                                            strategy::QuadratureTraining,
                                            datafree_pde_loss_function,
+                                           datafree_asl_loss_function,
                                            datafree_bc_loss_function)
-    @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
+    @unpack domains, eqs, asl, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep

     eltypeθ = eltype(pinnrep.flat_init_params)

-    bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
+    bounds = get_bounds(domains, eqs, asl, bcs, eltypeθ, dict_indvars, dict_depvars,
                         strategy)
-    pde_bounds, bcs_bounds = bounds
-
-    lbs, ubs = pde_bounds
-    pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
-                          for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)]
-    lbs, ubs = bcs_bounds
-    bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
-                         for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)]
+    pde_bounds, asl_bounds, bcs_bounds = bounds
+
+    all_loss_functions = map([(datafree_pde_loss_function, pde_bounds),
+                              (datafree_asl_loss_function, asl_bounds),
+                              (datafree_bc_loss_function, bcs_bounds)]) do (datafree_loss_function, bounds)
+        lbs, ubs = bounds
+        return [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
+                for (_loss, lb, ub) in zip(datafree_loss_function, lbs, ubs)]
+    end

-    pde_loss_functions, bc_loss_functions
+    return Tuple(all_loss_functions)
 end

 function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::QuadratureTraining;
@@ -366,4 +378,4 @@ function get_loss_function(loss_function, train_set, eltypeθ,
                            strategy::WeightedIntervalTraining;
                            τ = nothing)
     loss = (θ) -> mean(abs2, loss_function(train_set, θ))
-end
\ No newline at end of file
+end
diff --git a/test/additional_symbolic_loss_tests.jl b/test/additional_symbolic_loss_tests.jl
new file mode 100644
index 0000000000..27c154ae97
--- /dev/null
+++ b/test/additional_symbolic_loss_tests.jl
@@ -0,0 +1,173 @@
+using NeuralPDE: DomainSets
+using Random
+using Test
+using ComponentArrays
+using Optimization
+using OptimizationOptimisers
+using NeuralPDE
+using LinearAlgebra
+using Lux
+import ModelingToolkit: Interval
+
+@parameters x0 x1 x2 x3
+@variables ρ01(..) ρ02(..) ρ03(..) ρ12(..) ρ13(..) ρ23(..)
+
+# the 4-torus
+domain = [
+    x0 ∈ Interval(0.0, 1.0),
+    x1 ∈ Interval(0.0, 1.0),
+    x2 ∈ Interval(0.0, 1.0),
+    x3 ∈ Interval(0.0, 1.0),
+]
+
+∂₀ = Differential(x0)
+∂₁ = Differential(x1)
+∂₂ = Differential(x2)
+∂₃ = Differential(x3)
+
+d₂(ρ) = [
+    # commented are the signed permutations of the indices
+
+    #(0,1,2) + (2,0,1) + (1,2,0) - (1,0,2) - (2,1,0) - (0,2,1)
+    2 * ∂₀(ρ[4]) - 2 * ∂₁(ρ[2]) + 2 * ∂₂(ρ[1]),
+    #(0,1,3) + (3,0,1) + (1,3,0) - (1,0,3) - (0,3,1) - (3,1,0)
+    2 * ∂₀(ρ[5]) - 2 * ∂₁(ρ[3]) + 2 * ∂₃(ρ[1]),
+    #(0,2,3) + (3,0,2) + (2,3,0) - (2,0,3) - (0,3,2) - (3,2,0)
+    2 * ∂₀(ρ[6]) - 2 * ∂₂(ρ[3]) + 2 * ∂₃(ρ[2]),
+    #(1,2,3) + (3,1,2) + (2,3,1) - (2,1,3) - (1,3,2) - (3,2,1)
+    2 * ∂₁(ρ[6]) - 2 * ∂₂(ρ[5]) + 2 * ∂₃(ρ[4]),
+]
+
+u(ρ) = ρ[1] * ρ[6] - ρ[2] * ρ[5] + ρ[3] * ρ[4]
+
+K₁(ρ) = 2(ρ[1] + ρ[6]) / u(ρ)
+K₂(ρ) = 2(ρ[2] - ρ[5]) / u(ρ)
+K₃(ρ) = 2(ρ[3] + ρ[4]) / u(ρ)
+
+K(ρ) = [
+    K₁(ρ),
+    K₂(ρ),
+    K₃(ρ),
+]
+
+# energy
+fₑ(ρ) = (K(ρ)[1]^2 + K(ρ)[2]^2 + K(ρ)[3]^2) * u(ρ)
+
+energies =
+    let ρ = [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3),
+             ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)]
+        [fₑ(ρ)]
+    end
+
+# periodic boundary conditions for the 4-torus
+bcs = [
+    ρ01(0.0, x1, x2, x3) ~ ρ01(1.0, x1, x2, x3),
+    ρ01(x0, 0.0, x2, x3) ~ ρ01(x0, 1.0, x2, x3),
+    ρ01(x0, x1, 0.0, x3) ~ ρ01(x0, x1, 1.0, x3),
+    ρ01(x0, x1, x2, 0.0) ~ ρ01(x0, x1, x2, 1.0),
+    ρ02(0.0, x1, x2, x3) ~ ρ02(1.0, x1, x2, x3),
+    ρ02(x0, 0.0, x2, x3) ~ ρ02(x0, 1.0, x2, x3),
+    ρ02(x0, x1, 0.0, x3) ~ ρ02(x0, x1, 1.0, x3),
+    ρ02(x0, x1, x2, 0.0) ~ ρ02(x0, x1, x2, 1.0),
+    ρ03(0.0, x1, x2, x3) ~ ρ03(1.0, x1, x2, x3),
+    ρ03(x0, 0.0, x2, x3) ~ ρ03(x0, 1.0, x2, x3),
+    ρ03(x0, x1, 0.0, x3) ~ ρ03(x0, x1, 1.0, x3),
+    ρ03(x0, x1, x2, 0.0) ~ ρ03(x0, x1, x2, 1.0),
+    ρ12(0.0, x1, x2, x3) ~ ρ12(1.0, x1, x2, x3),
+    ρ12(x0, 0.0, x2, x3) ~ ρ12(x0, 1.0, x2, x3),
+    ρ12(x0, x1, 0.0, x3) ~ ρ12(x0, x1, 1.0, x3),
+    ρ12(x0, x1, x2, 0.0) ~ ρ12(x0, x1, x2, 1.0),
+    ρ13(0.0, x1, x2, x3) ~ ρ13(1.0, x1, x2, x3),
+    ρ13(x0, 0.0, x2, x3) ~ ρ13(x0, 1.0, x2, x3),
+    ρ13(x0, x1, 0.0, x3) ~ ρ13(x0, x1, 1.0, x3),
+    ρ13(x0, x1, x2, 0.0) ~ ρ13(x0, x1, x2, 1.0),
+    ρ23(0.0, x1, x2, x3) ~ ρ23(1.0, x1, x2, x3),
+    ρ23(x0, 0.0, x2, x3) ~ ρ23(x0, 1.0, x2, x3),
+    ρ23(x0, x1, 0.0, x3) ~ ρ23(x0, x1, 1.0, x3),
+    ρ23(x0, x1, x2, 0.0) ~ ρ23(x0, x1, x2, 1.0),
+]
+
+# equations for dρ = 0.
+eqClosed(ρ) = d₂(ρ)[:] .~ 0
+
+eqs =
+    let ρ = [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3),
+             ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)]
+        vcat(
+            eqClosed(ρ),
+        )
+    end
+
+input_ = length(domain)
+n = 16
+
+ixToSym = Dict(
+    1 => :ρ01,
+    2 => :ρ02,
+    3 => :ρ03,
+    4 => :ρ12,
+    5 => :ρ13,
+    6 => :ρ23
+)
+
+chains = NamedTuple((ixToSym[ix],
+                     Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, 1)))
+                    for ix in 1:6)
+chains0 = collect(chains)
+
+function test_donaldson_energy_loss_no_logs(ϵ, sym_prob, prob)
+    # pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
+    # bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions
+    # energy_inner_loss_functions = sym_prob.loss_functions.asl_loss_functions
+
+    ps = map(c -> Lux.setup(Random.default_rng(), c)[1], chains) |> ComponentArray .|> Float64
+    prob1 = remake(prob; u0 = ComponentVector(depvar = ps))
+
+    callback(ϵ::Float64) = function (p, l)
+        # println("loss: ", l)
+        # println("pde_losses: ", map(l_ -> l_(p), pde_inner_loss_functions))
+        # println("bcs_losses: ", map(l_ -> l_(p), bcs_inner_loss_functions))
+        # println("energy losses: ", map(l_ -> l_(p), energy_inner_loss_functions))
+        return l < ϵ
+    end
+    _sol = Optimization.solve(prob1, Adam(0.01); callback = callback(ϵ), maxiters = 1)
+    return true
+end
+
+@named pdesystem = PDESystem(eqs, bcs, domain, [x0, x1, x2, x3],
+                             [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3),
+                              ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)])
+discretization = PhysicsInformedNN(chains0, QuasiRandomTraining(1000))
+sym_prob = symbolic_discretize(pdesystem, discretization)
+prob = discretize(pdesystem, discretization)
+@info "testing additional symbolic loss functions: solver runs without additional symbolic losses."
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob)
+
+@named pdesystem1 = PDESystem([], bcs, domain, [x0, x1, x2, x3],
+                              [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3),
+                               ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)])
+discretization = PhysicsInformedNN(chains0, QuasiRandomTraining(1000);
+                                   additional_symb_loss = energies)
+sym_prob = symbolic_discretize(pdesystem1, discretization)
+prob = discretize(pdesystem1, discretization)
+@info "testing additional symbolic loss functions: quasi random training: solver runs with only additional symbolic loss function."
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob)
+
+@named pdesystem2 = PDESystem(eqs, bcs, domain, [x0, x1, x2, x3],
+                              [ρ01(x0, x1, x2, x3), ρ02(x0, x1, x2, x3), ρ03(x0, x1, x2, x3),
+                               ρ12(x0, x1, x2, x3), ρ13(x0, x1, x2, x3), ρ23(x0, x1, x2, x3)])
+discretization = PhysicsInformedNN(chains0, StochasticTraining(1000);
+                                   additional_symb_loss = energies)
+sym_prob = symbolic_discretize(pdesystem2, discretization)
+prob = discretize(pdesystem2, discretization)
+@info "testing additional symbolic loss functions: stochastic training: solver runs with additional symbolic loss function and PDE system."
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob)
+
+discretization = PhysicsInformedNN(chains0, GridTraining(0.1);
+                                   additional_symb_loss = energies)
+sym_prob = symbolic_discretize(pdesystem2, discretization)
+prob = discretize(pdesystem2, discretization)
+@info "testing additional symbolic loss functions: grid training: solver runs with additional symbolic loss function and PDE system."
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob)
+
+discretization = PhysicsInformedNN(chains0, QuadratureTraining();
+                                   additional_symb_loss = energies)
+sym_prob = symbolic_discretize(pdesystem1, discretization)
+prob = discretize(pdesystem1, discretization)
+@info "testing additional symbolic loss functions: quadrature training: solver runs with only additional symbolic loss function."
+@test test_donaldson_energy_loss_no_logs(0.5, sym_prob, prob)
diff --git a/test/runtests.jl b/test/runtests.jl
index 5d6ac6909e..391ea4f9a7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -34,6 +34,7 @@ end
 if GROUP == "All" || GROUP == "NNPDE2"
     @time @safetestset "Additional Loss" begin include("additional_loss_tests.jl") end
+    @time @safetestset "Additional Symbolic Loss" begin include("additional_symbolic_loss_tests.jl") end
     @time @safetestset "Direction Function Approximation" begin include("direct_function_tests.jl") end
 end
 if GROUP == "All" || GROUP == "NeuralAdapter"
@@ -65,4 +66,4 @@ end
         @safetestset "NNPDE_gpu" begin include("NNPDE_tests_gpu.jl") end
         @safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end
     end
-end
\ No newline at end of file
+end
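For readers skimming the diff: a minimal sketch of how the new `additional_symb_loss` keyword is exercised end to end, condensed from the 4-torus test above to a 1-D toy problem. Only `additional_symb_loss` and the asl plumbing come from this PR; the ODE, the `u(x)^2` penalty, the network size, and the optimiser settings are illustrative assumptions.

```julia
using NeuralPDE, Lux, Optimization, OptimizationOptimisers
import ModelingToolkit: Interval

@parameters x
@variables u(..)
Dx = Differential(x)

# A standard PDE residual and boundary condition, as in any PINN setup.
eqs = [Dx(u(x)) ~ cos(2π * x)]
bcs = [u(0.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]

# A bare symbolic expression (not an Equation): with this PR it is handled by the
# new `parse_equation` fallback and penalized directly as an extra loss term,
# sampled over the domain like a PDE residual.
energies = [u(x)^2]

chain = Lux.Chain(Dense(1, 16, Lux.σ), Dense(16, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.05);
                                   additional_symb_loss = energies)

@named pdesystem = PDESystem(eqs, bcs, domains, [x], [u(x)])
prob = discretize(pdesystem, discretization)
res = Optimization.solve(prob, Adam(0.01); maxiters = 100)
```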
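Likewise, a sketch of how the widened adaptive-loss API composes with the additional symbolic losses, reusing `chain` and `energies` from the previous snippet. The optimiser choices and weights are the defaults from the structs above, spelled out for clarity; the `reweight_every = 100` value is arbitrary.

```julia
using Flux  # the adaptive losses take Flux.Optimise optimisers

# Every loss group (PDE, additional symbolic, BC) now carries its own weights
# and, for MiniMax, its own maximizing optimiser for the minimax game.
adaloss = MiniMaxAdaptiveLoss(100;
                              pde_max_optimiser = Flux.ADAM(1e-4),
                              asl_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5),
                              pde_loss_weights = 1,
                              asl_loss_weights = 1,
                              bc_loss_weights = 1)

discretization = PhysicsInformedNN(chain, GridTraining(0.05);
                                   additional_symb_loss = energies,
                                   adaptive_loss = adaloss)
```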