Skip to content

Commit

Permalink
allow for tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 28, 2017
1 parent 3083d48 commit a605ec1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
19 changes: 16 additions & 3 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ function build_solution(
N = length((size(prob.u0)..., length(u)))
end

if has_analytic(prob.f)
if typeof(prob.f) <: Tuple
f = prob.f[1]
else
f = prob.f
end

if has_analytic(f)
u_analytic = Vector{typeof(prob.u0)}(0)
errors = Dict{Symbol,eltype(prob.u0)}()
sol = ODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(k),
Expand All @@ -45,9 +51,16 @@ function build_solution(
end

function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true)

if typeof(sol.prob.f) <: Tuple
f = sol.prob.f[1]
else
f = sol.prob.f
end

if fill_uanalytic
for i in 1:size(sol.u,1)
push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0))
push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0))
end
end

Expand All @@ -61,7 +74,7 @@ function calculate_solution_errors!(sol::AbstractODESolution;fill_uanalytic=true
if sol.dense && dense_errors
densetimes = collect(linspace(sol.t[1],sol.t[end],100))
interp_u = sol(densetimes)
interp_analytic = [sol.prob.f(Val{:analytic},t,sol.u[1]) for t in densetimes]
interp_analytic = [f(Val{:analytic},t,sol.u[1]) for t in densetimes]
sol.errors[:L∞] = maximum(vecvecapply((x)->abs.(x),interp_u-interp_analytic))
sol.errors[:L2] = sqrt(recursive_mean(vecvecapply((x)->float.(x).^2,interp_u-interp_analytic)))
end
Expand Down
17 changes: 15 additions & 2 deletions src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ function build_solution(
N = length((size(prob.u0)..., length(u)))
end

if has_analytic(prob.f)
if typeof(prob.f) <: Tuple
f = prob.f[1]
else
f = prob.f
end

if has_analytic(f)
u_analytic = Vector{typeof(prob.u0)}(0)
errors = Dict{Symbol,eltype(prob.u0)}()
sol = RODESolution{T,N,typeof(u),typeof(u_analytic),typeof(errors),typeof(t),typeof(W),
Expand All @@ -49,9 +55,16 @@ function build_solution(
end

function calculate_solution_errors!(sol::AbstractRODESolution;fill_uanalytic=true,timeseries_errors=true,dense_errors=true)

if typeof(sol.prob.f) <: Tuple
f = sol.prob.f[1]
else
f = sol.prob.f
end

if fill_uanalytic
for i in 1:size(sol.u,1)
push!(sol.u_analytic,sol.prob.f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i]))
push!(sol.u_analytic,f(Val{:analytic},sol.t[i],sol.prob.u0,sol.W[i]))
end
end

Expand Down

0 comments on commit a605ec1

Please sign in to comment.