Skip to content

Commit

Permalink
Fixed PositiveDefinite erroring on taking the gradient of `permuted…
Browse files Browse the repository at this point in the history
…ims` on an empty array. Noted broken Tracker gradient test for `PositiveDefinite`
  • Loading branch information
nicholaskl97 committed Jan 30, 2025
1 parent 3ea0e79 commit 5c6888b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/layers/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ end
function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
ϕ0, new_model_st = pd.model(st.x0, ps, st.model)
ϕx, final_model_st = pd.model(x, ps, new_model_st)
init = @ignore_derivatives permutedims(empty(ϕ0))
return (
mapreduce(hcat, zip(eachcol(x), eachcol(ϕx)); init=permutedims(empty(ϕ0))) do (x, ϕx)
mapreduce(hcat, zip(eachcol(x), eachcol(ϕx)); init=init) do (x, ϕx)
pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0)
end,
merge(st, (; model=final_model_st))
Expand Down
2 changes: 1 addition & 1 deletion test/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ end
@jet pd(x, ps, st)

__f = (x, ps) -> sum(first(pd(x, ps, st)))
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoTracker()])
end
end

Expand Down

0 comments on commit 5c6888b

Please sign in to comment.