From 707d7bc40824592694ff109d1934e9838c719ec0 Mon Sep 17 00:00:00 2001 From: maki49 <1579492865@qq.com> Date: Tue, 1 Aug 2023 18:33:20 +0800 Subject: [PATCH] update and UT for restore_psik; add lacpy interface --- source/module_base/lapack_connector.h | 4 +- source/module_base/scalapack_connector.h | 6 +- source/module_ri/exx_symmetry.cpp | 117 +++++++++++------ source/module_ri/exx_symmetry.h | 22 ++-- source/module_ri/test/exx_symmetry_test.cpp | 138 +++++++++++++++++++- 5 files changed, 235 insertions(+), 52 deletions(-) diff --git a/source/module_base/lapack_connector.h b/source/module_base/lapack_connector.h index 29d70e8e34..7b19edbdf4 100644 --- a/source/module_base/lapack_connector.h +++ b/source/module_base/lapack_connector.h @@ -145,8 +145,8 @@ extern "C" void zhegs2_(int *itype, char *uplo, int *n, std::complex *a, int *lda, std::complex *b, int *ldb, int *info); // copies a into b - void dlacpy_(char *uplo, int *m, int *n, double* a, int *lda, double *b, int *ldb); - void zlacpy_(char *uplo, int *m, int *n, std::complex* a, int *lda, std::complex *b, int *ldb); + void dlacpy_(const char* uplo, const int* m, const int* n, double* a, const int* lda, double* b, const int* ldb); + void zlacpy_(const char* uplo, const int* m, const int* n, std::complex* a, const int* lda, std::complex* b, const int* ldb); // generates a real elementary reflector H of order n, such that // H * ( alpha ) = ( beta ), H is unitary. diff --git a/source/module_base/scalapack_connector.h b/source/module_base/scalapack_connector.h index a373639e5f..e6e0f057ca 100644 --- a/source/module_base/scalapack_connector.h +++ b/source/module_base/scalapack_connector.h @@ -14,14 +14,16 @@ extern "C" const int *ictxt, const int *lld, int *info); void pdpotrf_(const char* uplo, const int* n, double* a, const int* ia, const int* ja, const int* desca, int* info); -// void pzpotrf_(char *uplo, int *n, double _Complex *a, int *ia, int *ja, int *desca, int *info); + void pzpotrf_(const char* uplo, const int* n, std::complex* a, const int* ia, const int* ja, const int* desca, int* info); void pdpotri_(const char* uplo, const int* n, double* a, const int* ia, const int* ja, const int* desca, int* info); void pzpotri_(const char* uplo, const int* n, std::complex* a, const int* ia, const int* ja, const int* desca, int* info); - void pdtran_(int *m , int *n , + void pzlacpy_(const char* uplo, const int* m, const int* n, const std::complex* a, const int* ia, const int* ja, const int* desca, std::complex* b, const int* ib, const int* jb, const int* descb); + + void pdtran_(int* m, int* n, double *alpha , double *a , int *ia , int *ja , int *desca , double *beta , double *c , int *ic , int *jc , int *descc ); diff --git a/source/module_ri/exx_symmetry.cpp b/source/module_ri/exx_symmetry.cpp index ce06a6ea79..903fe70f88 100644 --- a/source/module_ri/exx_symmetry.cpp +++ b/source/module_ri/exx_symmetry.cpp @@ -1,4 +1,3 @@ -#ifdef __EXX #include "exx_symmetry.h" #include #include "module_psi/psi.h" @@ -6,13 +5,11 @@ #ifdef __MPI #include #include "module_base/scalapack_connector.h" -#else -#include "module_base/lapack_connector.h" #endif +#include "module_base/lapack_connector.h" namespace ExxSym { - std::vector>> rearange_smat( const int ikibz, const std::vector> sloc_ikibz, @@ -87,36 +84,82 @@ namespace ExxSym return fullmat; } - psi::Psi, psi::DEVICE_CPU> restore_psik( + psi::Psi, psi::DEVICE_CPU> restore_psik_lapack( const int& ikibz, - const std::vector>& psi_ikibz, + const psi::Psi, psi::DEVICE_CPU>& psi_ikibz, + const std::vector>& sloc_ikibz, + const std::vector>>& sloc_ik, + const int& nbasis, + const int& nbands) + { + int nkstar = sloc_ik.size(); + psi::Psi, psi::DEVICE_CPU> c_k(nkstar, nbands, nbasis); + psi::Psi, psi::DEVICE_CPU> tmpSc(1, nbands, nbasis); + // only col_maj is considered now + // 1. S(gk)c_gk + char transa = 'N'; + char transb = 'N'; + std::complex alpha = 1.0; + std::complex beta = 0.0; + + zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, + &alpha, sloc_ikibz.data(), &nbasis, psi_ikibz.get_pointer(), &nbasis, + &beta, tmpSc.get_pointer(), &nbasis); + + //2. c_k = S^{-1}(k)S(gk)c_{gk} for each k + int ik = 0; + for (auto sk : sloc_ik)// copy, not reference: sk will be replaced by S^{-1}(k) after 2.1 (sk is const) + { + c_k.fix_k(ik); + + // 2.1 S^{-1}(k) + char uplo = 'U'; + int info = -1; + zpotrf_(&uplo, &nbasis, sk.data(), &nbasis, &info); + if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k).(info=" + std::to_string(info) + ")."); + zpotri_(&uplo, &nbasis, sk.data(), &nbasis, &info); + if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when calculating inv(S(k)).(info=" + std::to_string(info) + ")."); + //transpose and copy the upper triangle + std::vector> invsk(sk.size()); + std::vector> ones(sk.size(), 0); + for (int i = 0;i < nbasis;i++) ones[i * nbasis + i] = std::complex(1, 0); + char t = 'T'; + zgemm_(&t, &t, &nbasis, &nbasis, &nbasis, &alpha, sk.data(), &nbasis, ones.data(), &nbasis, &beta, invsk.data(), &nbasis); + zlacpy_(&uplo, &nbasis, &nbasis, sk.data(), &nbasis, invsk.data(), &nbasis); + + //2.2 S^{-1}(k) * S(gk)c_{gk} + zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, + &alpha, invsk.data(), &nbasis, tmpSc.get_pointer(), &nbasis, + &beta, c_k.get_pointer(), &nbasis); + ++ik; + } + return c_k; + } + +#ifdef __MPI + psi::Psi, psi::DEVICE_CPU> restore_psik_scalapack( + const int& ikibz, + const psi::Psi, psi::DEVICE_CPU>& psi_ikibz, const std::vector>& sloc_ikibz, const std::vector>>& sloc_ik, const int& nbasis, const int& nbands, - const Parallel_Orbitals& pv, - const bool col_inside) + const Parallel_Orbitals& pv) { - int kstar_size = sloc_ik.size(); - psi::Psi, psi::DEVICE_CPU> c_k(kstar_size, pv.ncol_bands, pv.get_row_size()); + int nkstar = sloc_ik.size(); + psi::Psi, psi::DEVICE_CPU> c_k(nkstar, pv.ncol_bands, pv.get_row_size()); psi::Psi, psi::DEVICE_CPU> tmpSc(1, pv.ncol_bands, pv.get_row_size()); // only col_maj is considered now - // 1. S(gk)c_gk + // 1. S(gk)c_gk char transa = 'N'; char transb = 'N'; std::complex alpha = 1.0; std::complex beta = 0.0; -#ifdef __MPI int i1 = 1; pzgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, &alpha, sloc_ikibz.data(), &i1, &i1, pv.desc, - psi_ikibz.data(), &i1, &i1, pv.desc_wfc, &beta, + psi_ikibz.get_pointer(), &i1, &i1, pv.desc_wfc, &beta, tmpSc.get_pointer(), &i1, &i1, pv.desc_wfc); -#else - zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, - &alpha, sloc_ikibz.data(), &nbasis, psi_ikibz.data(), &nbasis, - &beta, tmpSc.get_pointer(), &nbasis); -#endif //2. c_k = S^{-1}(k)S(gk)c_{gk} for each k int ik = 0; @@ -124,35 +167,35 @@ namespace ExxSym { c_k.fix_k(ik); - // - // std::vector> invsk = sk; - // 2.1 S^{-1}(k) char uplo = 'U'; int info = -1; -#ifdef __MPI pzpotrf_(&uplo, &nbasis, sk.data(), &i1, &i1, pv.desc, &info); - if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k)."); + if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k).(info=" + std::to_string(info) + ")."); pzpotri_(&uplo, &nbasis, sk.data(), &i1, &i1, pv.desc, &info); -#else - zpotrf_(&uplo, &nbasis, sk.data(), &nbasis, &info); - if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k)."); - zpotri_(&uplo, &nbasis, sk.data(), &nbasis, &info); -#endif + if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when calculating inv(S(k)).(info=" + std::to_string(info) + ")."); + //transpose and copy the upper triangle + std::vector> invsk(sk.size()); + std::vector> ones(sk.size(), 0); //row-major + for (int i = 0;i < nbasis;++i) + if (pv.in_this_processor(i, i)) + ones[pv.global2local_col(i) * pv.get_row_size() + pv.global2local_row(i)] = std::complex(1, 0); + char t = 'T'; + pzgemm_(&t, &t, &nbasis, &nbasis, &nbasis, + &alpha, sk.data(), &i1, &i1, pv.desc, + ones.data(), &i1, &i1, pv.desc, &beta, + invsk.data(), &i1, &i1, pv.desc); + pzlacpy_(&uplo, &nbasis, &nbasis, sk.data(), &i1, &i1, pv.desc, invsk.data(), &i1, &i1, pv.desc); //2.2 S^{-1}(k) * S(gk)c_{gk} -#ifdef __MPI pzgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, - &alpha, sk.data(), &i1, &i1, pv.desc, + &alpha, invsk.data(), &i1, &i1, pv.desc, tmpSc.get_pointer(), &i1, &i1, pv.desc_wfc, &beta, c_k.get_pointer(), &i1, &i1, pv.desc_wfc); -#else - zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis, - &alpha, sk.data(), &nbasis, tmpSc.get_pointer(), &nbasis, - &beta, c_k.get_pointer(), &nbasis); -#endif + ++ik; } return c_k; } -} -#endif \ No newline at end of file +#endif + +} \ No newline at end of file diff --git a/source/module_ri/exx_symmetry.h b/source/module_ri/exx_symmetry.h index 3d65d09de8..2647803176 100644 --- a/source/module_ri/exx_symmetry.h +++ b/source/module_ri/exx_symmetry.h @@ -1,4 +1,3 @@ -#ifdef __EXX #pragma once #include #include "module_basis/module_ao/parallel_orbitals.h" @@ -34,21 +33,28 @@ namespace ExxSym /// @param nbands [in] global number of bands /// @param pv [in] parallel orbitals (for both matrix and wavefunction) /// @param col_inside [in] whether the matrix is column-major (major means memory continuity) - /// @return - psi::Psi, psi::DEVICE_CPU> restore_psik( + /// @return c_k: wavefunction of each k in kstars[ikibz] +#ifdef __MPI + psi::Psi, psi::DEVICE_CPU> restore_psik_scalapack( const int& ikibz, - const std::vector>& psi_ikibz, + const psi::Psi, psi::DEVICE_CPU>& psi_ikibz, const std::vector>& sloc_ikibz, const std::vector>>& sloc_ik, const int& nbasis, const int& nbands, - const Parallel_Orbitals& pv, - const bool col_inside); + const Parallel_Orbitals& pv); +#endif + psi::Psi, psi::DEVICE_CPU> restore_psik_lapack( + const int& ikibz, + const psi::Psi, psi::DEVICE_CPU>& psi_ikibz, + const std::vector>& sloc_ikibz, + const std::vector>>& sloc_ik, + const int& nbasis, + const int& nbands); std::vector> get_full_smat( const std::vector>& locmat, const int& nbasis, const Parallel_2D& p2d, const bool col_inside); -} -#endif \ No newline at end of file +} \ No newline at end of file diff --git a/source/module_ri/test/exx_symmetry_test.cpp b/source/module_ri/test/exx_symmetry_test.cpp index 50d4ce519f..5557edf4ec 100644 --- a/source/module_ri/test/exx_symmetry_test.cpp +++ b/source/module_ri/test/exx_symmetry_test.cpp @@ -8,13 +8,88 @@ class SymExxTest : public testing::Test { protected: + int dsize = 1; + int my_rank = 0; + std::ofstream ofs_running; + Parallel_Orbitals pv; + + //cases std::map, std::vector> invmap_cases = { {{3, 2, 1, 0}, {3, 2, 1, 0}}, { {4, 1, 3, 0, 2}, {3, 1, 4, 2, 0} } }; std::vector, std::vector, std::vector>> mapmul_cases = { {{4, 1, 3, 0, 2}, {3, 1, 4, 2, 0}, {0, 1, 2, 3, 4}}, - {{3, 1, 4, 2, 0}, {2, 3, 0, 1, 4}, {1, 3, 4, 0, 2}} - }; + {{3, 1, 4, 2, 0}, {2, 3, 0, 1, 4}, {1, 3, 4, 0, 2}} }; + std::vector> nkstar_nbands_nbasis = { + {1, 3, 4}, {2, 8, 11} }; + + void set2d(int nbasis, int nbands) + { + pv.set_block_size(1); + pv.set_proc_dim(dsize); +#ifdef __MPI + pv.mpi_create_cart(MPI_COMM_WORLD); + pv.set_local2global(nbasis, nbasis, ofs_running, ofs_running); + pv.set_desc(nbasis, nbasis, pv.get_row_size()); + pv.set_global2local(nbasis, nbasis, true, ofs_running); + pv.set_nloc_wfc_Eij(nbands, ofs_running, ofs_running); + pv.set_desc_wfc_Eij(nbasis, nbands, pv.get_row_size()); +#else + pv.set_serial(nbasis, nbands); + pv.set_global2local(nbasis, nbasis, false, ofs_running); + pv.ncol_bands = nbands; + pv.nloc_wfc = nbasis * nbands; +#endif + } + void set_int(std::complex* s, int size) + { + for (int i = 0; i < size; i++) s[i] = std::complex(i, -i); + } + void set_rand(std::complex* s, int size) + { + for (int i = 0; i < size; i++) + s[i] = std::complex(rand(), rand()) / double(RAND_MAX) * 10.0 - 5.0; + } + void set_int_posisym(std::complex* s, int nbasis, int seed) + { + for (int i = 0; i < nbasis; i++) + for (int j = i; j < nbasis; j++) + { + int diff = j - i; + int sgn = diff % 2 ? -1 : 1; + s[j * nbasis + i] = s[i * nbasis + j] = std::complex(static_cast((nbasis - diff) * sgn + seed), 0); + } + } + void copy_from_global(const std::complex* sg, std::complex* sl, const int gr, const int gc, const int lr, const int lc, bool gcol_inside, bool lcol_inside) + { + // global is column major + for (int i = 0;i < lr;++i) + for (int j = 0;j < lc;++j) + if (gcol_inside) + if (lcol_inside) + sl[i * lc + j] = sg[pv.local2global_row(i) * gc + pv.local2global_col(j)]; + else + sl[j * lr + i] = sg[pv.local2global_row(i) * gc + pv.local2global_col(j)]; + else + if (lcol_inside) + sl[i * lc + j] = sg[pv.local2global_col(j) * gr + pv.local2global_row(i)]; + else + sl[j * lr + i] = sg[pv.local2global_col(j) * gr + pv.local2global_row(i)]; + } +#ifdef __MPI + void SetUp() override + { + MPI_Comm_size(MPI_COMM_WORLD, &dsize); + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + this->ofs_running.open("log" + std::to_string(my_rank) + ".txt"); + ofs_running << "dsize(nproc) = " << dsize << std::endl; + ofs_running << "my_rank = " << my_rank << std::endl; + } + void TearDown() override + { + ofs_running.close(); + } +#endif }; TEST_F(SymExxTest, invmap) @@ -35,8 +110,65 @@ TEST_F(SymExxTest, mapmul) } } +#ifdef __MPI +TEST_F(SymExxTest, restore_psik) +{ + //get from case + for (auto& sizes : nkstar_nbands_nbasis) + { + int nbasis = sizes[2]; + int nbands = sizes[1]; + int nkstar = sizes[0]; + + // setup 2d division for this size + this->set2d(nbasis, nbands); + + // generate global symmitric matrix and copy to local + std::vector> sfull_gk(nbasis * nbasis); + this->set_int_posisym(sfull_gk.data(), nbasis, nkstar); + + std::vector> sloc_gk(pv.get_local_size()); + this->copy_from_global(sfull_gk.data(), sloc_gk.data(), nbasis, nbasis, pv.get_row_size(), pv.get_col_size(), false, false); + + std::vector>> sfull_ks(nkstar, std::vector>(nbasis * nbasis)); + for (int ik = 0;ik < nkstar;++ik)this->set_int_posisym(sfull_ks[ik].data(), nbasis, ik); + std::vector>> sloc_ks(nkstar, std::vector>(pv.get_local_size())); + + for (int ik = 0;ik < nkstar;++ik) + this->copy_from_global(sfull_ks[ik].data(), sloc_ks[ik].data(), nbasis, nbasis, pv.get_row_size(), pv.get_col_size(), false, false); + + // generate global psi and copy to local (both are row-major) + int ikibz = 0; + psi::Psi, psi::DEVICE_CPU> psi_full_gk(1, nbands, nbasis); + this->set_int(psi_full_gk.get_pointer(), nbasis * nbands); + psi::Psi, psi::DEVICE_CPU> psi_loc_gk(1, pv.ncol_bands, pv.get_row_size()); + this->copy_from_global(psi_full_gk.get_pointer(), psi_loc_gk.get_pointer(), nbasis, nbands, pv.get_row_size(), pv.ncol_bands, false, false); + + // run + psi::Psi, psi::DEVICE_CPU> psi_loc_ks = ExxSym::restore_psik_scalapack(ikibz, psi_loc_gk, sloc_gk, sloc_ks, nbasis, nbands, pv); + psi::Psi, psi::DEVICE_CPU> psi_full_ks = ExxSym::restore_psik_lapack(ikibz, psi_full_gk, sfull_gk, sfull_ks, nbasis, nbands); + + // check + for (int ik = 0;ik < nkstar;ik++) + { + psi_loc_ks.fix_k(ik); + psi_full_ks.fix_k(ik); + for (int i = 0;i < pv.ncol_bands;i++) + for (int j = 0;j < pv.get_row_size();j++) + EXPECT_NEAR(psi_loc_ks(i, j).real(), psi_full_ks(pv.local2global_col(i), pv.local2global_row(j)).real(), 1e-10); + } + } +} +#endif int main(int argc, char** argv) { + srand(time(NULL)); // for random number generator +#ifdef __MPI + MPI_Init(&argc, &argv); +#endif testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + int result = RUN_ALL_TESTS(); +#ifdef __MPI + MPI_Finalize(); +#endif } \ No newline at end of file