Skip to content

Commit

Permalink
Refactor csr-IO and apply to nscf and restart
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Jul 6, 2024
1 parent 3bd6367 commit d38bd99
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 227 deletions.
4 changes: 2 additions & 2 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@ The band (KS orbital) energy for each (k-point, spin, band) will be printed in t
- auto: These files are saved in folder `OUT.${suffix}/restart/`;
- other: These files are saved in folder `${read_file_dir}/restart/`.

If EXX(exact exchange) is calculated (i.e. *[dft_fuctional](#dft_functional)==hse/hf/pbe0/scan0/opt_orb* or *[rpa](#rpa)==True*), the Hexx(k) files for each k-point will also be saved in the above folder, which can be read in EXX calculation with *[restart_load](#restart_load)==True*.
If EXX (exact exchange) is calculated (i.e. *[dft_functional](#dft_functional)==hse/hf/pbe0/scan0/opt_orb* or *[rpa](#rpa)==True*), the Hexx(R) files for each processor will also be saved in the above folder, which can be read in an EXX calculation with *[restart_load](#restart_load)==True*.
- **Default**: False

### restart_load
Expand All @@ -1717,7 +1717,7 @@ The band (KS orbital) energy for each (k-point, spin, band) will be printed in t
- **Availability**: Numerical atomic orbital basis
- **Description**: If [restart_save](#restart_save) is set to true and an electronic iteration is finished, calculations can be restarted from the charge density file, which was saved in the former calculation. Please ensure [read_file_dir](#read_file_dir) is correct and that the charge density file exists.

If EXX(exact exchange) is calculated (i.e. *[dft_fuctional](#dft_functional)==hse/hf/pbe0/scan0/opt_orb* or *[rpa](#rpa)==True*), the Hexx(k) files in the same folder for each k-point will also be read.
If EXX (exact exchange) is calculated (i.e. *[dft_functional](#dft_functional)==hse/hf/pbe0/scan0/opt_orb* or *[rpa](#rpa)==True*), the Hexx(R) files in the same folder for each processor will also be read.
- **Default**: False

### rpa
Expand Down
17 changes: 15 additions & 2 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <memory>
#ifdef __EXX
#include "module_ri/RPA_LRI.h"
#include "module_io/restart_exx_csr.h"
#endif

#ifdef __DEEPKS
Expand Down Expand Up @@ -1010,6 +1011,8 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter) {
&& (!GlobalC::exx_info.info_global.separate_loop
|| iter == 1)) // to avoid saving the same value repeatedly
{
////////// for Add_Hexx_Type::k
/*
hamilt::HS_Matrix_K<TK> Hexxk_save(&this->orb_con.ParaV, 1);
for (int ik = 0; ik < this->kv.get_nks(); ++ik) {
Hexxk_save.set_zero_hk();
Expand All @@ -1025,6 +1028,16 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter) {
ik,
this->orb_con.ParaV.get_local_size(),
Hexxk_save.get_hk());
}*/
////////// for Add_Hexx_Type:R
const std::string& restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
ModuleIO::write_Hexxs_csr(restart_HR_path, GlobalC::ucell, this->exd->get_Hexxs());
}
else
{
ModuleIO::write_Hexxs_csr(restart_HR_path, GlobalC::ucell, this->exc->get_Hexxs());
}
if (GlobalV::MY_RANK == 0) {
GlobalC::restart.save_disk("Eexx", 0, 1, &this->pelec->f_en.exx);
Expand Down Expand Up @@ -1144,9 +1157,9 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep) {
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR"
+ std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number) {
this->exd->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exd->get_Hexxs());
} else {
this->exc->write_Hexxs_csr(file_name_exx, GlobalC::ucell);
ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exc->get_Hexxs());
}
}
#endif
Expand Down
18 changes: 3 additions & 15 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h"
#include "module_io/rho_io.h"
#include "module_io/write_pot.h"
#ifdef __EXX
#include "module_io/restart_exx_csr.h"
#endif

namespace ModuleESolver {

Expand Down Expand Up @@ -599,21 +602,6 @@ void ESolver_KS_LCAO<TK, TR>::nscf() {

time_t time_start = std::time(nullptr);

#ifdef __EXX
#ifdef __MPI
// Peize Lin add 2018-08-14
if (GlobalC::exx_info.info_global.cal_exx) {
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR"
+ std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number) {
this->exd->read_Hexxs_csr(file_name_exx, GlobalC::ucell);
} else {
this->exc->read_Hexxs_csr(file_name_exx, GlobalC::ucell);
}
}
#endif // __MPI
#endif // __EXX

// mohan add 2021-02-09
// in ions, istep starts from 1,
// then when the istep is a variable of scf or nscf,
Expand Down
226 changes: 140 additions & 86 deletions source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,140 +7,193 @@
#include "module_ri/RI_2D_Comm.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_io/restart_exx_csr.h"

namespace hamilt
{
using TAC = std::pair<int, std::array<int, 3>>;
// allocate according to the read-in HexxR, used in nscf
template <typename Tdata, typename TR>
inline void reallocate_hcontainer(const std::vector<std::map<int, std::map<TAC, RI::Tensor<Tdata>>>>& Hexxs,
HContainer<TR>* hR)
{
bool need_allocate = false;
for (auto& Htmp1 : Hexxs[0])
{
const int& iat0 = Htmp1.first;
for (auto& Htmp2 : Htmp1.second)
{
const int& iat1 = Htmp2.first.first;
const Abfs::Vector3_Order<int>& R = RI_Util::array3_to_Vector3(Htmp2.first.second);
BaseMatrix<TR>* HlocR = hR->find_matrix(iat0, iat1, R.x, R.y, R.z);
if (HlocR == nullptr)
{ // add R to HContainer
need_allocate = true;
AtomPair<TR> tmp(iat0, iat1, R.x, R.y, R.z, hR->find_pair(0, 0)->get_paraV());
hR->insert_pair(tmp);
}
}
}
if (need_allocate) hR->allocate(nullptr, true);
}
/// @brief Extend hR so that it holds an atom pair for every (iat0, iat1, R) with
///        R running over the Born-von-Karman cells of Rs_period; used in scf.
///
/// @param nat           number of atoms; both atom indices run over [0, nat).
/// @param hR            container to extend; must already hold the atom pair (0, 0),
///                      whose parallel-orbital descriptor is reused for new pairs.
/// @param Rs_period     period of the Born-von-Karman supercell along each lattice direction.
/// @param cell_nearest  optional; when non-null each cell is first mapped through
///                      get_cell_nearest_discrete(iat0, iat1, cell), otherwise the
///                      Born-von-Karman cell is used directly.
template <typename TR>
inline void reallocate_hcontainer(const int nat, HContainer<TR>* hR,
    const std::array<int, 3>& Rs_period,
    const RI::Cell_Nearest<int, int, 3, double, 3>* const cell_nearest = nullptr)
{
    const auto Rs = RI_Util::get_Born_von_Karmen_cells(Rs_period);
    bool need_allocate = false;
    // Loop over the caller-supplied atom count rather than the global unit cell,
    // so that the `nat` argument is actually honoured.
    for (int iat0 = 0; iat0 < nat; ++iat0)
    {
        for (int iat1 = 0; iat1 < nat; ++iat1)
        {
            for (const auto& cell : Rs)
            {
                const Abfs::Vector3_Order<int>& R = RI_Util::array3_to_Vector3(
                    (cell_nearest ?
                        cell_nearest->get_cell_nearest_discrete(iat0, iat1, cell)
                        : cell));
                if (hR->find_matrix(iat0, iat1, R.x, R.y, R.z) == nullptr)
                {   // this (iat0, iat1, R) is missing in hR: add it
                    need_allocate = true;
                    AtomPair<TR> tmp(iat0, iat1, R.x, R.y, R.z, hR->find_pair(0, 0)->get_paraV());
                    hR->insert_pair(tmp);
                }
            }
        }
    }
    // Re-allocate once, and only if at least one pair was inserted.
    if (need_allocate) { hR->allocate(nullptr, true); }
}

template <typename TK, typename TR>
OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
LCAO_Matrix* LM_in,
hamilt::HContainer<TR>* hR_in,
HContainer<TR>*hR_in,
const K_Vectors& kv_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
Add_Hexx_Type add_hexx_type_in,
int* two_level_step_in,
const bool restart_in)
: OperatorLCAO<TK, TR>(hsk_in, kv_in.kvec_d, hR_in),
LM(LM_in),
kv(kv_in),
Hexxd(Hexxd_in),
Hexxc(Hexxc_in),
add_hexx_type(add_hexx_type_in),
two_level_step(two_level_step_in),
restart(restart_in)
{
ModuleBase::TITLE("OperatorEXX", "OperatorEXX");
this->cal_type = calculation_type::lcao_exx;
// if k points has no shift, use cell_nearest to reduce the memory cost
this->use_cell_nearest = (ModuleBase::Vector3<double>(std::fmod(this->kv.get_koffset(0), 1.0),
std::fmod(this->kv.get_koffset(1), 1.0), std::fmod(this->kv.get_koffset(2), 1.0)).norm() < 1e-10);

if (this->add_hexx_type == Add_Hexx_Type::R)
{
// set cell_nearest
std::map<int, std::array<double, 3>> atoms_pos;
for (int iat = 0; iat < GlobalC::ucell.nat; ++iat) {
atoms_pos[iat] = RI_Util::Vector3_to_array3(
GlobalC::ucell.atoms[GlobalC::ucell.iat2it[iat]]
.tau[GlobalC::ucell.iat2ia[iat]]);
if (GlobalV::CALCULATION == "nscf")
{ // if nscf, read HexxR first and reallocate hR according to the read-in HexxR
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, GlobalV::NSPIN, GlobalV::NLOCAL, *Hexxd);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
}
const std::array<std::array<double, 3>, 3> latvec
= { RI_Util::Vector3_to_array3(GlobalC::ucell.a1),
RI_Util::Vector3_to_array3(GlobalC::ucell.a2),
RI_Util::Vector3_to_array3(GlobalC::ucell.a3) };
const std::array<int, 3> Rs_period = { this->kv.nmp[0], this->kv.nmp[1], this->kv.nmp[2] };
this->cell_nearest.init(atoms_pos, latvec, Rs_period);

// reallocate hR if needed (Hexxd temp)
auto Rs = RI_Util::get_Born_von_Karmen_cells(Rs_period);
bool need_allocate = false;
for (int iat0 = 0;iat0 < GlobalC::ucell.nat;++iat0)
else
{
for (int iat1 = 0;iat1 < GlobalC::ucell.nat;++iat1)
ModuleIO::read_Hexxs_csr(file_name_exx, GlobalC::ucell, GlobalV::NSPIN, GlobalV::NLOCAL, *Hexxc);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
}
this->use_cell_nearest = false;
}
else
{ // if scf and Add_Hexx_Type::R, init cell_nearest and reallocate hR according to BvK cells
if (this->add_hexx_type == Add_Hexx_Type::R)
{
// if k points has no shift, use cell_nearest to reduce the memory cost
this->use_cell_nearest = (ModuleBase::Vector3<double>(std::fmod(this->kv.get_koffset(0), 1.0),
std::fmod(this->kv.get_koffset(1), 1.0), std::fmod(this->kv.get_koffset(2), 1.0)).norm() < 1e-10);

const std::array<int, 3> Rs_period = { this->kv.nmp[0], this->kv.nmp[1], this->kv.nmp[2] };
if (this->use_cell_nearest)
{
for (auto& cell : Rs)
{
const Abfs::Vector3_Order<int>& R = RI_Util::array3_to_Vector3(
(this->use_cell_nearest ?
cell_nearest.get_cell_nearest_discrete(iat0, iat1, cell)
: cell));
hamilt::BaseMatrix<TR>* HlocR = this->hR->find_matrix(iat0, iat1, R.x, R.y, R.z);
if (HlocR == nullptr)
{ // add R to HContainer
need_allocate = true;
hamilt::AtomPair<TR> tmp(iat0, iat1, R.x, R.y, R.z, this->LM->ParaV);
this->hR->insert_pair(tmp);
}
// set cell_nearest
std::map<int, std::array<double, 3>> atoms_pos;
for (int iat = 0; iat < GlobalC::ucell.nat; ++iat) {
atoms_pos[iat] = RI_Util::Vector3_to_array3(
GlobalC::ucell.atoms[GlobalC::ucell.iat2it[iat]]
.tau[GlobalC::ucell.iat2ia[iat]]);
}
const std::array<std::array<double, 3>, 3> latvec
= { RI_Util::Vector3_to_array3(GlobalC::ucell.a1),
RI_Util::Vector3_to_array3(GlobalC::ucell.a2),
RI_Util::Vector3_to_array3(GlobalC::ucell.a3) };
this->cell_nearest.init(atoms_pos, latvec, Rs_period);
reallocate_hcontainer(GlobalC::ucell.nat, this->hR, Rs_period, &this->cell_nearest);
}
else { reallocate_hcontainer(GlobalC::ucell.nat, this->hR, Rs_period); }
}
if (need_allocate) this->hR->allocate(nullptr, true);
}

if (this->restart)
{/// Now only Hexx depends on DM, so we can directly read Hexx to reduce the computational cost.
/// If other operators depends on DM, we can also read DM and then calculate the operators to save the memory to store operator terms.
assert(this->two_level_step != nullptr);
if (this->restart)
{/// Now only Hexx depends on DM, so we can directly read Hexx to reduce the computational cost.
/// If other operators depends on DM, we can also read DM and then calculate the operators to save the memory to store operator terms.
assert(this->two_level_step != nullptr);

if (this->add_hexx_type == Add_Hexx_Type::k)
{
/// read in Hexx(k)
if (std::is_same<TK, double>::value)
if (this->add_hexx_type == Add_Hexx_Type::k)
{
this->LM->Hexxd_k_load.resize(this->kv.get_nks());
for (int ik = 0; ik < this->kv.get_nks(); ik++)
/// read in Hexx(k)
if (std::is_same<TK, double>::value)
{
this->LM->Hexxd_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0);
this->restart = GlobalC::restart.load_disk(
"Hexx", ik,
this->LM->ParaV->get_local_size(), this->LM->Hexxd_k_load[ik].data(), false);
if (!this->restart) break;
this->LM->Hexxd_k_load.resize(this->kv.get_nks());
for (int ik = 0; ik < this->kv.get_nks(); ik++)
{
this->LM->Hexxd_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0);
this->restart = GlobalC::restart.load_disk(
"Hexx", ik,
this->LM->ParaV->get_local_size(), this->LM->Hexxd_k_load[ik].data(), false);
if (!this->restart) break;
}
}
else
{
this->LM->Hexxc_k_load.resize(this->kv.get_nks());
for (int ik = 0; ik < this->kv.get_nks(); ik++)
{
this->LM->Hexxc_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0);
this->restart = GlobalC::restart.load_disk(
"Hexx", ik,
this->LM->ParaV->get_local_size(), this->LM->Hexxc_k_load[ik].data(), false);
if (!this->restart) break;
}
}
}
else
else if (this->add_hexx_type == Add_Hexx_Type::R)
{
this->LM->Hexxc_k_load.resize(this->kv.get_nks());
for (int ik = 0; ik < this->kv.get_nks(); ik++)
{
this->LM->Hexxc_k_load[ik].resize(this->LM->ParaV->get_local_size(), 0.0);
this->restart = GlobalC::restart.load_disk(
"Hexx", ik,
this->LM->ParaV->get_local_size(), this->LM->Hexxc_k_load[ik].data(), false);
if (!this->restart) break;
// read in Hexx(R)
const std::string restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number) {
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, GlobalV::NSPIN, GlobalV::NLOCAL, *Hexxd);
}
else {
ModuleIO::read_Hexxs_csr(restart_HR_path, GlobalC::ucell, GlobalV::NSPIN, GlobalV::NLOCAL, *Hexxc);
}
}
}
else if (this->add_hexx_type == Add_Hexx_Type::R)
{
// refactor IO-csr functions
}

if (!this->restart)
std::cout << "WARNING: Hexx not found, restart from the non-exx loop." << std::endl
<< "If the loaded charge density is EXX-solved, this may lead to poor convergence." << std::endl;
GlobalC::restart.info_load.load_H_finish = this->restart;
}
if (!this->restart)
std::cout << "WARNING: Hexx not found, restart from the non-exx loop." << std::endl
<< "If the loaded charge density is EXX-solved, this may lead to poor convergence." << std::endl;
GlobalC::restart.info_load.load_H_finish = this->restart;
}
}
}

template<typename TK, typename TR>
void OperatorEXX<OperatorLCAO<TK, TR>>::contributeHR()
{
ModuleBase::TITLE("OperatorEXX", "contributeHR");
// Peize Lin add 2016-12-03
if (GlobalV::CALCULATION != "nscf" && this->two_level_step != nullptr && *this->two_level_step == 0 && !this->restart) return; //in the non-exx loop, do nothing
if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5)
{
if (this->restart && this->two_level_step != nullptr)
{
if (*this->two_level_step == 0)
{
this->add_loaded_HexxR();
return;
}
else // clear loaded Hexx and release memory
{
this->clear_loaded_HexxR();
}
}
// cal H(k) from H(R) normally
// add H(R) normally
if (GlobalC::exx_info.info_ri.real_number)
RI_2D_Comm::add_HexxR(
this->current_spin,
Expand All @@ -166,6 +219,7 @@ void OperatorEXX<OperatorLCAO<TK, TR>>::contributeHR()
template<typename TK, typename TR>
void OperatorEXX<OperatorLCAO<TK, TR>>::contributeHk(int ik)
{
ModuleBase::TITLE("OperatorEXX", "constributeHR");
// Peize Lin add 2016-12-03
if (GlobalV::CALCULATION != "nscf" && this->two_level_step != nullptr && *this->two_level_step == 0 && !this->restart) return; //in the non-exx loop, do nothing
if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ void OperatorLCAO<TK, TR>::init(const int ik_in)
case calculation_type::lcao_exx:
{
//update HR first
//in cal_type=lcao_exx, HR should be updated by most priority sub-chain nodes
this->contributeHR();
if (!this->hr_done)
{
this->contributeHR();
}

//update HK next
//in cal_type=lcao_exx, HK only need to update from one node
Expand Down
Loading

0 comments on commit d38bd99

Please sign in to comment.