Skip to content

Commit

Permalink
Remove DiagH Base class in pw (deepmodeling#5225)
Browse files Browse the repository at this point in the history
* Remove DiagoDavid Base class

* Remove Diago_DavSubspace Base class

* Remove DiagoCG Base class

* Remove DiagoBPCG Base class

* Remove DiagoBPCG Base class override

* Update docs

* Remove DiagH_mock in test_hsolver

* Remove DiagH_mock in test_hsolver
  • Loading branch information
Cstandardlib authored Oct 14, 2024
1 parent e9250db commit 0f0c838
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 37 deletions.
8 changes: 4 additions & 4 deletions source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace hsolver {
* @tparam Device The device used for calculations (e.g., cpu or gpu).
*/
template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class DiagoBPCG : public DiagH<T, Device>
class DiagoBPCG
{
private:
// Note GetTypeReal<T>::type will
Expand Down Expand Up @@ -56,15 +56,15 @@ class DiagoBPCG : public DiagH<T, Device>
void init_iter(const psi::Psi<T, Device> &psi_in);

/**
* @brief Diagonalize the Hamiltonian using the CG method.
* @brief Diagonalize the Hamiltonian using the BPCG method.
*
* This function is an override function for the CG method. It is called by the HsolverPW::solve() function.
* This function is called by the HsolverPW::solve() function.
*
* @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
* @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
*/
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in) override;
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in);


private:
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_cg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace hsolver {

template <typename T, typename Device = base_device::DEVICE_CPU>
class DiagoCG final : public DiagH<T, Device>
class DiagoCG final
{
// private: accessibility within class is private by default
// Note GetTypeReal<T>::type will
Expand All @@ -36,11 +36,11 @@ class DiagoCG final : public DiagH<T, Device>
const int& pw_diag_nmax,
const int& nproc_in_pool);

~DiagoCG() override;
~DiagoCG();

// virtual void init(){};
// refactor hpsi_info
// this is the override function diag() for CG method
// this is the diag() function for CG method
void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {});

private:
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace hsolver
{

template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class Diago_DavSubspace : public DiagH<T, Device>
class Diago_DavSubspace
{
private:
// Note GetTypeReal<T>::type will
Expand All @@ -29,7 +29,7 @@ class Diago_DavSubspace : public DiagH<T, Device>
const bool& need_subspace_in,
const diag_comm_info& diag_comm_in);

virtual ~Diago_DavSubspace() override;
~Diago_DavSubspace();

// See diago_david.h for information on the HPsiFunc function type
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace hsolver
{

template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class DiagoDavid : public DiagH<T, Device>
class DiagoDavid
{
private:
// Note GetTypeReal<T>::type will
Expand All @@ -25,7 +25,7 @@ class DiagoDavid : public DiagH<T, Device>
const bool use_paw_in,
const diag_comm_info& diag_comm_in);

virtual ~DiagoDavid() override;
~DiagoDavid();


// declare type of matrix-blockvector functions.
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!");
}

// prepare for the precondition of diagonalization
Expand Down
26 changes: 1 addition & 25 deletions source/module_hsolver/test/test_hsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,4 @@ class TestHSolver : public ::testing::Test

// double test_diagethr_d = hs_d.set_diagethr(0.0, 0, 0, 0.0);
// EXPECT_EQ(test_diagethr_d, 0.0);
// }
namespace hsolver
{
template <typename T, typename Device = base_device::DEVICE_CPU>
class DiagH_mock : public DiagH<T, Device>
{
private:
using Real = typename GetTypeReal<T>::type;

public:
DiagH_mock()
{
}
~DiagH_mock()
{
}

void diag(hamilt::Hamilt<T, Device>* phm_in, psi::Psi<T, Device>& psi, Real* eigenvalue_in)
{
return;
}
};
template class DiagH_mock<std::complex<float>>;
template class DiagH_mock<std::complex<double>>;
}
// }

0 comments on commit 0f0c838

Please sign in to comment.