Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MILC batched deflation #1529

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from

Conversation

leonhostetler
Copy link

This pull request implements batched deflation for MILC. Most of the work was limited to the QUDA/MILC interface, however, there is one change (the first one listed below) that goes beyond that scope, so please check that one in particular to ensure that my solution is okay.

  1. Changed errorQuda() in constructDeflationSpace() to warningQuda(). See below for more description
  2. Added a function qudaCleanUpDeflationSpace(). See below for more description
  3. Added some basic stuff to invertQuda() to preserve the deflation space. The preservation of the eigenvectors and eigenvalues is then controlled from the MILC side by the appropriate parameters in the QudaEigParam struct.
  4. Added ability to control tol_restart and cuda_prec_eigensolver from the MILC side
  5. Changes to invertQuda() were copied to invertQudaMsrc().

Note there is a companion MILC pull request at milc-qcd/milc_qcd#76

Note that this implementation only performs deflation for even parity solves. This seems to be fine when the UML solver is selected for MILC. With UML, the odd parity solution is reconstructed from the even parity solution and then polished with a few CG iterations. Usually this is a small number and might not benefit from deflation anyway. On the other hand, if the CG or CGZ solvers are preferred then odd parity deflation should also be implemented in the future.

More details:

  1. I changed errorQuda() in constructDeflationSpace() to warningQuda(). Here's why:

When loading eigenvectors from file, loadFromFile() -> computeEvals() is called, which extends the deflation space by the amount of the batch size:

if (size + batch_size > static_cast<int>(evecs.size())) resize(evecs, size + batch_size, QUDA_NULL_FIELD_CREATE);

After the first deflated solve, the deflation space is preserved, but it is now of size evecs.size() + batch_size.

This is fine for the first deflated solve, but on the second deflated solve, when constructDeflationSpace() is called and the preserved deflation space is attempted to be loaded, the following check fails:

if ((!space->svd && param.eig_param.n_conv != (int)space->evecs.size())
    || (space->svd && 2 * param.eig_param.n_conv != (int)space->evecs.size()))
  errorQuda("Preserved deflation space size %lu does not match expected %d", space->evecs.size(),
	    param.eig_param.n_conv);

It expects

param.eig_param.n_conv == space->evecs.size()

However, it is now actually

param.eig_param.n_conv + param.eig_param.compute_evals_batch_size == space->evecs.size()

Replacing the errorQuda() with warningQuda() allows it to run, and it runs fine. However, if there are cases where we really need the warning to be an error, then this is probably not the right way to fix it.

  1. Added a function qudaCleanUpDeflationSpace(), which can be called from MILC to clean up the deflation space. An alternative approach would be to set preserve_deflation_space to false on the last solve. However, this would limit how much we could amortize the cost of the eigenvector loading. The parameter input file for a large MILC calculation is chunked into "readin sets". MILC reads one of these readin sets, performs all calculations therein, and then reads the next readin set. So during a job, MILC does not know when it is doing the last solve. At best, it knows it's the last solve in the current readin set. So it seemed to me the best way to ensure that the deflation space is preserved over all readin sets is to add a cleanup function that can be called from MILC's finalize_quda() function in milc_qcd/generic/milc_to_quda_utilities.c.

@leonhostetler leonhostetler requested a review from a team as a code owner December 22, 2024 15:55
Copy link
Member

@maddyscientist maddyscientist left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR @leonhostetler. This will add some critical functionality between MILC and QUDA.

  • I agree with your change from error to warning for the size of the basis. to be honest I'm not even sure if we need it to be a warning at all.

@weinbe2
Copy link
Contributor

weinbe2 commented Jan 15, 2025

Thanks for this PR @leonhostetler. This will add some critical functionality between MILC and QUDA.

  • I agree with your change from error to warning for the size of the basis. to be honest I'm not even sure if we need it to be a warning at all.

I think we need to demand space->evecs.size() >= param.eig_param.n_conv, as in, make sure there at least enough eigenvectors. I agree there's no need for a warning otherwise, if there are extra vectors, whatever.

That line would then become:

if ((!space->svd && (int)space->evecs.size() < param.eig_param.n_conv)
    || (space->svd && (int)space->evecs.size() < (2 * param.eig_param.n_conv)))
  errorQuda("Preserved deflation space size %lu is smaller than the necessary %d", space->evecs.size(),
	    param.eig_param.n_conv);

If this seems reasonable can you please put this change in @leonhostetler ? I don't think I can push directly to your branch.

@leonhostetler
Copy link
Author

@weinbe2 I have committed the change you suggested after re-running tests on my end.

@weinbe2
Copy link
Contributor

weinbe2 commented Jan 16, 2025

@weinbe2 I have committed the change you suggested after re-running tests on my end.

Thanks @leonhostetler! Per our offline convo I'll do some due-diligence last tests, but ideally we'll get this all merged in tomorrow.

@weinbe2
Copy link
Contributor

weinbe2 commented Jan 22, 2025

As an update here: after some testing, I realized this PR breaks backwards compatibility of new QUDA + old MILC. This is because it modifies structs; C-style function signature/argument resolution doesn't encode this and because the function signature didn't change it'll happily link and then give undefined behavior at runtime. @leonhostetler and I are discussing how to safely resolve this offline.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants