From bbf9baddd711e0ee39c2643947c96e069d99c923 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Fri, 27 Sep 2024 21:10:52 +0800 Subject: [PATCH] fix LR-exx parallel --- .../operator_casida/operator_lr_exx.cpp | 31 +++++++++---------- .../operator_casida/operator_lr_exx.h | 4 +-- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/source/module_lr/operator_casida/operator_lr_exx.cpp b/source/module_lr/operator_casida/operator_lr_exx.cpp index 311a274c57..4e06d63fa2 100644 --- a/source/module_lr/operator_casida/operator_lr_exx.cpp +++ b/source/module_lr/operator_casida/operator_lr_exx.cpp @@ -45,7 +45,8 @@ namespace LR const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2]; for (int iw1 = 0;iw1 < nw1;++iw1) for (int iw2 = 0;iw2 < nw2;++iw2) - D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); + if(this->pmat->in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2))) + D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); } } } @@ -70,7 +71,8 @@ namespace LR const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2]; for (int iw1 = 0;iw1 < nw1;++iw1) for (int iw2 = 0;iw2 < nw2;++iw2) - D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); + if(this->pmat->in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2))) + D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); } } } @@ -133,23 +135,18 @@ namespace LR // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point // caution: parrallel - for (int io = 0;io < this->pX->get_col_size();++io) // nocc for serial - { - for (int iv = 0;iv < this->pX->get_row_size();++iv) // nvirt for serial - { + for (int io = 0;io < this->nocc;++io) + for (int iv = 0;iv < this->nvirt;++iv) for (int ik = 0;ik < nk;++ik) - { for (int is = 0;is < this->nspin_solve;++is) - { - this->cal_DM_onebase(this->pX->local2global_col(io), this->pX->local2global_row(iv), ik, is); //set Ds_onebase - // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10); - psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) += 2 * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry - alpha * lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], lri->Hexxs[is]); - } - } - } - } - + { + this->cal_DM_onebase(io, iv, ik, is); //set Ds_onebase for all e-h pairs (not only on this processor) + // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10); + const T& ene= 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry + lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], lri->Hexxs[is]); + if(this->pX->in_this_processor(iv, io)) + psi_out_bfirst(ik, this->pX->global2local_col(io) * this->pX->get_row_size() + this->pX->global2local_row(iv)) += ene; + } } } template class OperatorLREXX; diff --git a/source/module_lr/operator_casida/operator_lr_exx.h b/source/module_lr/operator_casida/operator_lr_exx.h index f36964d74b..1ebea0c9f2 100644 --- a/source/module_lr/operator_casida/operator_lr_exx.h +++ b/source/module_lr/operator_casida/operator_lr_exx.h @@ -44,8 +44,8 @@ namespace LR this->is_first_node = false; // reduce psi_ks for later use - this->psi_ks_full.resize(this->kv.get_nks(), this->psi_ks->get_nbands(), this->naos); - LR_Util::gather_2d_to_full(*this->pc, this->psi_ks->get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, this->psi_ks->get_nbands()); + this->psi_ks_full.resize(this->kv.get_nks(), nocc + nvirt, this->naos); + LR_Util::gather_2d_to_full(*this->pc, this->psi_ks->get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, nocc + nvirt); // get cells in BvK supercell const TC period = RI_Util::get_Born_vonKarmen_period(kv_in);