Skip to content

Commit

Permalink
Merge pull request #211 from YichengDWu/hybrid
Browse files Browse the repository at this point in the history
Allow hybrid derivative
  • Loading branch information
YichengDWu authored May 26, 2023
2 parents 682b406 + e6375bb commit e02be90
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/pde/discretize.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
function build_loss_function(pde_system::ModelingToolkit.PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg; derivative=finitediff, fdtype=Float64)
strategy::AbstractTrainingAlg, derivative,
derivative_bc, fdtype)
(; eqs, bcs, domain, ps, defaults, indvars, depvars) = pde_system
(; phi, init_params) = pinn

Expand All @@ -21,6 +22,7 @@ function build_loss_function(pde_system::ModelingToolkit.PDESystem, pinn::PINN,
datafree_pde_loss_functions = Tuple(build_loss_function(pinnrep, eq, i)
for (i, eq) in enumerate(eqs))

pinnrep = Lux.@set pinnrep.derivative = derivative_bc
datafree_bc_loss_functions = Tuple(build_loss_function(pinnrep, bc,
i +
length(datafree_pde_loss_functions))
Expand All @@ -32,7 +34,8 @@ function build_loss_function(pde_system::ModelingToolkit.PDESystem, pinn::PINN,
end

function build_loss_function(pde_system::PDESystem, pinn::PINN,
strategy::AbstractTrainingAlg; derivative=finitediff, fdtype=Float64)
strategy::AbstractTrainingAlg, derivative,
derivative_bc, fdtype)
(; eqs, bcs, ivs, dvs) = pde_system
(; phi, init_params) = pinn

Expand All @@ -47,6 +50,7 @@ function build_loss_function(pde_system::PDESystem, pinn::PINN,
datafree_pde_loss_functions = Tuple(build_loss_function(pinnrep, first(eq), i)
for (i, eq) in enumerate(eqs))

pinnrep = Lux.@set pinnrep.derivative = derivative_bc
datafree_bc_loss_functions = Tuple(build_loss_function(pinnrep, first(bc),
i +
length(datafree_pde_loss_functions))
Expand All @@ -57,6 +61,7 @@ function build_loss_function(pde_system::PDESystem, pinn::PINN,
return pde_and_bcs_loss_function
end

#=
function build_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
strategy::AbstractTrainingAlg, coord_branch_net;
derivative=finitediff)
Expand Down Expand Up @@ -84,9 +89,11 @@ function build_loss_function(pde_system::ParametricPDESystem, pinn::PINN,
datafree_bc_loss_functions)
return pde_and_bcs_loss_function
end
=#

"""
discretize(pde_system::PDESystem, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg;
strategy::AbstractTrainingAlg; derivative=finitediff,
additional_loss)
Convert the PDESystem into an `OptimizationProblem`. You will have access to each loss function
Expand All @@ -95,17 +102,18 @@ Convert the PDESystem into an `OptimizationProblem`. You will have access to eac
function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg;
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
fdtype=Float64, adtype=Optimization.AutoZygote())
derivative_bc = derivative, fdtype=Float64,
adtype=Optimization.AutoZygote())
datasets = sample(pde_system, sampler)
init_params = Lux.fmap(Base.Fix1(broadcast, fdtype), pinn.init_params)
init_params = _ComponentArray(init_params)

datasets = map(Base.Fix1(broadcast, fdtype), datasets)
datasets = init_params isa AbstractGPUComponentVector ?
map(Base.Fix1(adapt, CuArray), datasets) : datasets
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy;
derivative=derivative,
fdtype=fdtype)
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
derivative, derivative_bc,
fdtype)

function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(pinn.phi, θ)
Expand All @@ -114,7 +122,8 @@ function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
return Optimization.OptimizationProblem(f, init_params, datasets)
end

function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSampler,
# ParametricPDESystem no long supported
#=function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg, functionsampler::FunctionSampler,
coord_branch_net::AbstractArray;
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
Expand All @@ -132,8 +141,8 @@ function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSa
coord_branch_net = coord_branch_net isa Union{AbstractVector, StepRangeLen} ?
[coord_branch_net] : coord_branch_net
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
coord_branch_net; derivative=derivative,
fdtype=fdtype)
coord_branch_net, derivative,
fdtype)
function full_loss_function(θ, p)
return pde_and_bcs_loss_function(θ, p) + additional_loss(pinn.phi, θ)
end
Expand All @@ -142,11 +151,13 @@ function discretize(pde_system::ParametricPDESystem, pinn::PINN, sampler::PINNSa
p = PINOParameterHandler(datasets, pfs)
return Optimization.OptimizationProblem(f, init_params, p)
end
=#

function symbolic_discretize(pde_system, pinn::PINN, sampler::PINNSampler,
strategy::AbstractTrainingAlg;
additional_loss=Sophon.null_additional_loss, derivative=finitediff,
adtype=Optimization.AutoZygote(), fdtype=Float64)
derivative_bc = derivative, fdtype=Float64,
adtype=Optimization.AutoZygote())
(; eqs, bcs, domain, ps, defaults, indvars, depvars) = pde_system
(; phi, init_params) = pinn

Expand All @@ -169,6 +180,8 @@ function symbolic_discretize(pde_system, pinn::PINN, sampler::PINNSampler,
args, body = build_symbolic_loss_function(pinnrep, eq)
return :($args -> $body)
end

pinnrep = Lux.@set pinnrep.derivative = derivative_bc
bc_loss_function = map(bcs) do bc
args, body = build_symbolic_loss_function(pinnrep, bc)
return :($args -> $body)
Expand Down
31 changes: 31 additions & 0 deletions test/hybriddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using ModelingToolkit, DomainSets, TaylorDiff, Sophon, Test
using Optimization, OptimizationOptimJL, TaylorDiff

@parameters x,t
@variables u(..), v(..)
Dₜ = Differential(t)
Dₓ² = Differential(x)^2

eqs=[Dₜ(u(x,t)) ~ -Dₓ²(v(x,t))/2 - (abs2(v(x,t)) + abs2(u(x,t))) * v(x,t),
Dₜ(v(x,t)) ~ Dₓ²(u(x,t))/2 + (abs2(v(x,t)) + abs2(u(x,t))) * u(x,t)]

bcs = [u(x, 0.0f0) ~ 2sech(x),
v(x, 0.0f0) ~ 0.0f0,
u(-5.0f0, t) ~ u(5.0f0, t),
v(-5.0f0, t) ~ v(5.0f0, t)]

domains = [x Interval(-5.0f0, 5.0f0),
t Interval(0.0f0, π/2f0)]

@named pde_system = PDESystem(eqs, bcs, domains, [x,t], [u(x,t),v(x,t)])

finitediff = Sophon.finitediff
taylordiff = isdefined(Base, :get_extension) ? Sophon.taylordiff : Sophon.SophonTaylorDiffExt.taylordiff

pinn = PINN(u = Siren(2,1; hidden_dims=16,num_layers=4, omega = 1.0),
v = Siren(2,1; hidden_dims=16,num_layers=4, omega = 1.0))

sampler = QuasiRandomSampler(500, (200,200,20,20))
strategy = NonAdaptiveTraining(1,(10,10,1,1))

@test_nowarn Sophon.discretize(pde_system, pinn, sampler, strategy; derivative=finitediff, derivative_bc=taylordiff)

0 comments on commit e02be90

Please sign in to comment.