MILC batched deflation #1529
base: develop
Thanks for this PR @leonhostetler. This will add some critical functionality between MILC and QUDA.
- I agree with your change from error to warning for the size of the basis. To be honest, I'm not even sure if we need it to be a warning at all.
I think we need to demand That line would then become:
If this seems reasonable, can you please put this change in, @leonhostetler? I don't think I can push directly to your branch.
@weinbe2 I have committed the change you suggested after re-running tests on my end.
Thanks @leonhostetler! Per our offline convo I'll do some last due-diligence tests, but ideally we'll get this all merged in tomorrow.
As an update here: after some testing, I realized this PR breaks backwards compatibility of new QUDA + old MILC. This is because it modifies structs: C-style function signature/argument resolution doesn't encode struct layout, and because the function signature didn't change, it will happily link and then give undefined behavior at runtime. @leonhostetler and I are discussing how to safely resolve this offline.
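To make the failure mode above concrete, here is a minimal sketch of one common guard: the caller records the size of the struct it was compiled against, and the library compares that to its own view before touching any other field. The struct names, the `extra_tol` field, and `params_layout_matches()` are hypothetical and only illustrate the idea; this is not necessarily how the issue was resolved in this PR.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical parameter struct as an OLD caller was compiled against. */
typedef struct {
  size_t struct_size; /* caller stores sizeof() of the struct it was built with */
  int n_ev;
  double tol;
} ParamsV1;

/* The same struct as the NEW library sees it, with one field appended. */
typedef struct {
  size_t struct_size;
  int n_ev;
  double tol;
  double extra_tol; /* hypothetical new field */
} ParamsV2;

/* Library-side check, compiled against ParamsV2. The argument is an opaque
   pointer because the linker only matches the symbol name and never
   compares struct layouts. Returns 1 if the caller's layout matches. */
int params_layout_matches(const void *param)
{
  size_t caller_size;
  /* Read only the leading size field; memcpy avoids accessing the object
     through a struct type that may be larger than what the caller allocated. */
  memcpy(&caller_size, param, sizeof caller_size);
  return caller_size == sizeof(ParamsV2);
}
```

An old caller passing a `ParamsV1` would be rejected up front instead of having the library read past the end of the caller's struct.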
This pull request implements batched deflation for MILC. Most of the work was limited to the QUDA/MILC interface; however, there is one change (the first one listed below) that goes beyond that scope, so please check that one in particular to ensure that my solution is okay.

- Changed `errorQuda()` in `constructDeflationSpace()` to `warningQuda()`. See below for more description.
- Added `qudaCleanUpDeflationSpace()`. See below for more description.
- Modified `invertQuda()` to preserve the deflation space. The preservation of the eigenvectors and eigenvalues is then controlled from the MILC side by the appropriate parameters in the `QudaEigParam` struct.
- `tol_restart` and `cuda_prec_eigensolver` from the MILC side
- The changes to `invertQuda()` were copied to `invertQudaMsrc()`.

Note there is a companion MILC pull request at milc-qcd/milc_qcd#76
Note that this implementation only performs deflation for even parity solves. This seems to be fine when the UML solver is selected for MILC. With UML, the odd parity solution is reconstructed from the even parity solution and then polished with a few CG iterations. Usually this is a small number and might not benefit from deflation anyway. On the other hand, if the CG or CGZ solvers are preferred then odd parity deflation should also be implemented in the future.
More details:

- Changed `errorQuda()` in `constructDeflationSpace()` to `warningQuda()`. Here's why:

  When loading eigenvectors from file, `loadFromFile()` -> `computeEvals()` is called, which extends the deflation space by the amount of the batch size:

  After the first deflated solve, the deflation space is preserved, but it is now of size `evecs.size() + batch_size`.

  This is fine for the first deflated solve, but on the second deflated solve, when `constructDeflationSpace()` is called and an attempt is made to load the preserved deflation space, the following check fails:

  It expects

  However, it is now actually

  Replacing the `errorQuda()` with `warningQuda()` allows it to run, and it runs fine. However, if there are cases where we really need the warning to be an error, then this is probably not the right way to fix it.

- Added `qudaCleanUpDeflationSpace()`, which can be called from MILC to clean up the deflation space. An alternative approach would be to set `preserve_deflation_space` to false on the last solve. However, this would limit how much we could amortize the cost of the eigenvector loading. The parameter input file for a large MILC calculation is chunked into "readin sets". MILC reads one of these readin sets, performs all calculations therein, and then reads the next readin set. So during a job, MILC does not know when it is doing the last solve. At best, it knows it's the last solve in the current readin set. So it seemed to me the best way to ensure that the deflation space is preserved over all readin sets is to add a cleanup function that can be called from MILC's `finalize_quda()` function in `milc_qcd/generic/milc_to_quda_utilities.c`.
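The two points above can be illustrated with a toy model that tracks only vector counts. It is a sketch under stated assumptions, not QUDA code: `ToySpace` and every `toy_*` function are hypothetical, and the size the check expects is assumed here to be the number of loaded eigenvectors. It shows the preserved space ending up `batch_size` larger than expected on the second solve, and a single cleanup after all readin sets so the file is loaded only once.

```c
#include <assert.h>
#include <stdio.h>

/* Toy stand-in for a preserved deflation space; only sizes are tracked. */
typedef struct {
  int n_vecs;    /* number of vectors currently held */
  int preserved; /* nonzero once a space has been kept across solves */
} ToySpace;

/* First solve: load n_evecs vectors from file. Computing the eigenvalues in
   batches leaves batch_size extra vectors attached, as described above. */
void toy_load_from_file(ToySpace *s, int n_evecs, int batch_size)
{
  assert(n_evecs > 0 && batch_size >= 0);
  s->n_vecs = n_evecs + batch_size;
  s->preserved = 1;
}

/* Later solves: reuse the preserved space. Returns the size mismatch
   (0 means the sizes agree). A strict check would abort on a nonzero
   mismatch; this version only warns and carries on. */
int toy_construct(const ToySpace *s, int expected)
{
  int mismatch = s->n_vecs - expected;
  if (mismatch != 0)
    fprintf(stderr, "warning: preserved space has %d vectors, expected %d\n",
            s->n_vecs, expected);
  return mismatch;
}

/* Called once at shutdown, after every readin set has been processed. */
void toy_cleanup(ToySpace *s)
{
  s->n_vecs = 0;
  s->preserved = 0;
}

/* Run n_sets readin sets with solves_per_set solves each and return how
   many times the eigenvectors had to be loaded from file. */
int toy_run(int n_sets, int solves_per_set, int n_evecs, int batch_size)
{
  ToySpace s = {0, 0};
  int loads = 0;
  for (int set = 0; set < n_sets; ++set) {
    for (int solve = 0; solve < solves_per_set; ++solve) {
      if (!s.preserved) {
        toy_load_from_file(&s, n_evecs, batch_size);
        ++loads;
      } else {
        toy_construct(&s, n_evecs);
      }
    }
  }
  toy_cleanup(&s); /* plays the role of the call made from finalize_quda() */
  return loads;
}
```

In this model, `toy_run(3, 4, 64, 8)` loads the eigenvectors once for all twelve solves, whereas dropping the space at the end of each readin set would load them three times.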