From b161c2b414ba09d22b5dd74a08266c13414a764b Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 7 Oct 2024 19:11:39 +0530 Subject: [PATCH 1/2] chore: allow mtkp in BacksolveADjoint --- src/concrete_solve.jl | 12 ++++++------ src/parameters_handling.jl | 1 + src/sensitivity_interface.jl | 5 +---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index dd2af4e3a..503f07d3e 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -372,11 +372,11 @@ function DiffEqBase._concrete_solve_adjoint( saveat = eltype(prob.tspan)[], save_idxs = nothing, kwargs...) - if !(sensealg isa GaussAdjoint) && - !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(AdjointSensitivityParameterCompatibilityError()) - end + # if !(sensealg isa GaussAdjoint) && + # !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + # (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + # throw(AdjointSensitivityParameterCompatibilityError()) + # end if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity @@ -640,7 +640,7 @@ function DiffEqBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) dp = p === nothing || p === SciMLBase.NullParameters() ? nothing : - dp isa AbstractArray ? reshape(dp', size(p)) : dp + dp isa AbstractArray ? reshape(dp', size(tunables)) : dp if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator diff --git a/src/parameters_handling.jl b/src/parameters_handling.jl index e30ddb232..4d1f4eddf 100644 --- a/src/parameters_handling.jl +++ b/src/parameters_handling.jl @@ -9,6 +9,7 @@ recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x) recursive_copyto!(y::AbstractArray, x::Number) = y .= x recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x) +recursive_copyto!(y::AbstractArray, x::Tuple) = recursive_copyto!(y, only(x)) function recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F} map(recursive_copyto!, values(y), values(x)) end diff --git a/src/sensitivity_interface.jl b/src/sensitivity_interface.jl index 84321f067..7769e562b 100644 --- a/src/sensitivity_interface.jl +++ b/src/sensitivity_interface.jl @@ -413,10 +413,7 @@ function _adjoint_sensitivities(sol, sensealg, alg; callback = nothing, kwargs...) mtkp = SymbolicIndexingInterface.parameter_values(sol) - if !(mtkp isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (mtkp isa AbstractArray && !Base.isconcretetype(eltype(mtkp))) - throw(AdjointSensitivityParameterCompatibilityError()) - end + rcb = nothing if sol.prob isa ODEProblem adj_prob, rcb = ODEAdjointProblem(sol, sensealg, alg, t, dgdu_discrete, From e8ef95632095dd07e3c4a17c8c0b593e7aed7b2f Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 7 Oct 2024 19:13:24 +0530 Subject: [PATCH 2/2] chore: rm Parameter check for sensealg --- src/concrete_solve.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 503f07d3e..f0763b510 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -372,11 +372,6 @@ function DiffEqBase._concrete_solve_adjoint( saveat = eltype(prob.tspan)[], save_idxs = nothing, kwargs...) - # if !(sensealg isa GaussAdjoint) && - # !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - # (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - # throw(AdjointSensitivityParameterCompatibilityError()) - # end if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity