Skip to content

Commit

Permalink
Para2D distribute X
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Jul 8, 2023
1 parent 3bbaeeb commit 4818252
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 14 deletions.
7 changes: 6 additions & 1 deletion source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
#include "module_io/print_info.h"
#include "module_md/run_md.h"
#include "module_beyonddft/esolver_lrtd_lcao.hpp"

extern "C"
{
#include "module_base/blacs_connector.h"
}
// This is the driver function which defines the workflow of ABACUS calculations
// It relies on the class Esolver, which is a class that organizes workflows of single point calculations.
// For calculations involving change of configuration (lattice parameter & ionic motion),
Expand Down Expand Up @@ -90,6 +93,8 @@ void Driver::driver_run()
else
ModuleESolver::clean_esolver(p_esolver);

if (INPUT.basis_type == "lcao")
Cblacs_exit(1); // clean up blacs after all the esolvers are cleaned up without closing MPI
std::cout << "befor end" << std::endl;
ModuleBase::timer::tick("Driver", "driver_line");
return;
Expand Down
4 changes: 1 addition & 3 deletions source/module_basis/module_ao/ORB_control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ ORB_control::ORB_control() :
setup_2d(false)
{}
ORB_control::~ORB_control()
{
Cblacs_exit(1); //delete global variables in cblacs but do not close MPI
}
{}

void ORB_control::read_orb_first(
std::ofstream& ofs_in,
Expand Down
32 changes: 31 additions & 1 deletion source/module_basis/module_ao/parallel_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,4 +335,34 @@ void Parallel_2D::set_serial(const int& M_A, const int& N_A)
for (int i = 0; i < this->nrow; i++) this->row_set[i] = i;
for (int i = 0; i < this->ncol; i++) this->col_set[i] = i;
}
#endif
#endif

Parallel_2D& Parallel_2D::operator=(Parallel_2D&& rhs)
{
ModuleBase::TITLE("Parallel_2D", "operator=");
this->nrow = rhs.nrow;
this->ncol = rhs.ncol;
this->nloc = rhs.nloc;
this->nb = rhs.nb;
this->dim0 = rhs.dim0;
this->dim1 = rhs.dim1;
this->coord[0] = rhs.coord[0];
this->coord[1] = rhs.coord[1];
this->testpb = rhs.testpb;
this->row_set = std::move(rhs.row_set);
this->col_set = std::move(rhs.col_set);

if (this->trace_loc_row) delete[] this->trace_loc_row;
this->trace_loc_row = rhs.trace_loc_row;
rhs.trace_loc_row = nullptr;
if (this->trace_loc_col) delete[] this->trace_loc_col;
this->trace_loc_col = rhs.trace_loc_col;
rhs.trace_loc_col = nullptr;
#ifdef __MPI
this->blacs_ctxt = rhs.blacs_ctxt;
this->comm_2D = rhs.comm_2D;
for (int i = 0; i < 9; ++i)
this->desc[i] = rhs.desc[i];
#endif
return *this;
}
2 changes: 2 additions & 0 deletions source/module_basis/module_ao/parallel_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class Parallel_2D
Parallel_2D();
~Parallel_2D();

Parallel_2D& operator=(Parallel_2D&& rhs);

/// map from global index to local index
int* trace_loc_row = nullptr;
int* trace_loc_col = nullptr;
Expand Down
5 changes: 3 additions & 2 deletions source/module_beyonddft/esolver_lrtd_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ namespace ModuleESolver
Gint_k gint_k;

/// @brief variables for parallel distribution of KS orbitals
Parallel_2D ParaC;
Parallel_2D paraC_;
/// @brief variables for parallel distribution of excited states
Parallel_2D ParaX;
Parallel_2D paraX_;

void init_X();
void setup_2d_division(int nb, int gr, int gc);

};
}
44 changes: 38 additions & 6 deletions source/module_beyonddft/esolver_lrtd_lcao.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ ModuleESolver::ESolver_LRTD<FPTYPE, Device>::ESolver_LRTD(ModuleESolver::ESolver
//only need the eigenvalues. the 'elecstates' of excited states is different from ground state.
this->eig_ks = std::move(ks_sol.pelec->ekb);

// 2d-distribution
Parallel_2D* tmp_paraC = static_cast<Parallel_2D*>(&ks_sol.orb_con.ParaV);
this->paraC_ = std::move(*tmp_paraC);
tmp_paraC = nullptr;

// move the basis info (2-center integrals currently not needed)
// std::cout<<"before move orb"<<std::endl;
// this->orb = std::forward<LCAO_Orbitals>(GlobalC::ORB);
Expand All @@ -45,7 +50,7 @@ ModuleESolver::ESolver_LRTD<FPTYPE, Device>::ESolver_LRTD(ModuleESolver::ESolver
template<typename FPTYPE, typename Device>
void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::Init(Input& inp, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_KS_LCAO", "Init");
ModuleBase::TITLE("ESolver_LRTD", "Init");

this->p_input = &inp;
this->p_ucell = &ucell;
Expand All @@ -54,7 +59,7 @@ void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::Init(Input& inp, UnitCell& uce
// which determines the basis size of the excited states
this->nocc = LR_Util::cal_nocc(LR_Util::cal_nelec(ucell));
this->nbasis = GlobalV::NLOCAL; //use GlobalV temporarily about basis
this->nvirt = nbasis - nocc;
this->nvirt = this->eig_ks.nc - nocc; //nbands-nocc
this->npairs = this->nocc * this->nvirt;
this->nstates = inp.nstates;
GlobalV::ofs_running << "Setting LR-TDDFT parameters: " << std::endl;
Expand All @@ -75,20 +80,23 @@ void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::Init(Input& inp, UnitCell& uce
template<typename FPTYPE, typename Device>
void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::Run(int istep, UnitCell& cell)
{
ModuleBase::TITLE("ESolver_LRTD", "Run");
std::cout << "running ESolver_LRTD" << std::endl;
return;
}

template<typename FPTYPE, typename Device>
void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::init_X()
{
ModuleBase::TITLE("ESolver_LRTD", "Init");
//the eigenstate in the electron-hole pair representation
//Psi.nbasis = npairs, Psi.nbands = nstates.
//need a parallel distribution in the future
this->X.resize(this->nstates);

// setup ParaX
this->setup_2d_division(1, this->nocc, this->nvirt);
for (int i = 0; i < this->nstates; i++)
{
this->X[i] = std::move(psi::Psi<FPTYPE, Device>(this->nks, this->nocc, this->nvirt)); // to be changed into local size
this->X.emplace_back(this->nks, this->paraX_.get_row_size(), this->paraX_.get_col_size());
X[i].zero_out();
}

Expand All @@ -110,6 +118,30 @@ void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::init_X()
// use unit vectors as the initial guess
for (int i = 0; i < this->nstates; i++)
{
X[i](std::get<0>(ix2iciv[i]) , std::get<1>(ix2iciv[i])) = static_cast<FPTYPE>(1.0); // to be changed into local index
int row_global = std::get<0>(ix2iciv[i]);
int col_global = std::get<1>(ix2iciv[i]);
if (this->paraX_.in_this_processor(row_global, col_global))
X[i](this->paraX_.trace_loc_row[row_global], this->paraX_.trace_loc_col[col_global]) = static_cast<FPTYPE>(1.0);
}
}

template<typename FPTYPE, typename Device>
void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::setup_2d_division(int nb, int gr, int gc)
{
ModuleBase::TITLE("ESolver_LRTD", "setup_2d_division");
this->paraX_.set_block_size(nb);
#ifdef __MPI
int nprocs, myrank;
MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
this->paraX_.set_proc_dim(nprocs);
this->paraX_.mpi_create_cart(MPI_COMM_WORLD);
this->paraX_.set_local2global(gr, gc, GlobalV::ofs_running, GlobalV::ofs_warning);
this->paraX_.set_desc(gr, gc, this->paraX_.get_row_size());
this->paraX_.set_global2local(gr, gc, true, GlobalV::ofs_running);
#else
this->paraX_.set_proc_dim(1);
this->paraX_.set_serial(gr, gc);
this->paraX_.set_global2local(gr, gc, false, GlobalV::ofs_running);
#endif
}
10 changes: 9 additions & 1 deletion source/module_io/input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,10 @@ bool Input::Read(const std::string &fn)
else if (strcmp("beyonddft_method", word) == 0)
{
read_value(ifs, beyonddft_method);
}
else if (strcmp("nstates", word) == 0)
{
read_value(ifs, nstates);
}
//----------------------------------------------------------------------------------
else
Expand Down Expand Up @@ -3219,7 +3223,11 @@ void Input::Bcast()
// device control denghui added on 2022-11-05
//----------------------------------------------------------------------------------
Parallel_Common::bcast_string(device);

//----------------------------------------------------------------------------------
// beyond dft
//----------------------------------------------------------------------------------
Parallel_Common::bcast_string(beyonddft_method);
Parallel_Common::bcast_int(nstates);
return;
}
#endif
Expand Down

0 comments on commit 4818252

Please sign in to comment.