Skip to content

Commit

Permalink
Refactor david: complete removal of Psi and Hamilt datatypes (deepmod…
Browse files Browse the repository at this point in the history
…eling#4722)

* Refactor david: replace Psi parameter with pointer

* remove Psi-typed variable basis in david

* complete Psi removal in david: clear redundant and temporary code
  • Loading branch information
Cstandardlib authored Jul 17, 2024
1 parent 629ba64 commit be8f40f
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 70 deletions.
52 changes: 20 additions & 32 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
const int dim,
const int nband,
const int ldPsi,
psi::Psi<T, Device>& psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const int david_maxiter)
Expand Down Expand Up @@ -90,21 +90,17 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,

const int nbase_x = this->david_ndim * nband; // maximum dimension of the reduced basis set

T *psi_in = psi.get_pointer();
// T *psi_in = psi.get_pointer();

// the lowest N eigenvalues
base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->eigenvalue, nbase_x, "DAV::eig");
base_device::memory::set_memory_op<Real, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->eigenvalue, 0, nbase_x);

psi::Psi<T, Device> basis(1,
nbase_x,
dim,
&(psi.get_ngk(0))); // the reduced basis set
// basis(dim, nbase_x), leading dimension = dim
pbasis = basis.get_pointer();
ModuleBase::Memory::record("DAV::basis", nbase_x * dim * sizeof(T));
resmem_complex_op()(this->ctx, pbasis, nbase_x * dim, "DAV::basis");
setmem_complex_op()(this->ctx, pbasis, 0, nbase_x * dim);

//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// ModuleBase::ComplexMatrix hp(nbase_x, dim); // the product of H and psi in the reduced basis set
Expand Down Expand Up @@ -177,7 +173,6 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
this->SchmitOrth(dim,
nband,
m,
basis,
this->sphi,
&this->lagrange_matrix[m * nband],
pre_matrix_mm_m[m],
Expand All @@ -201,7 +196,7 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(this->hphi, pbasis, nbase_x, dim, 0, nband - 1);

this->cal_elem(dim, nbase, nbase_x, this->notconv, basis, this->hphi, this->sphi, this->hcc, this->scc);
this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hphi, this->sphi, this->hcc, this->scc);

this->diag_zhegvx(nbase, nband, this->hcc, this->scc, nbase_x, this->eigenvalue, this->vcc);

Expand All @@ -223,14 +218,13 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
nbase,
nbase_x,
this->notconv,
basis,
this->hphi,
this->sphi,
this->vcc,
unconv.data(),
this->eigenvalue);

this->cal_elem(dim, nbase, nbase_x, this->notconv, basis, this->hphi, this->sphi, this->hcc, this->scc);
this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hphi, this->sphi, this->hcc, this->scc);

this->diag_zhegvx(nbase, nband, this->hcc, this->scc, nbase_x, this->eigenvalue, this->vcc);

Expand Down Expand Up @@ -265,16 +259,16 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
gemm_op<T, Device>()(this->ctx,
'N',
'N',
dim, // m: row of A,C
dim, // m: row of A,C
nband, // n: col of B,C
nbase, // k: col of A, row of B
nbase, // k: col of A, row of B
this->one,
pbasis, // basis.get_pointer(), // A dim * nbase
pbasis, // A dim * nbase
dim,
this->vcc, // B nbase * n_band
this->vcc, // B nbase * n_band
nbase_x,
this->zero,
psi_in, // C dim * n_band
psi_in, // C dim * n_band
ldPsi
);

Expand All @@ -298,7 +292,6 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
eigenvalue_in,
psi_in, //psi,
ldPsi,
basis,
this->hphi,
this->sphi,
this->hcc,
Expand All @@ -317,13 +310,12 @@ int DiagoDavid<T, Device>::diag_mock(const HPsiFunc& hpsi_func,
}

template <typename T, typename Device>
void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func, // hamilt::Hamilt<T, Device>* phm_in,
void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int& dim,
const int& nbase, // current dimension of the reduced basis
const int nbase_x, // maximum dimension of the reduced basis set
const int& nbase, // current dimension of the reduced basis
const int nbase_x, // maximum dimension of the reduced basis set
const int& notconv,
psi::Psi<T, Device>& basis,
T* hphi,
T* sphi,
const T* vcc,
Expand Down Expand Up @@ -525,7 +517,6 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func, // hamilt::Hamil
this->SchmitOrth(dim,
nbase + notconv,
nbase + m,
basis,
sphi,
&lagrange[m * (nbase + notconv)],
pre_matrix_mm_m[m],
Expand Down Expand Up @@ -559,10 +550,9 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func, // hamilt::Hamil

template <typename T, typename Device>
void DiagoDavid<T, Device>::cal_elem(const int& dim,
int& nbase, // current dimension of the reduced basis
const int nbase_x, // maximum dimension of the reduced basis set
const int& notconv, // number of newly added basis vectors
const psi::Psi<T, Device>& basis,
int& nbase, // current dimension of the reduced basis
const int nbase_x, // maximum dimension of the reduced basis set
const int& notconv, // number of newly added basis vectors
const T* hphi,
const T* sphi,
T* hcc,
Expand Down Expand Up @@ -718,9 +708,8 @@ void DiagoDavid<T, Device>::refresh(const int& dim,
int& nbase,
const int nbase_x, // maximum dimension of the reduced basis set
const Real* eigenvalue_in,
const T *psi_in, // const psi::Psi<T, Device>& psi,
const T *psi_in,
const int ldPsi,
psi::Psi<T, Device>& basis,
T* hp,
T* sp,
T* hc,
Expand Down Expand Up @@ -861,7 +850,6 @@ template <typename T, typename Device>
void DiagoDavid<T, Device>::SchmitOrth(const int& dim,
const int nband,
const int m,
psi::Psi<T, Device>& basis,
const T* sphi,
T* lagrange_m,
const int mm_size,
Expand Down Expand Up @@ -1079,7 +1067,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
const int dim,
const int nband,
const int ldPsi,
psi::Psi<T, Device>& psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const int david_maxiter,
Expand All @@ -1101,7 +1089,7 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
int sum_dav_iter = 0;
do
{
sum_dav_iter += this->diag_mock(hpsi_func, spsi_func, dim, nband, ldPsi, psi, eigenvalue_in, david_diag_thr, david_maxiter);
sum_dav_iter += this->diag_mock(hpsi_func, spsi_func, dim, nband, ldPsi, psi_in, eigenvalue_in, david_diag_thr, david_maxiter);
++ntry;
} while (!check_block_conv(ntry, this->notconv, ntry_max, notconv_max));

Expand Down
41 changes: 18 additions & 23 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DiagoDavid : public DiagH<T, Device>
const int dim, // Dimension of the input matrix psi to be diagonalized
const int nband, // Number of required eigenpairs
const int ldPsi, // Leading dimension of the psi input
psi::Psi<T, Device>& psi, // Reference to the wavefunction object for eigenvectors
T *psi_in, // Pointer to eigenvectors
Real* eigenvalue_in, // Pointer to store the resulting eigenvalues
const Real david_diag_thr, // Convergence threshold for the Davidson iteration
const int david_maxiter, // Maximum allowed iterations for the Davidson method
Expand All @@ -56,25 +56,24 @@ class DiagoDavid : public DiagH<T, Device>
/// number of unconverged eigenvalues
int notconv = 0;

/// precondition for diag, diagonal approximation of
/// matrix A (i.e. Hamilt)
/// precondition for diag, diagonal approximation of matrix A(i.e. Hamilt)
const Real* precondition = nullptr;
Real* d_precondition = nullptr;

/// eigenvalue results
Real* eigenvalue = nullptr;

T *pbasis = nullptr; // basis set
T *pbasis = nullptr; /// pointer to basis set(dim, nbase_x), leading dimension = dim

T* hphi = nullptr; // the product of H and psi in the reduced basis set
T* hphi = nullptr; /// the product of H and psi in the reduced basis set

T* sphi = nullptr; // the Product of S and psi in the reduced basis set
T* sphi = nullptr; /// the Product of S and psi in the reduced basis set

T* hcc = nullptr; // Hamiltonian on the reduced basis
T* hcc = nullptr; /// Hamiltonian on the reduced basis

T* scc = nullptr; // Overlap on the reduced basis
T* scc = nullptr; /// Overlap on the reduced basis

T* vcc = nullptr; // Eigenvectors of hc
T* vcc = nullptr; /// Eigenvectors of hc

T* lagrange_matrix = nullptr;

Expand All @@ -83,13 +82,22 @@ class DiagoDavid : public DiagH<T, Device>
base_device::DEVICE_CPU* cpu_ctx = {};
base_device::AbacusDevice_t device = {};

int diag_mock(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int dim,
const int nband,
const int ldPsi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const int david_maxiter);

void cal_grad(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int& dim,
const int& nbase,
const int nbase_x,
const int& notconv,
psi::Psi<T, Device>& basis,
T* hphi,
T* sphi,
const T* vcc,
Expand All @@ -100,7 +108,6 @@ class DiagoDavid : public DiagH<T, Device>
int& nbase,
const int nbase_x,
const int& notconv,
const psi::Psi<T, Device>& basis,
const T* hphi,
const T* sphi,
T* hcc,
Expand All @@ -113,7 +120,6 @@ class DiagoDavid : public DiagH<T, Device>
const Real* eigenvalue,
const T *psi_in,
const int ldPsi,
psi::Psi<T, Device>& basis,
T* hphi,
T* sphi,
T* hcc,
Expand All @@ -123,7 +129,6 @@ class DiagoDavid : public DiagH<T, Device>
void SchmitOrth(const int& dim,
const int nband,
const int m,
psi::Psi<T, Device>& basis,
const T* sphi,
T* lagrange_m,
const int mm_size,
Expand All @@ -139,16 +144,6 @@ class DiagoDavid : public DiagH<T, Device>
Real* eigenvalue,
T* vcc);

int diag_mock(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int dim,
const int nband,
const int ldPsi,
psi::Psi<T, Device>& psi,
Real* eigenvalue_in,
const Real david_diag_thr,
const int david_maxiter);

bool check_block_conv(const int &ntry, const int &notconv, const int &ntry_max, const int &notconv_max);

using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
DiagoDavid<T, Device> david(pre_condition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
// do diag and add davidson iteration counts up to avg_iter
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
david.diag(hpsi_func, spsi_func, dim, nband, ldPsi, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max));
david.diag(hpsi_func, spsi_func, dim, nband, ldPsi, psi.get_pointer(), eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max));
}
return;
}
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class DiagoDavPrepare
auto spsi_func = [phm](const std::complex<float>* psi_in, std::complex<float>* spsi_out,const int nrow, const int npw, const int nbands){
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
};
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi, en, eps, maxiter);
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi.get_pointer(), en, eps, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class DiagoDavPrepare
auto spsi_func = [phm](const double* psi_in, double* spsi_out,const int nrow, const int npw, const int nbands){
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
};
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi, en, eps, maxiter);
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi.get_pointer(), en, eps, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class DiagoDavPrepare
auto spsi_func = [phm](const std::complex<double>* psi_in, std::complex<double>* spsi_out,const int nrow, const int npw, const int nbands){
phm->sPsi(psi_in, spsi_out, nrow, npw, nbands);
};
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi, en, eps, maxiter);
dav.diag(hpsi_func,spsi_func, dim, nband, ldPsi, phi.get_pointer(), en, eps, maxiter);

#ifdef __MPI
end = MPI_Wtime();
Expand Down
21 changes: 11 additions & 10 deletions source/module_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,22 +166,23 @@ int DiagoDavid<T, Device>::diag(const std::function<void(T*, T*, const int, cons
const int dim,
const int nband,
const int ldPsi,
psi::Psi<T, Device>& psi,
T *psi_in,
Real* eigenvalue_in,
const Real david_diag_thr,
const int david_maxiter,
const int ntry_max,
const int notconv_max) {
// do nothing, we dont need it
// do something
for (int ib = 0; ib < psi.get_nbands(); ib++) {
eigenvalue_in[ib] = 0.0;
for (int ig = 0; ig < psi.get_nbasis(); ig++) {
psi(ib, ig) += T(1.0, 0.0);
eigenvalue_in[ib] += psi(ib, ig).real();
}
eigenvalue_in[ib] /= psi.get_nbasis();
}
DiagoIterAssist<T, Device>::avg_iter += 1.0;
// for (int ib = 0; ib < psi.get_nbands(); ib++) {
// eigenvalue_in[ib] = 0.0;
// for (int ig = 0; ig < psi.get_nbasis(); ig++) {
// psi(ib, ig) += T(1.0, 0.0);
// eigenvalue_in[ib] += psi(ib, ig).real();
// }
// eigenvalue_in[ib] /= psi.get_nbasis();
// }
// DiagoIterAssist<T, Device>::avg_iter += 1.0;
return 1;
}
template class DiagoDavid<std::complex<float>, base_device::DEVICE_CPU>;
Expand Down
2 changes: 1 addition & 1 deletion source/module_lr/hsolver_lrtd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace LR
const int& nband = psi_k1_dav.get_nbands();
hsolver::DiagoDavid<T, Device> david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
hsolver::DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(david.diag(hpsi_func, spsi_func,
dim, nband, dim, psi_k1_dav, eigenvalue.data(), this->diag_ethr, david_maxiter, ntry_max, 0/*notconv_max*/));
dim, nband, dim, psi_k1_dav.get_pointer(), eigenvalue.data(), this->diag_ethr, david_maxiter, ntry_max, 0/*notconv_max*/));
}
else if (this->method == "dav_subspace") //need refactor
{
Expand Down

0 comments on commit be8f40f

Please sign in to comment.