Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 18, 2024
1 parent 2967a4a commit 36226f9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ MethodOfLines = "0.11"
ModelingToolkit = "9.9"
MonteCarloMeasurements = "1"
NeuralPDE = "5"
NeuralOperators = "0.5.0"
NeuralOperators = "0.5"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
Expand Down
32 changes: 13 additions & 19 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function PINOODE(chain,
end

@concrete struct PINOPhi
model <:AbstractLuxLayer
model <: AbstractLuxLayer
smodel <: StatefulLuxLayer
end

Expand Down Expand Up @@ -166,18 +166,17 @@ function initial_condition_loss(
end

function get_trainset(
strategy::GridTraining, chain::DeepONet, bounds, number_of_parameters, tspan, eltypeθ)
strategy::GridTraining, chain::DeepONet, bounds, number_of_parameters, tspan)
dt = strategy.dx
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
t_ = collect(tspan[1]:dt:tspan[2])
t = reshape(t_, 1, size(t_, 1), 1)
p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
(p, t)
end

function get_trainset(
strategy::GridTraining, chain::Chain, bounds, number_of_parameters, tspan, eltypeθ)
strategy::GridTraining, chain::Chain, bounds, number_of_parameters, tspan)
dt = strategy.dx
tspan_ = tspan[1]:dt:tspan[2]
pspan = [range(start = b[1], length = number_of_parameters, stop = b[2])
Expand All @@ -186,37 +185,33 @@ function get_trainset(
points -> collect(points), Iterators.product([pspan..., tspan_]...)))...)
x = reshape(x_, size(bounds, 1) + 1, prod(size.(pspan, 1)), size(tspan_, 1))
p, t = x[1:(end - 1), :, :], x[[end], :, :]
p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
(p, t)
end

function get_trainset(
strategy::StochasticTraining, chain::Union{DeepONet, Chain},
bounds, number_of_parameters, tspan, eltypeθ)
bounds, number_of_parameters, tspan)
(number_of_parameters != strategy.points && chain isa Chain) &&
throw(error("number_of_parameters should be the same strategy.points for StochasticTraining"))
p = reduce(vcat,
[(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
for bound in bounds])
t = (tspan[2] .- tspan[1]) .* rand(1, strategy.points, 1) .+ tspan[1]
p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
(p, t)
end

function generate_loss(
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
x = get_trainset(
strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan)
function loss(θ, _)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
end

function generate_loss(
strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
function loss(θ, _)
x = get_trainset(
strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
end
Expand Down Expand Up @@ -253,22 +248,21 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
phi, init_params = generate_pino_phi_θ(chain, init_params)

init_params = ComponentArray(init_params)
eltypeθ = eltype(init_params)

isinplace(prob) &&
throw(error("The PINOODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
if chain isa DeepONet
in_dim = chain.branch.layers.layer_1.in_dims
u = rand(eltypeθ, in_dim, number_of_parameters)
v = rand(eltypeθ, 1, 10, 1)
u = rand(in_dim, number_of_parameters)
v = rand(1, 10, 1)
x = (u, v)
phi(x, init_params)
end
if chain isa Chain
in_dim = chain.layers.layer_1.in_dims
x = rand(eltypeθ, in_dim, number_of_parameters)
x = rand(in_dim, number_of_parameters)
phi(x, init_params)
end
catch err
Expand All @@ -286,7 +280,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
end

inner_f = generate_loss(
strategy, prob, phi, bounds, number_of_parameters, tspan, eltypeθ)
strategy, prob, phi, bounds, number_of_parameters, tspan)

function total_loss(θ, _)
L2_loss = inner_f(θ, nothing)
Expand All @@ -312,7 +306,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan, eltypeθ)
x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan)
if chain isa DeepONet
u = phi(x, res.u)
elseif chain isa Chain
Expand Down
17 changes: 8 additions & 9 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Test
using OptimizationOptimisers
using Lux
using Statistics, Random
Expand All @@ -25,7 +24,7 @@ function get_trainset(chain::Lux.Chain, bounds, number_of_parameters, tspan, dt)
end

#Test Chain with Float64 accuracy
@testset "Example du = cos(p * t)" begin
@testitem "Example du = cos(p * t)" tags=[:pinoode] begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0
Expand Down Expand Up @@ -57,7 +56,7 @@ end
end

#Test DeepONet with Float64 accuracy
@testset "Example du = cos(p * t)" begin
@testitem "Example du = cos(p * t)" tags=[:pinoode] begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0
Expand Down Expand Up @@ -97,7 +96,7 @@ end
@test ground_solutionpredict_sol rtol=0.05
end

@testset "Example du = cos(p * t) + u" begin
@testitem "Example du = cos(p * t) + u" tags=[:pinoode] begin
eq_(u, p, t) = cos(p * t) + u
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
Expand All @@ -124,7 +123,7 @@ end
@test ground_solutionpredict_sol rtol=0.05
end

@testset "Example with data du = p*t^2" begin
@testitem "Example with data du = p*t^2" tags=[:pinoode] begin
equation = (u, p, t) -> p * t^2
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
Expand Down Expand Up @@ -169,8 +168,8 @@ end
@test ground_solutionpredict_sol rtol=0.05
end

#multiple parameters chain
@testset "Example du = cos(p * t)" begin
#multiple parameters Сhain
@testitem "Example multiple parameters Сhain du = p1 * cos(p2 * t) + p3" tags=[:pinoode] begin
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
tspan = (0.0, 1.0)
u0 = 1.0
Expand Down Expand Up @@ -212,7 +211,7 @@ end
end

#multiple parameters DeepOnet
@testset "Example du = cos(p * t)" begin
@testitem "Example multiple parameters DeepOnet du = p1 * cos(p2 * t) + p3" tags=[:pinoode] begin
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
tspan = (0.0, 1.0)
u0 = 1.0
Expand Down Expand Up @@ -254,7 +253,7 @@ end
end

#vector output
@testset "Example du = [cos(p * t), sin(p * t)]" begin
@testitem "Example du = [cos(p * t), sin(p * t)]" tags=[:pinoode] begin
equation = (u, p, t) -> [cos(p * t), sin(p * t)]
tspan = (0.0f0, 1.0f0)
u0 = [1.0f0, 0.0f0]
Expand Down

0 comments on commit 36226f9

Please sign in to comment.