diff --git a/source/module_lr/operator_casida/operator_lr_exx.cpp b/source/module_lr/operator_casida/operator_lr_exx.cpp index 9c010286c6..3d1c4e39b6 100644 --- a/source/module_lr/operator_casida/operator_lr_exx.cpp +++ b/source/module_lr/operator_casida/operator_lr_exx.cpp @@ -2,6 +2,7 @@ #include "operator_lr_exx.h" #include "module_lr/dm_trans/dm_trans.h" #include "module_lr/utils/lr_util.h" +#include "module_lr/utils/lr_util_print.h" namespace LR { template @@ -38,7 +39,7 @@ namespace LR auto& D2d = this->Ds_onebase[is][iat1][std::make_pair(iat2, cell)]; for (int iw1 = 0;iw1 < ucell.atoms[it1].nw;++iw1) for (int iw2 = 0;iw2 < ucell.atoms[it2].nw;++iw2) - D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, iv, ucell.itiaiw2iwt(it2, ia2, iw2)); + D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); } } } @@ -61,7 +62,7 @@ namespace LR auto& D2d = this->Ds_onebase[is][iat1][std::make_pair(iat2, cell)]; for (int iw1 = 0;iw1 < ucell.atoms[it1].nw;++iw1) for (int iw2 = 0;iw2 < ucell.atoms[it2].nw;++iw2) - D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, iv, ucell.itiaiw2iwt(it2, ia2, iw2)); + D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2)); } } } @@ -83,6 +84,19 @@ namespace LR psi_out_bfirst.fix_b(ib); // suppose Cs,Vs, have already been calculated in the ion-step of ground state, // DM_trans(k) and DM_trans(R) has already been calculated from psi_in in OperatorLRHxc::act + // but int RI_benchmark, DM_trans(k) should be first calculated here + if (cal_dm_trans) + { +#ifdef __MPI + std::vector dm_trans_2d = cal_dm_trans_pblas(psi_in_bfirst, *pX, *psi_ks, *pc, naos, nocc, nvirt, *pmat); + if (this->tdm_sym) for (auto& t : dm_trans_2d) LR_Util::matsym(t.data(), naos, *pmat); +#else + std::vector dm_trans_2d = cal_dm_trans_blas(psi_in_bfirst, *psi_ks, nocc, nvirt); + if (this->tdm_sym) for (auto& t : dm_trans_2d) LR_Util::matsym(t.data(), naos); +#endif + // tensor to vector, then set DMK + for (int ik = 0;ik < nk;++ik) { this->DM_trans[ib]->set_DMK_pointer(ik, dm_trans_2d[ik].data()); } + } // 1. set_Ds (once) // convert to vector for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor) @@ -92,17 +106,18 @@ namespace LR for (int ik = 0;ik < nk;++ik) {DMk_trans_pointer[ik] = &DMk_trans_vector[ik];} // if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex) std::vector>>> Ds_trans = - RI_2D_Comm::split_m2D_ktoR(this->kv, DMk_trans_pointer, *this->pmat, this->nspin_solve); + RI_2D_Comm::split_m2D_ktoR(this->kv, DMk_trans_pointer, *this->pmat, this->nspin_solve); //0.5 will be multiplied // 2. cal_Hs auto lri = this->exx_lri.lock(); - for (int ik = 0;ik < nk;++ik) + for (int is = 0;is < nspin_solve;++is) { - lri->exx_lri.set_Ds(std::move(Ds_trans[ik]), lri->info.dm_threshold); + // LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10); + lri->exx_lri.set_Ds(std::move(Ds_trans[is]), lri->info.dm_threshold); lri->exx_lri.cal_Hs(); - lri->Hexxs[ik] = RI::Communicate_Tensors_Map_Judge::comm_map2_first( - lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[ik]), std::get<1>(judge[ik])); - lri->post_process_Hexx(lri->Hexxs[ik]); + lri->Hexxs[is] = RI::Communicate_Tensors_Map_Judge::comm_map2_first( + lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[is]), std::get<1>(judge[is])); + lri->post_process_Hexx(lri->Hexxs[is]); } // 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point @@ -117,12 +132,14 @@ namespace LR for (int is = 0;is < this->nspin_solve;++is) { this->cal_DM_onebase(this->pX->local2global_col(io), this->pX->local2global_row(iv), ik, is); //set Ds_onebase - psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) -= 0.5 * //minus for exchange, 0.5 for spin + // LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10); + psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) += 2 * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry alpha * lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], lri->Hexxs[is]); } } } } + } } template class OperatorLREXX; diff --git a/source/module_lr/operator_casida/operator_lr_exx.h b/source/module_lr/operator_casida/operator_lr_exx.h index d852bade0c..d9f9d3e52a 100644 --- a/source/module_lr/operator_casida/operator_lr_exx.h +++ b/source/module_lr/operator_casida/operator_lr_exx.h @@ -30,10 +30,11 @@ namespace LR Parallel_2D* pX_in, Parallel_2D* pc_in, Parallel_Orbitals* pmat_in, - const double& alpha = 1.0) + const double& alpha = 1.0, + const bool& cal_dm_trans = false) : nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt), psi_ks(psi_ks_in), DM_trans(DM_trans_in), exx_lri(exx_lri_in), kv(kv_in), - pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in), alpha(alpha) + pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in), alpha(alpha), cal_dm_trans(cal_dm_trans) { ModuleBase::TITLE("OperatorLREXX", "OperatorLREXX"); this->cal_type = hamilt::calculation_type::lcao_exx; @@ -64,6 +65,8 @@ namespace LR const int& nocc; const int& nvirt; const double& alpha; + const bool cal_dm_trans = false; + const bool tdm_sym = false; ///< whether transition density matrix is symmetric const K_Vectors& kv; /// ground state wavefunction const psi::Psi* psi_ks = nullptr;