diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 9477bc9d0..52824ec2c 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -27,7 +27,13 @@ function build_solution( N = length((size(prob.u0)..., length(u))) end - if has_analytic(prob.f) + if typeof(prob.f) <: Tuple + f = prob.f[1] + else + f = prob.f + end + + if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}(0) errors = Dict{Symbol,eltype(prob.u0)}() sol = ODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(k), @@ -45,9 +51,16 @@ function build_solution( end function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true) + + if typeof(sol.prob.f) <: Tuple + f = sol.prob.f[1] + else + f = sol.prob.f + end + if fill_uanalytic for i in 1:size(sol.u,1) - push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0)) + push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0)) end end @@ -61,7 +74,7 @@ function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true if sol.dense && dense_errors densetimes = collect(linspace(sol.t[1],sol.t[end],100)) interp_u = sol(densetimes) - interp_analytic = [sol.prob.f(Val{:analytic},t,sol.u[1]) for t in densetimes] + interp_analytic = [f(Val{:analytic},t,sol.u[1]) for t in densetimes] sol.errors[:L∞] = maximum(vecvecapply((x)->abs.(x),interp_u-interp_analytic)) sol.errors[:L2] = sqrt(recursive_mean(vecvecapply((x)->float.(x).^2,interp_u-interp_analytic))) end diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index add6d1479..bebbdca72 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -29,7 +29,13 @@ function build_solution( N = length((size(prob.u0)..., length(u))) end - if has_analytic(prob.f) + if typeof(prob.f) <: Tuple + f = prob.f[1] + else + f = prob.f + end + + if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}(0) errors = Dict{Symbol,eltype(prob.u0)}() sol = RODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(W), @@ -49,9 +55,16 @@ function build_solution( end function calculate_solution_errors!(sol::AbstractRODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true) + + if typeof(sol.prob.f) <: Tuple + f = sol.prob.f[1] + else + f = sol.prob.f + end + if fill_uanalytic for i in 1:size(sol.u,1) - push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i])) + push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i])) end end