Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ODEFunctionExpr and simplify code #128

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 1 addition & 123 deletions src/ode_def_opts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,127 +78,5 @@ function ode_def_opts(name::Symbol, opts::Dict{Symbol, Bool}, curmod, ex::Expr,
mtk_diffeqs = [D(vars[i]) ~ mtk_ops[i] for i in 1:length(vars)]

sys = ODESystem(mtk_diffeqs, t, vars, params, name = gensym(:Parameterized))

f_ex_oop, f_ex_iip = ModelingToolkit.generate_function(sys, vars, params)

if opts[:build_tgrad]
try
tgrad_ex_oop, tgrad_ex_iip = ModelingToolkit.generate_tgrad(sys, vars, params)
catch
@warn "tgrad construction failed"
tgrad_ex_oop, tgrad_ex_iip = nothing, nothing
end
else
tgrad_ex_oop, tgrad_ex_iip = nothing, nothing
end

if opts[:build_jac]
try
J_ex_oop, J_ex_iip = ModelingToolkit.generate_jacobian(sys, vars, params)
catch
@warn "Jacobian construction failed"
J_ex_oop, J_ex_iip = nothing, nothing
end
else
J_ex_oop, J_ex_iip = nothing, nothing
end

if opts[:build_invW] && length(mtk_diffeqs) < 4
try
W_exs = ModelingToolkit.generate_factorized_W(sys, vars, params, false)
W_ex_oop, W_ex_iip = W_exs[1]
W_t_ex_oop, W_t_ex_iip = W_exs[2]
catch
@warn "W-expression construction failed"
W_ex_oop, W_ex_iip = (nothing, nothing)
W_t_ex_oop, W_t_ex_iip = (nothing, nothing)
end
else
W_ex_oop, W_ex_iip = (nothing, nothing)
W_t_ex_oop, W_t_ex_iip = (nothing, nothing)
end

fname = gensym(:ParameterizedDiffEqFunction)
tname = gensym(:ParameterizedTGradFunction)
jname = gensym(:ParameterizedJacobianFunction)
Wname = gensym(:ParameterizedWFactFunction)
W_tname = gensym(:ParameterizedW_tFactFunction)
funcname = gensym(:ParameterizedODEFunction)

if tgrad_ex_oop !== nothing
full_tex = quote
$tname($(tgrad_ex_oop.args[1].args...)) = $(tgrad_ex_oop.args[2])
$tname($(tgrad_ex_iip.args[1].args...)) = $(tgrad_ex_iip.args[2])
end
else
full_tex = quote
$tname = nothing
end
end

if J_ex_oop !== nothing
full_jex = quote
$jname($(J_ex_oop.args[1].args...)) = $(J_ex_oop.args[2])
$jname($(J_ex_iip.args[1].args...)) = $(J_ex_iip.args[2])
end
else
full_jex = quote
$jname = nothing
end
end

if W_ex_oop !== nothing
full_wex = quote
$Wname($(W_ex_oop.args[1].args...)) = $(W_ex_oop.args[2])
$Wname($(W_ex_iip.args[1].args...)) = $(W_ex_iip.args[2])
$W_tname($(W_t_ex_oop.args[1].args...)) = $(W_t_ex_oop.args[2])
$W_tname($(W_t_ex_iip.args[1].args...)) = $(W_t_ex_iip.args[2])
end
else
full_wex = quote
$Wname = nothing
$W_tname = nothing
end
end

quote
struct $name{F, TG, TJ, TW, TWt, S} <:
ParameterizedFunctions.DiffEqBase.AbstractParameterizedFunction{true}
f::F
mass_matrix::ParameterizedFunctions.LinearAlgebra.UniformScaling{Bool}
analytic::Nothing
tgrad::TG
jac::TJ
jvp::Nothing
vjp::Nothing
jac_prototype::Nothing
sparsity::Nothing
Wfact::TW
Wfact_t::TWt
paramjac::Nothing
syms::Vector{Symbol}
indepvar::Symbol
colorvec::Nothing
sys::S
initialization_data::Nothing
nlprob_data::Nothing
end

(f::$name)(args...) = f.f(args...)

function ParameterizedFunctions.SciMLBase.remake(func::$name; kwargs...)
return func
end

$fname($(f_ex_oop.args[1].args...)) = $(f_ex_oop.args[2])
$fname($(f_ex_iip.args[1].args...)) = $(f_ex_iip.args[2])
$full_tex
$full_jex
$full_wex

$name($fname, ParameterizedFunctions.LinearAlgebra.I, nothing, $tname, $jname,
nothing, nothing,
nothing, nothing, $Wname, $W_tname, nothing, $syms, $(Meta.quot(depvar)),
nothing, $sys, nothing, nothing)
end |> esc
ODEFunctionExpr(sys, tgrad = opts[:build_tgrad], jac = opts[:build_jac])
end
Loading