Commit: update

KirillZubov committed Oct 17, 2024
1 parent 1925aaf commit 6dd9e38
Showing 4 changed files with 51 additions and 50 deletions.
4 changes: 1 addition & 3 deletions Project.toml
@@ -124,12 +124,10 @@ MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-Preferences = "21216c6a-2e73-6563-6e65-726566657250"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StochasticDiffEq", "TensorBoardLogger", "Test"]

test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StochasticDiffEq", "TensorBoardLogger", "Test"]
2 changes: 2 additions & 0 deletions src/advancedHMC_MCMC.jl
@@ -73,6 +73,7 @@ suggested extra loss function for ODE solver case
"""
@views function L2loss2(ltd::LogTargetDensity, θ)
ltd.extraparams ≤ 0 && return false # XXX: type-stability?

f = ltd.prob.f
t = ltd.dataset[end]
u1 = ltd.dataset[2]
@@ -226,6 +227,7 @@ Prior logpdf for NN parameters + ODE constants.
@views function priorweights(ltd::LogTargetDensity, θ)
allparams = ltd.priors
nnwparams = allparams[1] # nn weights

ltd.extraparams ≤ 0 && return logpdf(nnwparams, θ)

# Vector of ode parameters priors
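For reference, the split-prior idea above — one prior for the NN weights, one per ODE constant, summed into a joint log-prior — can be sketched in isolation. This is a toy example assuming Distributions.jl; all names and sizes are illustrative, not NeuralPDE internals:

```julia
using Distributions, LinearAlgebra

θ_nn = randn(10)                      # stand-in for the flattened NN weights
θ_ode = [0.5, 1.2]                    # stand-in for the ODE constants
nnw_prior = MvNormal(zeros(10), Diagonal(ones(10)))
ode_priors = [Normal(0.0, 1.0), Normal(1.0, 2.0)]

logprior = logpdf(nnw_prior, θ_nn) +
           sum(logpdf(pr, x) for (pr, x) in zip(ode_priors, θ_ode))
```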
91 changes: 46 additions & 45 deletions src/pino_ode_solve.jl
@@ -11,7 +11,7 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`.
## Positional Arguments
-* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`.
+* `chain`: A neural network architecture, defined as a `AbstractLuxLayer` or `Flux.Chain`.
`Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`
* `opt`: The optimizer to train the neural network.
* `bounds`: A dictionary containing the bounds for the parameters of the parametric ODE.
@@ -30,17 +30,15 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`.
* Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets"
* Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations"
"""
-struct PINOODE{C, O, B, I, S <: Union{Nothing, AbstractTrainingStrategy},
-               AL <: Union{Nothing, Function}, K} <:
-       SciMLBase.AbstractODEAlgorithm
-    chain::C
-    opt::O
-    bounds::B
+@concrete struct PINOODE
+    chain
+    opt
+    bounds
    number_of_parameters::Int
-    init_params::I
-    strategy::S
-    additional_loss::AL
-    kwargs::K
+    init_params
+    strategy <: Union{Nothing, AbstractTrainingStrategy}
+    additional_loss <: Union{Nothing, Function}
+    kwargs
end

function PINOODE(chain,
@@ -51,38 +49,37 @@ function PINOODE(chain,
strategy = nothing,
additional_loss = nothing,
kwargs...)
-    !(chain isa Lux.AbstractLuxLayer) && (chain = Lux.transform(chain))
-    PINOODE(chain, opt, bounds, number_of_parameters,
+    chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
+    return PINOODE(chain, opt, bounds, number_of_parameters,
        init_params, strategy, additional_loss, kwargs)
end
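The Flux-to-Lux conversion on the new side of this hunk can be exercised on its own, roughly as follows. This is a sketch; it assumes both Flux and Lux are loaded so the adaptor extension is available:

```julia
using Flux, Lux

fchain = Flux.Chain(Flux.Dense(2 => 16, tanh), Flux.Dense(16 => 1))
lchain = Lux.FromFluxAdaptor()(fchain)   # now a Lux.AbstractLuxLayer
```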

-struct PINOPhi{C, S}
-    chain::C
-    st::S
-    function PINOPhi(chain::Lux.AbstractLuxLayer, st)
-        new{typeof(chain), typeof(st)}(chain, st)
-    end
+@concrete struct PINOPhi
+    model <: AbstractLuxLayer
+    smodel <: StatefulLuxLayer
end

+function PINOPhi(model::AbstractLuxLayer, st)
+    return PINOPhi(model, StatefulLuxLayer{false}(model, nothing, st))
+end
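The `StatefulLuxLayer` introduced here carries the layer state internally, so the forward pass no longer threads `st` by hand. A standalone sketch of the pattern (model and shapes are illustrative):

```julia
using Lux, Random

model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatefulLuxLayer{false}(model, nothing, st)

y = smodel(rand(Float32, 2, 5), ps)   # no explicit state threading
```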

-function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params)
-    θ, st = Lux.setup(Random.default_rng(), chain)
-    init_params = isnothing(init_params) ? θ : init_params
-    init_params = ComponentArrays.ComponentArray(init_params)
+function generate_pino_phi_θ(chain::AbstractLuxLayer, ::Nothing)
+    θ, st = LuxCore.setup(Random.default_rng(), chain)
    PINOPhi(chain, st), θ
end

+function generate_pino_phi_θ(chain::AbstractLuxLayer, init_params)
+    st = LuxCore.initialstates(Random.default_rng(), chain)
+    PINOPhi(chain, st), init_params
+end

-function (f::PINOPhi{C, T})(x::Array, θ) where {C <: Lux.Chain, T}
-    eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
-    x = convert.(eltypeθ, adapt(typeθ, x))
-    y, st = f.chain(x, θ, f.st)
-    y
+function (f::PINOPhi{C, T})(x, θ) where {C <: AbstractLuxLayer, T}
+    dev = safe_get_device(θ)
+    return f(dev, safe_expand(dev, x), θ)
end

-function (f::PINOPhi{C, T})(x::Tuple, θ) where {C <: DeepONet, T}
-    eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
-    x = (convert.(eltypeθ, adapt(typeθ, x[1])), convert.(eltypeθ, adapt(typeθ, x[2])))
-    y, st = f.chain(x, θ, f.st)
-    y
+function (f::PINOPhi{C, T})(dev, x, θ) where {C <: AbstractLuxLayer, T}
+    f.smodel(dev(x), θ)
end
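The new call path moves inputs onto whatever device the parameters live on. `safe_get_device` and `safe_expand` are internal helpers, but the underlying idea can be sketched with MLDataDevices directly (an assumption of this sketch, not code from this commit):

```julia
using MLDataDevices

θ = rand(Float32, 16)
dev = get_device(θ)             # CPUDevice() here; a GPU device if θ lives on a GPU
x = dev(rand(Float32, 2, 10))   # move the input batch to the parameters' device
```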

function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T}
@@ -180,7 +177,7 @@ function get_trainset(
end

function get_trainset(
-        strategy::GridTraining, chain::Lux.Chain, bounds, number_of_parameters, tspan, eltypeθ)
+        strategy::GridTraining, chain::Chain, bounds, number_of_parameters, tspan, eltypeθ)
dt = strategy.dx
tspan_ = tspan[1]:dt:tspan[2]
pspan = [range(start = b[1], length = number_of_parameters, stop = b[2])
@@ -194,9 +191,9 @@
end
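Concretely, the grid built here is the Cartesian product of the time range and the per-bound parameter ranges; for example (hypothetical numbers):

```julia
dt = 0.1
tspan = (0.0, 1.0)
number_of_parameters = 5
bounds = [(1.0, 2.0)]

tspan_ = tspan[1]:dt:tspan[2]
pspan = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
length(tspan_) * length(pspan[1])   # 11 * 5 = 55 (p, t) grid points
```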

function get_trainset(
-        strategy::StochasticTraining, chain::Union{DeepONet, Lux.Chain},
+        strategy::StochasticTraining, chain::Union{DeepONet, Chain},
bounds, number_of_parameters, tspan, eltypeθ)
-    (number_of_parameters != strategy.points && chain isa Lux.Chain) &&
+    (number_of_parameters != strategy.points && chain isa Chain) &&
throw(error("number_of_parameters should be the same strategy.points for StochasticTraining"))
p = reduce(vcat,
[(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
@@ -208,7 +205,8 @@ end

function generate_loss(
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
-    x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan, eltypeθ)
+    x = get_trainset(
+        strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
function loss(θ, _)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
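The loss assembled here is the usual PINO combination of a physics residual and an initial-condition mismatch. A self-contained toy version of the idea (nothing below is NeuralPDE internals; the stand-in network is contrived so both terms vanish):

```julia
using Zygote

f(u, p, t) = p * cos(p * t)        # ODE right-hand side
u_net(p, t) = sin(p * t) + 1.0     # stand-in for phi: exact solution with u(p, 0) = 1

dudt(p, t) = Zygote.gradient(τ -> u_net(p, τ), t)[1]
resid(p, t) = dudt(p, t) - f(u_net(p, t), p, t)

physics = sum(abs2, resid(p, t) for p in 1.0:0.5:2.0, t in 0.0:0.1:1.0)  # ≈ 0
ic = sum(abs2, u_net(p, 0.0) - 1.0 for p in 1.0:0.5:2.0)                 # ≈ 0
total_loss = physics + ic
```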
@@ -217,7 +215,8 @@ end
function generate_loss(
strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
function loss(θ, _)
-        x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan, eltypeθ)
+        x = get_trainset(
+            strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
end
@@ -240,10 +239,10 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
verbose = false,
saveat = nothing,
maxiters = nothing)
-    @unpack tspan, u0, f = prob
-    @unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss = alg
+    (; tspan, u0, f) = prob
+    (; chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss) = alg

-    if !(chain isa Lux.AbstractLuxLayer)
+    if !(chain isa AbstractLuxLayer)
error("Only Lux.AbstractLuxLayer neural networks are supported")

if !(chain isa DeepONet) || !(chain isa Chain)
@@ -252,22 +251,24 @@
end

phi, init_params = generate_pino_phi_θ(chain, init_params)
-    eltypeθ = eltype(init_params)

+    # init_params = ComponentArray(init_params)
+    # eltypeθ = eltype(init_params) #TODO?

isinplace(prob) &&
throw(error("The PINOODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
if chain isa DeepONet
in_dim = chain.branch.layers.layer_1.in_dims
-            u = rand(eltypeθ, in_dim, number_of_parameters)
-            v = rand(eltypeθ, 1, 10, 1)
+            u = rand(in_dim, number_of_parameters)
+            v = rand(1, 10, 1)
x = (u, v)
phi(x, init_params)
end
if chain isa Chain
in_dim = chain.layers.layer_1.in_dims
-            x = rand(eltypeθ, in_dim, number_of_parameters)
+            x = rand(in_dim, number_of_parameters)
phi(x, init_params)
end
catch err
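The `try` block above acts as a one-shot shape probe: a dummy forward pass that surfaces architecture/shape mismatches before training starts. A sketch of the assumed input layout:

```julia
# Assumed DeepONet layout: the branch net sees an (in_dim × n_params) matrix of
# parameter samples, the trunk net a 1 × n_t × 1 array of time points.
in_dim = 1
number_of_parameters = 50
u = rand(in_dim, number_of_parameters)
v = rand(1, 10, 1)
# phi((u, v), init_params)   # would throw early on a shape mismatch
```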
4 changes: 2 additions & 2 deletions src/training_strategies.jl
@@ -291,8 +291,8 @@ end
"""
WeightedIntervalTraining(weights, samples)
-A training strategy that generates points for training based on the given inputs.
-We split the timespan into equal segments based on the number of weights,
+A training strategy that generates points for training based on the given inputs.
+We split the timespan into equal segments based on the number of weights,
then sample points in each segment based on that segments corresponding weight,
such that the total number of sampled points is equivalent to the given samples
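A hedged construction example for the docstring above (weights and sample count are illustrative):

```julia
using NeuralPDE

# Three equal time segments weighted 70/20/10, 100 sampled points in total.
strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 100)
```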
