From cda658ec8b748be47497bb22352e687a844b8d74 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Sun, 20 Oct 2024 16:40:49 +0800 Subject: [PATCH] remove band-traverse and recover RI-benchmark --- source/module_lr/hamilt_casida.cpp | 2 +- source/module_lr/hamilt_casida.h | 2 +- .../operator_casida/operator_lr_diag.h | 14 ++-- .../operator_casida/operator_lr_exx.cpp | 76 +++++++++---------- .../operator_casida/operator_lr_hxc.cpp | 47 +++++------- .../ri_benchmark/operator_ri_hartree.h | 24 +++--- .../module_lr/ri_benchmark/ri_benchmark.hpp | 2 +- 7 files changed, 76 insertions(+), 91 deletions(-) diff --git a/source/module_lr/hamilt_casida.cpp b/source/module_lr/hamilt_casida.cpp index b15d591d91..1b36b76c22 100644 --- a/source/module_lr/hamilt_casida.cpp +++ b/source/module_lr/hamilt_casida.cpp @@ -43,7 +43,7 @@ namespace LR #endif } // output Amat - std::cout << "Full A matrix:" << std::endl; + std::cout << "Full A matrix: (elements < 1e-10 is set to 0)" << std::endl; LR_Util::print_value(Amat_full.data(), nk * npairs, nk * npairs); return Amat_full; } diff --git a/source/module_lr/hamilt_casida.h b/source/module_lr/hamilt_casida.h index f7d16dd0c2..ab40c82611 100644 --- a/source/module_lr/hamilt_casida.h +++ b/source/module_lr/hamilt_casida.h @@ -46,7 +46,7 @@ namespace LR if (ri_hartree_benchmark != "aims") { assert(aims_nbasis.empty()); } // always use nspin=1 for transition density matrix this->DM_trans = LR_Util::make_unique>(&pmat_in, 1, kv_in.kvec_d, nk); - LR_Util::initialize_DMR(*this->DM_trans, pmat_in, ucell_in, gd_in, orb_cutoff); + if (ri_hartree_benchmark == "none") { LR_Util::initialize_DMR(*this->DM_trans, pmat_in, ucell_in, gd_in, orb_cutoff); } // this->DM_trans->init_DMR(&gd_in, &ucell_in); // too large due to not restricted by orb_cutoff // add the diag operator (the first one) diff --git a/source/module_lr/operator_casida/operator_lr_diag.h b/source/module_lr/operator_casida/operator_lr_diag.h index a924399a24..b18343e04a 100644 --- a/source/module_lr/operator_casida/operator_lr_diag.h +++ b/source/module_lr/operator_casida/operator_lr_diag.h @@ -46,15 +46,11 @@ namespace LR { ModuleBase::TITLE("OperatorLRDiag", "act"); const int nlocal_ph = nk * pX.get_local_size(); // local size of particle-hole basis - for (int ib = 0;ib < nbands;++ib) - { - const int ibstart = ib * nlocal_ph; - hsolver::vector_mul_vector_op()(this->ctx, - nk * pX.get_local_size(), - hpsi + ibstart, - psi_in + ibstart, - this->eig_ks_diff.c); - } + hsolver::vector_mul_vector_op()(this->ctx, + nk * pX.get_local_size(), + hpsi, + psi_in, + this->eig_ks_diff.c); } private: const Parallel_2D& pX; diff --git a/source/module_lr/operator_casida/operator_lr_exx.cpp b/source/module_lr/operator_casida/operator_lr_exx.cpp index b3c8362756..fba9b5f898 100644 --- a/source/module_lr/operator_casida/operator_lr_exx.cpp +++ b/source/module_lr/operator_casida/operator_lr_exx.cpp @@ -83,56 +83,54 @@ namespace LR // convert parallel info to LibRI interfaces std::vector, std::set>> judge = RI_2D_Comm::get_2D_judge(this->pmat); - for (int ib = 0;ib < nbands;++ib) - { - const int xstart_b = ib * nk * pX.get_local_size(); - // suppose Cs,Vs, have already been calculated in the ion-step of ground state - // and DM_trans has been calculated in hPsi() outside. - // 1. set_Ds (once) - // convert to vector for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor) - std::vector> DMk_trans_vector = this->DM_trans->get_DMK_vector(); - // assert(DMk_trans_vector.size() == nk); - std::vector*> DMk_trans_pointer(nk); - for (int ik = 0;ik < nk;++ik) {DMk_trans_pointer[ik] = &DMk_trans_vector[ik];} - // if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex) - std::vector>>> Ds_trans = - aims_nbasis.empty() ? - RI_2D_Comm::split_m2D_ktoR(this->kv, DMk_trans_pointer, this->pmat, 1) - : RI_Benchmark::split_Ds(DMk_trans_vector, aims_nbasis, ucell); //0.5 will be multiplied - // LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10); - // 2. cal_Hs - auto lri = this->exx_lri.lock(); + // suppose Cs,Vs, have already been calculated in the ion-step of ground state + // and DM_trans has been calculated in hPsi() outside. + + // 1. set_Ds (once) + // convert to vector for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor) + std::vector> DMk_trans_vector = this->DM_trans->get_DMK_vector(); + // assert(DMk_trans_vector.size() == nk); + std::vector*> DMk_trans_pointer(nk); + for (int ik = 0;ik < nk;++ik) { DMk_trans_pointer[ik] = &DMk_trans_vector[ik]; } + // if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex) + std::vector>>> Ds_trans = + aims_nbasis.empty() ? + RI_2D_Comm::split_m2D_ktoR(this->kv, DMk_trans_pointer, this->pmat, 1) + : RI_Benchmark::split_Ds(DMk_trans_vector, aims_nbasis, ucell); //0.5 will be multiplied + // LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10); + // 2. cal_Hs + auto lri = this->exx_lri.lock(); - // LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10); - lri->exx_lri.set_Ds(std::move(Ds_trans[0]), lri->info.dm_threshold); - lri->exx_lri.cal_Hs(); - lri->Hexxs[0] = RI::Communicate_Tensors_Map_Judge::comm_map2_first( - lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[0]), std::get<1>(judge[0])); - lri->post_process_Hexx(lri->Hexxs[0]); + // LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10); + lri->exx_lri.set_Ds(std::move(Ds_trans[0]), lri->info.dm_threshold); + lri->exx_lri.cal_Hs(); + lri->Hexxs[0] = RI::Communicate_Tensors_Map_Judge::comm_map2_first( + lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[0]), std::get<1>(judge[0])); + lri->post_process_Hexx(lri->Hexxs[0]); - // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point - // caution: parrallel + // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point + // caution: parrallel - for (int io = 0;io < this->nocc;++io) + for (int io = 0;io < this->nocc;++io) + { + for (int iv = 0;iv < this->nvirt;++iv) { - for (int iv = 0;iv < this->nvirt;++iv) + for (int ik = 0;ik < nk;++ik) { - for (int ik = 0;ik < nk;++ik) + const int xstart_bk = ik * pX.get_local_size(); + this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor) + // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10); + const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry + lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]); + if (this->pX.in_this_processor(iv, io)) { - const int xstart_bk = xstart_b + ik * pX.get_local_size(); - this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor) - // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10); - const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry - lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]); - if (this->pX.in_this_processor(iv, io)) - { - hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene; - } + hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene; } } } } + } template class OperatorLREXX; template class OperatorLREXX>; diff --git a/source/module_lr/operator_casida/operator_lr_hxc.cpp b/source/module_lr/operator_casida/operator_lr_hxc.cpp index 5d5c53496d..58f01932b4 100644 --- a/source/module_lr/operator_casida/operator_lr_hxc.cpp +++ b/source/module_lr/operator_casida/operator_lr_hxc.cpp @@ -22,36 +22,31 @@ namespace LR ModuleBase::TITLE("OperatorLRHxc", "act"); const int& sl = ispin_ks[0]; const auto psil_ks = LR_Util::get_psi_spin(psi_ks, sl, nk); - const int& lgd = gint->gridt->lgd; - for (int ib = 0;ib < nbands;++ib) - { - const int xstart_b = ib * nbasis; - - this->DM_trans->cal_DMR(); //DM_trans->get_DMR_vector() is 2d-block parallized - // LR_Util::print_DMR(*DM_trans, ucell.nat, "DMR"); - - // ========================= begin grid calculation========================= - this->grid_calculation(nbands); //DM(R) to H(R) - // ========================= end grid calculation ========================= - - // V(R)->V(k) - std::vector v_hxc_2d(nk, LR_Util::newTensor({ pmat.get_col_size(), pmat.get_row_size() })); - for (auto& v : v_hxc_2d) v.zero(); - int nrow = ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver) ? this->pmat.get_row_size() : this->pmat.get_col_size(); - for (int ik = 0;ik < nk;++ik) { folding_HR(*this->hR, v_hxc_2d[ik].data(), this->kv.kvec_d[ik], nrow, 1); } // V(R) -> V(k) - // LR_Util::print_HR(*this->hR, this->ucell.nat, "4.VR"); - // if (this->first_print) - // for (int ik = 0;ik < nk;++ik) - // LR_Util::print_tensor(v_hxc_2d[ik], "4.V(k)[ik=" + std::to_string(ik) + "]", &this->pmat); - - // 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i} + + this->DM_trans->cal_DMR(); //DM_trans->get_DMR_vector() is 2d-block parallized + // LR_Util::print_DMR(*DM_trans, ucell.nat, "DMR"); + + // ========================= begin grid calculation========================= + this->grid_calculation(nbands); //DM(R) to H(R) + // ========================= end grid calculation ========================= + + // V(R)->V(k) + std::vector v_hxc_2d(nk, LR_Util::newTensor({ pmat.get_col_size(), pmat.get_row_size() })); + for (auto& v : v_hxc_2d) v.zero(); + int nrow = ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver) ? this->pmat.get_row_size() : this->pmat.get_col_size(); + for (int ik = 0;ik < nk;++ik) { folding_HR(*this->hR, v_hxc_2d[ik].data(), this->kv.kvec_d[ik], nrow, 1); } // V(R) -> V(k) + // LR_Util::print_HR(*this->hR, this->ucell.nat, "4.VR"); + // if (this->first_print) + // for (int ik = 0;ik < nk;++ik) + // LR_Util::print_tensor(v_hxc_2d[ik], "4.V(k)[ik=" + std::to_string(ik) + "]", &this->pmat); + + // 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i} #ifdef __MPI - cal_AX_pblas(v_hxc_2d, this->pmat, psil_ks, this->pc, naos, nocc[sl], nvirt[sl], this->pX[sl], hpsi + xstart_b); + cal_AX_pblas(v_hxc_2d, this->pmat, psil_ks, this->pc, naos, nocc[sl], nvirt[sl], this->pX[sl], hpsi); #else - cal_AX_blas(v_hxc_2d, psil_ks, nocc[sl], nvirt[sl], hpsi + xstart_b); + cal_AX_blas(v_hxc_2d, psil_ks, nocc[sl], nvirt[sl], hpsi); #endif - } } diff --git a/source/module_lr/ri_benchmark/operator_ri_hartree.h b/source/module_lr/ri_benchmark/operator_ri_hartree.h index 46fc50ff7e..cc398be90c 100644 --- a/source/module_lr/ri_benchmark/operator_ri_hartree.h +++ b/source/module_lr/ri_benchmark/operator_ri_hartree.h @@ -50,22 +50,18 @@ namespace RI_Benchmark } }; ~OperatorRIHartree() {} - void act(const psi::Psi& X_in, psi::Psi& X_out, const int nbands) const override + void act(const int nbands, const int nbasis, const int npol, const T* psi_in, T* hpsi, const int ngk_ik = 0)const override { assert(GlobalV::MY_RANK == 0); // only serial now - const int nk = 1; - const psi::Psi& X = LR_Util::k1_to_bfirst_wrapper(X_in, nk, npairs); - psi::Psi AX = LR_Util::k1_to_bfirst_wrapper(X_out, nk, npairs); - for (int ib = 0;ib < nbands;++ib) - { - TLRIX CsX_vo = cal_CsX(Cs_vo_mo, &X(ib, 0, 0)); - TLRIX CsX_ov = cal_CsX(Cs_ov_mo, &X(ib, 0, 0)); - // LR_Util::print_CsX(Cs_bX, nvirt, "Cs_bX of state " + std::to_string(ib)); - cal_AX(CV_vo, CsX_vo, &AX(ib, 0, 0), 4.); - cal_AX(CV_vo, CsX_ov, &AX(ib, 0, 0), 4.); - cal_AX(CV_ov, CsX_vo, &AX(ib, 0, 0), 4.); - cal_AX(CV_ov, CsX_ov, &AX(ib, 0, 0), 4.); - } + assert(nbasis == npairs); + TLRIX CsX_vo = cal_CsX(Cs_vo_mo, psi_in); + TLRIX CsX_ov = cal_CsX(Cs_ov_mo, psi_in); + // LR_Util::print_CsX(Cs_bX, nvirt, "Cs_bX of state " + std::to_string(ib)); + // 4 for 4 terms in the expansion of local RI + cal_AX(CV_vo, CsX_vo, hpsi, 4.); + cal_AX(CV_vo, CsX_ov, hpsi, 4.); + cal_AX(CV_ov, CsX_vo, hpsi, 4.); + cal_AX(CV_ov, CsX_ov, hpsi, 4.); } protected: const int& naos; diff --git a/source/module_lr/ri_benchmark/ri_benchmark.hpp b/source/module_lr/ri_benchmark/ri_benchmark.hpp index 671ccd7774..82ab9249ac 100644 --- a/source/module_lr/ri_benchmark/ri_benchmark.hpp +++ b/source/module_lr/ri_benchmark/ri_benchmark.hpp @@ -153,7 +153,7 @@ namespace RI_Benchmark return Amat_full; } template - TLRIX cal_CsX(const TLRI& Cs_mo, TK* X) + TLRIX cal_CsX(const TLRI& Cs_mo, const TK* X) { TLRIX CsX; for (auto& it1 : Cs_mo)