diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index f758b3b79a..e48e46cfcb 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -1,4 +1,5 @@ #include "module_hamilt_general/operator.h" +#include "module_base/timer.h" using namespace hamilt; @@ -34,10 +35,42 @@ Operator::~Operator() } template -typename Operator::hpsi_info Operator::hPsi(hpsi_info&) const +typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { - ModuleBase::WARNING_QUIT("Operator::hPsi", "hPsi error!"); - return hpsi_info(nullptr, 0, nullptr); + ModuleBase::timer::tick("Operator", "hPsi"); + using syncmem_op = psi::memory::synchronize_memory_op; + auto psi_input = std::get<0>(input); + std::tuple psi_info = psi_input->to_range(std::get<1>(input)); + int nbands = std::get<1>(psi_info); + + FPTYPE* tmhpsi = this->get_hpsi(input); + const FPTYPE* tmpsi_in = std::get<0>(psi_info); + //if range in hpsi_info is illegal, the first return of to_range() would be nullptr + if (tmpsi_in == nullptr) + { + ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); + } + + this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik)); + Operator* node((Operator*)this->next_op); + while (node != nullptr) + { + node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik)); + node = (Operator*)(node->next_op); + } + + ModuleBase::timer::tick("Operator", "hPsi"); + + //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return + FPTYPE* hpsi_pointer = std::get<2>(input); + if (this->in_place) + { + // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size()); + syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size()); + delete this->hpsi; + this->hpsi = new psi::Psi(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); + } + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } template diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 230510012f..5d44660346 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -38,7 +38,10 @@ class Operator //this is the core function for Operator // do H|psi> from input |psi> , - // output of hpsi would be first member of the returned tuple + /// as default, different operators donate hPsi independently + /// run this->act function for the first operator and run all act() for other nodes in chain table + /// if this procedure is not suitable for your operator, just override this function. + /// output of hpsi would be first member of the returned tuple typedef std::tuple*, const psi::Range, FPTYPE*> hpsi_info; virtual hpsi_info hPsi(hpsi_info& input)const; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp index 37fda9c761..7951441875 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp @@ -7,45 +7,6 @@ using namespace hamilt; template OperatorPW::~OperatorPW(){}; -template -typename OperatorPW::hpsi_info OperatorPW::hPsi( - hpsi_info& input) const -{ - ModuleBase::timer::tick("OperatorPW", "hPsi"); - auto psi_input = std::get<0>(input); - std::tuple*, int> psi_info = psi_input->to_range(std::get<1>(input)); - int nbands = std::get<1>(psi_info); - - std::complex *tmhpsi = this->get_hpsi(input); - const std::complex *tmpsi_in = std::get<0>(psi_info); - //if range in hpsi_info is illegal, the first return of to_range() would be nullptr - if(tmpsi_in == nullptr) - { - ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!"); - } - - this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik)); - OperatorPW* node((OperatorPW*)this->next_op); - while(node != nullptr) - { - node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik)); - node = (OperatorPW*)(node->next_op); - } - - ModuleBase::timer::tick("OperatorPW", "hPsi"); - - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return - std::complex* hpsi_pointer = std::get<2>(input); - if(this->in_place) - { - // ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size()); - syncmem_complex_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size()); - delete this->hpsi; - this->hpsi = new psi::Psi, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); - } - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); -} - namespace hamilt { template class OperatorPW; template class OperatorPW; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h index d7e5069eaa..580c85bf47 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h @@ -7,16 +7,8 @@ template class OperatorPW : public Operator, Device> { public: - virtual ~OperatorPW(); - - //in PW code, different operators donate hPsi independently - //run this->act function for the first operator and run all act() for other nodes in chain table - using hpsi_info = typename hamilt::Operator, Device>::hpsi_info; - virtual hpsi_info hPsi(hpsi_info& input)const; - //main function which should be modified in Operator for PW base - - std::string classname = ""; - using syncmem_complex_op = psi::memory::synchronize_memory_op, Device, Device>; + virtual ~OperatorPW(); + std::string classname = ""; }; }//end namespace hamilt