Skip to content

Commit

Permalink
refactor Operator::act()
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Oct 8, 2023
1 parent f947c66 commit bf8f229
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
36 changes: 24 additions & 12 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ Operator<T, Device>::~Operator()
template<typename T, typename Device>
typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
{
ModuleBase::timer::tick("Operator", "hPsi");
using syncmem_op = psi::memory::synchronize_memory_op<T, Device, Device>;
auto psi_input = std::get<0>(input);
std::tuple<const T*, int> psi_info = psi_input->to_range(std::get<1>(input));
Expand All @@ -50,17 +49,6 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
{
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
}

this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik));
Operator* node((Operator*)this->next_op);
while (node != nullptr)
{
node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik));
node = (Operator*)(node->next_op);
}

ModuleBase::timer::tick("Operator", "hPsi");

//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
T* hpsi_pointer = std::get<2>(input);
if (this->in_place)
Expand All @@ -70,9 +58,33 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
delete this->hpsi;
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}

auto call_act = [&, this](const Operator* op) -> void {
switch (act_type)
{
case 2:
op->act(*psi_input, *this->hpsi);
break;
default:
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik));
break;
}
};

ModuleBase::timer::tick("Operator", "hPsi");
call_act(this);
Operator* node((Operator*)this->next_op);
while (node != nullptr)
{
call_act(node);
node = (Operator*)(node->next_op);
}
ModuleBase::timer::tick("Operator", "hPsi");

return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}


template<typename T, typename Device>
void Operator<T, Device>::init(const int ik_in)
{
Expand Down
11 changes: 8 additions & 3 deletions source/module_hamilt_general/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,25 @@ class Operator

///do operation : |hpsi_choosed> = V|psi_choosed>
///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi
/// interface type 1: pointer-only (default)
virtual void act(const int nbands,
const int nbasis,
const int npol,
const T* tmpsi_in,
T* tmhpsi,
const int ngk_ik = 0)const {};

/// an developer-friendly interface for act() function
virtual psi::Psi<T> act(const psi::Psi<T>& psi_in) const { return psi_in; };
/// developer-friendly interfaces for act() function
/// interface type 2: input and change the Psi-type HPsi
virtual void act(const psi::Psi<T>& psi_in, psi::Psi<T>& psi_out) const {};
/// interface type 3: return a Psi-type HPsi
// virtual psi::Psi<T> act(const psi::Psi<T>& psi_in) const { return psi_in; };

Operator* next_op = nullptr;

protected:
protected:
int ik = 0;
int act_type = 1; ///< determine which act() interface would be called in hPsi()

mutable bool in_place = false;

Expand Down

0 comments on commit bf8f229

Please sign in to comment.