Skip to content

Commit

Permalink
move act() interface to basic operator
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Sep 1, 2023
1 parent d37e0f6 commit 1779e38
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 125 deletions.
10 changes: 10 additions & 0 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ FPTYPE* Operator<FPTYPE, Device>::get_hpsi(const hpsi_info& info) const
return hpsi_pointer;
}

template<typename FPTYPE, typename Device>
void Operator<FPTYPE, Device>::act(
const int nbands,
const int nbasis,
const int npol,
const FPTYPE* tmpsi_in,
FPTYPE* tmhpsi,
const int ngk_ik)const
{
}

namespace hamilt {
template class Operator<float, psi::DEVICE_CPU>;
Expand Down
11 changes: 10 additions & 1 deletion source/module_hamilt_general/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,16 @@ class Operator

virtual void add(Operator* next);

virtual int get_ik() const {return this->ik;}
virtual int get_ik() const { return this->ik; }

//do operation : |hpsi_choosed> = V|psi_choosed>
//V is the target operator act on choosed psi, the consequence should be added to choosed hpsi
virtual void act(const int nbands,
const int nbasis,
const int npol,
const FPTYPE* tmpsi_in,
FPTYPE* tmhpsi,
const int ngk_ik = 0)const;

Operator* next_op = nullptr;

Expand Down
25 changes: 13 additions & 12 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,27 @@ Ekinetic<OperatorPW<FPTYPE, Device>>::~Ekinetic() {}

template<typename FPTYPE, typename Device>
void Ekinetic<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
ModuleBase::timer::tick("Operator", "EkineticPW");
const int npw = psi_in->get_ngk(this->ik);
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
ModuleBase::timer::tick("Operator", "EkineticPW");
int max_npw = nbasis / npol;

const FPTYPE *gk2_ik = &(this->gk2[this->ik * this->gk2_col]);
// denghui added 20221019
ekinetic_op()(this->ctx, n_npwx, npw, this->max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
// for (int ib = 0; ib < n_npwx; ++ib)
ekinetic_op()(this->ctx, nbands, ngk_ik, max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in);
// for (int ib = 0; ib < nbands; ++ib)
// {
// for (int ig = 0; ig < npw; ++ig)
// for (int ig = 0; ig < ngk_ik; ++ig)
// {
// tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
// }
// tmhpsi += this->max_npw;
// tmpsi_in += this->max_npw;
// tmhpsi += max_npw;
// tmpsi_in += max_npw;
// }
ModuleBase::timer::tick("Operator", "EkineticPW");
}
Expand Down
15 changes: 6 additions & 9 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual ~Ekinetic();

virtual void act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik = 0)const override;

// denghuilu added for copy construct at 20221105
int get_gk2_row() const {return this->gk2_row;}
Expand All @@ -49,10 +50,6 @@ class Ekinetic<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

private:

mutable int max_npw = 0;

mutable int npol = 0;

FPTYPE tpiba2 = 0.0;
const FPTYPE* gk2 = nullptr;
int gk2_row = 0;
Expand Down
25 changes: 12 additions & 13 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ Meta<OperatorPW<FPTYPE, Device>>::~Meta()

template<typename FPTYPE, typename Device>
void Meta<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi
)const
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
if (XC_Functional::get_func_type() != 3)
{
Expand All @@ -52,29 +53,27 @@ void Meta<OperatorPW<FPTYPE, Device>>::act(

ModuleBase::timer::tick("Operator", "MetaPW");

const int npw = psi_in->get_ngk(this->ik);
const int current_spin = this->isk[this->ik];
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
int max_npw = nbasis / npol;
//npol == 2 case has not been considered
this->npol = psi_in->npol;

for (int ib = 0; ib < n_npwx; ++ib)
for (int ib = 0; ib < nbands; ++ib)
{
for (int j = 0; j < 3; j++)
{
meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), tmpsi_in, this->porter);
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), tmpsi_in, this->porter);
wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik);

if(this->vk_col != 0) {
vector_mul_vector_op()(this->ctx, this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col);
}

wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik);
meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), this->porter, tmhpsi, true);
meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<FPTYPE>(), wfcpw->get_kvec_c_data<FPTYPE>(), this->porter, tmhpsi, true);

} // x,y,z directions
tmhpsi += this->max_npw;
tmpsi_in += this->max_npw;
tmhpsi += max_npw;
tmpsi_in += max_npw;
}
ModuleBase::timer::tick("Operator", "MetaPW");
}
Expand Down
10 changes: 6 additions & 4 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class Meta<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual ~Meta();

virtual void act(const psi::Psi<std::complex<FPTYPE>, Device>* psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi) const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk = 0)const override;

// denghui added for copy constructor at 20221105
FPTYPE get_tpiba() const
Expand Down
35 changes: 18 additions & 17 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,30 +206,31 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::add_nonlocal_pp(std::complex<FPTYPE>
}

template<typename FPTYPE, typename Device>
void Nonlocal<OperatorPW<FPTYPE, Device>>::act
(
const psi::Psi<std::complex<FPTYPE>, Device>* psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const
void Nonlocal<OperatorPW<FPTYPE, Device>>::act(
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
ModuleBase::timer::tick("Operator", "NonlocalPW");
this->npw = psi_in->get_ngk(this->ik);
this->max_npw = psi_in->get_nbasis() / psi_in->npol;
this->npol = psi_in->npol;
this->npw = ngk_ik;
this->max_npw = nbasis / npol;
this->npol = npol;

if (this->ppcell->nkb > 0)
{
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// qianrui optimize 2021-3-31
int nkb = this->ppcell->nkb;
if (this->nkb_m < n_npwx * nkb) {
resmem_complex_op()(this->ctx, this->becp, n_npwx * nkb, "Nonlocal<PW>::becp");
if (this->nkb_m < nbands * nkb) {
resmem_complex_op()(this->ctx, this->becp, nbands * nkb, "Nonlocal<PW>::becp");
}
// ModuleBase::ComplexMatrix becp(n_npwx, nkb, false);
// ModuleBase::ComplexMatrix becp(nbands, nkb, false);
char transa = 'C';
char transb = 'N';
if (n_npwx == 1)
if (nbands == 1)
{
int inc = 1;
// denghui replace 2022-10-20
Expand All @@ -250,7 +251,7 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
}
else
{
int npm = n_npwx;
int npm = nbands;
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// denghui replace 2022-10-20
gemm_op()(
Expand All @@ -264,16 +265,16 @@ void Nonlocal<OperatorPW<FPTYPE, Device>>::act
this->vkb,
this->ppcell->vkb.nc,
tmpsi_in,
this->max_npw,
max_npw,
&this->zero,
this->becp,
nkb
);
}

Parallel_Reduce::reduce_complex_double_pool(becp, nkb * n_npwx);
Parallel_Reduce::reduce_complex_double_pool(becp, nkb * nbands);

this->add_nonlocal_pp(tmhpsi, becp, n_npwx);
this->add_nonlocal_pp(tmhpsi, becp, nbands);
}
ModuleBase::timer::tick("Operator", "NonlocalPW");
}
Expand Down
12 changes: 6 additions & 6 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class Nonlocal<OperatorPW<FPTYPE, Device>> : public OperatorPW<FPTYPE, Device>

virtual void init(const int ik_in)override;

virtual void act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi
)const override;
virtual void act(const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk = 0)const override;

const int *get_isk() const {return this->isk;}
const pseudopot_cell_vnl *get_ppcell() const {return this->ppcell;}
Expand Down
19 changes: 5 additions & 14 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ typename OperatorPW<FPTYPE, Device>::hpsi_info OperatorPW<FPTYPE, Device>::hPsi(
ModuleBase::timer::tick("OperatorPW", "hPsi");
auto psi_input = std::get<0>(input);
std::tuple<const std::complex<FPTYPE>*, int> psi_info = psi_input->to_range(std::get<1>(input));
int n_npwx = std::get<1>(psi_info);
int nbands = std::get<1>(psi_info);

std::complex<FPTYPE> *tmhpsi = this->get_hpsi(input);
const std::complex<FPTYPE> *tmpsi_in = std::get<0>(psi_info);
Expand All @@ -24,11 +24,11 @@ typename OperatorPW<FPTYPE, Device>::hpsi_info OperatorPW<FPTYPE, Device>::hPsi(
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
}

this->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik));
OperatorPW* node((OperatorPW*)this->next_op);
while(node != nullptr)
{
node->act(psi_input, n_npwx, tmpsi_in, tmhpsi);
node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik));
node = (OperatorPW*)(node->next_op);
}

Expand All @@ -41,18 +41,9 @@ typename OperatorPW<FPTYPE, Device>::hpsi_info OperatorPW<FPTYPE, Device>::hPsi(
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_complex_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<std::complex<FPTYPE>, Device>(hpsi_pointer, *psi_input, 1, n_npwx/psi_input->npol);
this->hpsi = new psi::Psi<std::complex<FPTYPE>, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/psi_input->npol), hpsi_pointer);
}

template<typename FPTYPE, typename Device>
void OperatorPW<FPTYPE, Device>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi) const
{
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}

namespace hamilt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ class OperatorPW : public Operator<std::complex<FPTYPE>, Device>
using hpsi_info = typename hamilt::Operator<std::complex<FPTYPE>, Device>::hpsi_info;
virtual hpsi_info hPsi(hpsi_info& input)const;
//main function which should be modified in Operator for PW base
//do operation : |hpsi_choosed> = V|psi_choosed>
//V is the target operator act on choosed psi, the consequence should be added to choosed hpsi
virtual void act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi)const;

std::string classname = "";
using syncmem_complex_op = psi::memory::synchronize_memory_op<std::complex<FPTYPE>, Device, Device>;
Expand Down
27 changes: 13 additions & 14 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,22 @@ Veff<OperatorPW<FPTYPE, Device>>::~Veff()

template<typename FPTYPE, typename Device>
void Veff<OperatorPW<FPTYPE, Device>>::act(
const psi::Psi<std::complex<FPTYPE>, Device> *psi_in,
const int n_npwx,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi
)const
const int nbands,
const int nbasis,
const int npol,
const std::complex<FPTYPE>* tmpsi_in,
std::complex<FPTYPE>* tmhpsi,
const int ngk_ik)const
{
ModuleBase::timer::tick("Operator", "VeffPW");

this->max_npw = psi_in->get_nbasis() / psi_in->npol;
int max_npw = nbasis / npol;
const int current_spin = this->isk[this->ik];
this->npol = psi_in->npol;

// std::complex<FPTYPE> *porter = new std::complex<FPTYPE>[wfcpw->nmaxgr];
for (int ib = 0; ib < n_npwx; ib += this->npol)
for (int ib = 0; ib < nbands; ib += npol)
{
if (this->npol == 1)
if (npol == 1)
{
// wfcpw->recip2real(tmpsi_in, porter, this->ik);
wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
Expand All @@ -77,7 +77,7 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
// std::complex<FPTYPE> *porter1 = new std::complex<FPTYPE>[wfcpw->nmaxgr];
// fft to real space and doing things.
wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
wfcpw->recip_to_real(this->ctx, tmpsi_in + this->max_npw, this->porter1, this->ik);
wfcpw->recip_to_real(this->ctx, tmpsi_in + max_npw, this->porter1, this->ik);
if(this->veff_col != 0)
{
/// denghui added at 20221109
Expand All @@ -102,10 +102,10 @@ void Veff<OperatorPW<FPTYPE, Device>>::act(
}
// (3) fft back to G space.
wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + this->max_npw, this->ik, true);
wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + max_npw, this->ik, true);
}
tmhpsi += this->max_npw * this->npol;
tmpsi_in += this->max_npw * this->npol;
tmhpsi += max_npw * npol;
tmpsi_in += max_npw * npol;
}
ModuleBase::timer::tick("Operator", "VeffPW");
}
Expand All @@ -120,7 +120,6 @@ hamilt::Veff<OperatorPW<FPTYPE, Device>>::Veff(const Veff<OperatorPW<T_in, Devic
this->veff_col = veff->get_veff_col();
this->veff_row = veff->get_veff_row();
this->wfcpw = veff->get_wfcpw();
this->npol = veff->get_npol();
resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr);
resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr);
this->veff = veff->get_veff();
Expand Down
Loading

0 comments on commit 1779e38

Please sign in to comment.