Skip to content

Commit

Permalink
refactor: type differential_vars in generate_loss for DAEs as it caus…
Browse files Browse the repository at this point in the history
…es overwrite
  • Loading branch information
sathvikbhagavan committed Jan 31, 2024
1 parent fa71c18 commit a491a85
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff =
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end

function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool, differential_vars)
function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool, differential_vars::AbstractVector)
if autodiff
autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
else
Expand All @@ -64,7 +64,7 @@ function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool, differential_v
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, differential_vars) where {C, T, U}
p, differential_vars::AbstractVector) where {C, T, U}
out = Array(phi(t, θ))
dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars))
arrt = Array(t)
Expand All @@ -73,7 +73,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars)
differential_vars::AbstractVector)
ts = tspan[1]:(strategy.dx):tspan[2]
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
function loss(θ, _)
Expand Down

0 comments on commit a491a85

Please sign in to comment.