
Add forward enzyme rules for init and solve #416

Merged (11 commits, Nov 8, 2023)

Conversation

@sharanry sharanry (Contributor) commented Nov 5, 2023

TODO:

  • Test against other algorithms
  • Error on unsupported algorithms.


codecov bot commented Nov 5, 2023

Codecov Report

Merging #416 (0935919) into main (b9da6ac) will decrease coverage by 1.11%.
The diff coverage is 0.00%.

@@            Coverage Diff             @@
##             main     #416      +/-   ##
==========================================
- Coverage   64.94%   63.83%   -1.11%     
==========================================
  Files          26       26              
  Lines        2068     2099      +31     
==========================================
- Hits         1343     1340       -3     
- Misses        725      759      +34     
Files Coverage Δ
ext/LinearSolveEnzymeExt.jl 0.00% <0.00%> (ø)

... and 1 file with indirect coverage changes


test/enzyme.jl Outdated
Comment on lines 219 to 220
@test_broken en_jac ≈ fd_jac
Contributor Author

These tests fail due to numerical imprecision. Not entirely sure why.

@sharanry sharanry (Contributor Author) Nov 5, 2023

df/db:

manual_jac = [-5.061916901336408, 1.5770609499033128, 5.446411839233853, 0.612648432464526]
fd_jac = [-5.061916947364807, 1.5770610570907593, 5.446411848068237, 0.6126484870910645]
en_jac = [-5.061916901336408, 1.5770609499033128, 5.446411839233853, 0.612648432464526]

Member

That's an expected difference due to summation order.
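A minimal sketch of comparing the two Jacobians with an explicit tolerance (using the values quoted above; the loose `rtol` accounts for both summation-order differences and finite-difference truncation error):

```julia
using Test

# Jacobians computed by different paths (Enzyme AD vs. finite differences)
# agree only up to roundoff and truncation error, so compare with an
# explicit relative tolerance instead of the default `≈`.
en_jac = [-5.061916901336408, 1.5770609499033128, 5.446411839233853, 0.612648432464526]
fd_jac = [-5.061916947364807, 1.5770610570907593, 5.446411848068237, 0.6126484870910645]

@test isapprox(en_jac, fd_jac; rtol = 1e-6)
```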

Contributor Author

addressed in 31745cc

end

dres = deepcopy(res)
invA = inv(A)
Member

never use inv
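A minimal sketch of the pattern being suggested (not the PR's code): factor once and reuse the factorization for every solve, instead of forming `inv(A)`:

```julia
using LinearAlgebra

A = rand(5, 5) + 5I      # shifted to be well-conditioned for the example
b1, b2 = rand(5), rand(5)

# Avoid: x1 = inv(A) * b1
# Forming the explicit inverse costs an extra O(n^3) step and is
# numerically less accurate than a backsolve on the factorization.

F = lu(A)    # factor once...
x1 = F \ b1  # ...then each solve is a cheap O(n^2) backsolve
x2 = F \ b2
```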

Contributor Author

addressed in d010911

Member

Not fully, see #416 (comment). Your version has an extra factorization and is using the wrong linear solver.

Contributor Author

addressed in 31745cc

invA = inv(A)
db = linsolve.dval.b
dA = linsolve.dval.A
dres.u .= invA * (db - dA * res.u)
Member

There's no need to refactorize here. Differentiating A * u = b gives A * du = db - dA * u, i.e. du = A \ (db - dA * u). But that is just a linsolve call. Not only that, it's the same operator as the one used in the forward pass, so you don't need to refactorize A. Therefore, this should simply reuse the same linsolve, set linsolve.b = db - dA * u, and then solve!.

But I don't think the formula is correct. Isn't it just linsolve.b = db and then solve!(linsolve) then du = linsolve.u?
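The forward rule under discussion, obtained by differentiating A * u = b to get A * du = db - dA * u, can be sketched outside of LinearSolve's caching machinery (dense arrays, plain LU) as:

```julia
using LinearAlgebra

A  = rand(4, 4) + 4I   # example operator, shifted to be well-conditioned
b  = rand(4)
dA = rand(4, 4)        # tangent (perturbation direction) of A
db = rand(4)           # tangent of b

F  = lu(A)             # one factorization serves both solves
u  = F \ b             # primal solve:  A * u  = b
du = F \ (db - dA * u) # tangent solve: A * du = db - dA * u

@assert A * du ≈ db - dA * u   # the tangent satisfies the differentiated system
```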

Contributor Author

But linsolve.A seems to mutate after the solve. That is why I had to add all this other stuff.

MWE:

using LinearSolve
A = rand(5,5)
b = rand(5)

prob = LinearProblem(A,b)
linsolve = init(prob)
@show linsolve.A
sol = solve!(linsolve)
@show linsolve.A
5×5 Matrix{Float64}:
 0.207897  0.0737386  0.220551   0.935437  0.883482
 0.362979  0.687049   0.220086   0.216771  0.145246
 0.538937  0.577133   0.965879   0.438704  0.689019
 0.348145  0.174776   0.0491639  0.145817  0.155529
 0.92258   0.637729   0.508441   0.166917  0.218296

5×5 Matrix{Float64}:
 0.92258    0.637729   0.508441   0.166917  0.218296
 0.393439   0.436142   0.0200455  0.1511    0.0593595
 0.584162   0.469106   0.659464   0.270316  0.533653
 0.225343  -0.160427   0.16558    0.877305  0.755451
 0.37736   -0.151045  -0.211799   0.185687  0.0548682

Member

Yes, the mutation it's doing is lu!(A), i.e. it's reusing A's memory to store the LU factorization. You want to let it use exactly that same mutated A via cache.cacheval in order to do the next solve, which is precisely what's done in the caching interface. See https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/
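A sketch of the caching-interface pattern from the linked tutorial: swap in a new right-hand side on the same cache, so the second solve! reuses the factorization computed by the first:

```julia
using LinearSolve

A  = rand(4, 4)
b1 = rand(4)
b2 = rand(4)

prob     = LinearProblem(A, b1)
linsolve = init(prob)
sol1     = solve!(linsolve)  # factorizes A (mutating it in place via lu!)

linsolve.b = b2              # new right-hand side, same cached factorization
sol2 = solve!(linsolve)      # no refactorization: just a backsolve
```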

Contributor Author

Sounds good. And the simple derivation of the formula is mentioned in https://scicomp.stackexchange.com/a/29421.

@sharanry sharanry (Contributor Author) Nov 5, 2023

addressed in 31745cc and 19a54fe

@sharanry sharanry marked this pull request as ready for review November 6, 2023 17:30
@ChrisRackauckas (Member)

This looks like a real test failure.

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)
Contributor

This doesn't handle batching atm

Member

Throw a good error for now? It would be good to get something merged even if it doesn't handle every case, as long as the errors are clear.
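One hypothetical way to make the unhandled batched case error clearly, extending the branch quoted above (BatchDuplicated and BatchDuplicatedNoNeed are Enzyme.jl's batched activity types; this is a sketch, not the PR's actual code):

```julia
# Hypothetical sketch: reject batched activity modes with a clear message
# rather than silently computing a wrong result.
if RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
    error("Batched forward mode is not currently supported by the " *
          "LinearSolve Enzyme rules; use Duplicated instead.")
elseif RT <: DuplicatedNoNeed
    return dres
elseif RT <: Duplicated
    return Duplicated(res, dres)
end
```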

Contributor Author

Tried to address this in 585cbb8. We can make a follow-up PR for batch support.

@sharanry sharanry (Contributor Author) commented Nov 7, 2023

@ChrisRackauckas The test failures seem unrelated.

@ChrisRackauckas ChrisRackauckas merged commit 9b540ba into SciML:main Nov 8, 2023
10 of 13 checks passed