diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 50f186fef8..37b8694a56 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -118,6 +118,16 @@ FPTYPE* Operator::get_hpsi(const hpsi_info& info) const return hpsi_pointer; } +template +void Operator::act( + const int nbands, + const int nbasis, + const int npol, + const FPTYPE* tmpsi_in, + FPTYPE* tmhpsi, + const int ngk_ik)const +{ +} namespace hamilt { template class Operator; diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 6c0121d1c2..4f8c10a57d 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -46,7 +46,16 @@ class Operator virtual void add(Operator* next); - virtual int get_ik() const {return this->ik;} + virtual int get_ik() const { return this->ik; } + + //do operation : |hpsi_choosed> = V|psi_choosed> + //V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + virtual void act(const int nbands, + const int nbasis, + const int npol, + const FPTYPE* tmpsi_in, + FPTYPE* tmhpsi, + const int ngk_ik = 0)const; Operator* next_op = nullptr; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp index d6f387aad8..c4e69cf909 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.cpp @@ -31,26 +31,27 @@ Ekinetic>::~Ekinetic() {} template void Ekinetic>::act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi)const + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik)const { - ModuleBase::timer::tick("Operator", "EkineticPW"); - const int npw = psi_in->get_ngk(this->ik); - this->max_npw = psi_in->get_nbasis() / psi_in->npol; + ModuleBase::timer::tick("Operator", "EkineticPW"); + int max_npw = nbasis / npol; const FPTYPE *gk2_ik = &(this->gk2[this->ik * this->gk2_col]); // denghui added 20221019 - ekinetic_op()(this->ctx, n_npwx, npw, this->max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in); - // for (int ib = 0; ib < n_npwx; ++ib) + ekinetic_op()(this->ctx, nbands, ngk_ik, max_npw, tpiba2, gk2_ik, tmhpsi, tmpsi_in); + // for (int ib = 0; ib < nbands; ++ib) // { - // for (int ig = 0; ig < npw; ++ig) + // for (int ig = 0; ig < ngk_ik; ++ig) // { // tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig]; // } - // tmhpsi += this->max_npw; - // tmpsi_in += this->max_npw; + // tmhpsi += max_npw; + // tmpsi_in += max_npw; // } ModuleBase::timer::tick("Operator", "EkineticPW"); } diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h index 505afa1666..32d6e55aeb 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/ekinetic_pw.h @@ -34,11 +34,12 @@ class Ekinetic> : public OperatorPW virtual ~Ekinetic(); - virtual void act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi)const override; + virtual void act(const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik = 0)const override; // denghuilu added for copy construct at 20221105 int get_gk2_row() const {return this->gk2_row;} @@ -49,10 +50,6 @@ class Ekinetic> : public OperatorPW private: - mutable int max_npw = 0; - - mutable int npol = 0; - FPTYPE tpiba2 = 0.0; const FPTYPE* gk2 = nullptr; int gk2_row = 0; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp index f3ae9ec76f..a9d7047048 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp @@ -39,11 +39,12 @@ Meta>::~Meta() template void Meta>::act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi -)const + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik)const { if (XC_Functional::get_func_type() != 3) { @@ -52,17 +53,15 @@ void Meta>::act( ModuleBase::timer::tick("Operator", "MetaPW"); - const int npw = psi_in->get_ngk(this->ik); const int current_spin = this->isk[this->ik]; - this->max_npw = psi_in->get_nbasis() / psi_in->npol; + int max_npw = nbasis / npol; //npol == 2 case has not been considered - this->npol = psi_in->npol; - for (int ib = 0; ib < n_npwx; ++ib) + for (int ib = 0; ib < nbands; ++ib) { for (int j = 0; j < 3; j++) { - meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data(), wfcpw->get_kvec_c_data(), tmpsi_in, this->porter); + meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data(), wfcpw->get_kvec_c_data(), tmpsi_in, this->porter); wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik); if(this->vk_col != 0) { @@ -70,11 +69,11 @@ void Meta>::act( } wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik); - meta_op()(this->ctx, this->ik, j, npw, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data(), wfcpw->get_kvec_c_data(), this->porter, tmhpsi, true); + meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data(), wfcpw->get_kvec_c_data(), this->porter, tmhpsi, true); } // x,y,z directions - tmhpsi += this->max_npw; - tmpsi_in += this->max_npw; + tmhpsi += max_npw; + tmpsi_in += max_npw; } ModuleBase::timer::tick("Operator", "MetaPW"); } diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h index 5df7b95e84..17d1d12b3a 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h @@ -35,10 +35,12 @@ class Meta> : public OperatorPW virtual ~Meta(); - virtual void act(const psi::Psi, Device>* psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi) const override; + virtual void act(const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk = 0)const override; // denghui added for copy constructor at 20221105 FPTYPE get_tpiba() const diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp index b773855e89..76aa9c5e43 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp @@ -206,30 +206,31 @@ void Nonlocal>::add_nonlocal_pp(std::complex } template -void Nonlocal>::act -( - const psi::Psi, Device>* psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi)const +void Nonlocal>::act( + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik)const { ModuleBase::timer::tick("Operator", "NonlocalPW"); - this->npw = psi_in->get_ngk(this->ik); - this->max_npw = psi_in->get_nbasis() / psi_in->npol; - this->npol = psi_in->npol; + this->npw = ngk_ik; + this->max_npw = nbasis / npol; + this->npol = npol; if (this->ppcell->nkb > 0) { //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // qianrui optimize 2021-3-31 int nkb = this->ppcell->nkb; - if (this->nkb_m < n_npwx * nkb) { - resmem_complex_op()(this->ctx, this->becp, n_npwx * nkb, "Nonlocal::becp"); + if (this->nkb_m < nbands * nkb) { + resmem_complex_op()(this->ctx, this->becp, nbands * nkb, "Nonlocal::becp"); } - // ModuleBase::ComplexMatrix becp(n_npwx, nkb, false); + // ModuleBase::ComplexMatrix becp(nbands, nkb, false); char transa = 'C'; char transb = 'N'; - if (n_npwx == 1) + if (nbands == 1) { int inc = 1; // denghui replace 2022-10-20 @@ -250,7 +251,7 @@ void Nonlocal>::act } else { - int npm = n_npwx; + int npm = nbands; //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // denghui replace 2022-10-20 gemm_op()( @@ -264,16 +265,16 @@ void Nonlocal>::act this->vkb, this->ppcell->vkb.nc, tmpsi_in, - this->max_npw, + max_npw, &this->zero, this->becp, nkb ); } - Parallel_Reduce::reduce_complex_double_pool(becp, nkb * n_npwx); + Parallel_Reduce::reduce_complex_double_pool(becp, nkb * nbands); - this->add_nonlocal_pp(tmhpsi, becp, n_npwx); + this->add_nonlocal_pp(tmhpsi, becp, nbands); } ModuleBase::timer::tick("Operator", "NonlocalPW"); } diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h index 21bcf96a55..914a7b8cb7 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h @@ -37,12 +37,12 @@ class Nonlocal> : public OperatorPW virtual void init(const int ik_in)override; - virtual void act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi - )const override; + virtual void act(const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk = 0)const override; const int *get_isk() const {return this->isk;} const pseudopot_cell_vnl *get_ppcell() const {return this->ppcell;} diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp index 8096648ee2..105dc5f92a 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp @@ -14,7 +14,7 @@ typename OperatorPW::hpsi_info OperatorPW::hPsi( ModuleBase::timer::tick("OperatorPW", "hPsi"); auto psi_input = std::get<0>(input); std::tuple*, int> psi_info = psi_input->to_range(std::get<1>(input)); - int n_npwx = std::get<1>(psi_info); + int nbands = std::get<1>(psi_info); std::complex *tmhpsi = this->get_hpsi(input); const std::complex *tmpsi_in = std::get<0>(psi_info); @@ -24,11 +24,11 @@ typename OperatorPW::hpsi_info OperatorPW::hPsi( ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!"); } - this->act(psi_input, n_npwx, tmpsi_in, tmhpsi); + this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik)); OperatorPW* node((OperatorPW*)this->next_op); while(node != nullptr) { - node->act(psi_input, n_npwx, tmpsi_in, tmhpsi); + node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik)); node = (OperatorPW*)(node->next_op); } @@ -41,18 +41,9 @@ typename OperatorPW::hpsi_info OperatorPW::hPsi( // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size()); syncmem_complex_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size()); delete this->hpsi; - this->hpsi = new psi::Psi, Device>(hpsi_pointer, *psi_input, 1, n_npwx/psi_input->npol); + this->hpsi = new psi::Psi, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); } - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, n_npwx/psi_input->npol), hpsi_pointer); -} - -template -void OperatorPW::act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi) const -{ + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } namespace hamilt { diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h index 0c7759eb75..d7e5069eaa 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h @@ -14,13 +14,6 @@ class OperatorPW : public Operator, Device> using hpsi_info = typename hamilt::Operator, Device>::hpsi_info; virtual hpsi_info hPsi(hpsi_info& input)const; //main function which should be modified in Operator for PW base - //do operation : |hpsi_choosed> = V|psi_choosed> - //V is the target operator act on choosed psi, the consequence should be added to choosed hpsi - virtual void act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi)const; std::string classname = ""; using syncmem_complex_op = psi::memory::synchronize_memory_op, Device, Device>; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp index 9a56cc1ff1..e2d2ae79e5 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp @@ -38,22 +38,22 @@ Veff>::~Veff() template void Veff>::act( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi -)const + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik)const { ModuleBase::timer::tick("Operator", "VeffPW"); - this->max_npw = psi_in->get_nbasis() / psi_in->npol; + int max_npw = nbasis / npol; const int current_spin = this->isk[this->ik]; - this->npol = psi_in->npol; // std::complex *porter = new std::complex[wfcpw->nmaxgr]; - for (int ib = 0; ib < n_npwx; ib += this->npol) + for (int ib = 0; ib < nbands; ib += npol) { - if (this->npol == 1) + if (npol == 1) { // wfcpw->recip2real(tmpsi_in, porter, this->ik); wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik); @@ -77,7 +77,7 @@ void Veff>::act( // std::complex *porter1 = new std::complex[wfcpw->nmaxgr]; // fft to real space and doing things. wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik); - wfcpw->recip_to_real(this->ctx, tmpsi_in + this->max_npw, this->porter1, this->ik); + wfcpw->recip_to_real(this->ctx, tmpsi_in + max_npw, this->porter1, this->ik); if(this->veff_col != 0) { /// denghui added at 20221109 @@ -102,10 +102,10 @@ void Veff>::act( } // (3) fft back to G space. wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true); - wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + this->max_npw, this->ik, true); + wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + max_npw, this->ik, true); } - tmhpsi += this->max_npw * this->npol; - tmpsi_in += this->max_npw * this->npol; + tmhpsi += max_npw * npol; + tmpsi_in += max_npw * npol; } ModuleBase::timer::tick("Operator", "VeffPW"); } @@ -120,7 +120,6 @@ hamilt::Veff>::Veff(const Veffveff_col = veff->get_veff_col(); this->veff_row = veff->get_veff_row(); this->wfcpw = veff->get_wfcpw(); - this->npol = veff->get_npol(); resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr); resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr); this->veff = veff->get_veff(); diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h index 638f4c5786..1cbf9c7201 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h @@ -33,18 +33,17 @@ class Veff> : public OperatorPW virtual ~Veff(); - virtual void act ( - const psi::Psi, Device> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi - )const override; + virtual void act(const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik = 0)const override; // denghui added for copy constructor at 20221105 const FPTYPE *get_veff() const {return this->veff;} int get_veff_col() const {return this->veff_col;} - int get_veff_row() const {return this->veff_row;} - int get_npol() const {return this->npol;} + int get_veff_row() const { return this->veff_row; } const int *get_isk() const {return isk;} const ModulePW::PW_Basis_K* get_wfcpw() const { @@ -53,10 +52,6 @@ class Veff> : public OperatorPW private: - mutable int max_npw = 0; - - mutable int npol = 0; - const int* isk = nullptr; const ModulePW::PW_Basis_K* wfcpw = nullptr; diff --git a/source/module_hsolver/test/diago_mock.h b/source/module_hsolver/test/diago_mock.h index d597d3aded..83f41a14c4 100644 --- a/source/module_hsolver/test/diago_mock.h +++ b/source/module_hsolver/test/diago_mock.h @@ -491,11 +491,12 @@ class OperatorMock_d : public hamilt::OperatorPW } virtual void act ( - const psi::Psi> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi - )const + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik = 0)const { int nprocs=1, mypnum=0; #ifdef __MPI @@ -504,7 +505,7 @@ class OperatorMock_d : public hamilt::OperatorPW #endif std::complex *hpsi0 = new std::complex[DIAGOTEST::npw]; - for(int m = 0; m< n_npwx; m++) + for (int m = 0; m < nbands; m++) { for(int i=0;i } Parallel_Reduce::reduce_complex_double_pool(hpsi0, DIAGOTEST::npw); DIAGOTEST::divide_psi>(hpsi0, tmhpsi); - tmhpsi += psi_in->get_nbasis(); - tmpsi_in += psi_in->get_nbasis(); + tmhpsi += nbasis; + tmpsi_in += nbasis; } delete [] hpsi0; } @@ -535,11 +536,12 @@ class OperatorMock_f : public hamilt::OperatorPW } virtual void act ( - const psi::Psi> *psi_in, - const int n_npwx, - const std::complex* tmpsi_in, - std::complex* tmhpsi - )const + const int nbands, + const int nbasis, + const int npol, + const std::complex* tmpsi_in, + std::complex* tmhpsi, + const int ngk_ik = 0)const { int nprocs=1, mypnum=0; #ifdef __MPI @@ -548,7 +550,7 @@ class OperatorMock_f : public hamilt::OperatorPW #endif std::complex *hpsi0 = new std::complex[DIAGOTEST::npw]; - for(int m = 0; m< n_npwx; m++) + for (int m = 0; m < nbands; m++) { for(int i=0;i MPI_Allreduce(MPI_IN_PLACE, hpsi0, DIAGOTEST::npw, MPI_C_FLOAT_COMPLEX, MPI_SUM, MPI_COMM_WORLD); #endif DIAGOTEST::divide_psi_f>(hpsi0, tmhpsi); - tmhpsi += psi_in->get_nbasis(); - tmpsi_in += psi_in->get_nbasis(); + tmhpsi += nbasis; + tmpsi_in += nbasis; } delete [] hpsi0; }