diff --git a/source/module_cell/klist.h b/source/module_cell/klist.h index 999b40ad15..ec5573d0a9 100644 --- a/source/module_cell/klist.h +++ b/source/module_cell/klist.h @@ -127,6 +127,11 @@ class K_Vectors return this->nkstot_full; } + double get_koffset(const int i) const + { + return this->koffset[i]; + } + void set_nks(int value) { this->nks = value; diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.cpp index 1513a4f3a7..5bd7d9ba9a 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.cpp @@ -361,16 +361,17 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, #ifdef __EXX if (GlobalC::exx_info.info_global.cal_exx) { - Operator* exx = new OperatorEXX>(this->hsk, - LM_in, - this->hR, - *this->kv, - LM_in->Hexxd, - LM_in->Hexxc, - exx_two_level_step, - !GlobalC::restart.info_load.restart_exx - && GlobalC::restart.info_load.load_H); - this->getOperator()->add(exx); + Operator*exx = new OperatorEXX>(this->hsk, + LM_in, + this->hR, + *this->kv, + LM_in->Hexxd, + LM_in->Hexxc, + Add_Hexx_Type::R, + exx_two_level_step, + !GlobalC::restart.info_load.restart_exx + && GlobalC::restart.info_load.load_H); + this->getOperator()->add(exx); } #endif // if NSPIN==2, HR should be separated into two parts, save HR into this->hRS2 diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h index 6cb1e8fde2..3d9b1f9c90 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h @@ -4,6 +4,7 @@ #ifdef __EXX #include +#include #include "operator_lcao.h" #include "module_cell/klist.h" #include "module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h" @@ -20,7 +21,7 @@ class OperatorEXX : public T }; #endif - +enum Add_Hexx_Type { R, k }; template class OperatorEXX> : public OperatorLCAO { @@ -32,13 +33,16 @@ class OperatorEXX> : public OperatorLCAO const K_Vectors& kv_in, std::vector>>>* Hexxd_in = nullptr, std::vector>>>>* Hexxc_in = nullptr, + Add_Hexx_Type add_hexx_type_in = Add_Hexx_Type::R, int* two_level_step_in = nullptr, const bool restart_in = false); virtual void contributeHk(int ik) override; + virtual void contributeHR() override; private: - + Add_Hexx_Type add_hexx_type = Add_Hexx_Type::R; + int current_spin = 0; bool HR_fixed_done = false; std::vector>>>* Hexxd = nullptr; @@ -53,9 +57,14 @@ class OperatorEXX> : public OperatorLCAO bool restart = false; void add_loaded_Hexx(const int ik); + void add_loaded_HexxR() {}; + void clear_loaded_HexxR() {}; const K_Vectors& kv; LCAO_Matrix* LM = nullptr; + // if k points has no shift, use cell_nearest to reduce the memory cost + RI::Cell_Nearest cell_nearest; + bool use_cell_nearest = true; }; } // namespace hamilt diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp index 4158abcaf5..200fb86f3b 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp @@ -1,127 +1,217 @@ -#ifndef OPEXXLCAO_HPP -#define OPEXXLCAO_HPP - -#ifdef __EXX - -#include "op_exx_lcao.h" -#include "module_ri/RI_2D_Comm.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "module_hamilt_general/module_xc/xc_functional.h" - -namespace hamilt -{ - -template -OperatorEXX>::OperatorEXX( - HS_Matrix_K* hsk_in, - LCAO_Matrix* LM_in, - hamilt::HContainer* hR_in, - const K_Vectors& kv_in, - std::vector>>>* Hexxd_in, - std::vector>>>>* Hexxc_in, - int* two_level_step_in, - const bool restart_in) - : OperatorLCAO(hsk_in, kv_in.kvec_d, hR_in), - kv(kv_in), - Hexxd(Hexxd_in), - Hexxc(Hexxc_in), - two_level_step(two_level_step_in), - restart(restart_in), - LM(LM_in) -{ - this->cal_type = calculation_type::lcao_exx; - if (this->restart) - { /// Now only Hexx depends on DM, so we can directly read Hexx to reduce the computational cost. - /// If other operators depends on DM, we can also read DM and then calculate the operators to save the memory to - /// store operator terms. - assert(this->two_level_step != nullptr); - /// read in Hexx - if (std::is_same::value) - { - this->LM->Hexxd_k_load.resize(this->kv.get_nks()); - for (int ik = 0; ik < this->kv.get_nks(); ik++) - { - this->LM->Hexxd_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0); - this->restart = GlobalC::restart.load_disk("Hexx", - ik, - this->LM->ParaV->get_local_size(), - this->LM->Hexxd_k_load[ik].data(), - false); - if (!this->restart) - break; - } - } - else - { - this->LM->Hexxc_k_load.resize(this->kv.get_nks()); - for (int ik = 0; ik < this->kv.get_nks(); ik++) - { - this->LM->Hexxc_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0); - this->restart = GlobalC::restart.load_disk("Hexx", - ik, - this->LM->ParaV->get_local_size(), - this->LM->Hexxc_k_load[ik].data(), - false); - if (!this->restart) - break; - } - } - if (!this->restart) - std::cout << "WARNING: Hexx not found, restart from the non-exx loop." << std::endl - << "If the loaded charge density is EXX-solved, this may lead to poor convergence." << std::endl; - GlobalC::restart.info_load.load_H_finish = this->restart; - } -} - -template -void OperatorEXX>::contributeHk(int ik) -{ - // Peize Lin add 2016-12-03 - if (GlobalV::CALCULATION != "nscf" && this->two_level_step != nullptr && *this->two_level_step == 0 && !this->restart) return; //in the non-exx loop, do nothing - if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5) - { - if (this->restart && this->two_level_step != nullptr) - { - if (*this->two_level_step == 0) - { - this->add_loaded_Hexx(ik); - return; - } - else // clear loaded Hexx and release memory - { - if (this->LM->Hexxd_k_load.size() > 0) - { - this->LM->Hexxd_k_load.clear(); - this->LM->Hexxd_k_load.shrink_to_fit(); - } - else if (this->LM->Hexxc_k_load.size() > 0) - { - this->LM->Hexxc_k_load.clear(); - this->LM->Hexxc_k_load.shrink_to_fit(); - } - } - } - // cal H(k) from H(R) normally - - if (GlobalC::exx_info.info_ri.real_number) - RI_2D_Comm::add_Hexx( - this->kv, - ik, - GlobalC::exx_info.info_global.hybrid_alpha, - this->Hexxd == nullptr ? *this->LM->Hexxd : *this->Hexxd, - *this->LM->ParaV, - this->hsk->get_hk()); - else - RI_2D_Comm::add_Hexx( - this->kv, - ik, - GlobalC::exx_info.info_global.hybrid_alpha, - this->Hexxc == nullptr ? *this->LM->Hexxc : *this->Hexxc, - *this->LM->ParaV, - this->hsk->get_hk()); - } -} - -} // namespace hamilt -#endif // __EXX +#ifndef OPEXXLCAO_HPP +#define OPEXXLCAO_HPP + +#ifdef __EXX + +#include "op_exx_lcao.h" +#include "module_ri/RI_2D_Comm.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_hamilt_general/module_xc/xc_functional.h" + +namespace hamilt +{ + +template +OperatorEXX>::OperatorEXX(HS_Matrix_K* hsk_in, + LCAO_Matrix* LM_in, + hamilt::HContainer* hR_in, + const K_Vectors& kv_in, + std::vector>>>* Hexxd_in, + std::vector>>>>* Hexxc_in, + Add_Hexx_Type add_hexx_type_in, + int* two_level_step_in, + const bool restart_in) + : OperatorLCAO(hsk_in, kv_in.kvec_d, hR_in), + kv(kv_in), + Hexxd(Hexxd_in), + Hexxc(Hexxc_in), + add_hexx_type(add_hexx_type_in), + two_level_step(two_level_step_in), + restart(restart_in) +{ + this->cal_type = calculation_type::lcao_exx; + // if k points has no shift, use cell_nearest to reduce the memory cost + this->use_cell_nearest = (ModuleBase::Vector3(std::fmod(this->kv.get_koffset(0), 1.0), + std::fmod(this->kv.get_koffset(1), 1.0), std::fmod(this->kv.get_koffset(2), 1.0)).norm() < 1e-10); + + if (this->add_hexx_type == Add_Hexx_Type::R) + { + // set cell_nearest + std::map> atoms_pos; + for (int iat = 0; iat < GlobalC::ucell.nat; ++iat) { + atoms_pos[iat] = RI_Util::Vector3_to_array3( + GlobalC::ucell.atoms[GlobalC::ucell.iat2it[iat]] + .tau[GlobalC::ucell.iat2ia[iat]]); + } + const std::array, 3> latvec + = { RI_Util::Vector3_to_array3(GlobalC::ucell.a1), + RI_Util::Vector3_to_array3(GlobalC::ucell.a2), + RI_Util::Vector3_to_array3(GlobalC::ucell.a3) }; + const std::array Rs_period = { this->kv.nmp[0], this->kv.nmp[1], this->kv.nmp[2] }; + this->cell_nearest.init(atoms_pos, latvec, Rs_period); + + // reallocate hR if needed (Hexxd temp) + auto Rs = RI_Util::get_Born_von_Karmen_cells(Rs_period); + bool need_allocate = false; + for (int iat0 = 0;iat0 < GlobalC::ucell.nat;++iat0) + { + for (int iat1 = 0;iat1 < GlobalC::ucell.nat;++iat1) + { + for (auto& cell : Rs) + { + const Abfs::Vector3_Order& R = RI_Util::array3_to_Vector3( + (this->use_cell_nearest ? + cell_nearest.get_cell_nearest_discrete(iat0, iat1, cell) + : cell)); + hamilt::BaseMatrix* HlocR = this->hR->find_matrix(iat0, iat1, R.x, R.y, R.z); + if (HlocR == nullptr) + { // add R to HContainer + need_allocate = true; + hamilt::AtomPair tmp(iat0, iat1, R.x, R.y, R.z, this->LM->ParaV); + this->hR->insert_pair(tmp); + } + } + } + } + if (need_allocate) this->hR->allocate(nullptr, true); + } + + if (this->restart) + {/// Now only Hexx depends on DM, so we can directly read Hexx to reduce the computational cost. + /// If other operators depends on DM, we can also read DM and then calculate the operators to save the memory to store operator terms. + assert(this->two_level_step != nullptr); + + if (this->add_hexx_type == Add_Hexx_Type::k) + { + /// read in Hexx(k) + if (std::is_same::value) + { + this->LM->Hexxd_k_load.resize(this->kv.get_nks()); + for (int ik = 0; ik < this->kv.get_nks(); ik++) + { + this->LM->Hexxd_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0); + this->restart = GlobalC::restart.load_disk( + "Hexx", ik, + this->LM->ParaV->get_local_size(), this->LM->Hexxd_k_load[ik].data(), false); + if (!this->restart) break; + } + } + else + { + this->LM->Hexxc_k_load.resize(this->kv.get_nks()); + for (int ik = 0; ik < this->kv.get_nks(); ik++) + { + this->LM->Hexxc_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0); + this->restart = GlobalC::restart.load_disk( + "Hexx", ik, + this->LM->ParaV->get_local_size(), this->LM->Hexxc_k_load[ik].data(), false); + if (!this->restart) break; + } + } + } + else if (this->add_hexx_type == Add_Hexx_Type::R) + { + // refactor IO-csr functions + } + + if (!this->restart) + std::cout << "WARNING: Hexx not found, restart from the non-exx loop." << std::endl + << "If the loaded charge density is EXX-solved, this may lead to poor convergence." << std::endl; + GlobalC::restart.info_load.load_H_finish = this->restart; + } +} + +template +void OperatorEXX>::contributeHR() +{ + // Peize Lin add 2016-12-03 + if (GlobalV::CALCULATION != "nscf" && this->two_level_step != nullptr && *this->two_level_step == 0 && !this->restart) return; //in the non-exx loop, do nothing + if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5) + { + if (this->restart && this->two_level_step != nullptr) + { + if (*this->two_level_step == 0) + { + this->add_loaded_HexxR(); + return; + } + else // clear loaded Hexx and release memory + { + this->clear_loaded_HexxR(); + } + } + // cal H(k) from H(R) normally + if (GlobalC::exx_info.info_ri.real_number) + RI_2D_Comm::add_HexxR( + this->current_spin, + GlobalC::exx_info.info_global.hybrid_alpha, + this->Hexxd == nullptr ? *this->LM->Hexxd : *this->Hexxd, + *this->LM->ParaV, + GlobalV::NPOL, + *this->hR, + this->use_cell_nearest ? &this->cell_nearest : nullptr); + else + RI_2D_Comm::add_HexxR( + this->current_spin, + GlobalC::exx_info.info_global.hybrid_alpha, + this->Hexxc == nullptr ? *this->LM->Hexxc : *this->Hexxc, + *this->LM->ParaV, + GlobalV::NPOL, + *this->hR, + this->use_cell_nearest ? &this->cell_nearest : nullptr); + } + if (GlobalV::NSPIN == 2) this->current_spin = 1 - this->current_spin; +} + +template +void OperatorEXX>::contributeHk(int ik) +{ + // Peize Lin add 2016-12-03 + if (GlobalV::CALCULATION != "nscf" && this->two_level_step != nullptr && *this->two_level_step == 0 && !this->restart) return; //in the non-exx loop, do nothing + if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5) + { + if (this->restart && this->two_level_step != nullptr) + { + if (*this->two_level_step == 0) + { + this->add_loaded_Hexx(ik); + return; + } + else // clear loaded Hexx and release memory + { + if (this->LM->Hexxd_k_load.size() > 0) + { + this->LM->Hexxd_k_load.clear(); + this->LM->Hexxd_k_load.shrink_to_fit(); + } + else if (this->LM->Hexxc_k_load.size() > 0) + { + this->LM->Hexxc_k_load.clear(); + this->LM->Hexxc_k_load.shrink_to_fit(); + } + } + } + // cal H(k) from H(R) normally + + if (GlobalC::exx_info.info_ri.real_number) + RI_2D_Comm::add_Hexx( + this->kv, + ik, + GlobalC::exx_info.info_global.hybrid_alpha, + this->Hexxd == nullptr ? *this->LM->Hexxd : *this->Hexxd, + *this->LM->ParaV, + this->hsk->get_hk()); + else + RI_2D_Comm::add_Hexx( + this->kv, + ik, + GlobalC::exx_info.info_global.hybrid_alpha, + this->Hexxc == nullptr ? *this->LM->Hexxc : *this->Hexxc, + *this->LM->ParaV, + this->hsk->get_hk()); + } +} + +} // namespace hamilt +#endif // __EXX #endif // OPEXXLCAO_HPP \ No newline at end of file diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.cpp index 50937ff6af..f186685f1b 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.cpp @@ -185,7 +185,7 @@ void OperatorLCAO::init(const int ik_in) //update HK next //in cal_type=lcao_exx, HK only need to update from one node - this->contributeHk(ik_in); + // this->contributeHk(ik_in); break; } diff --git a/source/module_ri/RI_2D_Comm.h b/source/module_ri/RI_2D_Comm.h index d4be823ddf..c48ad2d056 100644 --- a/source/module_ri/RI_2D_Comm.h +++ b/source/module_ri/RI_2D_Comm.h @@ -8,6 +8,7 @@ #include "module_basis/module_ao/parallel_orbitals.h" #include "module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h" +#include "module_hamilt_lcao/module_hcontainer/hcontainer.h" #include "module_cell/klist.h" #include @@ -18,6 +19,7 @@ #include #include #include +#include namespace RI_2D_Comm { @@ -45,6 +47,16 @@ namespace RI_2D_Comm const Parallel_Orbitals& pv, TK* hk); + template + extern void add_HexxR( + const int current_spin, + const double alpha, + const std::vector>>>& Hs, + const Parallel_Orbitals& pv, + const int npol, + hamilt::HContainer& HlocR, + const RI::Cell_Nearest* const cell_nearest = nullptr); + template extern std::vector> Hexxs_to_Hk( const K_Vectors &kv, diff --git a/source/module_ri/RI_2D_Comm.hpp b/source/module_ri/RI_2D_Comm.hpp index af77cb9508..9d9f09ff09 100644 --- a/source/module_ri/RI_2D_Comm.hpp +++ b/source/module_ri/RI_2D_Comm.hpp @@ -215,4 +215,51 @@ int RI_2D_Comm::get_iwt(const int iat, const int iw_b, const int is_b) return iwt; } +template +void RI_2D_Comm::add_HexxR( + const int current_spin, + const double alpha, + const std::vector>>>& Hs, + const Parallel_Orbitals& pv, + const int npol, + hamilt::HContainer& hR, + const RI::Cell_Nearest* const cell_nearest) +{ + ModuleBase::TITLE("RI_2D_Comm", "add_HexxR"); + ModuleBase::timer::tick("RI_2D_Comm", "add_HexxR"); + + for (const auto& Hs_tmpA : Hs[current_spin]) + { + const TA& iat0 = Hs_tmpA.first; + for (const auto& Hs_tmpB : Hs_tmpA.second) + { + const TA& iat1 = Hs_tmpB.first.first; + const Abfs::Vector3_Order R = RI_Util::array3_to_Vector3( + (cell_nearest ? + cell_nearest->get_cell_nearest_discrete(iat0, iat1, Hs_tmpB.first.second) + : Hs_tmpB.first.second)); + hamilt::BaseMatrix* HlocR = hR.find_matrix(iat0, iat1, R.x, R.y, R.z); + if (HlocR == nullptr) + { // add R to HContainer + hamilt::AtomPair tmp(iat0, iat1, R.x, R.y, R.z, &pv); + hR.insert_pair(tmp); + HlocR = hR.find_matrix(iat0, iat1, R.x, R.y, R.z); + } + auto row_indexes = pv.get_indexes_row(iat0); + auto col_indexes = pv.get_indexes_col(iat1); + const RI::Tensor& HexxR = (Tdata)alpha * Hs_tmpB.second; + // std::cout << "iat0=" << iat0 << ", iat1=" << iat1 << std::endl; + for (int lw0 = 0;lw0 < row_indexes.size();lw0 += npol) + for (int lw1 = 0;lw1 < col_indexes.size();lw1 += npol) + { + const int& gw0 = row_indexes[lw0]; + const int& gw1 = col_indexes[lw1]; + // std::cout << "gw0=" << gw0 << ", gw1=" << gw1 << ", lw0=" << lw0 << ", lw1=" << lw1 << std::endl; + HlocR->add_element(lw0, lw1, RI::Global_Func::convert(HexxR(gw0, gw1))); + } + } + } + ModuleBase::timer::tick("RI_2D_Comm", "add_HexxR"); +} + #endif \ No newline at end of file