-
-
Notifications
You must be signed in to change notification settings - Fork 209
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #790 from SciML/DAE_problem
DAE problem
- Loading branch information
Showing
8 changed files
with
372 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Differential Algebraic Equation Specialized Physics-Informed Neural Solver | ||
|
||
```@docs | ||
NNDAE | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Solving DAEs with Physics-Informed Neural Networks (PINNs) | ||
|
||
!!! note | ||
|
||
It is highly recommended you first read the [solving ordinary differential | ||
equations with DifferentialEquations.jl tutorial](https://docs.sciml.ai/DiffEqDocs/stable/tutorials/dae_example/) before reading this tutorial. | ||
|
||
|
||
This tutorial is an introduction to using physics-informed neural networks (PINNs) for solving differential algebraic equations (DAEs). | ||
|
||
|
||
## Solving an DAE with PINNs | ||
|
||
Let's solve a simple DAE system: | ||
|
||
```@example dae | ||
using NeuralPDE | ||
using Random, Flux | ||
using OrdinaryDiffEq, Optimisers, Statistics | ||
import Lux, OptimizationOptimisers, OptimizationOptimJL | ||
example = (du, u, p, t) -> [cos(2pi * t) - du[1], u[2] + cos(2pi * t) - du[2]] | ||
u₀ = [1.0, -1.0] | ||
du₀ = [0.0, 0.0] | ||
tspan = (0.0f0, 1.0f0) | ||
``` | ||
|
||
Differential_vars is a logical array which declares which variables are the differential (non-algebraic) vars | ||
|
||
```@example dae | ||
differential_vars = [true, false] | ||
``` | ||
|
||
```@example dae | ||
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars) | ||
chain = Flux.Chain(Dense(1, 15, cos), Dense(15, 15, sin), Dense(15, 2)) | ||
opt = OptimizationOptimisers.Adam(0.1) | ||
alg = NNDAE(chain, opt; autodiff = false) | ||
sol = solve(prob, | ||
alg, verbose = false, dt = 1 / 100.0f0, | ||
maxiters = 3000, abstol = 1.0f-10) | ||
``` | ||
|
||
Now lets compare the predictions from the learned network with the ground truth which we can obtain by numerically solving the DAE. | ||
|
||
```@example dae | ||
function example1(du, u, p, t) | ||
du[1] = cos(2pi * t) | ||
du[2] = u[2] + cos(2pi * t) | ||
nothing | ||
end | ||
M = [1.0 0 | ||
0 0] | ||
f = ODEFunction(example1, mass_matrix = M) | ||
prob_mm = ODEProblem(f, u₀, tspan) | ||
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8) | ||
``` | ||
|
||
```@example dae | ||
using Plots | ||
plot(ground_sol, tspan = tspan, layout = (2, 1), label = "ground truth") | ||
plot!(sol, tspan = tspan, layout = (2, 1), label = "dae with pinns") | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
""" | ||
```julia | ||
NNDAE(chain, | ||
OptimizationOptimisers.Adam(0.1), | ||
init_params = nothing; | ||
autodiff = false, | ||
kwargs...) | ||
``` | ||
Algorithm for solving differential algebraic equationsusing a neural network. This is a specialization | ||
of the physics-informed neural network which is used as a solver for a standard `DAEProblem`. | ||
!!! warn | ||
Note that NNDAE only supports DAEs which are written in the out-of-place form, i.e. | ||
`du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the NNDAE | ||
will exit with an error. | ||
## Positional Arguments | ||
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractExplicitLayer`. | ||
* `opt`: The optimizer to train the neural network. | ||
* `init_params`: The initial parameter of the neural network. By default, this is `nothing` | ||
which thus uses the random initialization provided by the neural network library. | ||
## Keyword Arguments | ||
* `autodiff`: The switch between automatic(not supported yet) and numerical differentiation for | ||
the PDE operators. The reverse mode of the loss function is always | ||
automatic differentiation (via Zygote), this is only for the derivative | ||
in the loss function (the derivative with respect to time). | ||
* `strategy`: The training strategy used to choose the points for the evaluations. | ||
By default, `GridTraining` is used with `dt` if given. | ||
""" | ||
struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy} | ||
} <: DiffEqBase.AbstractDAEAlgorithm | ||
chain::C | ||
opt::O | ||
init_params::P | ||
autodiff::Bool | ||
strategy::S | ||
kwargs::K | ||
end | ||
|
||
function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, | ||
kwargs...) | ||
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs) | ||
end | ||
|
||
function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool, differential_vars) | ||
if autodiff | ||
autodiff && throw(ArgumentError("autodiff not supported for DAE problem.")) | ||
else | ||
dphi = (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t))) | ||
batch_size = size(t)[1] | ||
|
||
reduce(vcat, | ||
[if dv == true | ||
dphi[[i], :] | ||
else | ||
zeros(1, batch_size) | ||
end | ||
for (i, dv) in enumerate(differential_vars)]) | ||
end | ||
end | ||
|
||
function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ, | ||
p, differential_vars) where {C, T, U} | ||
out = Array(phi(t, θ)) | ||
dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars)) | ||
arrt = Array(t) | ||
loss = reduce(hcat, [f(dphi[:, i], out[:, i], p, arrt[i]) for i in 1:size(out, 2)]) | ||
sum(abs2, loss) / length(t) | ||
end | ||
|
||
function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, | ||
differential_vars) | ||
ts = tspan[1]:(strategy.dx):tspan[2] | ||
autodiff && throw(ArgumentError("autodiff not supported for GridTraining.")) | ||
function loss(θ, _) | ||
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) | ||
end | ||
return loss | ||
end | ||
|
||
function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem, | ||
alg::NNDAE, | ||
args...; | ||
dt = nothing, | ||
# timeseries_errors = true, | ||
save_everystep = true, | ||
# adaptive = false, | ||
abstol = 1.0f-6, | ||
reltol = 1.0f-3, | ||
verbose = false, | ||
saveat = nothing, | ||
maxiters = nothing, | ||
tstops = nothing) | ||
u0 = prob.u0 | ||
du0 = prob.du0 | ||
tspan = prob.tspan | ||
f = prob.f | ||
p = prob.p | ||
t0 = tspan[1] | ||
|
||
#hidden layer | ||
chain = alg.chain | ||
opt = alg.opt | ||
autodiff = alg.autodiff | ||
|
||
#train points generation | ||
init_params = alg.init_params | ||
|
||
# A logical array which declares which variables are the differential (non-algebraic) vars | ||
differential_vars = prob.differential_vars | ||
|
||
if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain | ||
phi, init_params = generate_phi_θ(chain, t0, u0, init_params) | ||
else | ||
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported") | ||
end | ||
|
||
if isinplace(prob) | ||
throw(error("The NNODE solver only supports out-of-place DAE definitions, i.e. du=f(u,p,t).")) | ||
end | ||
|
||
try | ||
phi(t0, init_params) | ||
catch err | ||
if isa(err, DimensionMismatch) | ||
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match")) | ||
else | ||
throw(err) | ||
end | ||
end | ||
|
||
strategy = if alg.strategy === nothing | ||
if dt !== nothing | ||
GridTraining(dt) | ||
else | ||
error("dt is not defined") | ||
end | ||
end | ||
|
||
inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars) | ||
|
||
# Creates OptimizationFunction Object from total_loss | ||
total_loss(θ, _) = inner_f(θ, phi) | ||
|
||
# Optimization Algo for Training Strategies | ||
opt_algo = Optimization.AutoZygote() | ||
# Creates OptimizationFunction Object from total_loss | ||
optf = OptimizationFunction(total_loss, opt_algo) | ||
|
||
iteration = 0 | ||
callback = function (p, l) | ||
iteration += 1 | ||
verbose && println("Current loss is: $l, Iteration: $iteration") | ||
l < abstol | ||
end | ||
optprob = OptimizationProblem(optf, init_params) | ||
res = solve(optprob, opt; callback, maxiters, alg.kwargs...) | ||
|
||
#solutions at timepoints | ||
if saveat isa Number | ||
ts = tspan[1]:saveat:tspan[2] | ||
elseif saveat isa AbstractArray | ||
ts = saveat | ||
elseif dt !== nothing | ||
ts = tspan[1]:dt:tspan[2] | ||
elseif save_everystep | ||
ts = range(tspan[1], tspan[2], length = 100) | ||
else | ||
ts = [tspan[1], tspan[2]] | ||
end | ||
|
||
if u0 isa Number | ||
u = [first(phi(t, res.u)) for t in ts] | ||
else | ||
u = [phi(t, res.u) for t in ts] | ||
end | ||
|
||
sol = DiffEqBase.build_solution(prob, alg, ts, u; | ||
k = res, dense = true, | ||
calculate_error = false, | ||
retcode = ReturnCode.Success) | ||
DiffEqBase.has_analytic(prob.f) && | ||
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, | ||
dense_errors = false) | ||
sol | ||
end #solve |
Oops, something went wrong.