Skip to content

Commit

Permalink
matrix-lapack, reuse DM, fix wf-sign-dependence
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Nov 18, 2023
1 parent e7eb223 commit 27af2b2
Show file tree
Hide file tree
Showing 15 changed files with 221 additions and 85 deletions.
2 changes: 1 addition & 1 deletion source/module_beyonddft/AX/test/AX_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "mpi.h"
#include "../AX.h"

#include "module_beyonddft/utils/lr_util.h"
#include "module_beyonddft/utils/lr_util_algorithms.hpp"

struct matsize
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_beyonddft/dm_trans/test/dm_trans_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "mpi.h"
#include "../dm_trans.h"
#ifdef __MPI
#include "module_beyonddft/utils/lr_util.h"
#include "module_beyonddft/utils/lr_util_algorithms.hpp"
#endif
struct matsize
{
Expand Down
6 changes: 1 addition & 5 deletions source/module_beyonddft/esolver_lrtd_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ namespace ModuleESolver
delete this->phsol;
delete this->pot;
delete this->psi_ks;
delete this->DM_trans;
delete this->X;
}

Expand Down Expand Up @@ -76,10 +75,7 @@ namespace ModuleESolver
//pelec in ESolver_FP
// const psi::Psi<T>* psi_ks = nullptr;
psi::Psi<T>* psi_ks = nullptr;
ModuleBase::matrix eig_ks;
/// transition density matrix in AO representation
elecstate::DensityMatrix<T, double>* DM_trans = nullptr;
// energy of ground state is in pelec->ekb
ModuleBase::matrix eig_ks;///< energy of ground state

/// @brief Excited state info. size: nstates * nks * (nocc(local) * nvirt (local))
psi::Psi<T>* X;
Expand Down
17 changes: 7 additions & 10 deletions source/module_beyonddft/esolver_lrtd_lcao.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,10 @@ ModuleESolver::ESolver_LRTD<T, TR>::ESolver_LRTD(Input& inp, UnitCell& ucell) :

// if EXX from scratch, init 2-center integral and calclate Cs, Vs
#ifdef __EXX
bool exx_from_scratch = false;
if (exx_from_scratch && xc_kernel == "hf")
if (xc_kernel == "hf")
{
int Lmax = GlobalC::exx_info.info_ri.abfs_Lmax;
#ifndef USE_NEW_TWO_CENTER
int Lmax = GlobalC::exx_info.info_ri.abfs_Lmax;
this->orb_con.set_orb_tables(GlobalV::ofs_running,
GlobalC::UOT,
GlobalC::ORB,
Expand All @@ -271,10 +270,11 @@ ModuleESolver::ESolver_LRTD<T, TR>::ESolver_LRTD(Input& inp, UnitCell& ucell) :
ucell.infoNL.Beta);
#else
two_center_bundle.reset(new TwoCenterBundle);
two_center_bundle->build(ucell.ntype, ucell.orbital_fn, ucell.infoNL.Beta,
two_center_bundle->build(ucell.ntype, ucell.orbital_fn, nullptr/*ucell.infoNL.Beta*/,
GlobalV::deepks_setorb, &ucell.descriptor_file);
GlobalC::UOT.two_center_bundle = std::move(two_center_bundle);
#endif
std::cout << GlobalC::exx_info.info_ri.dm_threshold << std::endl;
this->exx_lri = std::make_shared<Exx_LRI<T>>(GlobalC::exx_info.info_ri);
this->exx_lri->init(MPI_COMM_WORLD, this->kv); // using GlobalC::ORB
this->exx_lri->cal_exx_ions();
Expand Down Expand Up @@ -388,18 +388,15 @@ void ModuleESolver::ESolver_LRTD<T, TR>::init_A(hamilt::HContainer<double>* pHR_
pHR->allocate(true);
}
pHR->set_paraV(&this->paraMat_);
this->DM_trans = new elecstate::DensityMatrix<T, double>(&this->kv, &this->paraMat_, this->nspin);
this->DM_trans->init_DMR(*pHR);
this->p_hamilt = new hamilt::HamiltCasidaLR<T>(xc_kernel, this->nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, this->psi_ks, this->eig_ks, this->DM_trans, pHR,
this->p_hamilt = new hamilt::HamiltCasidaLR<T>(xc_kernel, this->nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, this->psi_ks, this->eig_ks, pHR,
#ifdef __EXX
this->exx_lri.get(),
#endif
this->gint, this->pot, this->kv, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));
this->gint, this->pot, this->kv, & this->paraX_, & this->paraC_, & this->paraMat_);

// init HSolver
this->phsol = new hsolver::HSolverLR<T>();
this->phsol = new hsolver::HSolverLR<T>(this->npairs);
this->phsol->set_diagethr(0, 0, std::max(1e-13, lr_thr));
// if (!pHR_in) delete pHR;
}

template<typename T, typename TR>
Expand Down
73 changes: 66 additions & 7 deletions source/module_beyonddft/hamilt_casida.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,36 @@ namespace hamilt
const UnitCell& ucell_in,
const psi::Psi<T>* psi_ks_in,
const ModuleBase::matrix& eig_ks,
elecstate::DensityMatrix<T, double>* DM_trans_in,
// elecstate::DensityMatrix<T, double>* DM_trans_in,
HContainer<double>*& hR_in,
#ifdef __EXX
Exx_LRI<T>* exx_lri_in,
#endif
TGint* gint_in,
elecstate::PotHxcLR* pot_in,
const K_Vectors& kv_in,
const std::vector<Parallel_2D*> p2d_in)
Parallel_2D* pX_in,
Parallel_2D* pc_in,
Parallel_Orbitals* pmat_in) : nocc(nocc), nvirt(nvirt), pX(pX_in)
{
ModuleBase::TITLE("HamiltCasidaLR", "HamiltCasidaLR");
this->classname = "HamiltCasidaLR";
assert(hR_in != nullptr);
this->hR = new HContainer<double>(std::move(*hR_in));
this->DM_trans.resize(1);
this->DM_trans[0] = new elecstate::DensityMatrix<T, double>(&kv_in, pmat_in, nspin);
this->DM_trans[0]->init_DMR(*this->hR);
// add the diag operator (the first one)
this->ops = new OperatorLRDiag<T>(eig_ks, p2d_in.at(0), kv_in.nks, nspin, nocc, nvirt);
this->ops = new OperatorLRDiag<T>(eig_ks, pX_in, kv_in.nks, nspin, nocc, nvirt);
//add Hxc operator
OperatorLRHxc<T>* lr_hxc = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks_in, DM_trans_in, this->hR, gint_in, pot_in, kv_in.kvec_d, p2d_in);
OperatorLRHxc<T>* lr_hxc = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks_in,
this->DM_trans, this->hR, gint_in, pot_in, kv_in, pX_in, pc_in, pmat_in);
this->ops->add(lr_hxc);
#ifdef __EXX
if (xc_kernel == "hf")
{
//add Exx operator
Operator<T>* lr_exx = new OperatorLREXX<T>(nspin, naos, nocc, nvirt, ucell_in, psi_ks_in, DM_trans_in, exx_lri_in, kv_in, p2d_in);
{ //add Exx operator
Operator<T>* lr_exx = new OperatorLREXX<T>(nspin, naos, nocc, nvirt, ucell_in, psi_ks_in,
this->DM_trans, exx_lri_in, kv_in, pX_in, pc_in, pmat_in);
this->ops->add(lr_exx);
}
#endif
Expand All @@ -55,10 +61,63 @@ namespace hamilt
delete this->ops;
}
delete this->hR;
for (auto& d : this->DM_trans)delete d;
};

HContainer<double>* getHR() { return this->hR; }

virtual std::vector<T> matrix() override
{
ModuleBase::TITLE("HamiltCasidaLR", "matrix");
int npairs = this->nocc * this->nvirt;
std::vector<T> Amat_full(npairs * npairs, 0.0);
for (int lj = 0;lj < this->pX->get_col_size();++lj)
for (int lb = 0;lb < this->pX->get_row_size();++lb)
{//calculate A^{ai} for each bj
int b = pX->local2global_row(lb);
int j = pX->local2global_col(lj);
int bj = j * nvirt + b;
psi::Psi<T> X_bj(1, 1, this->pX->get_local_size());
X_bj.zero_out();
X_bj(0, 0, lj * this->pX->get_row_size() + lb) = this->one();
psi::Psi<T> A_aibj(1, 1, this->pX->get_local_size());
A_aibj.zero_out();
Operator<T>* node(this->ops);
while (node != nullptr)
{
node->act(X_bj, A_aibj, 1);
node = (Operator<T>*)(node->next_op);
}
// reduce ai for a fixed bj
LR_Util::gather_2d_to_full(*this->pX, A_aibj.get_pointer(), Amat_full.data() + bj * npairs, false, this->nvirt, this->nocc);
}
// reduce all bjs
MPI_Allreduce(MPI_IN_PLACE, Amat_full.data(), npairs * npairs, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
// output Amat
std::cout << "Amat_full:" << std::endl;
for (int i = 0;i < npairs;++i)
{
for (int j = 0;j < npairs;++j)
{
std::cout << Amat_full[i * npairs + j] << " ";
}
std::cout << std::endl;
}
return Amat_full;
}
private:
int nocc;
int nvirt;
Parallel_2D* pX = nullptr;
T one();
HContainer<double>* hR = nullptr;
/// transition density matrix in AO representation
/// Hxc only: size=1, calculate on the same address for each bands
/// Hxc+Exx: size=nbands, store the result of each bands for common use
std::vector<elecstate::DensityMatrix<T, double>*> DM_trans;
};

template<> double HamiltCasidaLR<double>::one() { return 1.0; }
template<> std::complex<double> HamiltCasidaLR<std::complex<double>>::one() { return std::complex<double>(1.0, 0.0); }

}
11 changes: 11 additions & 0 deletions source/module_beyonddft/hsolver_lrtd.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "hsolver_lrtd.h"
#include "module_hsolver/diago_david.h"
#include "module_hsolver/diago_cg.h"
#include "module_beyonddft/utils/lr_util.h"

namespace hsolver
{
Expand Down Expand Up @@ -30,6 +31,16 @@ namespace hsolver
this->pdiagh = new DiagoCG<T, Device>(precondition.data());
this->pdiagh->method = this->method;
}
else if (this->method == "lapack")
{
std::vector<T> Amat_full = pHamilt->matrix();
eigenvalue.resize(npairs);
LR_Util::diag_lapack(npairs, Amat_full.data(), eigenvalue.data());
std::cout << "eigenvalues:" << std::endl;
for (auto& e : eigenvalue)std::cout << e << " ";
std::cout << std::endl;
return;
}
else
throw std::runtime_error("HSolverLR::solve: method not implemented");

Expand Down
3 changes: 2 additions & 1 deletion source/module_beyonddft/hsolver_lrtd.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace hsolver
{
private:
using Real = typename GetTypeReal<T>::type;
const int npairs = 0;
public:
HSolverLR() {};
HSolverLR(const int npairs_in) :npairs(npairs_in) {};
virtual Real set_diagethr(const int istep, const int iter, const Real ethr) override
{
this->diag_ethr = ethr;
Expand Down
18 changes: 10 additions & 8 deletions source/module_beyonddft/operator_casida/operator_lr_exx.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ namespace hamilt
const int& nvirt,
const UnitCell& ucell_in,
const psi::Psi<T>* psi_ks_in,
elecstate::DensityMatrix<T, double>* DM_trans_in,
std::vector<elecstate::DensityMatrix<T, double>*>& DM_trans_in,
// HContainer<double>* hR_in,
Exx_LRI<T>* exx_lri_in,
const K_Vectors& kv_in,
const std::vector<Parallel_2D*> p2d_in /*< 2d-block parallel info of {X, c, matrix}*/)
Parallel_2D* pX_in,
Parallel_2D* pc_in,
Parallel_Orbitals* pmat_in)
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt),
psi_ks(psi_ks_in), DM_trans(DM_trans_in), exx_lri(exx_lri_in), kv(kv_in),
pX(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2)), ucell(ucell_in)
pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in)
{
ModuleBase::TITLE("OperatorLREXX", "OperatorLREXX");
this->nks = std::is_same<T, double>::value ? 1 : this->kv.kvec_d.size();
Expand All @@ -38,14 +40,15 @@ namespace hamilt
this->is_first_node = false;

// reduce psi_ks for later use
this->psi_ks_full.resize(this->nsk, this->nbands_ks, this->naos);
LR_Util::gather_2d_to_full(*this->pc, this->psi_ks->get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, this->nbands_ks);
this->psi_ks_full.resize(this->nsk, this->psi_ks->get_nbands(), this->naos);
LR_Util::gather_2d_to_full(*this->pc, this->psi_ks->get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, this->psi_ks->get_nbands());

// get cells in BvK supercell
const TC period = RI_Util::get_Born_vonKarmen_period(kv_in);
this->BvK_cells = RI_Util::get_Born_von_Karmen_cells(period);

this->allocate_Ds_onebase();
this->exx_lri->Hexxs.resize(this->nspin);
};

void init(const int ik_in) override {};
Expand All @@ -60,14 +63,13 @@ namespace hamilt
int naos;
int nocc;
int nvirt;
int nbands_ks;
const K_Vectors& kv;
/// ground state wavefunction
const psi::Psi<T>* psi_ks = nullptr;
psi::Psi<T> psi_ks_full;

/// transition density matrix
elecstate::DensityMatrix<T, double>* DM_trans;
std::vector<elecstate::DensityMatrix<T, double>*>& DM_trans;

/// density matrix of a certain (i, a, k), with full naos*naos size for each key
/// D^{iak}_{\mu\nu}(k): 1/N_k * c^*_{ak,\mu} c_{ik,\nu}
Expand All @@ -92,7 +94,7 @@ namespace hamilt
///parallel info
Parallel_2D* pc = nullptr;
Parallel_2D* pX = nullptr;
Parallel_2D* pmat = nullptr;
Parallel_Orbitals* pmat = nullptr;


// allocate Ds_onebase
Expand Down
31 changes: 28 additions & 3 deletions source/module_beyonddft/operator_casida/operator_lr_exx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace hamilt
int iat2 = ucell.itia2iat(it2, ia2);
auto& D2d = this->Ds_onebase[is][iat1][std::make_pair(iat2, cell)];
for (int iw1 = 0;iw1 < ucell.atoms[it1].nw;++iw1)
for (int iw2 = 0;iw1 < ucell.atoms[it2].nw;++iw2)
for (int iw2 = 0;iw2 < ucell.atoms[it2].nw;++iw2)
D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, iv, ucell.itiaiw2iwt(it2, ia2, iw2));
}
}
Expand Down Expand Up @@ -80,14 +80,39 @@ namespace hamilt

// 1. set_Ds (once)
// convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
std::vector<std::vector<T>> DMk_trans_vector = this->DM_trans->get_DMK_vector();
std::cout << "ib=" << ib << std::endl;
std::vector<std::vector<T>> DMk_trans_vector = this->DM_trans[ib]->get_DMK_vector();
assert(DMk_trans_vector.size() == this->nsk);
std::vector<const std::vector<T>*> DMk_trans_pointer(this->nsk);
for (int is = 0;is < this->nsk;++is) DMk_trans_pointer[is] = &DMk_trans_vector[is];
// if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
RI_2D_Comm::split_m2D_ktoR<T>(this->kv, DMk_trans_pointer, *this->pmat);

// output DM_trans
GlobalV::ofs_running << "DM_trans in OperatorLREXX::act" << std::endl;
for (int is = 0;is < this->nsk;++is)
{
GlobalV::ofs_running << "is = " << is << std::endl;
for (auto& mabR : Ds_trans[is
])
{
int ia = mabR.first;
GlobalV::ofs_running << "ia = " << ia << std::endl;
for (auto& mbR : mabR.second)
{
GlobalV::ofs_running << "ib = " << mbR.first.first << ", R=" << mbR.first.second[0] <<
" " << mbR.first.second[1] << " " << mbR.first.second[2] << ", size=" << mbR.second.shape[0] << ", " << mbR.second.shape[1] << std::endl;
auto& t = mbR.second;
for (int i = 0;i < t.shape[0];++i)
{
for (int j = 0;j < t.shape[1];++j)
GlobalV::ofs_running << t(i, j) << " ";
GlobalV::ofs_running << std::endl;
}
}
}
}
// 2. cal_Hs
for (int is = 0;is < this->nsk;++is)
{
Expand All @@ -110,7 +135,7 @@ namespace hamilt
for (int is = 0;is < this->nspin;++is)
{
this->cal_DM_onebase(this->pX->local2global_col(io), this->pX->local2global_row(iv), ik, is); //set Ds_onebase
psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) =
psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) +=
this->exx_lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], this->exx_lri->Hexxs[is]);
}
}
Expand Down
Loading

0 comments on commit 27af2b2

Please sign in to comment.