Updated the closures with @closure to avoid boxing #924

Open · wants to merge 2 commits into base: master
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -16,7 +16,7 @@ steps:
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

- label: "Documentation"
- label: "Documentation"p
Member:

Suggested change
- label: "Documentation"p
- label: "Documentation"

plugins:
- JuliaCI/julia#v1:
version: "1"
2 changes: 2 additions & 0 deletions Project.toml
@@ -17,6 +17,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
Member:
what is this?

Author:
It was added for searching for a keyword across all files, to automate the search with a script.
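For context, a rough sketch of the kind of search script meant here; it assumes Glob.jl's `glob`, the `find_keyword` helper is hypothetical, and none of this is part of the PR:

```julia
using Glob

# Hypothetical helper: list the .jl files under `dir` whose contents mention `keyword`.
function find_keyword(keyword, dir = "src")
    matches = String[]
    for path in glob("*.jl", dir)               # Glob.jl pattern match within one directory
        occursin(keyword, read(path, String)) && push!(matches, path)
    end
    return matches
end

find_keyword("@closure")                        # e.g. locate the closure annotations
```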

Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -38,6 +39,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Member:
why is this needed?

Author:
It was continuously throwing a `KeyError: key "SparseArrays" not found` error; adding this dependency fixed it.

Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
8 changes: 4 additions & 4 deletions docs/src/developer/debugging.md
@@ -36,11 +36,11 @@ eltypeθ = eltype(init_params)
phi = NeuralPDE.get_phi(chain)
derivative = NeuralPDE.get_numeric_derivative()

u_ = (cord, θ, phi) -> sum(phi(cord, θ))
u_ = @closure (cord, θ, phi) -> sum(phi(cord, θ))

phi([1, 2], init_params)

phi_ = (p) -> phi(p, init_params)[1]
phi_ = @closure (p) -> phi(p, init_params)[1]
dphi = Zygote.gradient(phi_, [1.0, 2.0])

dphi1 = derivative(phi, u_, [1.0, 2.0], [[0.0049215667, 0.0]], 1, init_params)
@@ -57,7 +57,7 @@ multioutput = chain isa AbstractArray
strategy = NeuralPDE.GridTraining(dx)
integral = NeuralPDE.get_numeric_integral(strategy, indvars, multioutput, chain, derivative)

_pde_loss_function = NeuralPDE.build_loss_function(eq, indvars, depvars, phi, derivative,
_pde_loss_function = @closure NeuralPDE.build_loss_function(eq, indvars, depvars, phi, derivative,
Member:
this isn't a closure

Author:
I added @closure here on the assumption that it would be needed for capturing variables, but I see now that this is a plain function call, not a closure definition.
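To make the distinction concrete, a small sketch assuming FastClosures.jl's `@closure`; `make_scaler` is illustrative and not part of the diff:

```julia
using FastClosures

function make_scaler(a)
    # A closure: an anonymous function capturing the local variable `a`.
    # @closure wraps the capture in a `let` block, FastClosures.jl's workaround
    # for the boxing that can occur with captured variables.
    return @closure x -> x + a
end

make_scaler(1.0)(2.0)  # 3.0

# By contrast, a plain call such as
#   _pde_loss_function = NeuralPDE.build_loss_function(eq, indvars, depvars, ...)
# only binds the call's result to a name; no closure is being defined, so there
# is nothing for @closure to act on.
```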

integral, multioutput, init_params,
strategy)
```
@@ -82,7 +82,7 @@ julia> bc_indvars = NeuralPDE.get_variables(bcs,indvars,depvars)
```

```julia
_bc_loss_functions = [NeuralPDE.build_loss_function(bc, indvars, depvars,
_bc_loss_functions = [ @closure NeuralPDE.build_loss_function(bc, indvars, depvars,
phi, derivative, integral, multioutput,
init_params, strategy,
bc_indvars = bc_indvar)
26 changes: 21 additions & 5 deletions docs/src/examples/nonlinear_elliptic.md
@@ -45,9 +45,18 @@ root(x) = f(x) - g(x)

# Analytic solution
k = find_zero(root, (0, 1), Bisection()) # k is a root of the algebraic (transcendental) equation f(x) = g(x)
θ(x, y) = (cosh(sqrt(f(k)) * x) + sinh(sqrt(f(k)) * x)) * (y + 1) # Analytical solution to Helmholtz equation
w_analytic(x, y) = θ(x, y) - h(k) / f(k)
u_analytic(x, y) = k * w_analytic(x, y)
θ = let k = k
(x, y) -> (cosh(sqrt(f(k)) * x) + sinh(sqrt(f(k)) * x)) * (y + 1)
end

w_analytic = let θ = θ, h_k = h(k) / f(k) # Closure for analytic function
(x, y) -> θ(x, y) - h_k
end

u_analytic = let k = k, w_analytic = w_analytic # Closure for u_analytic
(x, y) -> k * w_analytic(x, y)
end


# Nonlinear Steady-State Systems of Two Reaction-Diffusion Equations with 3 arbitrary function f, g, h
eqs_ = [
@@ -105,14 +114,21 @@ res = solve(prob, BFGS(); maxiters = 100, callback)
phi = discretization.phi

# Analysis
# Analysis with closure
Member:
why mention this?

Author:
In PR #900 it was suggested to annotate closures with @closure to avoid boxing. I went through the places where closures might be needed, added the annotations, and noted here that the analysis uses a closure.
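For reference, a minimal sketch of the boxing issue that motivates `@closure`, assuming FastClosures.jl; both functions are illustrative and not part of this diff:

```julia
using FastClosures

function boxed_example()
    x = 1.0
    f = () -> x           # `x` is reassigned below, so Julia wraps it in a Core.Box
    x = 2.0
    return f()            # 2.0, read through the box
end

function unboxed_example()
    x = 1.0
    f = @closure () -> x  # captures the current value of `x` via a `let` block, no box
    x = 2.0
    return f()            # 1.0, since the capture is a snapshot at definition time
end
```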

xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
depvars = [:u, :w]
minimizers_ = [res.u.depvar[depvars[i]] for i in 1:2]

analytic_sol_func(x, y) = [u_analytic(x, y), w_analytic(x, y)]
analytic_sol_func = let u_analytic = u_analytic, w_analytic = w_analytic # Closure for analytic function
(x, y) -> [u_analytic(x, y), w_analytic(x, y)]
end

u_real = [[analytic_sol_func(x, y)[i] for x in xs for y in ys] for i in 1:2]
u_predict = [[phi[i]([x, y], minimizers_[i])[1] for x in xs for y in ys] for i in 1:2]
u_predict = let phi = phi, minimizers_ = minimizers_ # Closure for predicted values
[[phi[i]([x, y], minimizers_[i])[1] for x in xs for y in ys] for i in 1:2]
end
diff_u = [abs.(u_real[i] .- u_predict[i]) for i in 1:2]

ps = []
for i in 1:2
p1 = plot(xs, ys, u_real[i], linetype = :contourf, title = "u$i, analytic")
2 changes: 1 addition & 1 deletion docs/src/tutorials/Lotka_Volterra_BPINNs.md
@@ -70,7 +70,7 @@ Let's define a PINN.

```@example bpinn
# Neural Networks must have 2 outputs as u -> [dx,dy] in function lotka_volterra()
chain = Chain(Dense(1, 6, tanh), Dense(6, 6, tanh), Dense(6, 2))
chain = @closure Chain(Dense(1, 6, tanh), Dense(6, 6, tanh), Dense(6, 2))
```

The dataset we generated can be passed for doing parameter estimation using provided priors in `param` keyword argument for [`BNNODE`](@ref).
2 changes: 1 addition & 1 deletion docs/src/tutorials/constraints.md
@@ -54,7 +54,7 @@ chain = Lux.Chain(Dense(1, inn, Lux.σ),
lb = [x_0]
ub = [x_end]
function norm_loss_function(phi, θ, p)
function inner_f(x, θ)
@closure function inner_f(x, θ)
0.01 * phi(x, θ) .- 1
end
prob = IntegralProblem(inner_f, lb, ub, θ)
6 changes: 3 additions & 3 deletions docs/src/tutorials/derivative_neural_network.md
@@ -93,7 +93,7 @@ domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(0.0, 1.0)]

input_ = length(domains)
n = 15
chain = [Chain(Dense(input_, n, σ), Dense(n, n, σ), Dense(n, 1)) for _ in 1:7]
chain = [@closure Chain(Dense(input_, n, σ), Dense(n, n, σ), Dense(n, 1)) for _ in 1:7]

training_strategy = StochasticTraining(128)
discretization = PhysicsInformedNN(chain, training_strategy)
@@ -107,7 +107,7 @@ pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions[1:7]
approx_derivative_loss_functions = sym_prob.loss_functions.bc_loss_functions[9:end]

callback = function (p, l)
callback = @closure function (p, l)
println("loss: ", l)
println("pde_losses: ", map(l_ -> l_(p.u), pde_inner_loss_functions))
println("bcs_losses: ", map(l_ -> l_(p.u), bcs_inner_loss_functions))
@@ -128,7 +128,7 @@ And some analysis:
using Plots

ts, xs = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
minimizers_ = [res.u.depvar[sym_prob.depvars[i]] for i in 1:length(chain)]
minimizers_ = [@closure res.u.depvar[sym_prob.depvars[i]] for i in 1:length(chain)]

u1_real(t, x) = exp(-t) * sinpi(x)
u2_real(t, x) = exp(-t) * cospi(x)
4 changes: 2 additions & 2 deletions docs/src/tutorials/low_level.md
@@ -51,7 +51,7 @@ phi = sym_prob.phi
pde_loss_functions = sym_prob.loss_functions.pde_loss_functions
bc_loss_functions = sym_prob.loss_functions.bc_loss_functions

callback = function (p, l)
@closure callback = function (p, l)
println("loss: ", l)
println("pde_losses: ", map(l_ -> l_(p.u), pde_loss_functions))
println("bcs_losses: ", map(l_ -> l_(p.u), bc_loss_functions))
@@ -60,7 +60,7 @@ end

loss_functions = [pde_loss_functions; bc_loss_functions]

loss_function(θ, p) = sum(map(l -> l(θ), loss_functions))
@closure loss_function(θ, p) = sum(map(l -> l(θ), loss_functions))

f_ = OptimizationFunction(loss_function, AutoZygote())
prob = OptimizationProblem(f_, sym_prob.flat_init_params)
6 changes: 3 additions & 3 deletions docs/src/tutorials/pino_ode.md
@@ -15,7 +15,7 @@ using NeuralOperators
using NeuralPDE

# Define the parametric ODE equation
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
@closure equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)
@@ -46,7 +46,7 @@ Now let's compare the prediction from the learned operator with the ground truth
```@example pino
using Plots

function get_trainset(bounds, tspan, number_of_parameters, dt)
@closure function get_trainset(bounds, tspan, number_of_parameters, dt)
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
t_ = collect(tspan[1]:dt:tspan[2])
@@ -55,7 +55,7 @@ function get_trainset(bounds, tspan, number_of_parameters, dt)
end

# Compute the ground truth solution for each parameter
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
@closure ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
function ground_solution_f(p, t)
reduce(hcat,
[[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
5 changes: 3 additions & 2 deletions docs/src/tutorials/systems.md
@@ -81,7 +81,7 @@ sym_prob = symbolic_discretize(pdesystem, discretization)
pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions

callback = function (p, l)
callback = @closure function (p, l)
println("loss: ", l)
println("pde_losses: ", map(l_ -> l_(p.u), pde_inner_loss_functions))
println("bcs_losses: ", map(l_ -> l_(p.u), bcs_inner_loss_functions))
@@ -106,7 +106,8 @@ bc_loss_functions = sym_prob.loss_functions.bc_loss_functions

loss_functions = [pde_loss_functions; bc_loss_functions]

loss_function(θ, _) = sum(l -> l(θ), loss_functions)
loss_function = @closure (θ, _) -> sum(l -> l(θ), loss_functions)


f_ = OptimizationFunction(loss_function, AutoZygote())
prob = OptimizationProblem(f_, sym_prob.flat_init_params)
8 changes: 4 additions & 4 deletions src/adaptive_losses.jl
@@ -16,7 +16,7 @@ change during optimization
@concrete mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
pde_loss_weights::Vector{T}
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
end

function NonAdaptiveLoss{T}(; pde_loss_weights = 1.0, bc_loss_weights = 1.0,
@@ -28,7 +28,7 @@ end

NonAdaptiveLoss(; kwargs...) = NonAdaptiveLoss{Float64}(; kwargs...)

function generate_adaptive_loss_function(::PINNRepresentation, ::NonAdaptiveLoss, _, __)
@closure function generate_adaptive_loss_function(::PINNRepresentation, ::NonAdaptiveLoss, _, __)
Member:
wrong spot

return Returns(nothing)
end
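To illustrate the reviewer's point, a hypothetical sketch (not the actual NeuralPDE source): the annotation belongs on the closure built inside the function, not on the method definition itself.

```julia
using FastClosures

# Hypothetical helper: the returned anonymous function is the closure, so that
# is where @closure goes.
function generate_weighted_loss(weights, losses)
    return @closure θ -> sum(w * l(θ) for (w, l) in zip(weights, losses))
end
```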

@@ -83,7 +83,7 @@ function GradientScaleAdaptiveLoss(args...; kwargs...)
return GradientScaleAdaptiveLoss{Float64}(args...; kwargs...)
end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
@closure function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
adaloss::GradientScaleAdaptiveLoss, pde_loss_functions, bc_loss_functions)
weight_change_inertia = adaloss.weight_change_inertia
iteration = pinnrep.iteration
@@ -168,7 +168,7 @@ end

MiniMaxAdaptiveLoss(args...; kwargs...) = MiniMaxAdaptiveLoss{Float64}(args...; kwargs...)

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
@closure function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
Member:
in the wrong spot

adaloss::MiniMaxAdaptiveLoss, _, __)
pde_max_optimiser_setup = Optimisers.setup(
adaloss.pde_max_optimiser, adaloss.pde_loss_weights)
8 changes: 5 additions & 3 deletions src/discretize.jl
@@ -314,10 +314,12 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
end
end
integration_arr = Matrix{Float64}(undef, 1, 0)
for i in 1:size(cord, 2)
integration_arr = hcat(integration_arr,
integration_(cord[:, i], lb_[:, i], ub_[:, i], θ))

for i in axes(cord, 2)
integration_arr = hcat(integration_arr,
integration_(cord[:, i], lb_[:, i], ub_[:, i], θ))
end

return integration_arr
end
end
13 changes: 11 additions & 2 deletions src/eltype_matching.jl
@@ -1,8 +1,17 @@
struct EltypeAdaptor{T} end

(l::EltypeAdaptor)(x) = fmap(Adapt.adapt(l), x)
function ensure_same_device(x, device)
if (typeof(x) != device) && !(x isa Number)
error("Device mismatch detected. Ensure all data is on the same device.")
end
return x
end


(l::EltypeAdaptor)(x) = fmap(y -> ensure_same_device(y, l), x)

function (l::EltypeAdaptor)(x::AbstractArray{T}) where {T}
return (isbitstype(T) || T <: Number) ? Adapt.adapt(l, x) : map(l, x)
return (isbitstype(T) || T <: Number) ? x : map(y -> ensure_same_device(y, l), x)
end

function Adapt.adapt_storage(::EltypeAdaptor{T}, x::AbstractArray) where {T}