Skip to content

Commit

Permalink
move the act-based hPsi into basic Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Sep 5, 2023
1 parent ef5d620 commit 7dad605
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 53 deletions.
39 changes: 36 additions & 3 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "module_hamilt_general/operator.h"
#include "module_base/timer.h"

using namespace hamilt;

Expand Down Expand Up @@ -34,10 +35,42 @@ Operator<FPTYPE, Device>::~Operator()
}

template<typename FPTYPE, typename Device>
typename Operator<FPTYPE, Device>::hpsi_info Operator<FPTYPE, Device>::hPsi(hpsi_info&) const
typename Operator<FPTYPE, Device>::hpsi_info Operator<FPTYPE, Device>::hPsi(hpsi_info& input) const
{
ModuleBase::WARNING_QUIT("Operator::hPsi", "hPsi error!");
return hpsi_info(nullptr, 0, nullptr);
ModuleBase::timer::tick("Operator", "hPsi");
using syncmem_op = psi::memory::synchronize_memory_op<FPTYPE, Device, Device>;
auto psi_input = std::get<0>(input);
std::tuple<const FPTYPE*, int> psi_info = psi_input->to_range(std::get<1>(input));
int nbands = std::get<1>(psi_info);

FPTYPE* tmhpsi = this->get_hpsi(input);
const FPTYPE* tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if (tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
}

this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik));
Operator* node((Operator*)this->next_op);
while (node != nullptr)
{
node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik));
node = (Operator*)(node->next_op);
}

ModuleBase::timer::tick("Operator", "hPsi");

//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
FPTYPE* hpsi_pointer = std::get<2>(input);
if (this->in_place)
{
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<FPTYPE, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}

template<typename FPTYPE, typename Device>
Expand Down
5 changes: 4 additions & 1 deletion source/module_hamilt_general/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ class Operator
//this is the core function for Operator
// do H|psi> from input |psi> ,

// output of hpsi would be first member of the returned tuple
/// as default, different operators donate hPsi independently
/// run this->act function for the first operator and run all act() for other nodes in chain table
/// if this procedure is not suitable for your operator, just override this function.
/// output of hpsi would be first member of the returned tuple
typedef std::tuple<const psi::Psi<FPTYPE, Device>*, const psi::Range, FPTYPE*> hpsi_info;
virtual hpsi_info hPsi(hpsi_info& input)const;

Expand Down
39 changes: 0 additions & 39 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,6 @@ using namespace hamilt;
template<typename FPTYPE, typename Device>
OperatorPW<FPTYPE, Device>::~OperatorPW(){};

template<typename FPTYPE, typename Device>
typename OperatorPW<FPTYPE, Device>::hpsi_info OperatorPW<FPTYPE, Device>::hPsi(
hpsi_info& input) const
{
ModuleBase::timer::tick("OperatorPW", "hPsi");
auto psi_input = std::get<0>(input);
std::tuple<const std::complex<FPTYPE>*, int> psi_info = psi_input->to_range(std::get<1>(input));
int nbands = std::get<1>(psi_info);

std::complex<FPTYPE> *tmhpsi = this->get_hpsi(input);
const std::complex<FPTYPE> *tmpsi_in = std::get<0>(psi_info);
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
if(tmpsi_in == nullptr)
{
ModuleBase::WARNING_QUIT("OperatorPW", "please choose correct range of psi for hPsi()!");
}

this->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(this->ik));
OperatorPW* node((OperatorPW*)this->next_op);
while(node != nullptr)
{
node->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, tmhpsi, psi_input->get_ngk(node->ik));
node = (OperatorPW*)(node->next_op);
}

ModuleBase::timer::tick("OperatorPW", "hPsi");

//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
std::complex<FPTYPE>* hpsi_pointer = std::get<2>(input);
if(this->in_place)
{
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
syncmem_complex_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
delete this->hpsi;
this->hpsi = new psi::Psi<std::complex<FPTYPE>, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
}

namespace hamilt {
template class OperatorPW<float, psi::DEVICE_CPU>;
template class OperatorPW<double, psi::DEVICE_CPU>;
Expand Down
12 changes: 2 additions & 10 deletions source/module_hamilt_pw/hamilt_pwdft/operator_pw/operator_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@ template<typename FPTYPE, typename Device = psi::DEVICE_CPU>
class OperatorPW : public Operator<std::complex<FPTYPE>, Device>
{
public:
virtual ~OperatorPW();

//in PW code, different operators donate hPsi independently
//run this->act function for the first operator and run all act() for other nodes in chain table
using hpsi_info = typename hamilt::Operator<std::complex<FPTYPE>, Device>::hpsi_info;
virtual hpsi_info hPsi(hpsi_info& input)const;
//main function which should be modified in Operator for PW base

std::string classname = "";
using syncmem_complex_op = psi::memory::synchronize_memory_op<std::complex<FPTYPE>, Device, Device>;
virtual ~OperatorPW();
std::string classname = "";
};

}//end namespace hamilt
Expand Down

0 comments on commit 7dad605

Please sign in to comment.