Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MILC batched deflation #1529

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion include/quda_milc_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,52 @@ extern "C" {
} QudaInvertArgs_t;

/**
* Parameters related to deflated solvers.
* Parameters related to deflated linear solvers.
*/
typedef struct {
size_t struct_size; /** Size of this struct in bytes. Used to check that host application and QUDA see the same struct size **/
double tol_restart;
QudaPrecision prec_eigensolver;
int poly_deg; /** Degree of the Chebyshev polynomial **/
double a_min; /** Range used in polynomial acceleration **/
double a_max;
QudaBoolean preserve_evals; /** Whether to preserve the evals or recompute them **/
int n_ev; /** Size of the eigenvector search space **/
int n_kr; /** Total size of Krylov space **/
int n_conv; /** Number of requested converged eigenvectors **/
int n_ev_deflate; /** Number of requested converged eigenvectors to use in deflation **/
double tol; /** Tolerance on the least well known eigenvalue's residual **/
int max_restarts; /** For IRLM/IRAM, quit after n restarts **/
int batched_rotate; /** For the Ritz rotation, the maximal number of extra vectors the solver may allocate **/
int block_size; /** For block method solvers, the block size **/
char vec_infile[256]; /** Filename prefix where to load the null-space vectors */
char vec_outfile[256]; /** Filename prefix for where to save the null-space vectors */
QudaParity vec_in_parity; /** Parity of the incoming eigenvectors **/
QudaPrecision save_prec; /** The precision with which to save the vectors */
QudaBoolean partfile; /** Whether to save eigenvectors in QIO singlefile or partfile format */
QudaBoolean io_parity_inflate; /** Whether to inflate single-parity eigen-vector I/O **/
QudaBoolean use_norm_op;
QudaBoolean use_pc;
QudaEigType eig_type; /** Type of eigensolver algorithm to employ **/
QudaEigSpectrumType spectrum; /** Which part of the spectrum to solve **/
double qr_tol; /** Tolerance on the QR iteration **/
QudaBoolean require_convergence; /** If true, the solver will error out if the convergence criteria are not met **/
int check_interval; /** For IRLM/IRAM, check every nth restart **/
QudaBoolean use_dagger; /** If use_dagger, use Mdag **/
QudaBoolean compute_gamma5; /** Performs the \gamma_5 OP solve by post multiplying the eignvectors with \gamma_5 before computing the eigenvalues */
QudaBoolean compute_svd; /** Performs an MdagM solve, then constructs the left and right SVD. **/
QudaBoolean use_eigen_qr; /** Use Eigen routines to eigensolve the upper Hessenberg via QR **/
QudaBoolean use_poly_acc; /** Use Polynomial Acceleration **/
QudaBoolean arpack_check; /** In the test function, cross check the device result against ARPACK **/
char arpack_logfile[512]; /** For Arpack cross check, name of the Arpack logfile **/
int compute_evals_batch_size; /** The batch size used when computing eigenvalues **/
QudaBoolean preserve_deflation; /** Whether to preserve the deflation space between solves **/

} QudaEigensolverArgs_t;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This broadly looks good---to be more future-proof (in particular to if folks want to use this structure for other fermion types for whatever reason), we should add more of the fields that are currently getting default values in qudaInvertDeflatable.

QudaEigType eig_type = ( qep.block_size > 1 ) ? QUDA_EIG_BLK_TR_LANCZOS : QUDA_EIG_TR_LANCZOS;  /* or QUDA_EIG_IR_ARNOLDI, QUDA_EIG_BLK_IR_ARNOLDI */
QudaEigSpectrumType spectrum = QUDA_SPECTRUM_SR_EIG; /* Smallest Real. Other options: LM, SM, LR, SR, LI, SI */
QudaBoolean require_convergence = QUDA_BOOLEAN_TRUE;
QudaBoolean check_interval = 10;
QudaBoolean use_dagger = QUDA_BOOLEAN_FALSE;
QudaBoolean compute_gamma5 = QUDA_BOOLEAN_FALSE;
QudaBoolean compute_svd = QUDA_BOOLEAN_FALSE;
QudaBoolean use_eigen_qr = QUDA_BOOLEAN_TRUE;
QudaBoolean use_poly_acc = QUDA_BOOLEAN_TRUE;
QudaBoolean compute_evals_batch_size = 16;

The defaults can be then set on the MILC side of things instead of on the QUDA side of things.


Re: loading eigenvectors computed for the odd operator then applying them to the even operator (or any combination), you could add a field QudaParity vec_in_parity; or something of the sort. For the scope of this PR you could just require that the input parity be even otherwise it errors, but in a next PR we could expand it to handling constructing the even eigenvectors from input odd ones (or vice-versa).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part should be fixed now with cdedc07. The second part, RE loading eigenvectors, should be fixed now with 02eb6f5 and 2598582. Let me know if anything should be changed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! One thing I forgot: you should add a field struct_size that gets set to sizeof(QudaEigensolverArgs_t) on the MILC side, and then on the QUDA side you separately evaluate that and compare it with what the size of the struct gets set to in MILC. It should be the first field in the structure.

The reason to do this is to make error checking more explicit---if QUDA and MILC disagree on the size of the structure due to some mixed version issues, we'll get an error. Let me know if this makes sense/if you'd like me to better sketch it out.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed by b3f4e8d and c87e4dd



/**
* Parameters related to EigCG deflated solvers.
*/

typedef struct {
Expand Down Expand Up @@ -163,6 +208,11 @@ extern "C" {
*/
void qudaSetLayout(QudaLayout_t layout);

/**
* Clean up the QUDA deflation space.
*/
void qudaCleanUpDeflationSpace();

/**
* Destroy the QUDA context.
*/
Expand Down Expand Up @@ -363,6 +413,42 @@ extern "C" {
double* const final_rel_resid,
int* num_iters);

/**
* Solve Ax=b with deflation for an improved staggered operator. All fields are fields
* passed and returned are host (CPU) field in MILC order. This
* function requires that persistent gauge and clover fields have
* been created prior. This interface is experimental.
*
* @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single)
* @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single)
* @param[in] mass Fermion mass parameter
* @param[in] inv_args Struct setting some solver metadata
* @param[in] eig_args Struct setting some eigensolver metadata
* @param[in] target_residual Target residual
* @param[in] target_relative_residual Target Fermilab residual
* @param[in] milc_fatlink Fat-link field on the host
* @param[in] milc_longlink Long-link field on the host
* @param[in] source Right-hand side source field
* @param[out] solution Solution spinor field
* @param[in] final_residual True residual
* @param[in] final_relative_residual True Fermilab residual
* @param[in] num_iters Number of iterations taken
*/
void qudaInvertDeflatable(int external_precision,
int quda_precision,
double mass,
QudaInvertArgs_t inv_args,
QudaEigensolverArgs_t eig_args,
double target_residual,
double target_fermilab_residual,
const void* const milc_fatlink,
const void* const milc_longlink,
void* source,
void* solution,
double* const final_resid,
double* const final_rel_resid,
int* num_iters);

/**
* Prepare a staggered/HISQ multigrid solve with given fat and
* long links. All fields passed are host (CPU) fields
Expand Down Expand Up @@ -455,6 +541,44 @@ extern "C" {
int* num_iters,
int num_src);

/**
* Solve Ax=b with deflation for an improved staggered operator with many right hand sides.
* All fields are fields passed and returned are host (CPU) field in MILC order.
* This function requires that persistent gauge and clover fields have
* been created prior. This interface is experimental.
*
* @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single)
* @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single)
* @param[in] mass Fermion mass parameter
* @param[in] inv_args Struct setting some solver metadata
* @param[in] eig_args Struct setting some eigensolver metadata
* @param[in] target_residual Target residual
* @param[in] target_relative_residual Target Fermilab residual
* @param[in] milc_fatlink Fat-link field on the host
* @param[in] milc_longlink Long-link field on the host
* @param[in] source array of right-hand side source fields
* @param[out] solution array of solution spinor fields
* @param[in] final_residual True residual
* @param[in] final_relative_residual True Fermilab residual
* @param[in] num_iters Number of iterations taken
* @param[in] num_src Number of source fields
*/
void qudaInvertMsrcDeflatable(int external_precision,
int quda_precision,
double mass,
QudaInvertArgs_t inv_args,
QudaEigensolverArgs_t eig_args,
double target_residual,
double target_fermilab_residual,
const void* const fatlink,
const void* const longlink,
void** sourceArray,
void** solutionArray,
double* const final_residual,
double* const final_fermilab_residual,
int* num_iters,
int num_src);

/**
* Solve for multiple shifts (e.g., masses) using an improved
* staggered operator. All fields are fields passed and returned
Expand Down
Loading