From c4eebcb8b2870d9d92a14ada81be2c1dcbb57b2a Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sun, 10 Dec 2023 01:31:44 +0800 Subject: [PATCH] add to hK passed in rather than LM.Hloc --- .../hamilt_lcaodft/LCAO_matrix.cpp | 70 ++- .../hamilt_lcaodft/LCAO_matrix.h | 3 +- .../operator_lcao/op_exx_lcao.cpp | 22 +- source/module_ri/RI_2D_Comm.h | 141 +++--- source/module_ri/RI_2D_Comm.hpp | 411 +++++++++--------- 5 files changed, 314 insertions(+), 333 deletions(-) diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.cpp index d65aaf5e69..5d96c55a4e 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.cpp @@ -139,16 +139,18 @@ void LCAO_Matrix::allocate_HS_R(const int &nnR) // set 'dtype' matrix element (iw1_all, iw2_all) with // an input value 'v' //------------------------------------------------------ -void LCAO_Matrix::set_HSgamma( - const int &iw1_all, // index i for atomic orbital (row) - const int &iw2_all, // index j for atomic orbital (column) - const double& v, // value for matrix element (i,j) - double* HSloc) //input pointer for store the matrix +template +void LCAO_Matrix::set_mat2d( + const int& global_ir, // index i for atomic orbital (row) + const int& global_ic, // index j for atomic orbital (column) + const T& v, // value for matrix element (i,j) + const Parallel_Orbitals& pv, + T* HSloc) //input pointer for store the matrix { // use iw1_all and iw2_all to set Hloc // becareful! The ir and ic may be < 0 !!! - const int ir = this->ParaV->global2local_row(iw1_all); - const int ic = this->ParaV->global2local_col(iw2_all); + const int ir = pv.global2local_row(global_ir); + const int ic = pv.global2local_col(global_ic); //const int index = ir * ParaO.ncol + ic; long index=0; @@ -156,22 +158,22 @@ void LCAO_Matrix::set_HSgamma( // save the matrix as column major format if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER()) { - index=ic*this->ParaV->nrow+ir; + index = ic * pv.nrow + ir; } else { - index=ir*this->ParaV->ncol+ic; + index = ir * pv.ncol + ic; } - if( index >= this->ParaV->nloc) + if (index >= pv.nloc) { - std::cout << " iw1_all = " << iw1_all << std::endl; - std::cout << " iw2_all = " << iw2_all << std::endl; + std::cout << " iw1_all = " << global_ir << std::endl; + std::cout << " iw2_all = " << global_ic << std::endl; std::cout << " ir = " << ir << std::endl; std::cout << " ic = " << ic << std::endl; std::cout << " index = " << index << std::endl; - std::cout << " this->ParaV->nloc = " << this->ParaV->nloc << std::endl; - ModuleBase::WARNING_QUIT("LCAO_Matrix","set_HSgamma"); + std::cout << " ParaV->nloc = " << pv.nloc << std::endl; + ModuleBase::WARNING_QUIT("LCAO_Matrix", "set_mat2d"); } //using input pointer HSloc @@ -180,41 +182,21 @@ void LCAO_Matrix::set_HSgamma( return; } -void LCAO_Matrix::set_HSk(const int &iw1_all, const int &iw2_all, const std::complex &v, const char &dtype, const int spin) +void LCAO_Matrix::set_HSgamma(const int& iw1_all, const int& iw2_all, const double& v, double* HSloc) +{ + LCAO_Matrix::set_mat2d(iw1_all, iw2_all, v, *this->ParaV, HSloc); + return; +} +void LCAO_Matrix::set_HSk(const int& iw1_all, const int& iw2_all, const std::complex& v, const char& dtype, const int spin) { - // use iw1_all and iw2_all to set Hloc - // becareful! The ir and ic may < 0!!!!!!!!!!!!!!!! - const int ir = this->ParaV->global2local_row(iw1_all); - const int ic = this->ParaV->global2local_col(iw2_all); - //const int index = ir * this->ParaV->ncol + ic; - long index; - if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER()) - { - index=ic*this->ParaV->nrow+ir; - } - else - { - index=ir*this->ParaV->ncol+ic; - } - assert(index < this->ParaV->nloc); if (dtype=='S')//overlap Hamiltonian. - { - this->Sloc2[index] += v; - } + LCAO_Matrix::set_mat2d>(iw1_all, iw2_all, v, *this->ParaV, this->Sloc2.data()); else if (dtype=='T' || dtype=='N')// kinetic and nonlocal Hamiltonian. - { - this->Hloc_fixed2[index] += v; // because kinetic and nonlocal Hamiltonian matrices are already block-cycle staraged after caculated in lcao_nnr.cpp - // this statement will not be used. - } + LCAO_Matrix::set_mat2d>(iw1_all, iw2_all, v, *this->ParaV, this->Hloc_fixed2.data()); else if (dtype=='L') // Local potential Hamiltonian. - { - this->Hloc2[index] += v; - } + LCAO_Matrix::set_mat2d>(iw1_all, iw2_all, v, *this->ParaV, this->Hloc2.data()); else - { - ModuleBase::WARNING_QUIT("LCAO_Matrix","set_HSk"); - } - + ModuleBase::WARNING_QUIT("LCAO_Matrix", "set_HSk"); return; } diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h b/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h index ebe09667d4..4f6f7a5c64 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h +++ b/source/module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h @@ -189,7 +189,8 @@ class LCAO_Matrix double* DHloc_fixed_23; double* DHloc_fixed_33; - + template + static void set_mat2d(const int& global_ir, const int& global_ic, const T& v, const Parallel_Orbitals& pv, T* mat); void set_HSgamma(const int& iw1_all, const int& iw2_all, const double& v, double* HSloc); void set_HSk(const int &iw1_all, const int &iw2_all, const std::complex &v, const char &dtype, const int spin = 0); diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.cpp index 67f0be54e1..d7af2691d5 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.cpp @@ -33,15 +33,17 @@ void OperatorEXX>::contributeHk(int ik) kv, ik, GlobalC::exx_info.info_global.hybrid_alpha, - *this->LM->Hexxd, - *this->LM); + *this->LM->Hexxd, + *this->LM->ParaV, + *this->hK); else RI_2D_Comm::add_Hexx( kv, ik, GlobalC::exx_info.info_global.hybrid_alpha, - *this->LM->Hexxc, - *this->LM); + *this->LM->Hexxc, + *this->LM->ParaV, + *this->hK); } } @@ -57,14 +59,16 @@ void OperatorEXX, double>>::contributeHk(int i ik, GlobalC::exx_info.info_global.hybrid_alpha, *this->LM->Hexxd, - *this->LM); + *this->LM->ParaV, + *this->hK); else RI_2D_Comm::add_Hexx( kv, ik, GlobalC::exx_info.info_global.hybrid_alpha, *this->LM->Hexxc, - *this->LM); + *this->LM->ParaV, + *this->hK); } } @@ -80,14 +84,16 @@ void OperatorEXX, std::complex>>::cont ik, GlobalC::exx_info.info_global.hybrid_alpha, *this->LM->Hexxd, - *this->LM); + *this->LM->ParaV, + *this->hK); else RI_2D_Comm::add_Hexx( kv, ik, GlobalC::exx_info.info_global.hybrid_alpha, *this->LM->Hexxc, - *this->LM); + *this->LM->ParaV, + *this->hK); } } diff --git a/source/module_ri/RI_2D_Comm.h b/source/module_ri/RI_2D_Comm.h index 7194035271..51b5c6a996 100644 --- a/source/module_ri/RI_2D_Comm.h +++ b/source/module_ri/RI_2D_Comm.h @@ -1,71 +1,72 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-08-17 -//======================= - -#ifndef RI_2D_COMM_H -#define RI_2D_COMM_H - -#include "module_basis/module_ao/parallel_orbitals.h" -#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h" -#include "module_cell/klist.h" - -#include - -#include -#include -#include -#include -#include -#include - -namespace RI_2D_Comm -{ - using TA = int; - using Tcell = int; - static const size_t Ndim = 3; - using TC = std::array; - using TAC = std::pair; - -//public: - template - extern std::vector>>> - split_m2D_ktoR(const K_Vectors &kv, const std::vector &mks_2D, const Parallel_Orbitals &pv); - - // judge[is] = {s0, s1} - extern std::vector, std::set>> - get_2D_judge(const Parallel_Orbitals &pv); - - template - extern void add_Hexx( - const K_Vectors &kv, - const int ik, - const double alpha, - const std::vector>>> &Hs, - LCAO_Matrix &lm); - - template - extern std::vector> Hexxs_to_Hk( - const K_Vectors &kv, - const Parallel_Orbitals &pv, - const std::vector< std::map>>> &Hexxs, - const int ik); - template - std::vector> pulay_mixing( - const Parallel_Orbitals &pv, - std::deque>> &Hk_seq, - const std::vector> &Hk_new, - const double mixing_beta, - const std::string mixing_mode); - -//private: - extern std::vector get_ik_list(const K_Vectors &kv, const int is_k); - extern inline std::tuple get_iat_iw_is_block(const int iwt); - extern inline int get_is_block(const int is_k, const int is_row_b, const int is_col_b); - extern inline std::tuple split_is_block(const int is_b); - extern inline int get_iwt(const int iat, const int iw_b, const int is_b); -} - -#include "RI_2D_Comm.hpp" - +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-08-17 +//======================= + +#ifndef RI_2D_COMM_H +#define RI_2D_COMM_H + +#include "module_basis/module_ao/parallel_orbitals.h" +#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_matrix.h" +#include "module_cell/klist.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace RI_2D_Comm +{ + using TA = int; + using Tcell = int; + static const size_t Ndim = 3; + using TC = std::array; + using TAC = std::pair; + +//public: + template + extern std::vector>>> + split_m2D_ktoR(const K_Vectors &kv, const std::vector &mks_2D, const Parallel_Orbitals &pv); + + // judge[is] = {s0, s1} + extern std::vector, std::set>> + get_2D_judge(const Parallel_Orbitals &pv); + + template + extern void add_Hexx( + const K_Vectors& kv, + const int ik, + const double alpha, + const std::vector>>>& Hs, + const Parallel_Orbitals& pv, + std::vector& Hloc); + + template + extern std::vector> Hexxs_to_Hk( + const K_Vectors &kv, + const Parallel_Orbitals &pv, + const std::vector< std::map>>> &Hexxs, + const int ik); + template + std::vector> pulay_mixing( + const Parallel_Orbitals &pv, + std::deque>> &Hk_seq, + const std::vector> &Hk_new, + const double mixing_beta, + const std::string mixing_mode); + +//private: + extern std::vector get_ik_list(const K_Vectors &kv, const int is_k); + extern inline std::tuple get_iat_iw_is_block(const int iwt); + extern inline int get_is_block(const int is_k, const int is_row_b, const int is_col_b); + extern inline std::tuple split_is_block(const int is_b); + extern inline int get_iwt(const int iat, const int iw_b, const int is_b); +} + +#include "RI_2D_Comm.hpp" + #endif \ No newline at end of file diff --git a/source/module_ri/RI_2D_Comm.hpp b/source/module_ri/RI_2D_Comm.hpp index 332183dbbc..2575de63fd 100644 --- a/source/module_ri/RI_2D_Comm.hpp +++ b/source/module_ri/RI_2D_Comm.hpp @@ -1,211 +1,202 @@ -//======================= -// AUTHOR : Peize Lin -// DATE : 2022-08-17 -//======================= - -#ifndef RI_2D_COMM_HPP -#define RI_2D_COMM_HPP - -#include "RI_2D_Comm.h" -#include "RI_Util.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "module_base/tool_title.h" -#include "module_base/timer.h" - -#include - -#include -#include -#include - -inline RI::Tensor tensor_conj(const RI::Tensor& t) { return t; } -inline RI::Tensor> tensor_conj(const RI::Tensor>& t) -{ - RI::Tensor> r(t.shape); - for (int i = 0;i < t.data->size();++i)(*r.data)[i] = std::conj((*t.data)[i]); - return r; -} -template -auto RI_2D_Comm::split_m2D_ktoR(const K_Vectors &kv, const std::vector &mks_2D, const Parallel_Orbitals &pv) --> std::vector>>> -{ - ModuleBase::TITLE("RI_2D_Comm","split_m2D_ktoR"); - ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR"); - - const TC period = RI_Util::get_Born_vonKarmen_period(kv); - const std::map nspin_k = {{1,1}, {2,2}, {4,1}}; - const double SPIN_multiple = std::map{{1,0.5}, {2,1}, {4,1}}.at(GlobalV::NSPIN); // why? - - std::vector>>> mRs_a2D(GlobalV::NSPIN); - for(int is_k=0; is_k ik_list = RI_2D_Comm::get_ik_list(kv, is_k); - for(const TC &cell : RI_Util::get_Born_von_Karmen_cells(period)) - { - RI::Tensor mR_2D; - for(const int ik : ik_list) - { - using Tdata_m = typename Tmatrix::value_type; - RI::Tensor mk_2D = RI_Util::Vector_to_Tensor(*mks_2D[ik], pv.get_col_size(), pv.get_row_size()); - const Tdata_m frac = SPIN_multiple - * RI::Global_Func::convert( std::exp( - -ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell) * GlobalC::ucell.latvec)))); - auto set_mR_2D = [&mR_2D](auto&& mk_frac) { - if (mR_2D.empty()) - mR_2D = RI::Global_Func::convert(mk_frac); - else - mR_2D = mR_2D + RI::Global_Func::convert(mk_frac); - }; - if (static_cast(std::round(SPIN_multiple * kv.wk[ik] * kv.nkstot_full)) == 2) - set_mR_2D(mk_2D * (frac * 0.5) + tensor_conj(mk_2D * (frac * 0.5))); - else set_mR_2D(mk_2D * frac); - } - - for(int iwt0_2D=0; iwt0_2D!=mR_2D.shape[0]; ++iwt0_2D) - { - const int iwt0 = - ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER() - ? pv.local2global_col(iwt0_2D) - : pv.local2global_row(iwt0_2D); - int iat0, iw0_b, is0_b; - std::tie(iat0,iw0_b,is0_b) = RI_2D_Comm::get_iat_iw_is_block(iwt0); - const int it0 = GlobalC::ucell.iat2it[iat0]; - for(int iwt1_2D=0; iwt1_2D!=mR_2D.shape[1]; ++iwt1_2D) - { - const int iwt1 = - ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER() - ? pv.local2global_row(iwt1_2D) - : pv.local2global_col(iwt1_2D); - int iat1, iw1_b, is1_b; - std::tie(iat1,iw1_b,is1_b) = RI_2D_Comm::get_iat_iw_is_block(iwt1); - const int it1 = GlobalC::ucell.iat2it[iat1]; - - const int is_b = RI_2D_Comm::get_is_block(is_k, is0_b, is1_b); - RI::Tensor &mR_a2D = mRs_a2D[is_b][iat0][{iat1,cell}]; - if(mR_a2D.empty()) - mR_a2D = RI::Tensor({static_cast(GlobalC::ucell.atoms[it0].nw), static_cast(GlobalC::ucell.atoms[it1].nw)}); - mR_a2D(iw0_b,iw1_b) = mR_2D(iwt0_2D, iwt1_2D); - } - } - } - } - ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR"); - return mRs_a2D; -} - - -template -void RI_2D_Comm::add_Hexx( - const K_Vectors &kv, - const int ik, - const double alpha, - const std::vector>>> &Hs, - LCAO_Matrix &lm) -{ - ModuleBase::TITLE("RI_2D_Comm","add_Hexx"); - ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx"); - - const Parallel_Orbitals& pv = *lm.ParaV; - - const std::map> is_list = {{1,{0}}, {2,{kv.isk[ik]}}, {4,{0,1,2,3}}}; - for(const int is_b : is_list.at(GlobalV::NSPIN)) - { - int is0_b, is1_b; - std::tie(is0_b,is1_b) = RI_2D_Comm::split_is_block(is_b); - for(const auto &Hs_tmpA : Hs[is_b]) - { - const TA &iat0 = Hs_tmpA.first; - for(const auto &Hs_tmpB : Hs_tmpA.second) - { - const TA &iat1 = Hs_tmpB.first.first; - const TC &cell1 = Hs_tmpB.first.second; - const std::complex frac = alpha - * std::exp( ModuleBase::TWO_PI*ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell1)*GlobalC::ucell.latvec)) ); - const RI::Tensor &H = Hs_tmpB.second; - for(size_t iw0_b=0; iw0_b(H(iw0_b, iw1_b)) * RI::Global_Func::convert(frac), - lm.Hloc.data()); - else - lm.set_HSk(iwt0, iwt1, - RI::Global_Func::convert>(H(iw0_b, iw1_b)) * frac, - 'L', -1); - } - } - } - } - } - ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx"); -} - -std::tuple -RI_2D_Comm::get_iat_iw_is_block(const int iwt) -{ - const int iat = GlobalC::ucell.iwt2iat[iwt]; - const int iw = GlobalC::ucell.iwt2iw[iwt]; - switch(GlobalV::NSPIN) - { - case 1: case 2: - return std::make_tuple(iat, iw, 0); - case 4: - return std::make_tuple(iat, iw/2, iw%2); - default: - throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - } -} - -int RI_2D_Comm::get_is_block(const int is_k, const int is_row_b, const int is_col_b) -{ - switch(GlobalV::NSPIN) - { - case 1: return 0; - case 2: return is_k; - case 4: return is_row_b*2+is_col_b; - default: throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - } -} - -std::tuple -RI_2D_Comm::split_is_block(const int is_b) -{ - switch(GlobalV::NSPIN) - { - case 1: case 2: - return std::make_tuple(0, 0); - case 4: - return std::make_tuple(is_b/2, is_b%2); - default: - throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - } -} - - - -int RI_2D_Comm::get_iwt(const int iat, const int iw_b, const int is_b) -{ - const int it = GlobalC::ucell.iat2it[iat]; - const int ia = GlobalC::ucell.iat2ia[iat]; - int iw=-1; - switch(GlobalV::NSPIN) - { - case 1: case 2: - iw = iw_b; break; - case 4: - iw = iw_b*2+is_b; break; - default: - throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); - } - const int iwt = GlobalC::ucell.itiaiw2iwt(it,ia,iw); - return iwt; -} - +//======================= +// AUTHOR : Peize Lin +// DATE : 2022-08-17 +//======================= + +#ifndef RI_2D_COMM_HPP +#define RI_2D_COMM_HPP + +#include "RI_2D_Comm.h" +#include "RI_Util.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_base/tool_title.h" +#include "module_base/timer.h" + +#include + +#include +#include +#include + +inline RI::Tensor tensor_conj(const RI::Tensor& t) { return t; } +inline RI::Tensor> tensor_conj(const RI::Tensor>& t) +{ + RI::Tensor> r(t.shape); + for (int i = 0;i < t.data->size();++i)(*r.data)[i] = std::conj((*t.data)[i]); + return r; +} +template +auto RI_2D_Comm::split_m2D_ktoR(const K_Vectors &kv, const std::vector &mks_2D, const Parallel_Orbitals &pv) +-> std::vector>>> +{ + ModuleBase::TITLE("RI_2D_Comm","split_m2D_ktoR"); + ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR"); + + const TC period = RI_Util::get_Born_vonKarmen_period(kv); + const std::map nspin_k = {{1,1}, {2,2}, {4,1}}; + const double SPIN_multiple = std::map{{1,0.5}, {2,1}, {4,1}}.at(GlobalV::NSPIN); // why? + + std::vector>>> mRs_a2D(GlobalV::NSPIN); + for(int is_k=0; is_k ik_list = RI_2D_Comm::get_ik_list(kv, is_k); + for(const TC &cell : RI_Util::get_Born_von_Karmen_cells(period)) + { + RI::Tensor mR_2D; + for(const int ik : ik_list) + { + using Tdata_m = typename Tmatrix::value_type; + RI::Tensor mk_2D = RI_Util::Vector_to_Tensor(*mks_2D[ik], pv.get_col_size(), pv.get_row_size()); + const Tdata_m frac = SPIN_multiple + * RI::Global_Func::convert( std::exp( + -ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell) * GlobalC::ucell.latvec)))); + auto set_mR_2D = [&mR_2D](auto&& mk_frac) { + if (mR_2D.empty()) + mR_2D = RI::Global_Func::convert(mk_frac); + else + mR_2D = mR_2D + RI::Global_Func::convert(mk_frac); + }; + if (static_cast(std::round(SPIN_multiple * kv.wk[ik] * kv.nkstot_full)) == 2) + set_mR_2D(mk_2D * (frac * 0.5) + tensor_conj(mk_2D * (frac * 0.5))); + else set_mR_2D(mk_2D * frac); + } + + for(int iwt0_2D=0; iwt0_2D!=mR_2D.shape[0]; ++iwt0_2D) + { + const int iwt0 = + ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER() + ? pv.local2global_col(iwt0_2D) + : pv.local2global_row(iwt0_2D); + int iat0, iw0_b, is0_b; + std::tie(iat0,iw0_b,is0_b) = RI_2D_Comm::get_iat_iw_is_block(iwt0); + const int it0 = GlobalC::ucell.iat2it[iat0]; + for(int iwt1_2D=0; iwt1_2D!=mR_2D.shape[1]; ++iwt1_2D) + { + const int iwt1 = + ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER() + ? pv.local2global_row(iwt1_2D) + : pv.local2global_col(iwt1_2D); + int iat1, iw1_b, is1_b; + std::tie(iat1,iw1_b,is1_b) = RI_2D_Comm::get_iat_iw_is_block(iwt1); + const int it1 = GlobalC::ucell.iat2it[iat1]; + + const int is_b = RI_2D_Comm::get_is_block(is_k, is0_b, is1_b); + RI::Tensor &mR_a2D = mRs_a2D[is_b][iat0][{iat1,cell}]; + if(mR_a2D.empty()) + mR_a2D = RI::Tensor({static_cast(GlobalC::ucell.atoms[it0].nw), static_cast(GlobalC::ucell.atoms[it1].nw)}); + mR_a2D(iw0_b,iw1_b) = mR_2D(iwt0_2D, iwt1_2D); + } + } + } + } + ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR"); + return mRs_a2D; +} + + +template +void RI_2D_Comm::add_Hexx( + const K_Vectors &kv, + const int ik, + const double alpha, + const std::vector>>> &Hs, + const Parallel_Orbitals& pv, + std::vector& Hloc) +{ + ModuleBase::TITLE("RI_2D_Comm","add_Hexx"); + ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx"); + + const std::map> is_list = {{1,{0}}, {2,{kv.isk[ik]}}, {4,{0,1,2,3}}}; + for(const int is_b : is_list.at(GlobalV::NSPIN)) + { + int is0_b, is1_b; + std::tie(is0_b,is1_b) = RI_2D_Comm::split_is_block(is_b); + for(const auto &Hs_tmpA : Hs[is_b]) + { + const TA &iat0 = Hs_tmpA.first; + for(const auto &Hs_tmpB : Hs_tmpA.second) + { + const TA &iat1 = Hs_tmpB.first.first; + const TC &cell1 = Hs_tmpB.first.second; + const std::complex frac = alpha + * std::exp( ModuleBase::TWO_PI*ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell1)*GlobalC::ucell.latvec)) ); + const RI::Tensor &H = Hs_tmpB.second; + for(size_t iw0_b=0; iw0_b(H(iw0_b, iw1_b)) * RI::Global_Func::convert(frac), pv, Hloc.data()); + } + } + } + } + } + ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx"); +} + +std::tuple +RI_2D_Comm::get_iat_iw_is_block(const int iwt) +{ + const int iat = GlobalC::ucell.iwt2iat[iwt]; + const int iw = GlobalC::ucell.iwt2iw[iwt]; + switch(GlobalV::NSPIN) + { + case 1: case 2: + return std::make_tuple(iat, iw, 0); + case 4: + return std::make_tuple(iat, iw/2, iw%2); + default: + throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + } +} + +int RI_2D_Comm::get_is_block(const int is_k, const int is_row_b, const int is_col_b) +{ + switch(GlobalV::NSPIN) + { + case 1: return 0; + case 2: return is_k; + case 4: return is_row_b*2+is_col_b; + default: throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + } +} + +std::tuple +RI_2D_Comm::split_is_block(const int is_b) +{ + switch(GlobalV::NSPIN) + { + case 1: case 2: + return std::make_tuple(0, 0); + case 4: + return std::make_tuple(is_b/2, is_b%2); + default: + throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + } +} + + + +int RI_2D_Comm::get_iwt(const int iat, const int iw_b, const int is_b) +{ + const int it = GlobalC::ucell.iat2it[iat]; + const int ia = GlobalC::ucell.iat2ia[iat]; + int iw=-1; + switch(GlobalV::NSPIN) + { + case 1: case 2: + iw = iw_b; break; + case 4: + iw = iw_b*2+is_b; break; + default: + throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__)); + } + const int iwt = GlobalC::ucell.itiaiw2iwt(it,ia,iw); + return iwt; +} + #endif \ No newline at end of file