From e474540eaeca1d4ca46eb261757bd6b4eab17b7d Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 7 Jul 2023 17:49:35 -0400 Subject: [PATCH] Start implementation of kernels compatible with IIP problems --- src/DiffEqGPU.jl | 11 +++ src/integrators/integrator_utils.jl | 11 +-- src/integrators/types.jl | 3 +- src/perform_step/gpu_tsit5_perform_step.jl | 81 ++++++++++++++++++++++ src/solve.jl | 29 +++++--- 5 files changed, 119 insertions(+), 16 deletions(-) diff --git a/src/DiffEqGPU.jl b/src/DiffEqGPU.jl index a030dfbb..d64891b9 100644 --- a/src/DiffEqGPU.jl +++ b/src/DiffEqGPU.jl @@ -883,6 +883,16 @@ function vectorized_map_solve(probs, alg, adaptive = adaptive, kwargs...) end +function handle_iip_prob(prob) + if DiffEqBase.isinplace(prob) + remake(prob; + u0 = convert(SArray, prob.u0), + p = prob.p isa SciMLBase.NullParameters ? prob.p : convert(SArray, prob.p)) + else + prob + end +end + function batch_solve(ensembleprob, alg, ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I, adaptive; @@ -926,6 +936,7 @@ function batch_solve(ensembleprob, alg, # Get inner saveat if global one isn't specified _saveat = get(probs[1].kwargs, :saveat, nothing) saveat = _saveat === nothing ? get(kwargs, :saveat, nothing) : _saveat + # probs = handle_iip_prob.(probs) solts, solus = batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptive; saveat = saveat, kwargs...) 
[begin diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index 8ffa847f..40b2fb43 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -26,8 +26,8 @@ end if saveat === nothing && save_everystep saved = true savedexactly = true - @inbounds us[integrator.step_idx] = integrator.u - @inbounds ts[integrator.step_idx] = integrator.t + @inbounds us[integrator.step_idx] = convert(eltype(us), integrator.u) + @inbounds ts[integrator.step_idx] = convert(eltype(ts), integrator.t) integrator.step_idx += 1 elseif saveat !== nothing saved = true @@ -35,9 +35,10 @@ end while integrator.cur_t <= length(saveat) && saveat[integrator.cur_t] <= integrator.t savet = saveat[integrator.cur_t] Θ = (savet - integrator.tprev) / integrator.dt - @inbounds us[integrator.cur_t] = _ode_interpolant(Θ, integrator.dt, - integrator.uprev, integrator) - @inbounds ts[integrator.cur_t] = savet + @inbounds us[integrator.cur_t] = convert(eltype(us), + _ode_interpolant(Θ, integrator.dt, + integrator.uprev, integrator)) + @inbounds ts[integrator.cur_t] = convert(eltype(ts), savet) integrator.cur_t += 1 end end diff --git a/src/integrators/types.jl b/src/integrators/types.jl index 8c5a6796..744ed308 100644 --- a/src/integrators/types.jl +++ b/src/integrators/types.jl @@ -318,7 +318,6 @@ end TS, CB, ST} cs, as, rs = SimpleDiffEq._build_tsit5_caches(T) - !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 last_event_error = zero(T) @@ -347,7 +346,7 @@ end saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} cs, as, btildes, rs = SimpleDiffEq._build_atsit5_caches(T) - !IIP && @assert S <: SArray + @assert S <: SArray qoldinit = T(1e-4) event_last_time = 1 diff --git a/src/perform_step/gpu_tsit5_perform_step.jl b/src/perform_step/gpu_tsit5_perform_step.jl index 158e4d2c..67f8ea4d 100644 --- a/src/perform_step/gpu_tsit5_perform_step.jl +++ b/src/perform_step/gpu_tsit5_perform_step.jl @@ -1,3 +1,84 @@ +@inline 
function step!(integ::GPUT5I{true, S, T}, ts, us) where {T, S}
    # Take one fixed-size step for an in-place right-hand side `f!(du, u, p, t)`.
    #
    # Mutates `integ`: `tprev`, `t`, `uprev`, `u`, the stage buffers `k1`..`k7`, the
    # scratch buffer `tmp`, and possibly `tstops_idx` / `u_modified`. Afterwards `ts`
    # and `us` are handed to `handle_callbacks!`, whose second result is returned.
    #
    # NOTE(review): judging by the file name and `GPUT5I`, `integ.as` / `integ.cs`
    # hold the Tsit5 tableau -- confirm against the cache builder.
    c1, c2, c3, c4 = integ.cs
    a21, a31, a32, a41, a42, a43, a51, a52, a53, a54,
    a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76 = integ.as

    dt = integ.dt
    t = integ.t
    p = integ.p
    f! = integ.f

    u = integ.u
    tmp = integ.tmp
    k1 = integ.k1
    k2 = integ.k2
    k3 = integ.k3
    k4 = integ.k4
    k5 = integ.k5
    k6 = integ.k6
    k7 = integ.k7

    # Snapshot the state at the start of the step and let every stage read from the
    # snapshot. Reading from `u` itself only works as long as the final update below
    # stays strictly element-local; the snapshot removes that hidden requirement.
    uprev = integ.uprev
    uprev .= u

    integ.tprev = t

    # Land exactly on the next tstop if this step would reach or pass it (with a
    # tolerance of 100 eps); otherwise advance by the nominal step size.
    tstops = integ.tstops
    if tstops !== nothing && integ.tstops_idx <= length(tstops) &&
       (tstops[integ.tstops_idx] - integ.t - integ.dt - 100 * eps(T) < 0)
        integ.t = tstops[integ.tstops_idx]
        # Shorten the step so that it ends on the tstop.
        dt = integ.t - integ.tprev
        integ.tstops_idx += 1
    else
        integ.t += dt
    end

    if integ.u_modified
        # The state was changed since the last step, so the stored derivative is
        # stale and must be recomputed.
        f!(k1, uprev, p, t)
        integ.u_modified = false
    else
        # Reuse the derivative evaluated at the end of the previous step.
        @inbounds k1 .= k7
    end

    @inbounds begin
        for i in eachindex(u)
            tmp[i] = uprev[i] + dt * a21 * k1[i]
        end
        f!(k2, tmp, p, t + c1 * dt)

        for i in eachindex(u)
            tmp[i] = uprev[i] + dt * (a31 * k1[i] + a32 * k2[i])
        end
        f!(k3, tmp, p, t + c2 * dt)

        for i in eachindex(u)
            tmp[i] = uprev[i] + dt * (a41 * k1[i] + a42 * k2[i] + a43 * k3[i])
        end
        f!(k4, tmp, p, t + c3 * dt)

        for i in eachindex(u)
            tmp[i] = uprev[i] +
                     dt * (a51 * k1[i] + a52 * k2[i] + a53 * k3[i] + a54 * k4[i])
        end
        f!(k5, tmp, p, t + c4 * dt)

        for i in eachindex(u)
            tmp[i] = uprev[i] +
                     dt * (a61 * k1[i] + a62 * k2[i] + a63 * k3[i] + a64 * k4[i] +
                      a65 * k5[i])
        end
        f!(k6, tmp, p, t + dt)

        # Write the new state in place.
        for i in eachindex(u)
            u[i] = uprev[i] +
                   dt * (a71 * k1[i] + a72 * k2[i] + a73 * k3[i] + a74 * k4[i] +
                    a75 * k5[i] + a76 * k6[i])
        end
    end

    # Derivative at the new state; the next call copies it into `k1`.
    f!(k7, u, p, t + dt)

    _, saved_in_cb = handle_callbacks!(integ, ts, us)

    return saved_in_cb
end

# Unchanged context: start of the out-of-place method, continued on the next line.
@inline function step!(integ::GPUT5I{false, S, T}, ts, us) where {T, S}
    c1, c2, c3, c4, c5, c6 =
integ.cs dt = integ.dt diff --git a/src/solve.jl b/src/solve.jl index 52d80093..af8afd9d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -34,6 +34,8 @@ function vectorized_solve(probs, prob::ODEProblem, alg; timeseries = prob.tspan[1]:dt:prob.tspan[2] nsteps = length(timeseries) + prob = handle_iip_prob(prob) + if saveat === nothing if save_everystep len = length(prob.tspan[1]:dt:prob.tspan[2]) @@ -208,17 +210,26 @@ end i = @index(Global, Linear) # get the actual problem for this thread - prob = @inbounds probs[i] + _prob = @inbounds probs[i] # get the input/output arrays for this thread ts = @inbounds view(_ts, :, i) us = @inbounds view(_us, :, i) - _saveat = get(prob.kwargs, :saveat, nothing) + _saveat = get(_prob.kwargs, :saveat, nothing) saveat = _saveat === nothing ? saveat : _saveat - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops, + prob = if DiffEqBase.isinplace(_prob) + remake(_prob; + u0 = convert(MArray, _prob.u0), + p = _prob.p isa SciMLBase.NullParameters ? _prob.p : convert(MArray, _prob.p)) + else + _prob + end + + integ = init(alg, prob.f, DiffEqBase.isinplace(prob), prob.u0, prob.tspan[1], dt, + prob.p, tstops, callback, save_everystep, saveat) u0 = prob.u0 @@ -232,8 +243,8 @@ end @inbounds us[1] = u0 end else - @inbounds ts[integ.step_idx] = prob.tspan[1] - @inbounds us[integ.step_idx] = prob.u0 + @inbounds ts[integ.step_idx] = convert(eltype(ts), prob.tspan[1]) + @inbounds us[integ.step_idx] = convert(eltype(us), prob.u0) end integ.step_idx += 1 @@ -244,13 +255,13 @@ end end if integ.t > tspan[2] && saveat === nothing ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] + @inbounds us[end] = convert(eltype(us), integ(tspan[2])) + @inbounds ts[end] = convert(eltype(ts), tspan[2]) end if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t + @inbounds us[2] = convert(eltype(us), integ.u) + @inbounds ts[2] = convert(eltype(ts), integ.t) end end