From 6dd9e3861bf9c4e31e75ac1dec64a166fa095e2a Mon Sep 17 00:00:00 2001
From: KirillZubov
Date: Thu, 17 Oct 2024 17:35:51 +0400
Subject: [PATCH] update

---
 Project.toml               |  4 +-
 src/advancedHMC_MCMC.jl    |  2 +
 src/pino_ode_solve.jl      | 91 +++++++++++++++++++-------------------
 src/training_strategies.jl |  4 +-
 4 files changed, 51 insertions(+), 50 deletions(-)

diff --git a/Project.toml b/Project.toml
index 14ba864c90..6b1d3b9e1c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -124,12 +124,10 @@ MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StochasticDiffEq", "TensorBoardLogger", "Test"]
-
+test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StochasticDiffEq", "TensorBoardLogger", "Test"]
\ No newline at end of file
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 1751cbc82e..380d284f55 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -73,6 +73,7 @@ suggested extra loss function for ODE solver case
 """
 @views function L2loss2(ltd::LogTargetDensity, θ)
     ltd.extraparams ≤ 0 && return false # XXX: type-stability?
+
     f = ltd.prob.f
     t = ltd.dataset[end]
     u1 = ltd.dataset[2]
@@ -226,6 +227,7 @@ Prior logpdf for NN parameters + ODE constants.
 @views function priorweights(ltd::LogTargetDensity, θ)
     allparams = ltd.priors
     nnwparams = allparams[1] # nn weights
+
     ltd.extraparams ≤ 0 && return logpdf(nnwparams, θ)
 
     # Vector of ode parameters priors
diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl
index b10a3a880a..b1fa2c9be6 100644
--- a/src/pino_ode_solve.jl
+++ b/src/pino_ode_solve.jl
@@ -11,7 +11,7 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`.
 
 ## Positional Arguments
 
-* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`.
+* `chain`: A neural network architecture, defined as an `AbstractLuxLayer` or `Flux.Chain`.
          `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`
 * `opt`: The optimizer to train the neural network.
 * `bounds`: A dictionary containing the bounds for the parameters of the parametric ODE.
@@ -30,17 +30,15 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`.
 * Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets"
 * Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations"
 """
-struct PINOODE{C, O, B, I, S <: Union{Nothing, AbstractTrainingStrategy},
-    AL <: Union{Nothing, Function}, K} <:
-       SciMLBase.AbstractODEAlgorithm
-    chain::C
-    opt::O
-    bounds::B
+@concrete struct PINOODE
+    chain
+    opt
+    bounds
     number_of_parameters::Int
-    init_params::I
-    strategy::S
-    additional_loss::AL
-    kwargs::K
+    init_params
+    strategy <: Union{Nothing, AbstractTrainingStrategy}
+    additional_loss <: Union{Nothing, Function}
+    kwargs
 end
 
 function PINOODE(chain,
@@ -51,38 +49,37 @@ function PINOODE(chain,
         strategy = nothing,
         additional_loss = nothing,
         kwargs...)
-    !(chain isa Lux.AbstractLuxLayer) && (chain = Lux.transform(chain))
-    PINOODE(chain, opt, bounds, number_of_parameters,
+    chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
+    return PINOODE(chain, opt, bounds, number_of_parameters,
         init_params, strategy, additional_loss, kwargs)
 end
 
-struct PINOPhi{C, S}
-    chain::C
-    st::S
-    function PINOPhi(chain::Lux.AbstractLuxLayer, st)
-        new{typeof(chain), typeof(st)}(chain, st)
-    end
+@concrete struct PINOPhi
+    model <: AbstractLuxLayer
+    smodel <: StatefulLuxLayer
+end
+
+function PINOPhi(model::AbstractLuxLayer, st)
+    return PINOPhi(model, StatefulLuxLayer{false}(model, nothing, st))
 end
 
-function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params)
-    θ, st = Lux.setup(Random.default_rng(), chain)
-    init_params = isnothing(init_params) ? θ : init_params
-    init_params = ComponentArrays.ComponentArray(init_params)
+function generate_pino_phi_θ(chain::AbstractLuxLayer, ::Nothing)
+    θ, st = LuxCore.setup(Random.default_rng(), chain)
+    PINOPhi(chain, st), θ
+end
+
+function generate_pino_phi_θ(chain::AbstractLuxLayer, init_params)
+    st = LuxCore.initialstates(Random.default_rng(), chain)
     PINOPhi(chain, st), init_params
 end
 
-function (f::PINOPhi{C, T})(x::Array, θ) where {C <: Lux.Chain, T}
-    eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
-    x = convert.(eltypeθ, adapt(typeθ, x))
-    y, st = f.chain(x, θ, f.st)
-    y
+function (f::PINOPhi{C, T})(x, θ) where {C <: AbstractLuxLayer, T}
+    dev = safe_get_device(θ)
+    return f(dev, safe_expand(dev, x), θ)
 end
 
-function (f::PINOPhi{C, T})(x::Tuple, θ) where {C <: DeepONet, T}
-    eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
-    x = (convert.(eltypeθ, adapt(typeθ, x[1])), convert.(eltypeθ, adapt(typeθ, x[2])))
-    y, st = f.chain(x, θ, f.st)
-    y
+function (f::PINOPhi{C, T})(dev, x, θ) where {C <: AbstractLuxLayer, T}
+    f.smodel(dev(x), θ)
 end
 
 function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T}
@@ -180,7 +177,7 @@ function get_trainset(
 end
 
 function get_trainset(
-        strategy::GridTraining, chain::Lux.Chain, bounds, number_of_parameters, tspan, eltypeθ)
+        strategy::GridTraining, chain::Chain, bounds, number_of_parameters, tspan, eltypeθ)
     dt = strategy.dx
     tspan_ = tspan[1]:dt:tspan[2]
     pspan = [range(start = b[1], length = number_of_parameters, stop = b[2])
@@ -194,9 +191,9 @@ function get_trainset(
 end
 
 function get_trainset(
-        strategy::StochasticTraining, chain::Union{DeepONet, Lux.Chain},
+        strategy::StochasticTraining, chain::Union{DeepONet, Chain},
         bounds, number_of_parameters, tspan, eltypeθ)
-    (number_of_parameters != strategy.points && chain isa Lux.Chain) &&
+    (number_of_parameters != strategy.points && chain isa Chain) &&
        throw(error("number_of_parameters should be the same as strategy.points for StochasticTraining"))
     p = reduce(vcat,
         [(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
@@ -208,7 +205,8 @@ end
 
 function generate_loss(
         strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
-    x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan, eltypeθ)
+    x = get_trainset(
+        strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
     function loss(θ, _)
         initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
     end
@@ -217,7 +215,8 @@ end
 
 function generate_loss(
         strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
     function loss(θ, _)
-        x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan, eltypeθ)
+        x = get_trainset(
+            strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
         initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
     end
 end
@@ -240,10 +239,10 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
         verbose = false,
         saveat = nothing,
         maxiters = nothing)
-    @unpack tspan, u0, f = prob
-    @unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss = alg
+    (; tspan, u0, f) = prob
+    (; chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss) = alg
 
-    if !(chain isa Lux.AbstractLuxLayer)
+    if !(chain isa AbstractLuxLayer)
         error("Only Lux.AbstractLuxLayer neural networks are supported")
 
         if !(chain isa DeepONet) || !(chain isa Chain)
@@ -252,7 +251,9 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
             error("Only DeepONet and Chain neural networks are supported with PINO ODE")
         end
     end
 
     phi, init_params = generate_pino_phi_θ(chain, init_params)
-    eltypeθ = eltype(init_params)
+
+    # init_params = ComponentArray(init_params)
+    # eltypeθ = eltype(init_params) #TODO?
 
     isinplace(prob) &&
         throw(error("The PINOODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
@@ -260,14 +261,14 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
     try
         if chain isa DeepONet
             in_dim = chain.branch.layers.layer_1.in_dims
-            u = rand(eltypeθ, in_dim, number_of_parameters)
-            v = rand(eltypeθ, 1, 10, 1)
+            u = rand(in_dim, number_of_parameters)
+            v = rand(1, 10, 1)
             x = (u, v)
             phi(x, init_params)
         end
         if chain isa Chain
             in_dim = chain.layers.layer_1.in_dims
-            x = rand(eltypeθ, in_dim, number_of_parameters)
+            x = rand(in_dim, number_of_parameters)
             phi(x, init_params)
         end
     catch err
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index 974f2529fa..ca07676f26 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -291,8 +291,8 @@ end
 """
     WeightedIntervalTraining(weights, samples)
 
-A training strategy that generates points for training based on the given inputs. 
-We split the timespan into equal segments based on the number of weights, 
+A training strategy that generates points for training based on the given inputs.
+We split the timespan into equal segments based on the number of weights,
 then sample points in each segment based on that segment's corresponding weight,
 such that the total number of sampled points is equivalent to the given samples
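
Reviewer note: the core of this patch replaces PINOPhi's hand-rolled chain/state
bookkeeping with Lux's StatefulLuxLayer. Below is a minimal, self-contained sketch
of that pattern outside NeuralPDE. The toy model and input sizes are illustrative
assumptions; the `StatefulLuxLayer{false}(model, nothing, st)` construction itself
is the one this diff introduces.

    # Sketch of the StatefulLuxLayer pattern adopted by PINOPhi (toy model,
    # illustrative sizes): wrap a Lux model and its state once, then call the
    # wrapper as a plain function of (input, parameters).
    using Lux, Random

    model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))
    ps, st = Lux.setup(Random.default_rng(), model)

    # Parameters stored as `nothing`: they are supplied at call time, as in PINOPhi.
    smodel = StatefulLuxLayer{false}(model, nothing, st)

    x = rand(Float32, 2, 16)
    y = smodel(x, ps)  # returns the output only; no (y, st) tuple to unpack

This is why the patch routes `generate_loss` through `phi.smodel.model` instead of
the removed `phi.chain` field.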
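
For end-to-end context, a usage sketch of the PINOODE solver this patch refactors
follows. It mirrors the package's PINO ODE examples but is not part of the diff:
the equation, bounds, layer sizes, optimizer settings, and sourcing `DeepONet`
from NeuralOperators.jl are all assumptions made for illustration.

    # Hypothetical end-to-end run of PINOODE (values are assumptions, not taken
    # from this patch): learn the solution operator of du/dt = cos(p * t) over a
    # range of parameter values p.
    using NeuralPDE, NeuralOperators, Lux, OptimizationOptimisers

    equation = (u, p, t) -> cos(p * t)
    tspan = (0.0, 1.0)
    u0 = 1.0
    prob = ODEProblem(equation, u0, tspan)

    # DeepONet: the branch net consumes the parameter p, the trunk net consumes t.
    deeponet = DeepONet(
        Chain(Dense(1 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 10)),
        Chain(Dense(1 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 10, tanh)))

    bounds = [(pi, 2pi)]       # admissible range for the single ODE parameter p
    number_of_parameters = 50  # parameter samples drawn per training set
    strategy = StochasticTraining(40)
    opt = OptimizationOptimisers.Adam(0.03)
    alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy)

    sol = solve(prob, alg; verbose = false, maxiters = 3000)

Note that the `number_of_parameters == strategy.points` restriction enforced in
`get_trainset` applies only to `Chain` models, not to `DeepONet`, which is why the
two values may differ here.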