Skip to content

Commit

Permalink
fix: incorrect DGM architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 14, 2024
1 parent 196a252 commit 63062cc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
18 changes: 10 additions & 8 deletions src/dgm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function (layer::DGMLSTMLayer)((S, x), ps, st::NamedTuple)
return S_new, st
end

dgm_lstm_block_rearrange(mx, (S, x)) = (mx, x)
dgm_lstm_block_rearrange(Sᵢ₊₁, (Sᵢ, x)) = Sᵢ₊₁, x

function DGMLSTMBlock(layers...)
blocks = AbstractLuxLayer[]
Expand Down Expand Up @@ -94,16 +94,18 @@ f(t, x, \\theta) &= \\sigma_{out}(W S^{L+1} + b).
"""
function DGM(in_dims::Int, out_dims::Int, modes::Int, layers::Int,
activation1, activation2, out_activation)
return DGM(Chain(SkipConnection(
Dense(in_dims => modes, activation1; init_bias = zeros32),
DGMLSTMBlock([DGMLSTMLayer(in_dims, modes, activation1, activation2)
for _ in 1:layers]...),
Dense(modes => out_dims, out_activation; init_bias = zeros32))))
return DGM(Chain(
SkipConnection(
Dense(in_dims => modes, activation1),
DGMLSTMBlock([DGMLSTMLayer(in_dims, modes, activation1, activation2)
for _ in 1:layers]...)),
Dense(modes => out_dims, out_activation)))
end

"""
DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function,
strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function,
activation2::Function, out_activation::Function, strategy::AbstractTrainingStrategy;
kwargs...)
returns a `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a `PDESystem` into an `OptimizationProblem` using the Deep Galerkin method.
Expand Down
26 changes: 8 additions & 18 deletions test/dgm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,8 @@ import ModelingToolkit: Interval, infimum, supremum
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
global iter += 1
if iter % 50 == 0
println("$iter => $l")
end
p.iter % 50 == 0 && println("$(p.iter) => $l")
return false
end

Expand All @@ -48,7 +44,8 @@ import ModelingToolkit: Interval, infimum, supremum
(length(xs), length(ys)))
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
@test maximum(abs, u_predict - u_real) < 0.1

@test u_realu_predict atol=0.01 norm=Base.Fix2(norm, Inf)
end

@testset "Black-Scholes PDE: European Call Option" begin
Expand Down Expand Up @@ -79,16 +76,12 @@ end
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [g(t, x)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
global iter += 1
if iter % 50 == 0
println("$iter => $l")
end
p.iter % 50 == 0 && println("$(p.iter) => $l")
return false
end

res = Optimization.solve(prob, Adam(0.1); callback, maxiters = 500)
res = Optimization.solve(prob, Adam(0.1); callback, maxiters = 100)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, Adam(0.01); callback, maxiters = 500)
phi = discretization.phi
Expand All @@ -106,7 +99,7 @@ end

u_real = [analytic_sol_func(t, x) for t in ts, x in xs]
u_predict = [first(phi([t, x], res.u)) for t in ts, x in xs]
@test_broken u_predictu_real rtol=0.05
@test u_predictu_real rtol=0.05
end

@testset "Burger's equation" begin
Expand Down Expand Up @@ -144,12 +137,9 @@ end
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
global iter = 0

callback = function (p, l)
global iter += 1
if iter % 20 == 0
println("$iter => $l")
end
p.iter % 50 == 0 && println("$(p.iter) => $l")
return false
end

Expand Down

0 comments on commit 63062cc

Please sign in to comment.