Skip to content

Commit

Permalink
refactor: exx dm
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Oct 16, 2023
1 parent 4269f01 commit 28c8648
Show file tree
Hide file tree
Showing 13 changed files with 280 additions and 113 deletions.
22 changes: 0 additions & 22 deletions source/module_elecstate/elecstate_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,6 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
//cal_dm(this->loc->ParaV, this->wg, psi, this->loc->dm_gamma);
elecstate::cal_dm_psi(this->DM->get_paraV_pointer(), this->wg, psi, *(this->DM));
this->DM->cal_DMR();

// interface for RI-related calculation, which needs loc.dm_gamma
#ifdef __EXX
if (GlobalC::exx_info.info_global.cal_exx || this->loc->out_dm)
{
this->loc->dm_gamma.resize(GlobalV::NSPIN);
for (int is = 0; is < GlobalV::NSPIN; ++is)
{
this->loc->set_dm_gamma(is, this->DM->get_DMK_pointer(is));
}
}
#else
if (this->loc->out_dm) // keep interface for old Output_DM until new one is ready
{
this->loc->dm_gamma.resize(GlobalV::NSPIN);
for (int is = 0; is < GlobalV::NSPIN; ++is)
{
this->loc->set_dm_gamma(is, this->DM->get_DMK_pointer(is));
}
}
#endif

ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");

for (int ik = 0; ik < psi.get_nk(); ++ik)
Expand Down
10 changes: 5 additions & 5 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ namespace ModuleESolver
#ifdef __EXX
// calculate exact-exchange
if (GlobalC::exx_info.info_ri.real_number)
this->exd->exx_eachiterinit(this->LOC, *(this->p_chgmix), iter);
this->exd->exx_eachiterinit(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), *(this->p_chgmix), iter);
else
this->exc->exx_eachiterinit(this->LOC, *(this->p_chgmix), iter);
this->exc->exx_eachiterinit(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), *(this->p_chgmix), iter);
#endif

if (GlobalV::dft_plus_u)
Expand Down Expand Up @@ -845,7 +845,7 @@ namespace ModuleESolver
// rpa_interface.rpa_exx_lcao().info.files_abfs = GlobalV::rpa_orbitals;
// rpa_interface.out_for_RPA(*(this->LOWF.ParaV), *(this->psi), this->LOC, this->pelec);
RPA_LRI<TK, double> rpa_lri_double(GlobalC::exx_info.info_ri);
rpa_lri_double.cal_postSCF_exx(this->LOC, MPI_COMM_WORLD, this->kv, *this->LOWF.ParaV);
rpa_lri_double.cal_postSCF_exx(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), MPI_COMM_WORLD, this->kv);
rpa_lri_double.init(MPI_COMM_WORLD, this->kv);
rpa_lri_double.out_for_RPA(*(this->LOWF.ParaV), *(this->psi), this->pelec);
}
Expand All @@ -872,9 +872,9 @@ namespace ModuleESolver
{
#ifdef __EXX
if (GlobalC::exx_info.info_ri.real_number)
return this->exd->exx_after_converge(*this->p_hamilt, this->LM, this->LOC, this->kv, iter);
return this->exd->exx_after_converge(*this->p_hamilt, this->LM, *dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), this->kv, iter);
else
return this->exc->exx_after_converge(*this->p_hamilt, this->LM, this->LOC, this->kv, iter);
return this->exc->exx_after_converge(*this->p_hamilt, this->LM, *dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), this->kv, iter);
#endif // __EXX
return true;
}
Expand Down
26 changes: 13 additions & 13 deletions source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_charge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ void Local_Orbital_Charge::set_dm_k(int ik, std::complex<double>* dm_k_in)
return;
}

void Local_Orbital_Charge::set_dm_gamma(int is, double* dm_gamma_in)
{
ModuleBase::TITLE("Local_Orbital_Charge", "set_dm_gamma");
dm_gamma[is].create(ParaV->ncol, ParaV->nrow);
for (int i = 0; i < ParaV->ncol; ++i)
{
for (int j = 0; j < ParaV->nrow; ++j)
{
dm_gamma[is](i, j) = dm_gamma_in[i * ParaV->nrow + j];
}
}
return;
}
// void Local_Orbital_Charge::set_dm_gamma(int is, double* dm_gamma_in)
// {
// ModuleBase::TITLE("Local_Orbital_Charge", "set_dm_gamma");
// dm_gamma[is].create(ParaV->ncol, ParaV->nrow);
// for (int i = 0; i < ParaV->ncol; ++i)
// {
// for (int j = 0; j < ParaV->nrow; ++j)
// {
// dm_gamma[is](i, j) = dm_gamma_in[i * ParaV->nrow + j];
// }
// }
// return;
// }
9 changes: 6 additions & 3 deletions source/module_ri/Exx_LRI_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ class LCAO_Matrix;
class Charge_Mixing;
namespace elecstate
{
class ElecState;
class ElecState;
template <typename TK, typename TR>
class DensityMatrix;
}

template<typename T, typename Tdata>
Expand All @@ -32,7 +34,8 @@ class Exx_LRI_Interface
void exx_beforescf(const K_Vectors& kv, const Charge_Mixing& chgmix);

/// @brief in eachiterinit: do DM mixing and calculate Hexx when entering 2nd SCF
void exx_eachiterinit(const Local_Orbital_Charge& loc, const Charge_Mixing& chgmix, const int& iter);
void exx_eachiterinit(const elecstate::DensityMatrix<T, double>& dm/**< double should be Tdata if complex-PBE-DM is supported*/,
const Charge_Mixing& chgmix, const int& iter);

/// @brief in hamilt2density: calculate Hexx and Eexx
void exx_hamilt2density(elecstate::ElecState& elec, const Parallel_Orbitals& pv);
Expand All @@ -41,7 +44,7 @@ class Exx_LRI_Interface
bool exx_after_converge(
hamilt::Hamilt<T>& hamilt,
LCAO_Matrix& lm,
const Local_Orbital_Charge& loc,
const elecstate::DensityMatrix<T, double>& dm/**< double should be Tdata if complex-PBE-DM is supported*/,
const K_Vectors& kv,
int& iter);

Expand Down
21 changes: 7 additions & 14 deletions source/module_ri/Exx_LRI_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,15 @@ void Exx_LRI_Interface<T, Tdata>::exx_beforescf(const K_Vectors& kv, const Charg
}

template<typename T, typename Tdata>
void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const Local_Orbital_Charge& loc, const Charge_Mixing& chgmix, const int& iter)
void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const elecstate::DensityMatrix<T, double>& dm, const Charge_Mixing& chgmix, const int& iter)
{
if (GlobalC::exx_info.info_global.cal_exx)
{
if (!GlobalC::exx_info.info_global.separate_loop && exx_lri->two_level_step)
{
const bool flag_restart = (iter==1) ? true : false;
if(GlobalV::GAMMA_ONLY_LOCAL)
exx_lri->mix_DMk_2D.mix(loc.dm_gamma, flag_restart);
else
exx_lri->mix_DMk_2D.mix(loc.dm_k, flag_restart);

exx_lri->cal_exx_elec(*loc.LOWF->ParaV);
exx_lri->mix_DMk_2D.mix(dm.get_DMK_vector(), flag_restart);
exx_lri->cal_exx_elec(*dm.get_paraV_pointer());
}
}
}
Expand Down Expand Up @@ -117,7 +113,7 @@ template<typename T, typename Tdata>
bool Exx_LRI_Interface<T, Tdata>::exx_after_converge(
hamilt::Hamilt<T>& hamilt,
LCAO_Matrix& lm,
const Local_Orbital_Charge& loc,
const elecstate::DensityMatrix<T, double>& dm,
const K_Vectors& kv,
int& iter)
{
Expand Down Expand Up @@ -203,14 +199,11 @@ bool Exx_LRI_Interface<T, Tdata>::exx_after_converge(
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].ncpp.xc_func);
}

const bool flag_restart = (exx_lri->two_level_step==0) ? true : false;
if (GlobalV::GAMMA_ONLY_LOCAL)
exx_lri->mix_DMk_2D.mix(loc.dm_gamma, flag_restart);
else
exx_lri->mix_DMk_2D.mix(loc.dm_k, flag_restart);
const bool flag_restart = (exx_lri->two_level_step == 0) ? true : false;
exx_lri->mix_DMk_2D.mix(dm.get_DMK_vector(), flag_restart);

// GlobalC::exx_lcao.cal_exx_elec(p_esolver->LOC, p_esolver->LOWF.wfc_k_grid);
exx_lri->cal_exx_elec(*loc.LOWF->ParaV);
exx_lri->cal_exx_elec(*dm.get_paraV_pointer());
iter = 0;
std::cout << " Updating EXX and rerun SCF" << std::endl;
exx_lri->two_level_step++;
Expand Down
20 changes: 10 additions & 10 deletions source/module_ri/Mix_DMk_2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ Mix_DMk_2D &Mix_DMk_2D::set_mixing(Base_Mixing::Mixing* mixing_in)
{
ModuleBase::TITLE("Mix_DMk_2D","set_mixing");
if(this->gamma_only)
for(Mix_Matrix<ModuleBase::matrix> &mix_one : this->mix_DMk_gamma)
for (Mix_Matrix<std::vector<double>>& mix_one : this->mix_DMk_gamma)
mix_one.init(mixing_in);
else
for(Mix_Matrix<ModuleBase::ComplexMatrix> &mix_one : this->mix_DMk_k)
for (Mix_Matrix<std::vector<std::complex<double>>>& mix_one : this->mix_DMk_k)
mix_one.init(mixing_in);
return *this;
}
Expand All @@ -33,39 +33,39 @@ Mix_DMk_2D &Mix_DMk_2D::set_mixing_beta(const double mixing_beta)
{
ModuleBase::TITLE("Mix_DMk_2D","set_mixing_beta");
if(this->gamma_only)
for(Mix_Matrix<ModuleBase::matrix> &mix_one : this->mix_DMk_gamma)
for (Mix_Matrix<std::vector<double>>& mix_one : this->mix_DMk_gamma)
mix_one.mixing_beta = mixing_beta;
else
for(Mix_Matrix<ModuleBase::ComplexMatrix> &mix_one : this->mix_DMk_k)
for (Mix_Matrix<std::vector<std::complex<double>>>& mix_one : this->mix_DMk_k)
mix_one.mixing_beta = mixing_beta;
return *this;
}

void Mix_DMk_2D::mix(const std::vector<ModuleBase::matrix> &dm, const bool flag_restart)
void Mix_DMk_2D::mix(const std::vector<std::vector<double>>& dm, const bool flag_restart)
{
ModuleBase::TITLE("Mix_DMk_2D","mix");
assert(this->mix_DMk_gamma.size() == dm.size());
for(int ik=0; ik<dm.size(); ++ik)
this->mix_DMk_gamma[ik].mix(dm[ik], flag_restart);
}
void Mix_DMk_2D::mix(const std::vector<ModuleBase::ComplexMatrix> &dm, const bool flag_restart)
void Mix_DMk_2D::mix(const std::vector<std::vector<std::complex<double>>>& dm, const bool flag_restart)
{
ModuleBase::TITLE("Mix_DMk_2D","mix");
assert(this->mix_DMk_k.size() == dm.size());
for(int ik=0; ik<dm.size(); ++ik)
this->mix_DMk_k[ik].mix(dm[ik], flag_restart);
}

std::vector<const ModuleBase::matrix*> Mix_DMk_2D::get_DMk_gamma_out() const
std::vector<const std::vector<double>*> Mix_DMk_2D::get_DMk_gamma_out() const
{
std::vector<const ModuleBase::matrix*> DMk_out(this->mix_DMk_gamma.size());
std::vector<const std::vector<double>*> DMk_out(this->mix_DMk_gamma.size());
for(int ik=0; ik<this->mix_DMk_gamma.size(); ++ik)
DMk_out[ik] = &this->mix_DMk_gamma[ik].get_data_out();
return DMk_out;
}
std::vector<const ModuleBase::ComplexMatrix*> Mix_DMk_2D::get_DMk_k_out() const
std::vector<const std::vector<std::complex<double>>*> Mix_DMk_2D::get_DMk_k_out() const
{
std::vector<const ModuleBase::ComplexMatrix*> DMk_out(this->mix_DMk_k.size());
std::vector<const std::vector<std::complex<double>>*> DMk_out(this->mix_DMk_k.size());
for(int ik=0; ik<this->mix_DMk_k.size(); ++ik)
DMk_out[ik] = &this->mix_DMk_k[ik].get_data_out();
return DMk_out;
Expand Down
12 changes: 6 additions & 6 deletions source/module_ri/Mix_DMk_2D.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,29 @@ class Mix_DMk_2D
* @param dm Double Density matrix.
* @param flag_restart Flag indicating whether restart mixing.
*/
void mix(const std::vector<ModuleBase::matrix> &dm, const bool flag_restart);
void mix(const std::vector<std::vector<double>>& dm, const bool flag_restart);

/**
* @brief Mixes the complex density matrix.
* @param dm Complex density matrix.
* @param flag_restart Flag indicating whether restart mixing.
*/
void mix(const std::vector<ModuleBase::ComplexMatrix> &dm, const bool flag_restart);
void mix(const std::vector<std::vector<std::complex<double>>>& dm, const bool flag_restart);

/**
* @brief Returns the double density matrix.
* @return Double density matrices for each k-points.
*/
std::vector<const ModuleBase::matrix*> get_DMk_gamma_out() const;
std::vector<const std::vector<double>*> get_DMk_gamma_out() const;
/**
* @brief Returns the complex density matrix.
* @return Complex density matrices for each k-points.
*/
std::vector<const ModuleBase::ComplexMatrix*> get_DMk_k_out() const;
std::vector<const std::vector<std::complex<double>>*> get_DMk_k_out() const;

private:
std::vector<Mix_Matrix<ModuleBase::matrix>> mix_DMk_gamma;
std::vector<Mix_Matrix<ModuleBase::ComplexMatrix>> mix_DMk_k;
std::vector<Mix_Matrix<std::vector<double>>> mix_DMk_gamma;
std::vector<Mix_Matrix<std::vector<std::complex<double>>>> mix_DMk_k;
bool gamma_only;
};

Expand Down
Loading

0 comments on commit 28c8648

Please sign in to comment.