Skip to content

Commit

Permalink
Procedure: build gamma_only AX
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Sep 7, 2023
1 parent a593aea commit fd653db
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 61 deletions.
11 changes: 7 additions & 4 deletions source/module_beyonddft/esolver_lrtd_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
#include "module_hamilt_lcao/module_gint/gint_k.h"
#include "module_hamilt_lcao/module_gint/grid_technique.h"
#include "module_elecstate/potentials/potential_new.h"
#include "module_beyonddft/potentials/pot_hxc_lrtd.hpp"
#include "module_beyonddft/hamilt_casida.hpp"

namespace ModuleESolver
{
Expand Down Expand Up @@ -52,12 +53,12 @@ namespace ModuleESolver
const UnitCell* p_ucell = nullptr;
const Input* p_input = nullptr;

hamilt::Hamilt<FPTYPE, Device>* p_hamilt = nullptr;
hamilt::HamiltCasidaLR<FPTYPE, Device>* p_hamilt = nullptr;
hsolver::HSolver<FPTYPE, Device>* phsol = nullptr;
// not to use ElecState because 2-particle state is quite different from 1-particle state.
// implement a independent one (ExcitedState) to pack physical properties if needed.
// put the components of ElecState here:
elecstate::PotBase* pot = nullptr;
elecstate::PotHxcLR* pot = nullptr;

// ground state info
//pelec in ESolver_FP
Expand All @@ -67,7 +68,7 @@ namespace ModuleESolver

/// @brief Excited state info. size: nstates*nks*nocc*nvirt
std::vector<psi::Psi<FPTYPE, Device>> X;
//psi::Psi<FPTYPE, Device>* AX = nullptr;
std::vector<psi::Psi<FPTYPE, Device>> AX;

size_t nocc;
size_t nvirt;
Expand Down Expand Up @@ -96,6 +97,8 @@ namespace ModuleESolver
Parallel_2D paraX_;
/// @brief variables for parallel distribution of matrix in AO representation
Parallel_2D paraMat_;

/// move to hsolver::updatePsiK
void init_X();

};
Expand Down
10 changes: 10 additions & 0 deletions source/module_beyonddft/esolver_lrtd_lcao.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ ModuleESolver::ESolver_LRTD<FPTYPE, Device>::ESolver_LRTD(ModuleESolver::ESolver
this->p_hamilt = new hamilt::HamiltCasidaLR<FPTYPE, Device>(this->nks, this->nbasis, this->nocc, this->nvirt,
&this->gint_k, this->pot, &this->gridt, std::vector<Parallel_2D*>({ &this->paraX_, &this->paraC_, &this->paraMat_ }));

// init HSolver
// this->phsol = new hsolver::HSolver

// try hPsi
int istate = 0;
int iks = 0;
using hpsi_info = typename hamilt::Operator<FPTYPE, Device>::hpsi_info;
hpsi_info dav_info(&this->X[istate], psi::Range(1, iks, 0, this->nocc - 1), this->AX[istate].get_pointer());
// this->p_hamilt->ops->hPsi(dav_info);

}

template<typename FPTYPE, typename Device>
Expand Down
2 changes: 1 addition & 1 deletion source/module_beyonddft/hamilt_casida.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace hamilt
const int& nocc,
const int& nvirt,
TGint* gint_in,
elecstate::PotBase* pot_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt,
const std::vector<Parallel_2D*> p2d_in)
{
Expand Down
45 changes: 25 additions & 20 deletions source/module_beyonddft/operator_casida/operatorA_hxc.h
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
#pragma once
#include "module_hamilt_general/operator.h"
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
#include "module_hamilt_lcao/module_gint/grid_technique.h"
#include "module_hamilt_lcao/module_gint/gint_k.h"

#include "module_beyonddft/potentials/pot_hxc_lrtd.hpp"

namespace hamilt
{

/// @brief kernel for Hxc operator
/// @tparam T
template<typename T, typename Device = psi::DEVICE_CPU>
class OperatorA_Hxc : public Operator<T, Device>
/// @brief Hxc part of A operator
template<typename FPTYPE = double, typename Device = psi::DEVICE_CPU>
class OperatorA_Hxc : public Operator<FPTYPE, Device>
{
public:
OperatorA_Hxc(const int& nsk,
const int& naos,
const int& nocc,
const int& nvirt,
Gint_Gamma* gg_in,
elecstate::PotBase* pot_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt_in/* grid parallel info*/,
const std::vector<Parallel_2D*> p2d_in /*< 2d-block parallel info of {X, c, matrix}*/)
: nsk(nsk), naos(naos), nocc(nocc), nvirt(nvirt),
gg(gg_in), pot(pot_in), gridt(gt_in),
px(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
pX(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
{
this->cal_type = calculation_type::lcao_gint;
this->is_first_node = true;
Expand All @@ -33,43 +33,48 @@ namespace hamilt
const int& nocc,
const int& nvirt,
Gint_k* gk_in,
elecstate::PotBase* pot_in,
elecstate::PotHxcLR* pot_in,
const Grid_Technique* gt_in/* grid parallel info*/,
const std::vector<Parallel_2D*> p2d_in /*< 2d-block parallel info of {X, c, matrix}*/)
: nsk(nsk), naos(naos), nocc(nocc), nvirt(nvirt),
gk(gk_in), pot(pot_in), gridt(gt_in),
px(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
pX(p2d_in.at(0)), pc(p2d_in.at(1)), pmat(p2d_in.at(2))
{
this->cal_type = calculation_type::lcao_gint;
this->is_first_node = true;
};
void init(const int ik_in) override {};
//move hpsi and act to base class !
// void act() const override; // call gint
void act() const;

virtual void act(const int nbands,
const int nbasis,
const int npol,
const FPTYPE* tmpsi_in,
FPTYPE* tmhpsi,
const int ngk_ik = 0)const override {};
//tmp, for only one state
virtual psi::Psi<FPTYPE> act(const psi::Psi<FPTYPE>& psi_in) const override;
private:
//global sizes
int nsk; //nspin*nkpoints
int naos;
int nocc;
int nvirt;

/// ground state wavefunction
const psi::Psi<FPTYPE, Device>* psi_ks = nullptr;
//parallel info
const Parallel_2D* pc = nullptr;
const Parallel_2D* px = nullptr;
const Parallel_2D* pmat = nullptr;
Parallel_2D* pc = nullptr;
Parallel_2D* pX = nullptr;
Parallel_2D* pmat = nullptr;
const Grid_Technique* gridt = nullptr;

elecstate::PotBase* pot = nullptr;
elecstate::PotHxcLR* pot = nullptr;

Gint_Gamma* gg = nullptr;
Gint_k* gk = nullptr;


/// \f[ \tilde{\rho}(r)=\sum_{\mu_j, \mu_b}\tilde{\rho}_{\mu_j,\mu_b}\phi_{\mu_b}(r)\phi_{\mu_j}(r) \f]
void cal_rho_trans();
};


}
}
#include "module_beyonddft/operator_casida/operatorA_hxc.hpp"
85 changes: 52 additions & 33 deletions source/module_beyonddft/operator_casida/operatorA_hxc.hpp
Original file line number Diff line number Diff line change
@@ -1,56 +1,75 @@
#pragma once
#include "operatorA_hxc.h"
#include <vector>
#include "module_base/blas_connector.h"
#include "utils/lr_util.h"

#include "module_beyonddft/utils/lr_util.h"
#include "module_hamilt_lcao/hamilt_lcaodft/DM_gamma_2d_to_grid.h"
#include "module_beyonddft/density/dm_trans.h"
#include "module_beyonddft/AX/AX.h"
namespace hamilt
{
// for double
template<typename T, typename Device>
void OperatorA_Hxc<T, Device>::act() const
template<typename FPTYPE, typename Device>
psi::Psi<FPTYPE> OperatorA_Hxc<FPTYPE, Device>::act(const psi::Psi<FPTYPE>& psi_in) const
{
auto block2grid = [](std::vector<ModuleBase::matrix> block, double*** grid)->void {};
// gamma-only now
// 1. transition density matrix
std::vector<ModuleBase::matrix> dm_trans_2d = cal_dm_trans_blas(*px, *pc);
double*** dm_trans_grid = LR_Util::new_p3(nspin, naos_local_grid, naos_local_grid);
block2grid(dm_trans_2d, dm_trans_grid);
// multi-k needs k-to-R FT
#ifdef __MPI
std::vector<container::Tensor> dm_trans_2d = cal_dm_trans_pblas(psi_in, *pX, *psi_ks, *pc, naos, nocc, nvirt, *pmat);
#else
std::vector<container::Tensor> dm_trans_2d = cal_dm_trans_blas(*pX, *pc);
#endif
double*** dm_trans_grid;
LR_Util::new_p3(dm_trans_grid, nsk, gridt->lgd, gridt->lgd);
//2d block to grid
DMgamma_2dtoGrid dm2g;
#ifdef __MPI
dm2g.setAlltoallvParameter(pmat->comm_2D, naos, pmat->blacs_ctxt, pmat->nb, gridt->lgd, gridt->trace_lo);
#endif
dm2g.cal_dk_gamma_from_2D(LR_Util::ten2mat_double(dm_trans_2d), dm_trans_grid, nsk, naos, gridt->lgd, GlobalV::ofs_running);

// 2. transition electron density
double** rho_trans = LR_Util::new_p2(nspin, this->gint_g->nbxx); // is nbxx local grid num ?
double** rho_trans;
LR_Util::new_p2(rho_trans, nsk, this->pot->nrxx);
Gint_inout inout_rho(dm_trans_grid, rho_trans, Gint_Tools::job_type::rho);
this->gint_g->cal_gint(&inout_rho);
this->gg->cal_gint(&inout_rho);

// 3. v_hxc = f_hxc * rho_trans
this->pot->update_from_charge(rho_trans, GlobalC::ucell);
ModuleBase::matrix vr_hxc(nsk, this->pot->nrxx); //grid
this->pot->cal_v_eff(rho_trans, &GlobalC::ucell, vr_hxc);

// 4. V^{Hxc}_{\mu,\nu}=\int{dr} \phi_\mu(r) v_{Hxc}(r) \phi_\mu(r)
// loop for nspin, or use current spin (how?)
// results are stored in gint_g->pvpR_grid(gamma_only)
// loop for nsk, or use current spin (how?)
// results are stored in gg->pvpR_grid(gamma_only)
// or gint_k->pvpR_reduced(multi_k)

std::vector<ModuleBase::matrix> v_hxc_local(nspin); // 2D local matrix)
for (int is = 0;is < this->nspin;++is)
std::vector<ModuleBase::matrix> v_hxc_2d(nsk);
this->gg->init_pvpR_grid(gridt->lgd);
auto setter = [this](const int& iw1_all, const int& iw2_all, const double& v, double* out) {
const int ir = this->pmat->global2local_row(iw1_all);
const int ic = this->pmat->global2local_col(iw2_all);
out[ic * this->pmat->nrow + ir] += v;
};
for (int is = 0;is < this->nsk;++is)
{
const double* v1_hxc_grid = this->pot->get_effective_v(is);
Gint_intout inout_vlocal(v1_hxc_grid, is, Gint_Tools::job_type::vlocal);
// this->gint_g->cal_gint(&inout);
bool new_e_iteration = false; // what is this?
this->gint_g->cal_vlocal(&inout_vlocal, new_e_iteration);

// grid-to-2d needs refactor !
v_hxc_2d[is].create(naos_local_row, naos_local_col);
//LR_Util::grid2block(this->px, this->pc, this->gint_g->pvpR_grid, v_hxc_local.c);
ModuleBase::GlobalFunc::ZEROS(this->gg->get_pvpR_grid(), gridt->lgd * gridt->lgd);
double* vr_hxc_is = &vr_hxc.c[is * this->pot->nrxx]; //current spin
Gint_inout inout_vlocal(vr_hxc_is, Gint_Tools::job_type::vlocal);
this->gg->cal_gint(&inout_vlocal);
v_hxc_2d[is].create(pmat->get_col_size(), pmat->get_row_size());
this->gg->vl_grid_to_2D(this->gg->get_pvpR_grid(), *pmat, this->gridt->lgd, (is == 0), v_hxc_2d[is].c, setter);
}
this->gg->delete_pvpR_grid();

// clear useless matrices
LR_Util::delete_p3(dm_trans_grid, nsk, gridt->lgd);
LR_Util::delete_p2(rho_trans, nsk);

// 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i}
// use 2 pzgemms
this->cal_AX_cVc(v_hxc_2d, this->px);
// result is in which "psi" ? X?

// final clear
LR_Util::delete_p3(dm_trans_grid);
LR_Util::delete_p2(rho_trans);
return;
#ifdef __MPI
return cal_AX_pblas(LR_Util::mat2ten_double(v_hxc_2d), *this->pmat, *this->psi_ks, *this->pc, naos, nocc, nvirt, *this->pX);
#else
return cal_AX_blas(LR_Util::mat2ten_double(v_hxc_2d), *this->psi_ks, nocc);
#endif
}
}
6 changes: 3 additions & 3 deletions source/module_beyonddft/potentials/pot_hxc_lrtd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ namespace elecstate
{
for (auto k : this->kernel_hxc) delete k;
}

void cal_v_eff(const Charge* chg/*excited state*/, const UnitCell* ucell, ModuleBase::matrix& v_eff) override
void cal_v_eff(const Charge* chg/*excited state*/, const UnitCell* ucell, ModuleBase::matrix& v_eff) override {};
void cal_v_eff(double** rho, const UnitCell* ucell, ModuleBase::matrix& v_eff)
{
ModuleBase::TITLE("PotHxcLR", "cal_v_eff");
ModuleBase::timer::tick("PotHxcLR", "cal_v_eff");
Expand All @@ -38,7 +38,7 @@ namespace elecstate
if (XC_Functional::get_func_type() == 1)
if (1 == nspin)// for LDA-spin0, just f*rho
for (int ir = 0;ir < nrxx;++ir)
v_eff(0, ir) += (this->kernel_hxc[0]->get_kernel()["Hartree"](0, ir) + this->kernel_hxc[1]->get_kernel()["LDA-v2rho2"](0, ir)) * chg->rho[0][ir];
v_eff(0, ir) += (this->kernel_hxc[0]->get_kernel()["Hartree"](0, ir) + this->kernel_hxc[1]->get_kernel()["LDA-v2rho2"](0, ir)) * rho[0][ir];
else //remain for spin2, 4
throw std::domain_error("GlobalV::NSPIN =" + std::to_string(GlobalV::NSPIN)
+ " unfinished in " + std::string(__FILE__) + " line " + std::to_string(__LINE__));
Expand Down

0 comments on commit fd653db

Please sign in to comment.