diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index c85dac51bd..b6fbfb75c0 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -37,7 +37,6 @@ Operator::~Operator() template typename Operator::hpsi_info Operator::hPsi(hpsi_info& input) const { - ModuleBase::timer::tick("Operator", "hPsi"); using syncmem_op = psi::memory::synchronize_memory_op; auto psi_input = std::get<0>(input); std::tuple psi_info = psi_input->to_range(std::get<1>(input)); @@ -50,17 +49,6 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp { ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!"); } - - this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik)); - Operator* node((Operator*)this->next_op); - while (node != nullptr) - { - node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik)); - node = (Operator*)(node->next_op); - } - - ModuleBase::timer::tick("Operator", "hPsi"); - //if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return T* hpsi_pointer = std::get<2>(input); if (this->in_place) @@ -70,9 +58,33 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp delete this->hpsi; this->hpsi = new psi::Psi(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol); } + + auto call_act = [&, this](const Operator* op) -> void { + switch (act_type) + { + case 2: + op->act(*psi_input, *this->hpsi); + break; + default: + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik)); + break; + } + }; + + ModuleBase::timer::tick("Operator", "hPsi"); + call_act(this); + Operator* node((Operator*)this->next_op); + while (node != nullptr) + { + call_act(node); + node = (Operator*)(node->next_op); + } + ModuleBase::timer::tick("Operator", "hPsi"); + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); } + template void Operator::init(const int ik_in) { diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 5634a9b549..fdddf170a6 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -53,6 +53,7 @@ class Operator ///do operation : |hpsi_choosed> = V|psi_choosed> ///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi + /// interface type 1: pointer-only (default) virtual void act(const int nbands, const int nbasis, const int npol, @@ -60,13 +61,17 @@ class Operator T* tmhpsi, const int ngk_ik = 0)const {}; - /// an developer-friendly interface for act() function - virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; + /// developer-friendly interfaces for act() function + /// interface type 2: input and change the Psi-type HPsi + virtual void act(const psi::Psi& psi_in, psi::Psi& psi_out) const {}; + /// interface type 3: return a Psi-type HPsi + // virtual psi::Psi act(const psi::Psi& psi_in) const { return psi_in; }; Operator* next_op = nullptr; - protected: +protected: int ik = 0; + int act_type = 1; ///< determine which act() interface would be called in hPsi() mutable bool in_place = false;