Skip to content

Commit

Permalink
Refactor: Separate LCAO-in-PW ESolver and HSolver from PW ones (deepm…
Browse files Browse the repository at this point in the history
…odeling#4536)

* LIP ESolver & HSolver

minor fixes

* release psig_ in PW after initialized

fix lip after release psig_

fix test build

* rename

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
3 people authored Jul 8, 2024
1 parent 7b9f47d commit f8e41d4
Show file tree
Hide file tree
Showing 20 changed files with 576 additions and 348 deletions.
2 changes: 2 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ OBJS_ESOLVER=esolver.o\
esolver_ks.o\
esolver_fp.o\
esolver_ks_pw.o\
esolver_ks_lcaopw.o\
esolver_sdft_pw.o\
esolver_lj.o\
esolver_dp.o\
Expand Down Expand Up @@ -299,6 +300,7 @@ OBJS_HSOLVER=diago_cg.o\
diago_dav_subspace.o\
diago_bpcg.o\
hsolver_pw.o\
hsolver_lcaopw.o\
hsolver_pw_sdft.o\
diago_iter_assist.o\
math_kernel_op.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_esolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ list(APPEND objects
esolver_ks.cpp
esolver_fp.cpp
esolver_ks_pw.cpp
esolver_ks_lcaopw.cpp
esolver_sdft_pw.cpp
esolver_lj.cpp
esolver_dp.cpp
Expand Down
24 changes: 18 additions & 6 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "esolver_ks_pw.h"
#include "esolver_sdft_pw.h"
#ifdef __LCAO
#include "esolver_ks_lcaopw.h"
#include "esolver_ks_lcao.h"
#include "esolver_ks_lcao_tddft.h"
#endif
Expand All @@ -15,12 +16,12 @@
namespace ModuleESolver
{

void ESolver::printname(void)
void ESolver::printname()
{
std::cout << classname << std::endl;
}

std::string determine_type(void)
std::string determine_type()
{
std::string esolver_type = "none";
if (GlobalV::BASIS_TYPE == "pw")
Expand All @@ -47,10 +48,10 @@ std::string determine_type(void)
}
else if(GlobalV::ESOLVER_TYPE == "ksdft")
{
esolver_type = "ksdft_pw";
esolver_type = "ksdft_lip";
}
#else
ModuleBase::WARNING_QUIT("ESolver", "LCAO basis type must be compiled with __LCAO");
ModuleBase::WARNING_QUIT("ESolver", "Calculation involving numerical orbitals must be compiled with __LCAO");
#endif
}
else if (GlobalV::BASIS_TYPE == "lcao")
Expand All @@ -65,7 +66,7 @@ std::string determine_type(void)
esolver_type = "ksdft_lcao";
}
#else
ModuleBase::WARNING_QUIT("ESolver", "LCAO basis type must be compiled with __LCAO");
ModuleBase::WARNING_QUIT("ESolver", "Calculation involving numerical orbitals must be compiled with __LCAO");
#endif
}

Expand Down Expand Up @@ -138,7 +139,18 @@ ESolver* init_esolver()
}
}
#ifdef __LCAO
else if (esolver_type == "ksdft_lcao")
else if (esolver_type == "ksdft_lip")
{
if (GlobalV::precision_flag == "single")
{
p_esolver = new ESolver_KS_LIP<std::complex<float>>();
}
else
{
p_esolver = new ESolver_KS_LIP<std::complex<double>>();
}
}
else if (esolver_type == "ksdft_lcao")
{
if (GlobalV::GAMMA_ONLY_LOCAL)
{
Expand Down
153 changes: 153 additions & 0 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#include "esolver_ks_lcaopw.h"

#include "module_hamilt_pw/hamilt_pwdft/elecond.h"
#include "module_io/input_conv.h"
#include "module_io/nscf_band.h"
#include "module_io/output_log.h"
#include "module_io/write_dos_pw.h"
#include "module_io/write_istate_info.h"
#include "module_io/write_wfc_pw.h"

#include <iostream>

//--------------temporary----------------------------
#include "module_elecstate/module_charge/symmetry_rho.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_general/module_ewald/H_Ewald_pw.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/print_info.h"
//-----force-------------------
#include "module_hamilt_pw/hamilt_pwdft/forces.h"
//-----stress------------------
#include "module_hamilt_pw/hamilt_pwdft/stress_pw.h"
//---------------------------------------------------
#include "module_base/memory.h"
#include "module_elecstate/elecstate_pw.h"
#include "module_hamilt_pw/hamilt_pwdft/hamilt_pw.h"
#include "module_hsolver/diago_iter_assist.h"
#include "module_hsolver/hsolver_lcaopw.h"
#include "module_hsolver/kernels/dngvd_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_io/berryphase.h"
#include "module_io/numerical_basis.h"
#include "module_io/numerical_descriptor.h"
#include "module_io/rho_io.h"
#include "module_io/to_wannier90_pw.h"
#include "module_io/winput.h"
#include "module_io/write_pot.h"
#include "module_io/write_wfc_r.h"

#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>

namespace ModuleESolver
{

template <typename T>
ESolver_KS_LIP<T>::ESolver_KS_LIP()
{
this->classname = "ESolver_KS_LIP";
this->basisname = "LIP";
}

template <typename T>
void ESolver_KS_LIP<T>::allocate_hsolver()
{
this->phsol = new hsolver::HSolverLIP<T>(this->pw_wfc);
}
template <typename T>
void ESolver_KS_LIP<T>::deallocate_hsolver()
{
if (this->phsol != nullptr)
{
delete reinterpret_cast<hsolver::HSolverLIP<T>*>(this->phsol);
this->phsol = nullptr;
}
}
template <typename T>
void ESolver_KS_LIP<T>::hamilt2density(const int istep, const int iter, const double ethr)
{
ModuleBase::TITLE("ESolver_KS_LIP", "hamilt2density");
ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density");

if (this->phsol != nullptr)
{
// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;
// choose if psi should be diag in subspace
// be careful that istep start from 0 and iter start from 1
// if (iter == 1)
hsolver::DiagoIterAssist<T>::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true;
hsolver::DiagoIterAssist<T>::SCF_ITER = iter;
hsolver::DiagoIterAssist<T>::PW_DIAG_THR = ethr;
hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;

// It is not a good choice to overload another solve function here, this will spoil the concept of
// multiple inheritance and polymorphism. But for now, we just do it in this way.
// In the future, there will be a series of class ESolver_KS_LCAO_PW, HSolver_LCAO_PW and so on.
std::weak_ptr<psi::Psi<T>> psig = this->p_wf_init->get_psig();

if (psig.expired())
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density", "psig lifetime is expired");
}

// from HSolverLIP
this->phsol->solve(this->p_hamilt, // hamilt::Hamilt<T>* pHamilt,
this->kspw_psi[0], // psi::Psi<T>& psi,
this->pelec, // elecstate::ElecState<T>* pelec,
psig.lock().get()[0]); // psi::Psi<T>& transform,

if (GlobalV::out_bandgap)
{
if (!GlobalV::TWO_EFERMI)
{
this->pelec->cal_bandgap();
}
else
{
this->pelec->cal_bandgap_updw();
}
}
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_LIP", "HSolver has not been allocated.");
}
// add exx
#ifdef __LCAO
#ifdef __EXX
this->pelec->set_exx(GlobalC::exx_lip.get_exx_energy()); // Peize Lin add 2019-03-09
#endif
#endif

// calculate the delta_harris energy
// according to new charge density.
// mohan add 2009-01-23
this->pelec->cal_energies(1);

Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::Pgrid, GlobalC::ucell.symm);
}

// compute magnetization, only for LSDA(spin==2)
GlobalC::ucell.magnet.compute_magnetization(this->pelec->charge->nrxx,
this->pelec->charge->nxyz,
this->pelec->charge->rho,
this->pelec->nelec_spin.data());

// deband is calculated from "output" charge density calculated
// in sum_band
// need 'rho(out)' and 'vr (v_h(in) and v_xc(in))'
this->pelec->f_en.deband = this->pelec->cal_delta_eband();

ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density");
}

template class ESolver_KS_LIP<std::complex<float>>;
template class ESolver_KS_LIP<std::complex<double>>;
// LIP is not supported on GPU yet.
} // namespace ModuleESolver
29 changes: 29 additions & 0 deletions source/module_esolver/esolver_ks_lcaopw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef ESOLVER_KS_LIP_H
#define ESOLVER_KS_LIP_H
#include "module_esolver/esolver_ks_pw.h"
#include "module_hsolver/hsolver_lcaopw.h"
namespace ModuleESolver
{

template <typename T>
class ESolver_KS_LIP : public ESolver_KS_PW<T, base_device::DEVICE_CPU>
{
private:
using Real = typename GetTypeReal<T>::type;

public:
ESolver_KS_LIP();

~ESolver_KS_LIP() = default;

/// All the other interfaces except this one are the same as ESolver_KS_PW.
virtual void hamilt2density(const int istep, const int iter, const double ethr) override;

protected:

virtual void allocate_hsolver() override;
virtual void deallocate_hsolver() override;

};
} // namespace ModuleESolver
#endif
76 changes: 27 additions & 49 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,10 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW() {
template <typename T, typename Device>
ESolver_KS_PW<T, Device>::~ESolver_KS_PW() {
// delete HSolver and ElecState
if (this->phsol != nullptr) {
delete reinterpret_cast<hsolver::HSolverPW<T, Device>*>(this->phsol);
this->phsol = nullptr;
}
if (this->pelec != nullptr) {
delete reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(
this->pelec);
this->deallocate_hsolver();
if (this->pelec != nullptr)
{
delete reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(this->pelec);
this->pelec = nullptr;
}
// delete Hamilt
Expand Down Expand Up @@ -155,9 +152,9 @@ void ESolver_KS_PW<T, Device>::before_all_runners(Input& inp, UnitCell& ucell) {
ESolver_KS<T, Device>::before_all_runners(inp, ucell);

// 2) initialize HSolver
if (this->phsol == nullptr) {
this->phsol
= new hsolver::HSolverPW<T, Device>(this->pw_wfc, &this->wf);
if (this->phsol == nullptr)
{
this->allocate_hsolver();
}

// 3) initialize ElecState,
Expand Down Expand Up @@ -215,7 +212,17 @@ void ESolver_KS_PW<T, Device>::before_all_runners(Input& inp, UnitCell& ucell) {
GlobalV::nelec);
}
}

template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::allocate_hsolver()
{
this->phsol = new hsolver::HSolverPW<T, Device>(this->pw_wfc, &this->wf);
}
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::deallocate_hsolver()
{
delete reinterpret_cast<hsolver::HSolverPW<T, Device>*>(this->phsol);
this->phsol = nullptr;
}
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::init_after_vc(Input& inp, UnitCell& ucell) {
ModuleBase::TITLE("ESolver_KS_PW", "init_after_vc");
Expand Down Expand Up @@ -259,8 +266,7 @@ void ESolver_KS_PW<T, Device>::init_after_vc(Input& inp, UnitCell& ucell) {
inp.erf_sigma);

delete this->phsol;
this->phsol
= new hsolver::HSolverPW<T, Device>(this->pw_wfc, &this->wf);
this->allocate_hsolver();

delete this->pelec;
this->pelec
Expand Down Expand Up @@ -576,36 +582,15 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX
= GlobalV::PW_DIAG_NMAX;

if (GlobalV::BASIS_TYPE != "lcao_in_pw") {
// from HSolverPW
this->phsol->solve(
this->p_hamilt, // hamilt::Hamilt<T, Device>* pHamilt,
this->kspw_psi[0], // psi::Psi<T, Device>& psi,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
GlobalV::KS_SOLVER); // const std::string method_in,
} else {
// It is not a good choice to overload another solve function here,
// this will spoil the concept of multiple inheritance and
// polymorphism. But for now, we just do it in this way. In the
// future, there will be a series of class ESolver_KS_LCAO_PW,
// HSolver_LCAO_PW and so on.
std::weak_ptr<psi::Psi<T, Device>> psig
= this->p_wf_init->get_psig();

if (psig.expired()) {
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density",
"psig lifetime is expired");
}
this->phsol->solve(this->p_hamilt, // hamilt::Hamilt<T, Device>* pHamilt,
this->kspw_psi[0], // psi::Psi<T, Device>& psi,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
GlobalV::KS_SOLVER); // const std::string method_in,

// from HSolverPW
this->phsol->solve(
this->p_hamilt, // hamilt::Hamilt<T, Device>* pHamilt,
this->kspw_psi[0], // psi::Psi<T, Device>& psi,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
psig.lock().get()[0]); // psi::Psi<T, Device>& transform,
}
if (GlobalV::out_bandgap) {
if (!GlobalV::TWO_EFERMI) {
if (GlobalV::out_bandgap)
{
if (!GlobalV::TWO_EFERMI)
{
this->pelec->cal_bandgap();
} else {
this->pelec->cal_bandgap_updw();
Expand All @@ -615,13 +600,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep,
ModuleBase::WARNING_QUIT("ESolver_KS_PW",
"HSolver has not been initialed!");
}
// add exx
#ifdef __LCAO
#ifdef __EXX
this->pelec->set_exx(
GlobalC::exx_lip.get_exx_energy()); // Peize Lin add 2019-03-09
#endif
#endif

// calculate the delta_harris energy
// according to new charge density.
Expand Down
Loading

0 comments on commit f8e41d4

Please sign in to comment.