Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing a function to be called multiple times with different inputs #627

Draft
wants to merge 49 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
ed280a6
Modified 1D approx test to show get_argument bug
nicholaskl97 Sep 2, 2022
1c0f0d0
Updated get_argument for eval with multiple inputs
nicholaskl97 Oct 27, 2022
63e0ddc
Forced get_argument when strategy != Quadrature
nicholaskl97 Oct 27, 2022
4e7b1b8
Test file for fixing get_argument
nicholaskl97 Oct 27, 2022
8a612dc
Test file for debugging symbolic_discretize
nicholaskl97 Oct 27, 2022
83e2475
transform_expression uses indvars now
nicholaskl97 Dec 31, 2022
13df657
Some test files
nicholaskl97 Dec 31, 2022
b17f92b
Merge branch 'master' into get_argument-fix
nicholaskl97 Dec 31, 2022
e885f45
Reverted get_argument to original state
nicholaskl97 Dec 31, 2022
74a2749
Removed temporary debug files
nicholaskl97 Dec 31, 2022
d0df2a3
Updated _vcat to accept multiple arguments
nicholaskl97 Jan 1, 2023
41a75f6
get_argument returns all args no just first per eq
nicholaskl97 Jan 12, 2023
c5d9960
Added implicit 1D and another 2D test case
nicholaskl97 Jan 12, 2023
64b56de
generate gridtrain trainsets based of pde vars
nicholaskl97 Jan 12, 2023
55fa847
added OptimJL and OptimOptimisers
nicholaskl97 Jan 12, 2023
b7e3d7a
get_bounds works with new transform_expression
nicholaskl97 Jan 12, 2023
fb199e4
Added test of ODE with hard constraint ic
nicholaskl97 Jan 12, 2023
2572dbf
_vcat now fills out scalar inputs to match batches
nicholaskl97 Jan 24, 2023
3e36fbe
cord now only has variables that show up in the eq
nicholaskl97 Jan 26, 2023
d115eae
GridTraining train_sets now work on the GPU
nicholaskl97 Feb 7, 2023
abb85a8
_vcat maintains Array types when filling
nicholaskl97 Feb 7, 2023
c7d3dc5
Formatting change
nicholaskl97 Feb 7, 2023
d9da546
StochasticTraining now actually uses bcs_points
nicholaskl97 Feb 17, 2023
18338d3
get_bounds uses bcs_points
nicholaskl97 Feb 17, 2023
cee31db
get_bounds uses get_variables
nicholaskl97 Feb 17, 2023
ea1c3b0
Merge branch 'master' into master
nicholaskl97 Feb 17, 2023
be3abf1
Increased test number of points
nicholaskl97 Feb 20, 2023
308454c
get_bounds is now okay with eqs with no variables
nicholaskl97 Feb 20, 2023
09b6cf6
symbolic_utilities doesn't need LinearAlgebra
nicholaskl97 Feb 20, 2023
6e4206b
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Feb 21, 2023
55d142a
Can now handle Ix(u(x,1)) and not just Ix(u(x,y))
nicholaskl97 Feb 21, 2023
a9b6b47
import ComponentArrays used in training_strategies
nicholaskl97 Feb 21, 2023
f815469
Added import ComponentArrays statements
nicholaskl97 Feb 22, 2023
5889a1b
Revert "Added import ComponentArrays statements"
nicholaskl97 Feb 22, 2023
424a7ef
Revert "import ComponentArrays used in training_strategies"
nicholaskl97 Feb 22, 2023
d581889
Revert "added OptimJL and OptimOptimisers"
nicholaskl97 Feb 22, 2023
edcb1a7
Replaced Lux.ComponentArray with using Co...Arrays
nicholaskl97 Feb 22, 2023
b07ae13
Formatted with JuliaFormtter
nicholaskl97 Feb 23, 2023
7a1e0b5
Docstrings were counting against code coverage
nicholaskl97 Mar 7, 2023
7f527c7
Improperly used docstrings changed to comments
nicholaskl97 Mar 8, 2023
530d50e
Added comments for _vcat
nicholaskl97 Mar 8, 2023
48c8b04
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Mar 8, 2023
e4f1536
Updated docstring for build_symbolic_loss_function
nicholaskl97 Mar 9, 2023
238b315
Reductions needed inits for cases like u(0)=0
nicholaskl97 Mar 10, 2023
44f3a28
Formatted with JuliaFormatter
nicholaskl97 Mar 10, 2023
fc7d36c
Added a new integral test
nicholaskl97 Apr 3, 2023
550ab40
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Apr 3, 2023
4dcf2a8
Merge remote-tracking branch 'origin/master'
nicholaskl97 May 29, 2023
00f07fc
Merge remote-tracking branch 'origin/master'
nicholaskl97 Jul 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 75 additions & 29 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,81 @@ Take expressions in the form:

to

:((cord, θ, phi, derivative, u)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:33], θ"[34:66])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[1], cord[2])
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, θ1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, θ2))) - 0,
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, θ2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, θ1))) - 0]
end
end
end)

for Flux.Chain, and

:((cord, θ, phi, derivative, u)->begin
#= ... =#
#= ... =#
begin
(u1, u2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[1], cord[2])
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0,
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0]
end
end
end)

for Lux.AbstractExplicitLayer
:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:205], θ[206:410])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0
Comment on lines +20 to +24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are those made and then not used?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not used in this case any more. I think they may be used in the integral case, but that might not be true either. I can look through the different cases to see if they are ever used and remove them if not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it looks like deprecated code now so it would be good to just remove it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, currently, cord1 = vcat(...) is being used for integral equations only and in my efforts to see if it's possible to remove it, I've found something I broke that wasn't being tested for the integral equations, so I'm can work more on fixing that next week. In particular, I'll look for a fix that removes any need for those lines.

Copy link
Author

@nicholaskl97 nicholaskl97 Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing I realized as I was working on this was that Ix(u(sin(x)) will now be interpreted as $\int u(\sin x) dx$. However, Dx(u(sin(x)) is (under my current changes) being interpreted as $u'(\sin x)$, not $\frac{d}{dx}\left[ u(\sin x) \right] = u'(\sin x) \cos x$. They're interpreted this way because that's how the numeric integral and numeric derivative functions were already written. However, it feels a little inconsistent with the way the integral was interpreted; it's instead consistent with an interpretation of Ix(u(sin(x)) as $U(\sin x)$, where $U$ is an antiderivative of $u$.

It feels to me like Ix(u(sin(x)) is $\int u(\sin x) dx$ and Dx(u(sin(x)) is $\frac{d}{dx}\left[ u(\sin x) \right]$, but then I don't know how you would actually specify $U(\sin x)$ or $u'(\sin x)$, or if you should even be allowed to. (I'm fine not letting people use $U(\sin x)$ since it's not uniquely defined, but it feels like they should be able to use $u'(\sin x)$.)

Thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, Dx(u(sin(x)) is (under my current changes) being interpreted as

That's not correct and would not play nicely. It should give the same result as what happens when basic symbolic interactions are done:

julia> using Symbolics

julia> @variables u x
2-element Vector{Num}:
 u
 x

julia> @variables u(..) x
2-element Vector{Any}:
  u
 x

julia> u(sin(x))
u(sin(x))

julia> D = Differential(x)
(::Differential) (generic function with 2 methods)

julia> D(u(sin(x)))
Differential(x)(u(sin(x)))

julia> expand_derivatives(D(u(sin(x))))
cos(x)*Differential(sin(x))(u(sin(x)))

end
end
end)

for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0, and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:205], θ[206:410])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0
end
end
end)

for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0 (i.e., separate loss functions are created for each equation)

with Flux.Chain; and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0
end
end
end)

for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0 and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0
end
end
end)

for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0

with Lux.Chain
"""
function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
eq_params = SciMLBase.NullParameters(),
Expand Down