Skip to content

Commit

Permalink
support reusing blacs_ctxt in Parallel_2D::set_desc
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Jul 22, 2023
1 parent f10f842 commit 573e6c4
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 15 deletions.
29 changes: 16 additions & 13 deletions source/module_basis/module_ao/parallel_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,31 +121,34 @@ void Parallel_2D::mpi_create_cart(const MPI_Comm& diag_world)
return;
}

void Parallel_2D::set_desc(const int& gr, const int& gc, const int& lld)
void Parallel_2D::set_desc(const int& gr, const int& gc, const int& lld, bool first_time)
{
ModuleBase::TITLE("Parallel_2D", "set_desc");
#ifdef __DEBUG
assert(this->comm_2D != MPI_COMM_NULL);
assert(gr > 0 && gc > 0 && lld > 0);
assert(this->nb > 0 && this->dim0 > 0 && this->dim1 > 0);
#endif
int myprow, mypcol;
int* usermap = new int[this->dim0 * this->dim1];
int info = 0;
for (int i = 0; i < this->dim0; ++i)
if (first_time)
{
for (int j = 0; j < this->dim1; ++j)
int myprow, mypcol;
int* usermap = new int[this->dim0 * this->dim1];
for (int i = 0; i < this->dim0; ++i)
{
int pcoord[2] = { i, j };
MPI_Cart_rank(comm_2D, pcoord, &usermap[i + j * this->dim0]);
for (int j = 0; j < this->dim1; ++j)
{
int pcoord[2] = { i, j };
MPI_Cart_rank(comm_2D, pcoord, &usermap[i + j * this->dim0]);
}
}
MPI_Fint comm_2D_f = MPI_Comm_c2f(comm_2D);
Cblacs_get(comm_2D_f, 0, &this->blacs_ctxt);
Cblacs_gridmap(&this->blacs_ctxt, usermap, this->dim0, this->dim0, this->dim1);
Cblacs_gridinfo(this->blacs_ctxt, &this->dim0, &this->dim1, &myprow, &mypcol);
delete[] usermap;
}
MPI_Fint comm_2D_f = MPI_Comm_c2f(comm_2D);
Cblacs_get(comm_2D_f, 0, &this->blacs_ctxt);
Cblacs_gridmap(&this->blacs_ctxt, usermap, this->dim0, this->dim0, this->dim1);
Cblacs_gridinfo(this->blacs_ctxt, &this->dim0, &this->dim1, &myprow, &mypcol);
delete[] usermap;
int ISRC = 0;
int info = 0;
descinit_(desc, &gr, &gc, &this->nb, &this->nb, &ISRC, &ISRC, &this->blacs_ctxt, &lld, &info);
}

Expand Down
3 changes: 2 additions & 1 deletion source/module_basis/module_ao/parallel_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class Parallel_2D
///@brief set the desc[9] of the 2D-block-cyclic distribution
void set_desc(const int& gr/**< global row size*/,
const int& gc/**< global col size*/,
const int& lld/**< leading local dimension*/);
const int& lld/**< leading local dimension*/,
bool first_time = true/**< true: call `Cblacs_get`; false: use `this->blacs_ctxt``*/);
#else
void set_serial(const int& M_A/**< global row size*/,
const int& N_A/**< global col size*/);
Expand Down
35 changes: 34 additions & 1 deletion source/module_basis/module_ao/test/parallel_2d_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* set the desc[9] of the 2D-block-cyclic distribution.
*
* - set_global2local
* set the map from global index to local index.
* set the map from global index to local index (init, reuse).
*
* - set_serial (serial)
* set the local(=global) sizes.
Expand Down Expand Up @@ -129,6 +129,39 @@ TEST_F(test_para2d, Divide2D)
}
}
}

TEST_F(test_para2d, DescReuseCtxt)
{
for (auto nb : nbs)
{
Parallel_2D p1;
p1.set_block_size(nb);
p1.set_proc_dim(dsize);
p1.mpi_create_cart(MPI_COMM_WORLD);
p1.set_local2global(sizes[0].first, sizes[0].second, ofs_running, ofs_running);
p1.set_desc(sizes[0].first, sizes[0].second, p1.get_row_size());

Parallel_2D p2; // use 2 different sizes, but they can share the same ctxt
p2.set_block_size(nb);
p2.set_proc_dim(dsize);
p2.comm_2D = p1.comm_2D;
p2.blacs_ctxt = p1.blacs_ctxt;
p2.set_local2global(sizes[1].first, sizes[1].second, ofs_running, ofs_running);
p2.set_desc(sizes[1].first, sizes[1].second, p2.get_row_size(), false);

EXPECT_EQ(p1.desc[1], p2.desc[1]);

Parallel_2D p3; // using default `set_desc`, p3 can't share the same ctxt with p1
p3.set_block_size(nb);
p3.set_proc_dim(dsize);
p3.comm_2D = p1.comm_2D;
p3.blacs_ctxt = p1.blacs_ctxt;
p3.set_local2global(sizes[2].first, sizes[2].second, ofs_running, ofs_running);
p3.set_desc(sizes[2].first, sizes[2].second, p3.get_row_size());

EXPECT_NE(p1.desc[1], p3.desc[1]);
}
}
#else
TEST_F(test_para2d, Serial)
{
Expand Down

0 comments on commit 573e6c4

Please sign in to comment.