Skip to content

Commit

Permalink
param xc_kernel & fix gint-call and Kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Sep 14, 2023
1 parent 379534f commit f1bfb8b
Show file tree
Hide file tree
Showing 17 changed files with 100 additions and 76 deletions.
6 changes: 6 additions & 0 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -3146,4 +3146,10 @@ These parameters are used to solve the excited states using. e.g. lr-tddft
- **Description**: The number of 2-particle states to be solved
- **Default**: 0

### xc_kernel

- **Type**: String
- **Description**: The exchange-correlation kernel used in the calculation. Currently, only `LDA` is supported.
- **Default**: LDA

[back to top](#full-list-of-input-keywords)
10 changes: 6 additions & 4 deletions source/module_beyonddft/esolver_lrtd_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace ModuleESolver
delete this->phsol;
delete this->pot;
delete this->psi_ks;
delete this->X;
delete this->AX;
}

///input: input, call, basis(LCAO), psi(ground state), elecstate
Expand Down Expand Up @@ -66,9 +68,9 @@ namespace ModuleESolver
ModuleBase::matrix eig_ks;
// energy of ground state is in pelec->ekb

/// @brief Excited state info. size: nstates*nks*nocc*nvirt
std::vector<psi::Psi<FPTYPE, Device>> X;
std::vector<psi::Psi<FPTYPE, Device>> AX;
/// @brief Excited state info. size: nstates * nks * (nocc(local) * nvirt (local))
psi::Psi<FPTYPE, Device>* X;
psi::Psi<FPTYPE, Device>* AX;

size_t nocc;
size_t nvirt;
Expand All @@ -85,7 +87,7 @@ namespace ModuleESolver
// adj info
Record_adj ra;
// grid parallel info (no need for 2d-block distribution)?
Grid_Technique gridt;
// Grid_Technique gridt;
// grid integration method(will be moved to OperatorKernelHxc)
ModulePW::PW_Basis_Big bigpw;
Gint_Gamma gint_g;
Expand Down
39 changes: 18 additions & 21 deletions source/module_beyonddft/esolver_lrtd_lcao.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@ ModuleESolver::ESolver_LRTD<FPTYPE, Device>::ESolver_LRTD(ModuleESolver::ESolver
this->gint_g = std::move(ks_sol.UHM.GG);
else
this->gint_k = std::move(ks_sol.UHM.GK);
// can GlobalC be moved? try

// grid parallel info
this->gridt = std::move(ks_sol.GridT);

// xc kernel
XC_Functional::set_xc_type(inp.xc_kernel);
//init potential and calculate kernels using ground state charge
if (this->pot == nullptr)
{
Expand Down Expand Up @@ -75,22 +73,22 @@ ModuleESolver::ESolver_LRTD<FPTYPE, Device>::ESolver_LRTD(ModuleESolver::ESolver
//init Hamiltonian
if (typeid(FPTYPE) == typeid(double))
this->p_hamilt = new hamilt::HamiltCasidaLR<FPTYPE, Device>(this->nks, this->nbasis, this->nocc, this->nvirt, this->psi_ks,
&this->gint_g, this->pot, &this->gridt, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));
&this->gint_g, this->pot, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));
else if (typeid(FPTYPE) == typeid(std::complex<double>))
this->p_hamilt = new hamilt::HamiltCasidaLR<FPTYPE, Device>(this->nks, this->nbasis, this->nocc, this->nvirt, this->psi_ks,
&this->gint_k, this->pot, &this->gridt, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));
&this->gint_k, this->pot, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));

// init HSolver
// this->phsol = new hsolver::HSolver

/// =========just for test==============
// try hPsi
int istate = 0;
int iks = 0;
// using hpsi_info = typename hamilt::Operator<FPTYPE, Device>::hpsi_info;
// hpsi_info dav_info(&this->X[istate], psi::Range(1, iks, 0, this->paraX_.get_row_size() - 1), this->AX[istate].get_pointer());
this->AX[0] = this->p_hamilt->ops->act(this->X[0]);

// try act
for (int istate = 0;istate < nstates;++istate)
{
this->X->fix_b(istate);
this->AX->fix_b(istate);
this->p_hamilt->ops->act(*this->X, *this->AX);
}
}

template<typename FPTYPE, typename Device>
Expand All @@ -112,15 +110,12 @@ void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::init_X()

int nsk = (nspin == 4) ? this->nks : this->nks * this->nspin;
LR_Util::setup_2d_division(this->paraX_, 1, this->nvirt, this->nocc);//nvirt - row, nocc - col
for (int i = 0; i < this->nstates; i++)
{
this->X.emplace_back(nsk, this->paraX_.get_col_size(), this->paraX_.get_row_size());
X[i].zero_out();
}
this->X = new psi::Psi<FPTYPE, Device>(nsk, this->nstates, this->paraX_.get_local_size(), nullptr, false); // band(state)-first
this->X->zero_out();
LR_Util::setup_2d_division(this->paraMat_, 1, this->nbasis, this->nbasis, this->paraX_.comm_2D, this->paraX_.blacs_ctxt);
LR_Util::setup_2d_division(this->paraC_, 1, this->nbasis, this->nocc + this->nvirt, this->paraX_.comm_2D, this->paraX_.blacs_ctxt);
this->AX = this->X;

// this->AX = new psi::Psi<FPTYPE, Device>(*this->X);
this->AX = new psi::Psi<FPTYPE, Device>(nsk, this->nstates, this->paraX_.get_local_size(), nullptr, false);
// set the initial guess of X
// if (E_{lumo}-E_{homo-1} < E_{lumo+1}-E{homo}), mode = 0, else 1(smaller first)
bool ix_mode = 0; //default
Expand All @@ -143,9 +138,11 @@ void ModuleESolver::ESolver_LRTD<FPTYPE, Device>::init_X()
// use unit vectors as the initial guess
for (int i = 0; i < this->nstates; i++)
{
this->X->fix_b(i);
int row_global = std::get<0>(ix2iciv[i]);
int col_global = std::get<1>(ix2iciv[i]);
if (this->paraX_.in_this_processor(row_global, col_global))
X[i](this->paraX_.global2local_row(row_global), this->paraX_.global2local_col(col_global)) = static_cast<FPTYPE>(1.0);
for (int isk = 0;isk < nsk;++isk)
(*X)(isk, this->paraX_.global2local_row(row_global) * this->paraX_.get_col_size() + this->paraX_.global2local_col(col_global)) = static_cast<FPTYPE>(1.0);
}
}
3 changes: 1 addition & 2 deletions source/module_beyonddft/hamilt_casida.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ namespace hamilt
const psi::Psi<FPTYPE, Device>* psi_ks_in,
TGint* gint_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt,
const std::vector<Parallel_2D*> p2d_in)
{
ModuleBase::TITLE("HamiltCasidaLR", "HamiltCasidaLR");
this->classname = "HamiltCasidaLR";

// ops and opsd in base class may be unified in the future?
//add Hxc operator (the first one)
this->ops = new OperatorA_Hxc<FPTYPE, Device>(nsk, naos, nocc, nvirt, psi_ks_in, gint_in, pot_in, gt, p2d_in);
this->ops = new OperatorA_Hxc<FPTYPE, Device>(nsk, naos, nocc, nvirt, psi_ks_in, gint_in, pot_in, p2d_in);
//add Exx operator (remaining)
// Operator<double>* A_Exx = new OperatorA_Exx<FPTYPE, TGint>;
// this->opsd->add(A_Exx);
Expand Down
10 changes: 4 additions & 6 deletions source/module_beyonddft/operator_casida/operatorA_hxc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ namespace hamilt
const psi::Psi<FPTYPE, Device>* psi_ks_in,
Gint_Gamma* gg_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt_in/* grid parallel info*/,
const std::vector<Parallel_2D*> p2d_in /*< 2d-block parallel info of {X, c, matrix}*/)
: nsk(nsk), naos(naos), nocc(nocc), nvirt(nvirt),
psi_ks(psi_ks_in), gg(gg_in), pot(pot_in), gridt(gt_in),
psi_ks(psi_ks_in), gg(gg_in), pot(pot_in),
pX(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
{
ModuleBase::TITLE("OperatorA_Hxc", "OperatorA_Hxc(gamma)");
Expand All @@ -37,10 +36,9 @@ namespace hamilt
const psi::Psi<FPTYPE, Device>* psi_ks_in,
Gint_k* gk_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt_in/* grid parallel info*/,
const std::vector<Parallel_2D*> p2d_in /*< 2d-block parallel info of {X, c, matrix}*/)
: nsk(nsk), naos(naos), nocc(nocc), nvirt(nvirt),
psi_ks(psi_ks_in), gk(gk_in), pot(pot_in), gridt(gt_in),
psi_ks(psi_ks_in), gk(gk_in), pot(pot_in),
pX(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
{
ModuleBase::TITLE("OperatorA_Hxc", "OperatorA_Hxc(k)");
Expand All @@ -56,7 +54,8 @@ namespace hamilt
FPTYPE* tmhpsi,
const int ngk_ik = 0)const override {};
//tmp, for only one state
virtual psi::Psi<FPTYPE> act(const psi::Psi<FPTYPE>& psi_in) const override;
// virtual psi::Psi<FPTYPE> act(const psi::Psi<FPTYPE>& psi_in) const override;
virtual void act(const psi::Psi<FPTYPE>& psi_in, psi::Psi<FPTYPE>& psi_out) const override;
private:
//global sizes
int nsk; //nspin*nkpoints
Expand All @@ -70,7 +69,6 @@ namespace hamilt
Parallel_2D* pc = nullptr;
Parallel_2D* pX = nullptr;
Parallel_2D* pmat = nullptr;
const Grid_Technique* gridt = nullptr;

elecstate::PotHxcLR* pot = nullptr;

Expand Down
27 changes: 13 additions & 14 deletions source/module_beyonddft/operator_casida/operatorA_hxc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ namespace hamilt
{
// for double
template<typename FPTYPE, typename Device>
psi::Psi<FPTYPE> OperatorA_Hxc<FPTYPE, Device>::act(const psi::Psi<FPTYPE>& psi_in) const
// psi::Psi<FPTYPE> OperatorA_Hxc<FPTYPE, Device>::act(const psi::Psi<FPTYPE>& psi_in) const
void OperatorA_Hxc<FPTYPE, Device>::act(const psi::Psi<FPTYPE>& psi_in, psi::Psi<FPTYPE>& psi_out) const
{
ModuleBase::TITLE("OperatorA_Hxc", "act");
const int& lgd = gg->gridt->lgd;
// gamma-only now
// 1. transition density matrix
// multi-k needs k-to-R FT
Expand All @@ -21,13 +24,13 @@ namespace hamilt
std::vector<container::Tensor> dm_trans_2d = cal_dm_trans_blas(*pX, *pc, nocc, nvirt);
#endif
double*** dm_trans_grid;
LR_Util::new_p3(dm_trans_grid, nsk, gridt->lgd, gridt->lgd);
LR_Util::new_p3(dm_trans_grid, nsk, lgd, lgd);
//2d block to grid
DMgamma_2dtoGrid dm2g;
#ifdef __MPI
dm2g.setAlltoallvParameter(pmat->comm_2D, naos, pmat->blacs_ctxt, pmat->nb, gridt->lgd, gridt->trace_lo);
dm2g.setAlltoallvParameter(pmat->comm_2D, naos, pmat->blacs_ctxt, pmat->nb, lgd, gg->gridt->trace_lo);
#endif
dm2g.cal_dk_gamma_from_2D(LR_Util::ten2mat_double(dm_trans_2d), dm_trans_grid, nsk, naos, gridt->lgd, GlobalV::ofs_running);
dm2g.cal_dk_gamma_from_2D(LR_Util::ten2mat_double(dm_trans_2d), dm_trans_grid, nsk, naos, lgd, GlobalV::ofs_running);

// 2. transition electron density
double** rho_trans;
Expand All @@ -44,36 +47,32 @@ namespace hamilt
// results are stored in gg->pvpR_grid(gamma_only)
// or gint_k->pvpR_reduced(multi_k)
std::vector<ModuleBase::matrix> v_hxc_2d(nsk);
this->gg->init_pvpR_grid(gridt->lgd);
this->gg->init_pvpR_grid(lgd);
auto setter = [this](const int& iw1_all, const int& iw2_all, const double& v, double* out) {
const int ir = this->pmat->global2local_row(iw1_all);
const int ic = this->pmat->global2local_col(iw2_all);
out[ic * this->pmat->nrow + ir] += v;
};
for (int is = 0;is < this->nsk;++is)
{
ModuleBase::GlobalFunc::ZEROS(this->gg->get_pvpR_grid(), gridt->lgd * gridt->lgd);
ModuleBase::GlobalFunc::ZEROS(this->gg->get_pvpR_grid(), lgd * lgd);
double* vr_hxc_is = &vr_hxc.c[is * this->pot->nrxx]; //current spin
Gint_inout inout_vlocal(vr_hxc_is, Gint_Tools::job_type::vlocal);
this->gg->cal_gint(&inout_vlocal);
v_hxc_2d[is].create(pmat->get_col_size(), pmat->get_row_size());
this->gg->vl_grid_to_2D(this->gg->get_pvpR_grid(), *pmat, this->gridt->lgd, (is == 0), v_hxc_2d[is].c, setter);
this->gg->vl_grid_to_2D(this->gg->get_pvpR_grid(), *pmat, lgd, (is == 0), v_hxc_2d[is].c, setter);
}
this->gg->delete_pvpR_grid();

// clear useless matrices
LR_Util::delete_p3(dm_trans_grid, nsk, gridt->lgd);
LR_Util::delete_p3(dm_trans_grid, nsk, lgd);
LR_Util::delete_p2(rho_trans, nsk);

// 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i}
#ifdef __MPI
psi::Psi<FPTYPE> AX(1, nsk, this->pX->get_local_size(), nullptr, false);
cal_AX_pblas(LR_Util::mat2ten_double(v_hxc_2d), *this->pmat, *this->psi_ks, *this->pc, naos, nocc, nvirt, *this->pX, AX);
return AX;
cal_AX_pblas(LR_Util::mat2ten_double(v_hxc_2d), *this->pmat, *this->psi_ks, *this->pc, naos, nocc, nvirt, *this->pX, psi_out);
#else
psi::Psi<FPTYPE> AX(1, nsk, nocc * nvirt, nullptr, false);
cal_AX_blas(LR_Util::mat2ten_double(v_hxc_2d), *this->psi_ks, nocc, nvirt, AX);
return AX;
cal_AX_blas(LR_Util::mat2ten_double(v_hxc_2d), *this->psi_ks, nocc, nvirt, psi_out);
#endif
}
}
3 changes: 2 additions & 1 deletion source/module_beyonddft/potentials/kernel_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
#include "module_basis/module_pw/pw_basis.h"
#include "module_elecstate/module_charge/charge.h"
#include "module_cell/unitcell.h"
// #include <ATen/tensor.h>
class KernelBase
{
public:
virtual ~KernelBase() = default;
virtual void cal_kernel(const Charge* chg_gs, const UnitCell* ucell, int& nspin) = 0;
virtual std::map<std::string, ModuleBase::matrix>& get_kernel() = 0;
virtual ModuleBase::matrix& get_kernel(const std::string& name) { return kernel_set_[name]; }
protected:
const ModulePW::PW_Basis* rho_basis_ = nullptr;
std::map<std::string, ModuleBase::matrix> kernel_set_; // [kernel_type][nspin][nrxx]
Expand Down
14 changes: 6 additions & 8 deletions source/module_beyonddft/potentials/kernel_hartree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ namespace elecstate
{
this->rho_basis_ = rho_basis_in;
}
std::map<std::string, ModuleBase::matrix>& get_kernel() override { return this->kernel_set_; }

void cal_kernel(const Charge* chg_gs, const UnitCell* ucell, int& nspin) override
{
ModuleBase::TITLE("KernelHartree", "cal_v_eff");
ModuleBase::timer::tick("KernelHartree", "cal_v_eff");
ModuleBase::matrix f_hartree(chg_gs->nrxx, nspin);

//1. Coulomb kernel in reciprocal space
this->kernel_set_.emplace("Hartree", ModuleBase::matrix(nspin, chg_gs->nrxx));
//1. Hartree kernel in reciprocal space
std::vector<std::complex<double>> Porter(this->rho_basis_->nmaxgr, std::complex<double>(0.0, 0.0));
#ifdef _OPENMP
#pragma omp parallel for
Expand All @@ -32,14 +30,14 @@ namespace elecstate
//2. FFT to real space
rho_basis_->recip2real(Porter.data(), Porter.data());

//3. Add to f_hartree
//3. Add to kernel_set_
if (nspin == 4)
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static, 512)
#endif
for (int ir = 0;ir < this->rho_basis_->nrxx;ir++)
f_hartree(ir, 0) += Porter[ir].real();
kernel_set_["Hartree"](0, ir) += Porter[ir].real();
}
else
{
Expand All @@ -48,9 +46,9 @@ namespace elecstate
#endif
for (int is = 0;is < nspin;is++)
for (int ir = 0;ir < this->rho_basis_->nrxx;ir++)
f_hartree(ir, is) += Porter[ir].real();
kernel_set_["Hartree"](is, ir) += Porter[ir].real();
}
this->kernel_set_.insert({ "Hartree", f_hartree });

ModuleBase::timer::tick("KernelHartree", "cal_v_eff");
return;
}
Expand Down
Loading

0 comments on commit f1bfb8b

Please sign in to comment.