Skip to content

Commit

Permalink
Performance: make KG function more converged (deepmodeling#2862)
Browse files Browse the repository at this point in the history
* do not free INTER_POOL if not initialized

* fix: wrong stress results of scan when NPROC_IN_POOL are large

* add memory record for S-KG

* Performance: make KG function more converged

* update INPUT

* add const number and annotation
  • Loading branch information
Qianruipku authored Aug 26, 2023
1 parent c23c943 commit db71b1c
Show file tree
Hide file tree
Showing 13 changed files with 5,940 additions and 5,888 deletions.
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ void ESolver_KS_PW<FPTYPE, Device>::postprocess()

if (INPUT.cal_cond)
{
this->KG(INPUT.cond_nche, INPUT.cond_fwhm, INPUT.cond_wcut, INPUT.cond_dw, INPUT.cond_dt, this->pelec->wg);
this->KG(INPUT.cond_fwhm, INPUT.cond_wcut, INPUT.cond_dw, INPUT.cond_dt, this->pelec->wg);
}
}

Expand Down
45 changes: 39 additions & 6 deletions source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,46 @@ namespace ModuleESolver
virtual void hamilt2estates(const double ethr) override;
virtual void nscf() override;
void postprocess() override;
//calculate conductivities with Kubo-Greenwood formula
void KG(const int nche_KG, const double fwhmin, const double wcut,
const double dw_in, const double dt_in, ModuleBase::matrix& wg);
void jjcorr_ks(const int ik, const int nt, const double dt, ModuleBase::matrix& wg, hamilt::Velocity& velop,
double* ct11, double* ct12, double* ct22);

protected:
/**
* @brief calculate Onsager coefficients Lmn(\omega) and conductivities with Kubo-Greenwood formula
*
* @param fwhmin FWHM for delta function
* @param wcut cutoff \omega for Lmn(\omega)
* @param dw_in \omega step
* @param dt_in time step
* @param wg wg(ik,ib) occupation for the ib-th band in the ik-th kpoint
*/
void KG(const double fwhmin,
const double wcut,
const double dw_in,
const double dt_in,
ModuleBase::matrix& wg);

/**
* @brief calculate the response function Cmn(t) for currents
*
* @param ik k point
* @param nt number of steps of time
* @param dt time step
* @param decut ignore dE which is larger than decut
* @param wg wg(ik,ib) occupation for the ib-th band in the ik-th kpoint
* @param velop velocity operator
* @param ct11 C11(t)
* @param ct12 C12(t)
* @param ct22 C22(t)
*/
void jjcorr_ks(const int ik,
const int nt,
const double dt,
const double decut,
ModuleBase::matrix& wg,
hamilt::Velocity& velop,
double* ct11,
double* ct12,
double* ct22);

protected:
virtual void beforescf(const int istep) override;
virtual void eachiterinit(const int istep, const int iter) override;
virtual void updatepot(const int istep, const int iter) override;
Expand Down
124 changes: 70 additions & 54 deletions source/module_esolver/esolver_ks_pw_tool.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "esolver_ks_pw.h"
#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/binstream.h"
namespace ModuleESolver
{
Expand All @@ -21,10 +21,13 @@ namespace ModuleESolver
// e/k = 11604.518026 , 1 eV = 11604.5 K
//------------------------------------------------------------------
#define TWOSQRT2LN2 2.354820045030949 // FWHM = 2sqrt(2ln2) * \sigma
#define FACTOR 1.839939223835727e7
template<typename FPTYPE, typename Device>
void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const double fwhmin, const double wcut,
const double dw_in, const double dt_in, ModuleBase::matrix& wg)
#define FACTOR 1.839939223835727e7
template <typename FPTYPE, typename Device>
void ESolver_KS_PW<FPTYPE, Device>::KG(const double fwhmin,
const double wcut,
const double dw_in,
const double dt_in,
ModuleBase::matrix& wg)
{
//-----------------------------------------------------------
// KS conductivity
Expand All @@ -33,32 +36,34 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const double fwhmin, c
int nw = ceil(wcut / dw_in);
double dw = dw_in / ModuleBase::Ry_to_eV; // converge unit in eV to Ry
double sigma = fwhmin / TWOSQRT2LN2 / ModuleBase::Ry_to_eV;
double dt = dt_in; // unit in a.u., 1 a.u. = 4.837771834548454e-17 s
int nt = ceil(sqrt(20) / sigma / dt);
double dt = dt_in; // unit in a.u., 1 a.u. = 4.837771834548454e-17 s
const double expfactor = 23; //exp(-23) = 1e-10
int nt = ceil(sqrt(2*expfactor)/sigma/dt); //set nt empirically
std::cout << "nw: " << nw << " ; dw: " << dw * ModuleBase::Ry_to_eV << " eV" << std::endl;
std::cout << "nt: " << nt << " ; dt: " << dt << " a.u.(ry^-1)" << std::endl;
assert(nw >= 1);
assert(nt >= 1);
const int nk = this->kv.nks;

double *ct11 = new double[nt];
double *ct12 = new double[nt];
double *ct22 = new double[nt];
double* ct11 = new double[nt];
double* ct12 = new double[nt];
double* ct22 = new double[nt];
ModuleBase::GlobalFunc::ZEROS(ct11, nt);
ModuleBase::GlobalFunc::ZEROS(ct12, nt);
ModuleBase::GlobalFunc::ZEROS(ct22, nt);

hamilt::Velocity velop(this->pw_wfc, this->kv.isk.data(), &GlobalC::ppcell, &GlobalC::ucell, INPUT.cond_nonlocal);
double decut = (wcut + 5*fwhmin) / ModuleBase::Ry_to_eV;
for (int ik = 0; ik < nk; ++ik)
{
velop.init(ik);
jjcorr_ks(ik, nt, dt, wg, velop, ct11,ct12,ct22);
jjcorr_ks(ik, nt, dt, decut, wg, velop, ct11, ct12, ct22);
}
#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, ct11, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
MPI_Allreduce(MPI_IN_PLACE, ct12, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
MPI_Allreduce(MPI_IN_PLACE, ct22, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
#endif
#endif
//------------------------------------------------------------------
// Output
//------------------------------------------------------------------
Expand All @@ -72,8 +77,15 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const double fwhmin, c
}

template <typename FPTYPE, typename Device>
void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const double dt, ModuleBase::matrix& wg, hamilt::Velocity &velop,
double* ct11, double* ct12, double* ct22)
void ESolver_KS_PW<FPTYPE, Device>::jjcorr_ks(const int ik,
const int nt,
const double dt,
const double decut,
ModuleBase::matrix& wg,
hamilt::Velocity& velop,
double* ct11,
double* ct12,
double* ct22)
{
char transn = 'N';
char transc = 'C';
Expand All @@ -83,17 +95,17 @@ void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const
const double ef = this->pelec->eferm.ef;
const int npw = this->kv.ngk[ik];
const int reducenb2 = (nbands - 1) * nbands / 2;
bool gamma_only = false; //ABACUS do not support gamma_only yet.
std::complex<double> *levc = &(this->psi[0](ik, 0, 0));
std::complex<double> *prevc = new std::complex<double>[3 * npwx * nbands];
std::complex<double> *pij = new std::complex<double>[nbands * nbands];
double *pij2 = new double[reducenb2];
bool gamma_only = false; // ABACUS do not support gamma_only yet.
std::complex<double>* levc = &(this->psi[0](ik, 0, 0));
std::complex<double>* prevc = new std::complex<double>[3 * npwx * nbands];
std::complex<double>* pij = new std::complex<double>[nbands * nbands];
double* pij2 = new double[reducenb2];
ModuleBase::GlobalFunc::ZEROS(pij2, reducenb2);
// px|right>
velop.act(this->psi, nbands*GlobalV::NPOL, levc, prevc);
velop.act(this->psi, nbands * GlobalV::NPOL, levc, prevc);
for (int id = 0; id < ndim; ++id)
{

zgemm_(&transc,
&transn,
&nbands,
Expand All @@ -110,26 +122,26 @@ void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const
#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, pij, nbands * nbands, MPI_DOUBLE_COMPLEX, MPI_SUM, POOL_WORLD);
#endif
if(!gamma_only)
for (int ib = 0, ijb = 0; ib < nbands; ++ib)
{
for (int jb = ib + 1; jb < nbands; ++jb, ++ijb)
if (!gamma_only)
for (int ib = 0, ijb = 0; ib < nbands; ++ib)
{
pij2[ijb] += norm(pij[ib * nbands + jb]);
for (int jb = ib + 1; jb < nbands; ++jb, ++ijb)
{
pij2[ijb] += norm(pij[ib * nbands + jb]);
}
}
}
}

if(GlobalV::RANK_IN_POOL == 0)
if (GlobalV::RANK_IN_POOL == 0)
{
int nkstot = this->kv.nkstot;
int ikglobal = this->kv.getik_global(ik);
std::stringstream ss;
ss<<GlobalV::global_out_dir<<"vmatrix"<<ikglobal+1<<".dat";
ss << GlobalV::global_out_dir << "vmatrix" << ikglobal + 1 << ".dat";
Binstream binpij(ss.str(), "w");
binpij<<8*reducenb2;
binpij << 8 * reducenb2;
binpij.write(pij2, reducenb2);
binpij<<8*reducenb2;
binpij << 8 * reducenb2;
}

int ntper = nt / GlobalV::NPROC_IN_POOL;
Expand All @@ -149,14 +161,16 @@ void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const
double tmct11 = 0;
double tmct12 = 0;
double tmct22 = 0;
double *enb = &(this->pelec->ekb(ik, 0));
double* enb = &(this->pelec->ekb(ik, 0));
for (int ib = 0, ijb = 0; ib < nbands; ++ib)
{
double ei = enb[ib];
double fi = wg(ik, ib);
for (int jb = ib + 1; jb < nbands; ++jb, ++ijb)
{
double ej = enb[jb];
if (ej - ei > decut )
continue;
double fj = wg(ik, jb);
double tmct = sin((ej - ei) * (it)*dt) * (fi - fj) * pij2[ijb];
tmct11 += tmct;
Expand All @@ -176,32 +190,33 @@ void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const

template <typename FPTYPE, typename Device>
void ESolver_KS_PW<FPTYPE, Device>::calcondw(const int nt,
const double dt,
const double fwhmin,
const double wcut,
const double dw_in,
double *ct11,
double *ct12,
double *ct22)
const double dt,
const double fwhmin,
const double wcut,
const double dw_in,
double* ct11,
double* ct12,
double* ct22)
{
double factor = FACTOR;
const int ndim = 3;
int nw = ceil(wcut / dw_in);
double dw = dw_in / ModuleBase::Ry_to_eV; // converge unit in eV to Ry
double sigma = fwhmin / TWOSQRT2LN2 / ModuleBase::Ry_to_eV;
std::ofstream ofscond("je-je.txt");
ofscond << std::setw(8) << "#t(a.u.)" << std::setw(15) << "c11(t)" << std::setw(15) << "c12(t)" << std::setw(15) << "c22(t)" << std::setw(15)
<< "decay" << std::endl;
ofscond << std::setw(8) << "#t(a.u.)" << std::setw(15) << "c11(t)" << std::setw(15) << "c12(t)" << std::setw(15)
<< "c22(t)" << std::setw(15) << "decay" << std::endl;
for (int it = 0; it < nt; ++it)
{
ofscond << std::setw(8) << (it)*dt << std::setw(15) << -2 * ct11[it] << std::setw(15) << -2 * ct12[it] << std::setw(15)
<< -2 * ct22[it] << std::setw(15) << exp(-double(1) / 2 * sigma * sigma * pow((it)*dt, 2)) << std::endl;
ofscond << std::setw(8) << (it)*dt << std::setw(15) << -2 * ct11[it] << std::setw(15) << -2 * ct12[it]
<< std::setw(15) << -2 * ct22[it] << std::setw(15)
<< exp(-double(1) / 2 * sigma * sigma * pow((it)*dt, 2)) << std::endl;
}
ofscond.close();
double *cw11 = new double[nw];
double *cw12 = new double[nw];
double *cw22 = new double[nw];
double *kappa = new double[nw];
double* cw11 = new double[nw];
double* cw12 = new double[nw];
double* cw22 = new double[nw];
double* kappa = new double[nw];
ModuleBase::GlobalFunc::ZEROS(cw11, nw);
ModuleBase::GlobalFunc::ZEROS(cw12, nw);
ModuleBase::GlobalFunc::ZEROS(cw22, nw);
Expand All @@ -218,8 +233,8 @@ void ESolver_KS_PW<FPTYPE, Device>::calcondw(const int nt,
}
}
ofscond.open("Onsager.txt");
ofscond << std::setw(8) << "## w(eV) " << std::setw(20) << "sigma(Sm^-1)" << std::setw(20) << "kappa(W(mK)^-1)" << std::setw(20)
<< "L12/e(Am^-1)" << std::setw(20) << "L22/e^2(Wm^-1)" << std::endl;
ofscond << std::setw(8) << "## w(eV) " << std::setw(20) << "sigma(Sm^-1)" << std::setw(20) << "kappa(W(mK)^-1)"
<< std::setw(20) << "L12/e(Am^-1)" << std::setw(20) << "L22/e^2(Wm^-1)" << std::endl;
for (int iw = 0; iw < nw; ++iw)
{
cw11[iw] *= double(2) / ndim / GlobalC::ucell.omega * factor; // unit in Sm^-1
Expand All @@ -229,12 +244,13 @@ void ESolver_KS_PW<FPTYPE, Device>::calcondw(const int nt,
* pow(2.17987092759e-18 / 1.6021766208e-19, 2); // unit in Wm^-1
kappa[iw] = (cw22[iw] - pow(cw12[iw], 2) / cw11[iw]) / Occupy::gaussian_parameter / ModuleBase::Ry_to_eV
/ 11604.518026;
ofscond << std::setw(8) << (iw + 0.5) * dw * ModuleBase::Ry_to_eV << std::setw(20) << cw11[iw] << std::setw(20) << kappa[iw]
<< std::setw(20) << cw12[iw] << std::setw(20) << cw22[iw] << std::endl;
ofscond << std::setw(8) << (iw + 0.5) * dw * ModuleBase::Ry_to_eV << std::setw(20) << cw11[iw] << std::setw(20)
<< kappa[iw] << std::setw(20) << cw12[iw] << std::setw(20) << cw22[iw] << std::endl;
}
std::cout << std::setprecision(6) << "DC electrical conductivity: " << cw11[0] - (cw11[1] - cw11[0]) * 0.5 << " Sm^-1"
<< std::endl;
std::cout << std::setprecision(6) << "Thermal conductivity: " << kappa[0] - (kappa[1] - kappa[0]) * 0.5 << " W(mK)^-1" << std::endl;
std::cout << std::setprecision(6) << "DC electrical conductivity: " << cw11[0] - (cw11[1] - cw11[0]) * 0.5
<< " Sm^-1" << std::endl;
std::cout << std::setprecision(6) << "Thermal conductivity: " << kappa[0] - (kappa[1] - kappa[0]) * 0.5
<< " W(mK)^-1" << std::endl;
;
ofscond.close();

Expand Down
13 changes: 9 additions & 4 deletions source/module_esolver/esolver_sdft_pw_tool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
double dw = dw_in / ModuleBase::Ry_to_eV; //converge unit in eV to Ry
double sigma = fwhmin / TWOSQRT2LN2 / ModuleBase::Ry_to_eV;
double dt = dt_in; //unit in a.u., 1 a.u. = 4.837771834548454e-17 s
int nt = ceil(sqrt(20)/sigma/dt);
const double expfactor = 18.42; //exp(-18.42) = 1e-8
int nt = ceil(sqrt(2*expfactor)/sigma/dt); //set nt empirically
std::cout<<"nw: "<<nw<<" ; dw: "<<dw*ModuleBase::Ry_to_eV<<" eV"<<std::endl;
std::cout<<"nt: "<<nt<<" ; dt: "<<dt<<" a.u.(ry^-1)"<<std::endl;
assert(nw >= 1);
Expand Down Expand Up @@ -214,7 +215,7 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
// ks conductivity
//-----------------------------------------------------------
if(GlobalV::MY_STOGROUP == 0 && totbands_ks > 0)
jjcorr_ks(ik, nt, dt, this->pelec->wg, velop, ct11,ct12,ct22);
jjcorr_ks(ik, nt, dt, (wcut + 5*fwhmin) / ModuleBase::Ry_to_eV, this->pelec->wg, velop, ct11,ct12,ct22);

//-----------------------------------------------------------
// sto conductivity
Expand All @@ -239,9 +240,9 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
ModuleBase::Memory::record("SDFT::j2psi", memory_cost);
//(1-f)*j*sqrt(f)|psi>
psi::Psi<std::complex<double>> j1sfpsi(1, ndim * totbands_per, npwx, kv.ngk.data());
ModuleBase::Memory::record("SDFT::psi0", memory_cost);
ModuleBase::Memory::record("SDFT::j1sfpsi", memory_cost);
psi::Psi<std::complex<double>> j2sfpsi(1, ndim * totbands_per, npwx, kv.ngk.data());
ModuleBase::Memory::record("SDFT::psi0", memory_cost);
ModuleBase::Memory::record("SDFT::j2sfpsi", memory_cost);
double* en;
if(ksbandper > 0) en = new double [ksbandper];
for(int ib = 0 ; ib < ksbandper ; ++ib)
Expand Down Expand Up @@ -474,6 +475,10 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
// const int dim_jmatrix = totbands_per*totbands;
ModuleBase::ComplexMatrix j1l(ndim,dim_jmatrix), j2l(ndim,dim_jmatrix);
ModuleBase::ComplexMatrix j1r(ndim,dim_jmatrix), j2r(ndim,dim_jmatrix);
ModuleBase::Memory::record("SDFT::j1l", sizeof(std::complex<double>) * ndim * dim_jmatrix);
ModuleBase::Memory::record("SDFT::j2l", sizeof(std::complex<double>) * ndim * dim_jmatrix);
ModuleBase::Memory::record("SDFT::j1r", sizeof(std::complex<double>) * ndim * dim_jmatrix);
ModuleBase::Memory::record("SDFT::j2r", sizeof(std::complex<double>) * ndim * dim_jmatrix);
char transa = 'C';
char transb = 'N';
int totbands_per3 = ndim*totbands_per;
Expand Down
1 change: 0 additions & 1 deletion tests/integrate/186_PW_KG_100/INPUT
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ mixing_type broyden
mixing_beta 0.4

cal_cond 1
cond_nche 18
cond_fwhm 8
cond_wcut 20
cond_dw 0.02
Expand Down
Loading

0 comments on commit db71b1c

Please sign in to comment.