Skip to content

Commit

Permalink
add exx_lip back
Browse files Browse the repository at this point in the history
use vector instead of pointers

use unique_ptr to enable const-reference constructors
  • Loading branch information
maki49 committed Jul 9, 2024
1 parent 52a9600 commit 1f49762
Show file tree
Hide file tree
Showing 19 changed files with 690 additions and 792 deletions.
14 changes: 13 additions & 1 deletion source/module_base/lapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ extern "C"
void zherk_(const char *uplo, const char *trans, const int *n, const int *k,
const double *alpha, const std::complex<double> *A, const int *lda,
const double *beta, std::complex<double> *C, const int *ldc);
void cherk_(const char* uplo, const char* trans, const int* n, const int* k,
const float* alpha, const std::complex<float>* A, const int* lda,
const float* beta, std::complex<float>* C, const int* ldc);

// computes all eigenvalues of a symmetric tridiagonal matrix
// using the Pal-Walker-Kahan variant of the QL or QR algorithm.
Expand Down Expand Up @@ -436,13 +439,22 @@ class LapackConnector
// if trans=='N': C = a * A * A.H + b * C
// if trans=='C': C = a * A.H * A + b * C
static inline
void zherk(const char uplo, const char trans, const int n, const int k,
void herk(const char uplo, const char trans, const int n, const int k,
const double alpha, const std::complex<double> *A, const int lda,
const double beta, std::complex<double> *C, const int ldc)
{
const char uplo_changed = change_uplo(uplo);
const char trans_changed = change_trans_NC(trans);
zherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
}
static inline
void herk(const char uplo, const char trans, const int n, const int k,
const float alpha, const std::complex<float>* A, const int lda,
const float beta, std::complex<float>* C, const int ldc)
{
const char uplo_changed = change_uplo(uplo);
const char trans_changed = change_trans_NC(trans);
cherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
}
};
#endif // LAPACKCONNECTOR_HPP
17 changes: 17 additions & 0 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,23 @@ ModuleIO::Output_Potential ESolver_KS<T, Device>::create_Output_Potential(int it
//! the 16th-20th functions of ESolver_KS
//! mohan add 2024-05-12
//------------------------------------------------------------------------------
template <typename T, typename Device>
void ESolver_KS<T, Device>::set_xc_first_loop(const UnitCell& ucell)
{
/** In the special "two-level" calculation case,
the first scf iteration only calculate the functional without exact
exchange. but in "nscf" calculation, there is no need of "two-level"
method. */
if (ucell.atoms[0].ncpp.xc_func == "HF"
|| ucell.atoms[0].ncpp.xc_func == "PBE0"
|| ucell.atoms[0].ncpp.xc_func == "HSE") {
XC_Functional::set_xc_type("pbe");
}
else if (ucell.atoms[0].ncpp.xc_func == "SCAN0") {
XC_Functional::set_xc_type("scan");
}
}

//! This is for mixed-precision pw/LCAO basis sets.
template class ESolver_KS<std::complex<float>, base_device::DEVICE_CPU>;
template class ESolver_KS<std::complex<double>, base_device::DEVICE_CPU>;
Expand Down
6 changes: 4 additions & 2 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ class ESolver_KS : public ESolver_FP

std::string basisname; //PW or LCAO

void print_wfcfft(Input& inp, std::ofstream &ofs);
};
void print_wfcfft(Input& inp, std::ofstream& ofs);

virtual void set_xc_first_loop(const UnitCell& ucell);
};
} // end of namespace
#endif
28 changes: 6 additions & 22 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,28 +185,12 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
#ifdef __EXX
// 7) initialize exx
// PLEASE simplify the Exx_Global interface
if (GlobalV::CALCULATION == "scf" || GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax"
|| GlobalV::CALCULATION == "md")
{
if (GlobalC::exx_info.info_global.cal_exx)
{
/* In the special "two-level" calculation case,
first scf iteration only calculate the functional without exact
exchange. but in "nscf" calculation, there is no need of "two-level"
method. */
if (ucell.atoms[0].ncpp.xc_func == "HF" || ucell.atoms[0].ncpp.xc_func == "PBE0"
|| ucell.atoms[0].ncpp.xc_func == "HSE")
{
XC_Functional::set_xc_type("pbe");
}
else if (ucell.atoms[0].ncpp.xc_func == "SCAN0")
{
XC_Functional::set_xc_type("scan");
}

// GlobalC::exx_lcao.init();
if (GlobalC::exx_info.info_ri.real_number)
{
if (GlobalV::CALCULATION == "scf" || GlobalV::CALCULATION == "relax"
|| GlobalV::CALCULATION == "cell-relax"
|| GlobalV::CALCULATION == "md") {
if (GlobalC::exx_info.info_global.cal_exx) {
this->set_xc_first_loop(ucell);
if (GlobalC::exx_info.info_ri.real_number) {
this->exx_lri_double->init(MPI_COMM_WORLD, this->kv);
}
else
Expand Down
105 changes: 102 additions & 3 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>

#include <sys/time.h>
#include "module_hamilt_pw/hamilt_pwdft/hamilt_lcaopw.h"

namespace ModuleESolver
{

Expand All @@ -64,6 +67,44 @@ namespace ModuleESolver
this->phsol = nullptr;
}
}
template <typename T>
void ESolver_KS_LIP<T>::allocate_hamilt()
{
this->p_hamilt = new hamilt::HamiltLIP<T>(this->pelec->pot, this->pw_wfc, &this->kv
#ifdef __EXX
, *this->exx_lip
#endif
);
}

template <typename T>
void ESolver_KS_LIP<T>::before_all_runners(Input& inp, UnitCell& cell)
{
ESolver_KS_PW<T>::before_all_runners(inp, cell);
#ifdef __EXX
if (GlobalV::CALCULATION == "scf" || GlobalV::CALCULATION == "relax"
|| GlobalV::CALCULATION == "cell-relax"
|| GlobalV::CALCULATION == "md")
if (GlobalC::exx_info.info_global.cal_exx)
{
this->set_xc_first_loop(cell);
this->exx_lip = std::unique_ptr<Exx_Lip<T>>(new Exx_Lip<T>(GlobalC::exx_info.info_lip,
cell.symm, &this->kv, this->p_wf_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec));
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_wf_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
}
#endif
}

template <typename T>
void ESolver_KS_LIP<T>::iter_init(const int istep, const int iter)
{
ESolver_KS_PW<T>::iter_init(istep, iter);
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx && !GlobalC::exx_info.info_global.separate_loop && this->two_level_step)
this->exx_lip->cal_exx();
#endif
}

template <typename T>
void ESolver_KS_LIP<T>::hamilt2density(const int istep, const int iter, const double ethr)
{
Expand Down Expand Up @@ -116,10 +157,9 @@ namespace ModuleESolver
ModuleBase::WARNING_QUIT("ESolver_KS_LIP", "HSolver has not been allocated.");
}
// add exx
#ifdef __LCAO
#ifdef __EXX
this->pelec->set_exx(GlobalC::exx_lip.get_exx_energy()); // Peize Lin add 2019-03-09
#endif
if (GlobalC::exx_info.info_global.cal_exx)
this->pelec->set_exx(this->exx_lip->get_exx_energy()); // Peize Lin add 2019-03-09
#endif

// calculate the delta_harris energy
Expand Down Expand Up @@ -147,6 +187,65 @@ namespace ModuleESolver
ModuleBase::timer::tick("ESolver_KS_LIP", "hamilt2density");
}

#ifdef __EXX
template <typename T>
bool ESolver_KS_LIP<T>::do_after_converge(int& iter)
{
if (GlobalC::exx_info.info_global.cal_exx)
{
// no separate_loop case
if (!GlobalC::exx_info.info_global.separate_loop)
{
GlobalC::exx_info.info_global.hybrid_step = 1;

// in no_separate_loop case, scf loop only did twice
// in first scf loop, exx updated once in beginning,
// in second scf loop, exx updated every iter

if (this->two_level_step)
return true;
else
{
// update exx and redo scf
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].ncpp.xc_func);
iter = 0;
std::cout << " Entering 2nd SCF, where EXX is updated" << std::endl;
this->two_level_step++;
return false;
}
}
// has separate_loop case
// exx converged or get max exx steps
else if (this->two_level_step == GlobalC::exx_info.info_global.hybrid_step
|| (iter == 1 && this->two_level_step != 0))
return true;
else
{
// update exx and redo scf
if (this->two_level_step == 0)
{
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].ncpp.xc_func);
}

std::cout << " Updating EXX " << std::flush;
timeval t_start; gettimeofday(&t_start, NULL);

this->exx_lip->cal_exx();
iter = 0;
this->two_level_step++;

timeval t_end; gettimeofday(&t_end, NULL);
std::cout << "and rerun SCF\t"
<< std::setprecision(3) << std::setiosflags(std::ios::scientific)
<< (double)(t_end.tv_sec - t_start.tv_sec) + (double)(t_end.tv_usec - t_start.tv_usec) / 1000000.0
<< std::defaultfloat << " (s)" << std::endl;
return false;
}
}
else { return true; }
}
#endif

template class ESolver_KS_LIP<std::complex<float>>;
template class ESolver_KS_LIP<std::complex<double>>;
// LIP is not supported on GPU yet.
Expand Down
14 changes: 14 additions & 0 deletions source/module_esolver/esolver_ks_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define ESOLVER_KS_LIP_H
#include "module_esolver/esolver_ks_pw.h"
#include "module_hsolver/hsolver_lcaopw.h"

#ifdef __EXX
#include "module_ri/exx_lip.h"
#endif
namespace ModuleESolver
{

Expand All @@ -19,10 +23,20 @@ namespace ModuleESolver
/// All the other interfaces except this one are the same as ESolver_KS_PW.
virtual void hamilt2density(const int istep, const int iter, const double ethr) override;

void before_all_runners(Input& inp, UnitCell& cell) override;
void iter_init(const int istep, const int iter) override;

protected:

virtual void allocate_hsolver() override;
virtual void deallocate_hsolver() override;
virtual void allocate_hamilt() override;

#ifdef __EXX
std::unique_ptr<Exx_Lip<T>> exx_lip;
int two_level_step = 0;
virtual bool do_after_converge(int& iter) override;
#endif

};
} // namespace ModuleESolver
Expand Down
27 changes: 14 additions & 13 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW() {
this->pelec = nullptr;
}
// delete Hamilt
if (this->p_hamilt != nullptr) {
delete reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
this->p_hamilt = nullptr;
}
this->deallocate_hamilt();
if (this->device == base_device::GpuDevice) {
#if defined(__CUDA) || defined(__ROCM)
hsolver::destoryBLAShandle();
Expand Down Expand Up @@ -224,6 +221,17 @@ void ESolver_KS_PW<T, Device>::deallocate_hsolver()
this->phsol = nullptr;
}
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::allocate_hamilt()
{
this->p_hamilt = new hamilt::HamiltPW<T, Device>(this->pelec->pot, this->pw_wfc, &this->kv);
}
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::deallocate_hamilt()
{
delete reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
this->p_hamilt = nullptr;
}
template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::init_after_vc(Input& inp, UnitCell& ucell) {
ModuleBase::TITLE("ESolver_KS_PW", "init_after_vc");
ModuleBase::timer::tick("ESolver_KS_PW", "init_after_vc");
Expand Down Expand Up @@ -394,17 +402,10 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep) {
// init Hamilt, this should be allocated before each scf loop
// Operators in HamiltPW should be reallocated once cell changed
// delete Hamilt if not first scf
if (this->p_hamilt != nullptr) {
delete reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
this->p_hamilt = nullptr;
}
this->deallocate_hamilt();

// allocate HamiltPW
if (this->p_hamilt == nullptr) {
this->p_hamilt = new hamilt::HamiltPW<T, Device>(this->pelec->pot,
this->pw_wfc,
&this->kv);
}
this->allocate_hamilt();

//----------------------------------------------------------
// about vdw, jiyy add vdwd3 and linpz add vdwd2
Expand Down
2 changes: 2 additions & 0 deletions source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>

virtual void allocate_hsolver();
virtual void deallocate_hsolver();
virtual void allocate_hamilt();
virtual void deallocate_hamilt();

//! hide the psi in ESolver_KS for tmp use
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi = nullptr;
Expand Down
7 changes: 3 additions & 4 deletions source/module_hamilt_general/module_xc/exx_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ struct Exx_Info
{
const Conv_Coulomb_Pot_K::Ccp_Type& ccp_type;
const double& hse_omega;
double lambda;
double lambda = 0.3;

Exx_Info_Lip(const Exx_Info::Exx_Info_Global& info_global)
: ccp_type(info_global.ccp_type), hse_omega(info_global.hse_omega)
{
}
:ccp_type(info_global.ccp_type),
hse_omega(info_global.hse_omega) {}
};
Exx_Info_Lip info_lip;

Expand Down
3 changes: 1 addition & 2 deletions source/module_hamilt_pw/hamilt_pwdft/global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
namespace GlobalC
{
#ifdef __EXX
Exx_Info exx_info;
Exx_Lip exx_lip(exx_info.info_lip);
Exx_Info exx_info;
#endif
pseudopot_cell_vnl ppcell;
UnitCell ucell;
Expand Down
3 changes: 1 addition & 2 deletions source/module_hamilt_pw/hamilt_pwdft/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ static const char* _hipfftGetErrorString(hipfftResult_t error)
namespace GlobalC
{
#ifdef __EXX
extern Exx_Info exx_info;
extern Exx_Lip exx_lip;
extern Exx_Info exx_info;
#endif
extern pseudopot_cell_vnl ppcell;
} // namespace GlobalC
Expand Down
27 changes: 27 additions & 0 deletions source/module_hamilt_pw/hamilt_pwdft/hamilt_lcaopw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef HAMILTLIP_H
#define HAMILTLIP_H

#include "module_hamilt_pw/hamilt_pwdft/hamilt_pw.h"
#ifdef __EXX
#include "module_ri/exx_lip.h"
#endif

namespace hamilt
{

template <typename T>
class HamiltLIP : public HamiltPW<T, base_device::DEVICE_CPU>
{
public:
HamiltLIP(elecstate::Potential* pot_in, ModulePW::PW_Basis_K* wfc_basis, K_Vectors* p_kv)
: HamiltPW<T, base_device::DEVICE_CPU>(pot_in, wfc_basis, p_kv) {};
#ifdef __EXX
HamiltLIP(elecstate::Potential* pot_in, ModulePW::PW_Basis_K* wfc_basis, K_Vectors* p_kv, Exx_Lip<T>& exx_lip_in)
: HamiltPW<T, base_device::DEVICE_CPU>(pot_in, wfc_basis, p_kv), exx_lip(exx_lip_in) {};
Exx_Lip<T>& exx_lip;
#endif
};

} // namespace hamilt

#endif
Loading

0 comments on commit 1f49762

Please sign in to comment.