From a45357666700f0eb990513b30ea4a8c1915a4510 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 15 Oct 2024 20:23:26 -0400
Subject: [PATCH] fix: different device handling

---
 Project.toml                |  2 +
 src/NeuralPDE.jl            |  1 +
 src/advancedHMC_MCMC.jl     |  5 ++-
 src/neural_adapter.jl       |  6 +--
 src/ode_solve.jl            | 19 ++++++++--
 src/pinn_types.jl           |  5 ++-
 src/rode_solve.jl           | 17 ++++++---
 src/training_strategies.jl  | 74 ++++++++++++++++++++++---------------
 test/NNPDE_tests_gpu_Lux.jl | 47 +++++++----------------
 9 files changed, 98 insertions(+), 78 deletions(-)

diff --git a/Project.toml b/Project.toml
index 6449919394..95f333523d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,6 +23,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -70,6 +71,7 @@ LuxCUDA = "0.3.3"
 LuxCore = "1.0.1"
 LuxLib = "1.3.2"
 MCMCChains = "6"
+MLDataDevices = "1.2.0"
 MethodOfLines = "0.11.6"
 ModelingToolkit = "9.46"
 MonteCarloMeasurements = "1.1"
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 2a161883f8..fc044ea45a 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -25,6 +25,7 @@ using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer
 using Lux: FromFluxAdaptor, recursive_eltype
 using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
 using MCMCChains: MCMCChains, Chains, sample
+using MLDataDevices: cpu_device, get_device
 using ModelingToolkit: ModelingToolkit, Num, PDESystem, toexpr, expand_derivatives,
                        infimum, supremum
 using MonteCarloMeasurements: Particles
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 0597167fdf..c45b97e85e 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -19,7 +19,10 @@ NN OUTPUT AT t,θ ~ phi(t,θ).
 """
 function (f::LogTargetDensity)(t::AbstractVector, θ)
     θ = vector_to_parameters(θ, f.init_params)
-    return f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.smodel(t', θ)
+    dev = get_device(θ)
+    t = t |> dev
+    u0 = f.prob.u0 |> dev
+    return u0 .+ (t' .- f.prob.tspan[1]) .* f.smodel(t', θ)
 end

 (f::LogTargetDensity)(t::Number, θ) = f([t], θ)[:, 1]
diff --git a/src/neural_adapter.jl b/src/neural_adapter.jl
index 23a78e476b..9c9580fee5 100644
--- a/src/neural_adapter.jl
+++ b/src/neural_adapter.jl
@@ -38,7 +38,7 @@ function get_loss_function_neural_adapter(
     eqs isa Array || (eqs = [eqs])
     eltypeθ = recursive_eltype(init_params)
     train_set = generate_training_sets(pde_system.domain, strategy.dx, eqs, eltypeθ)
-    return get_loss_function(loss, train_set, eltypeθ, strategy)
+    return get_loss_function(init_params, loss, train_set, eltypeθ, strategy)
 end

 function get_loss_function_neural_adapter(loss, init_params, pde_system,
@@ -51,7 +51,7 @@ function get_loss_function_neural_adapter(loss, init_params, pde_system,
     eltypeθ = recursive_eltype(init_params)
     bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)

-    return get_loss_function(loss, bound, eltypeθ, strategy)
+    return get_loss_function(init_params, loss, bound, eltypeθ, strategy)
 end

 function get_loss_function_neural_adapter(
@@ -64,7 +64,7 @@ function get_loss_function_neural_adapter(
     eltypeθ = recursive_eltype(init_params)
     bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)

-    return get_loss_function(loss, bound[1][1], bound[2][1], eltypeθ, strategy)
+    return get_loss_function(init_params, loss, bound[1][1], bound[2][1], eltypeθ, strategy)
 end

 """
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 232c9c3df0..47c13aba75 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -113,6 +113,8 @@ respects boundary conditions, i.e. `phi(t) = u0 + t*NN(t)`.
     smodel <: StatefulLuxLayer
 end

+Functors.@functor ODEPhi (u0, t0)
+
 function ODEPhi(model::AbstractLuxLayer, t0::Number, u0, st)
     return ODEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st))
 end
@@ -127,13 +129,22 @@ function generate_phi_θ(chain::AbstractLuxLayer, t, u0, init_params)
     return ODEPhi(chain, t, u0, st), init_params
 end

-(f::ODEPhi{<:Number})(t::Number, θ) = f.u0 + (t - f.t0) * first(f.smodel([t], θ.depvar))
+function (f::ODEPhi)(t, θ)
+    dev = get_device(θ)
+    return (dev(f))(dev, dev(t), θ)
+end

-(f::ODEPhi{<:Number})(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+function (f::ODEPhi{<:Number})(dev, t::Number, θ)
+    return f.u0 + (t - f.t0) * first(f.smodel(dev([t]), θ.depvar))
+end
+
+function (f::ODEPhi{<:Number})(_, t::AbstractVector, θ)
+    return f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+end

-(f::ODEPhi)(t::Number, θ) = f.u0 .+ (t .- f.t0) .* f.smodel([t], θ.depvar)
+(f::ODEPhi)(dev, t::Number, θ) = f.u0 .+ (t .- f.t0) .* f.smodel(dev([t]), θ.depvar)

-(f::ODEPhi)(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+(f::ODEPhi)(_, t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)

 """
     ode_dfdx(phi, t, θ, autodiff)
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 7e7c2e4aed..0f848183f1 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -38,9 +38,9 @@ function Phi(layer::AbstractLuxLayer)
         layer, nothing, initialstates(Random.default_rng(), layer)))
 end

-(f::Phi)(x::Number, θ) = f([x], θ)[1]
+(f::Phi)(x::Number, θ) = (f([x], θ) |> cpu_device())[1]

-(f::Phi)(x::AbstractArray, θ) = f.smodel(x, θ)
+(f::Phi)(x::AbstractArray, θ) = f.smodel(get_device(θ)(x), θ)

 """
     PhysicsInformedNN(chain, strategy; init_params = nothing, phi = nothing,
@@ -357,6 +357,7 @@ get_u() = (cord, θ, phi) -> phi(cord, θ)
 function numeric_derivative(phi, u, x, εs, order, θ)
     ε = εs[order]
     _epsilon = inv(first(ε[ε .!= zero(ε)]))
+    ε = ε |> get_device(x)

     # any(x->x!=εs[1],εs)
     # εs is the epsilon for each order, if they are all the same then we use a fancy formula
diff --git a/src/rode_solve.jl b/src/rode_solve.jl
index 6f06b98bc2..0c0cd176ed 100644
--- a/src/rode_solve.jl
+++ b/src/rode_solve.jl
@@ -20,19 +20,26 @@ end
     smodel <: StatefulLuxLayer
 end

+Functors.@functor RODEPhi (u0, t0)
+
 RODEPhi(phi::ODEPhi) = RODEPhi(phi.u0, phi.t0, phi.smodel)

-function (f::RODEPhi{<:Number})(t::Number, W, θ)
-    return f.u0 + (t - f.t0) * first(f.smodel([t, W], θ.depvar))
+function (f::RODEPhi)(t, W, θ)
+    dev = get_device(θ)
+    return (dev(f))(dev, dev(t), dev(W), θ)
+end
+
+function (f::RODEPhi{<:Number})(dev, t::Number, W, θ)
+    return f.u0 + (t - f.t0) * first(f.smodel(dev([t, W]), θ.depvar))
 end

-function (f::RODEPhi{<:Number})(t::AbstractVector, W, θ)
+function (f::RODEPhi{<:Number})(_, t::AbstractVector, W, θ)
     return f.u0 .+ (t' .- f.t0) .* f.smodel(vcat(t', W'), θ.depvar)
 end

-(f::RODEPhi)(t::Number, W, θ) = f.u0 .+ (t .- f.t0) .* f.smodel([t, W], θ.depvar)
+(f::RODEPhi)(dev, t::Number, W, θ) = f.u0 .+ (t .- f.t0) .* f.smodel(dev([t, W]), θ.depvar)

-function (f::RODEPhi)(t::AbstractVector, W, θ)
+function (f::RODEPhi)(_, t::AbstractVector, W, θ)
     return f.u0 .+ (t' .- f.t0) .* f.smodel(vcat(t', W'), θ.depvar)
 end
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index fefc50a9d2..a84191539d 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -25,7 +25,7 @@ function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,
     # vector of points (pde_train_sets must be rowwise)
     pde_loss_functions = if train_sets_pde !== nothing
         pde_train_sets = [train_set[:, 2:end] for train_set in train_sets_pde] |> adaptor
-        [get_loss_function(_loss, _set, eltypeθ, strategy)
+        [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
          for (_loss, _set) in zip(datafree_pde_loss_function, pde_train_sets)]
     else
         nothing
@@ -33,7 +33,7 @@ function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,

     bc_loss_functions = if train_sets_bc !== nothing
         bcs_train_sets = [train_set[:, 2:end] for train_set in train_sets_bc] |> adaptor
-        [get_loss_function(_loss, _set, eltypeθ, strategy)
+        [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
          for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]
     else
         nothing
@@ -53,17 +53,20 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
     # the points in the domain and on the boundary
     pde_train_sets, bcs_train_sets = train_sets |> adaptor

-    pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
                           for (_loss, _set) in zip(
        datafree_pde_loss_function, pde_train_sets)]

-    bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
                          for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]

     return pde_loss_functions, bc_loss_functions
 end

-function get_loss_function(loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
+function get_loss_function(
+        init_params, loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    train_set = train_set |> get_device(init_params) |> EltypeAdaptor{eltype0}()
     return θ -> mean(abs2, loss_function(train_set, θ))
 end

@@ -100,19 +103,21 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
     bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
     pde_bounds, bcs_bounds = bounds

-    pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
                           for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]

-    bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
                          for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

     pde_loss_functions, bc_loss_functions
 end

-function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, bound, eltypeθ,
+        strategy::StochasticTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)
     return θ -> begin
-        sets = generate_random_points(strategy.points, bound, eltypeθ) |>
+        sets = generate_random_points(strategy.points, bound, eltypeθ) |> dev |>
                EltypeAdaptor{recursive_eltype(θ)}()
         return mean(abs2, loss_function(sets, θ))
     end
@@ -175,35 +180,36 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
     bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
     pde_bounds, bcs_bounds = bounds

-    pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
                           for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]

     strategy_ = QuasiRandomTraining(strategy.bcs_points; strategy.sampling_alg,
        strategy.resampling, strategy.minibatch)

-    bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy_)
                          for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

     return pde_loss_functions, bc_loss_functions
 end

-function get_loss_function(loss_function, bound, eltypeθ, strategy::QuasiRandomTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, bound, eltypeθ,
+        strategy::QuasiRandomTraining; τ = nothing)
     (; sampling_alg, points, resampling, minibatch) = strategy
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)
+
     return if resampling
         θ -> begin
             sets = @ignore_derivatives QuasiMonteCarlo.sample(
                 points, bound[1], bound[2], sampling_alg)
-            sets = sets |> EltypeAdaptor{eltypeθ}()
+            sets = sets |> dev |> EltypeAdaptor{eltypeθ}()
             return mean(abs2, loss_function(sets, θ))
         end
     else
         point_batch = generate_quasi_random_points_batch(
-            points, bound, eltypeθ, sampling_alg, minibatch)
-        θ -> begin
-            sets = point_batch[rand(1:minibatch)] |> EltypeAdaptor{eltypeθ}()
-            return mean(abs2, loss_function(sets, θ))
-        end
+            points, bound, eltypeθ, sampling_alg, minibatch) |> dev |>
+                      EltypeAdaptor{eltypeθ}()
+        θ -> mean(abs2, loss_function(point_batch[rand(1:minibatch)], θ))
     end
 end

@@ -250,27 +256,33 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
     pde_bounds, bcs_bounds = bounds

     lbs, ubs = pde_bounds
-    pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, lb, ub, eltypeθ, strategy)
                           for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)]

     lbs, ubs = bcs_bounds
-    bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, lb, ub, eltypeθ, strategy)
                          for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)]

     return pde_loss_functions, bc_loss_functions
 end

-function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::QuadratureTraining;
-        τ = nothing)
-    length(lb) == 0 && return (θ) -> mean(abs2, loss_function(rand(eltypeθ, 1, 10), θ))
+function get_loss_function(init_params, loss_function, lb, ub, eltypeθ,
+        strategy::QuadratureTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)
+
+    if length(lb) == 0
+        return (θ) -> mean(abs2, loss_function(dev(rand(eltypeθ, 1, 10)), θ))
+    end
+
     area = eltypeθ(prod(abs.(ub .- lb)))
     f_ = (lb, ub, loss_, θ) -> begin
         function integrand(x, θ)
-            x = x |> EltypeAdaptor{eltypeθ}()
-            sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
+            x = x |> dev |> EltypeAdaptor{eltypeθ}()
+            return sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
         end
         integral_function = BatchIntegralFunction(integrand, max_batch = strategy.batch)
         prob = IntegralProblem(integral_function, (lb, ub), θ)
-        solve(prob, strategy.quadrature_alg; strategy.reltol, strategy.abstol,
+        return solve(prob, strategy.quadrature_alg; strategy.reltol, strategy.abstol,
             strategy.maxiters)[1]
     end
     return (θ) -> 1 / area * f_(lb, ub, loss_function, θ)
@@ -299,7 +311,9 @@ This training strategy can only be used with ODEs (`NNODE`).
     points::Int
 end

-function get_loss_function(loss_function, train_set, eltype0, ::WeightedIntervalTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, train_set, eltype0,
+        ::WeightedIntervalTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    train_set = train_set |> get_device(init_params) |> EltypeAdaptor{eltype0}()
     return (θ) -> mean(abs2, loss_function(train_set, θ))
 end
diff --git a/test/NNPDE_tests_gpu_Lux.jl b/test/NNPDE_tests_gpu_Lux.jl
index 378c240165..90674b23ff 100644
--- a/test/NNPDE_tests_gpu_Lux.jl
+++ b/test/NNPDE_tests_gpu_Lux.jl
@@ -1,17 +1,14 @@
-using Lux, ComponentArrays, OptimizationOptimisers
-using Test, NeuralPDE
-using Optimization
-using LuxCUDA, QuasiMonteCarlo
+using Lux, ComponentArrays, OptimizationOptimisers, Test, NeuralPDE, Optimization, LuxCUDA,
+      QuasiMonteCarlo, Random
 import ModelingToolkit: Interval, infimum, supremum
-using Random
 Random.seed!(100)

 callback = function (p, l)
     println("Current loss is: $l")
     return false
 end
-CUDA.allowscalar(false)
+
 const gpud = gpu_device()

 @testset "ODE" begin
@@ -32,22 +29,16 @@ const gpud = gpu_device()
     dt = 0.1f0
     # Neural network
     inner = 20
-    chain = Lux.Chain(Lux.Dense(1, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, 1))
+    chain = Chain(Dense(1, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ),
+        Dense(inner, inner, σ), Dense(inner, inner, σ), Dense(inner, 1))

     strategy = GridTraining(dt)
     ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray |> gpud
-    discretization = PhysicsInformedNN(chain,
-        strategy;
-        init_params = ps)
+    discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

     @named pde_system = PDESystem(eq, bcs, domains, [θ], [u(θ)])
     prob = discretize(pde_system, discretization)
-    res = Optimization.solve(prob, OptimizationOptimisers.Adam(1e-2); maxiters = 2000)
+    res = solve(prob, OptimizationOptimisers.Adam(1e-2); maxiters = 2000)
     phi = discretization.phi
     analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2
     ts = [infimum(d.domain):(dt / 10):supremum(d.domain) for d in domains][1]
@@ -73,13 +64,9 @@ end
     @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])

     inner = 30
-    chain = Lux.Chain(Lux.Dense(2, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, 1))
+    chain = Chain(Dense(2, inner, σ), Dense(inner, inner, σ),
+        Dense(inner, inner, σ), Dense(inner, inner, σ),
+        Dense(inner, inner, σ), Dense(inner, inner, σ), Dense(inner, 1))

     strategy = StochasticTraining(500)
     ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray |> gpud .|> Float64
@@ -119,11 +106,8 @@ end
     @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])

     inner = 20
-    chain = Lux.Chain(Lux.Dense(2, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, 1))
+    chain = Chain(Dense(2, inner, σ), Dense(inner, inner, σ),
+        Dense(inner, inner, σ), Dense(inner, inner, σ), Dense(inner, 1))

     strategy = QuasiRandomTraining(
         500; sampling_alg = SobolSample(), resampling = false, minibatch = 30)
@@ -173,11 +157,8 @@ end

     # Neural network
     inner = 25
-    chain = Lux.Chain(Lux.Dense(3, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, inner, Lux.σ),
-        Lux.Dense(inner, 1))
+    chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ),
+        Dense(inner, inner, σ), Dense(inner, inner, σ), Dense(inner, 1))

     strategy = GridTraining(0.05)
     ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray |> gpud .|> Float64
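
A minimal sketch (not part of the patch) of the convention the change adopts: training data follows the device of the parameters, queried with MLDataDevices.get_device and applied by calling the returned device object. The model, parameter, and data names below are illustrative only.

    using Lux, MLDataDevices, ComponentArrays, Random

    chain = Chain(Dense(1, 8, tanh), Dense(8, 1))
    ps, st = Lux.setup(Random.default_rng(), chain)
    ps = ComponentArray(ps)   # parameters may later be moved to a GPU via gpu_device()

    # Same pattern as the patched get_loss_function methods: move the sampled
    # points onto whichever device already holds the parameters.
    dev = get_device(ps)
    train_set = dev(rand(Float32, 1, 16))
    y, _ = chain(train_set, ps, st)   # inputs and parameters now agree on device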