diff --git a/source/module_basis/module_ao/parallel_2d.cpp b/source/module_basis/module_ao/parallel_2d.cpp index 7c2f9c372b..765a1b8bf0 100644 --- a/source/module_basis/module_ao/parallel_2d.cpp +++ b/source/module_basis/module_ao/parallel_2d.cpp @@ -121,7 +121,7 @@ void Parallel_2D::mpi_create_cart(const MPI_Comm& diag_world) return; } -void Parallel_2D::set_desc(const int& gr, const int& gc, const int& lld) +void Parallel_2D::set_desc(const int& gr, const int& gc, const int& lld, bool first_time) { ModuleBase::TITLE("Parallel_2D", "set_desc"); #ifdef __DEBUG @@ -129,23 +129,26 @@ void Parallel_2D::set_desc(const int& gr, const int& gc, const int& lld) assert(gr > 0 && gc > 0 && lld > 0); assert(this->nb > 0 && this->dim0 > 0 && this->dim1 > 0); #endif - int myprow, mypcol; - int* usermap = new int[this->dim0 * this->dim1]; - int info = 0; - for (int i = 0; i < this->dim0; ++i) + if (first_time) { - for (int j = 0; j < this->dim1; ++j) + int myprow, mypcol; + int* usermap = new int[this->dim0 * this->dim1]; + for (int i = 0; i < this->dim0; ++i) { - int pcoord[2] = { i, j }; - MPI_Cart_rank(comm_2D, pcoord, &usermap[i + j * this->dim0]); + for (int j = 0; j < this->dim1; ++j) + { + int pcoord[2] = { i, j }; + MPI_Cart_rank(comm_2D, pcoord, &usermap[i + j * this->dim0]); + } } + MPI_Fint comm_2D_f = MPI_Comm_c2f(comm_2D); + Cblacs_get(comm_2D_f, 0, &this->blacs_ctxt); + Cblacs_gridmap(&this->blacs_ctxt, usermap, this->dim0, this->dim0, this->dim1); + Cblacs_gridinfo(this->blacs_ctxt, &this->dim0, &this->dim1, &myprow, &mypcol); + delete[] usermap; } - MPI_Fint comm_2D_f = MPI_Comm_c2f(comm_2D); - Cblacs_get(comm_2D_f, 0, &this->blacs_ctxt); - Cblacs_gridmap(&this->blacs_ctxt, usermap, this->dim0, this->dim0, this->dim1); - Cblacs_gridinfo(this->blacs_ctxt, &this->dim0, &this->dim1, &myprow, &mypcol); - delete[] usermap; int ISRC = 0; + int info = 0; descinit_(desc, &gr, &gc, &this->nb, &this->nb, &ISRC, &ISRC, &this->blacs_ctxt, &lld, &info); } diff --git a/source/module_basis/module_ao/parallel_2d.h b/source/module_basis/module_ao/parallel_2d.h index a9a4567218..53060dd9a3 100644 --- a/source/module_basis/module_ao/parallel_2d.h +++ b/source/module_basis/module_ao/parallel_2d.h @@ -98,7 +98,8 @@ class Parallel_2D ///@brief set the desc[9] of the 2D-block-cyclic distribution void set_desc(const int& gr/**< global row size*/, const int& gc/**< global col size*/, - const int& lld/**< leading local dimension*/); + const int& lld/**< leading local dimension*/, + bool first_time = true/**< true: call `Cblacs_get`; false: use `this->blacs_ctxt``*/); #else void set_serial(const int& M_A/**< global row size*/, const int& N_A/**< global col size*/); diff --git a/source/module_basis/module_ao/test/parallel_2d_test.cpp b/source/module_basis/module_ao/test/parallel_2d_test.cpp index 534af3304a..eb8a836cb7 100644 --- a/source/module_basis/module_ao/test/parallel_2d_test.cpp +++ b/source/module_basis/module_ao/test/parallel_2d_test.cpp @@ -19,7 +19,7 @@ * set the desc[9] of the 2D-block-cyclic distribution. * * - set_global2local - * set the map from global index to local index. + * set the map from global index to local index (init, reuse). * * - set_serial (serial) * set the local(=global) sizes. @@ -129,6 +129,39 @@ TEST_F(test_para2d, Divide2D) } } } + +TEST_F(test_para2d, DescReuseCtxt) +{ + for (auto nb : nbs) + { + Parallel_2D p1; + p1.set_block_size(nb); + p1.set_proc_dim(dsize); + p1.mpi_create_cart(MPI_COMM_WORLD); + p1.set_local2global(sizes[0].first, sizes[0].second, ofs_running, ofs_running); + p1.set_desc(sizes[0].first, sizes[0].second, p1.get_row_size()); + + Parallel_2D p2; // use 2 different sizes, but they can share the same ctxt + p2.set_block_size(nb); + p2.set_proc_dim(dsize); + p2.comm_2D = p1.comm_2D; + p2.blacs_ctxt = p1.blacs_ctxt; + p2.set_local2global(sizes[1].first, sizes[1].second, ofs_running, ofs_running); + p2.set_desc(sizes[1].first, sizes[1].second, p2.get_row_size(), false); + + EXPECT_EQ(p1.desc[1], p2.desc[1]); + + Parallel_2D p3; // using default `set_desc`, p3 can't share the same ctxt with p1 + p3.set_block_size(nb); + p3.set_proc_dim(dsize); + p3.comm_2D = p1.comm_2D; + p3.blacs_ctxt = p1.blacs_ctxt; + p3.set_local2global(sizes[2].first, sizes[2].second, ofs_running, ofs_running); + p3.set_desc(sizes[2].first, sizes[2].second, p3.get_row_size()); + + EXPECT_NE(p1.desc[1], p3.desc[1]); + } +} #else TEST_F(test_para2d, Serial) {