Skip to content

Commit

Permalink
CVCX-type AX in R of Z-vector eq
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed May 22, 2024
1 parent c6b6902 commit ac629ca
Show file tree
Hide file tree
Showing 8 changed files with 1,011 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ if(ENABLE_LCAO)
deltaspin
numerical_atomic_orbitals
lr
lr_grad
)
if (USE_ELPA)
target_link_libraries(${ABACUS_BIN_NAME}
Expand Down
3 changes: 2 additions & 1 deletion source/module_beyonddft/Grad/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(dm_diff)
add_subdirectory(dm_diff)
add_subdirectory(CVCX)
11 changes: 11 additions & 0 deletions source/module_beyonddft/Grad/CVCX/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
if(ENABLE_LCAO)
add_library(
lr_grad
OBJECT
CVCX_parallel.cpp
CVCX_serial.cpp
)
if(BUILD_TESTING)
add_subdirectory(test)
endif()
endif()
82 changes: 82 additions & 0 deletions source/module_beyonddft/Grad/CVCX/CVCX.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#pragma once
#include <ATen/core/tensor.h>
#include "module_psi/psi.h"
#include <vector>
#ifdef __MPI
#include "module_basis/module_ao/parallel_2d.h"
#endif
namespace hamilt
{
// occ
/// $\sum_{k\mu\nu}C^*_{\mu i}K_{\mu\nu}C_{\nu k}X_{ak}^*$
template <typename T>
void CVCX_occ_forloop_serial(
const std::vector<container::Tensor>& V_istate,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate);
template <typename T>
void CVCX_occ_blas(
const std::vector<container::Tensor>& V_istate,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate,
const bool add_on = true);
#ifdef __MPI
template <typename T>
void CVCX_occ_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate,
const bool add_on = true);
#endif
// virt
/// $\sum_{b\mu\nu}X^*_{bi}C^*_{\mu b}K_{\mu\nu}C_{\nu a}$
template <typename T>
void CVCX_virt_forloop_serial(
const std::vector<container::Tensor>& V_istate,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate);
template <typename T>
void CVCX_virt_blas(
const std::vector<container::Tensor>& V_istate,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate,
const bool add_on = true);
#ifdef __MPI
template <typename T>
void CVCX_virt_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<T, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<T, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<T, psi::DEVICE_CPU>& AX_istate,
const bool add_on = true);
#endif
}
264 changes: 264 additions & 0 deletions source/module_beyonddft/Grad/CVCX/CVCX_parallel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
#ifdef __MPI
#include "CVCX.h"
#include "module_base/scalapack_connector.h"
#include "module_base/tool_title.h"
#include "module_beyonddft/utils/lr_util.h"
#include "module_beyonddft/utils/lr_util_print.h"
namespace hamilt
{
template <>
void CVCX_occ_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<double, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<double, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<double, psi::DEVICE_CPU>& AX_istate,
const bool add_on)
{
ModuleBase::TITLE("hamilt_lrtd", "CVCX_occ_pblas");
assert(pmat.comm_2D == pc.comm_2D);
assert(pmat.comm_2D == px.comm_2D);
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
assert(pmat.blacs_ctxt == px.blacs_ctxt);
assert(px.get_local_size() > 0 && AX_istate.get_nbasis() == px.get_local_size());

int nks = c.get_nk();
assert(V_istate.size() == nks);

Parallel_2D pcv;
LR_Util::setup_2d_division(pcv, pmat.get_block_size(), nocc, naos, pmat.comm_2D, pmat.blacs_ctxt);
Parallel_2D pcx;
LR_Util::setup_2d_division(pcx, pmat.get_block_size(), naos, nvirt, pmat.comm_2D, pmat.blacs_ctxt);
for (int isk = 0;isk < nks;++isk)
{
AX_istate.fix_k(isk);
X_istate.fix_k(isk);
c.fix_k(isk);

const int i1 = 1;
const int ivirt = nocc + 1;
const char trans = 'T';
const char notrans = 'N'; //c is col major
const double one = 1.0;
const double zero = 0.0;

// c^TV[nocc*naos]
container::Tensor cv(DAT::DT_DOUBLE, DEV::CpuDevice, { pcv.get_col_size(), pcv.get_row_size() });
pdgemm_(&trans, &notrans, &nocc, &naos, &naos,
&one, c.get_pointer(), &i1, &i1, pc.desc,
V_istate[isk].data<double>(), &i1, &i1, pmat.desc,
&zero, cv.data<double>(), &i1, &i1, pcv.desc);

// cX^T[naos*nvirt]
container::Tensor cx(DAT::DT_DOUBLE, DEV::CpuDevice, { pcx.get_col_size(), pcx.get_row_size() });
pdgemm_(&notrans, &trans, &naos, &nvirt, &nocc,
&one, c.get_pointer(), &i1, &i1, pc.desc,
X_istate.get_pointer(), &i1, &i1, px.desc,
&zero, cx.data<double>(), &i1, &i1, pcx.desc);

//AX_istate=[cX^T]^T[c^TV]^T (nvirt major)
pdgemm_(&trans, &trans, &nvirt, &nocc, &naos,
&one, cx.data<double>(), &i1, &i1, pcx.desc,
cv.data<double>(), &i1, &i1, pcv.desc,
add_on ? &one : &zero, AX_istate.get_pointer(), &i1, &i1, px.desc);
}
}

template <>
void CVCX_occ_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<std::complex<double>, psi::DEVICE_CPU>& AX_istate,
const bool add_on)
{
ModuleBase::TITLE("hamilt_lrtd", "CVCX_occ_pblas");
assert(pmat.comm_2D == pc.comm_2D);
assert(pmat.comm_2D == px.comm_2D);
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
assert(pmat.blacs_ctxt == px.blacs_ctxt);
assert(px.get_local_size() > 0 && AX_istate.get_nbasis() == px.get_local_size());

int nks = c.get_nk();
assert(V_istate.size() == nks);

Parallel_2D pcv;
LR_Util::setup_2d_division(pcv, pmat.get_block_size(), nocc, naos, pmat.comm_2D, pmat.blacs_ctxt);
Parallel_2D pcx;
LR_Util::setup_2d_division(pcx, pmat.get_block_size(), naos, nvirt, pmat.comm_2D, pmat.blacs_ctxt);
for (int isk = 0;isk < nks;++isk)
{
AX_istate.fix_k(isk);
X_istate.fix_k(isk);
c.fix_k(isk);

const int i1 = 1;
const int ivirt = nocc + 1;
const char trans = 'T';
const char dagger = 'C';
const char notrans = 'N'; //c is col major
const std::complex<double> one(1.0, 0.0);
const std::complex<double> zero(0.0, 0.0);

// c^TV[nocc*naos]
container::Tensor cv(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pcv.get_col_size(), pcv.get_row_size() });
pzgemm_(&dagger, &notrans, &nocc, &naos, &naos,
&one, c.get_pointer(), &i1, &i1, pc.desc,
V_istate[isk].data<std::complex<double>>(), &i1, &i1, pmat.desc,
&zero, cv.data<std::complex<double>>(), &i1, &i1, pcv.desc);

// cX^T[naos*nvirt]
container::Tensor cx(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pcx.get_col_size(), pcx.get_row_size() });
pzgemm_(&notrans, &dagger, &naos, &nvirt, &nocc,
&one, c.get_pointer(), &i1, &i1, pc.desc,
X_istate.get_pointer(), &i1, &i1, px.desc,
&zero, cx.data<std::complex<double>>(), &i1, &i1, pcx.desc);

//AX_istate=[cX^T]^T[c^TV]^T (nvirt major)
pzgemm_(&trans, &trans, &nvirt, &nocc, &naos,
&one, cx.data<std::complex<double>>(), &i1, &i1, pcx.desc,
cv.data<std::complex<double>>(), &i1, &i1, pcv.desc,
add_on ? &one : &zero, AX_istate.get_pointer(), &i1, &i1, px.desc);
}
}

template <>
void CVCX_virt_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<double, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<double, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<double, psi::DEVICE_CPU>& AX_istate,
const bool add_on)
{
ModuleBase::TITLE("hamilt_lrtd", "CVCX_virt_pblas");
assert(pmat.comm_2D == pc.comm_2D);
assert(pmat.comm_2D == px.comm_2D);
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
assert(pmat.blacs_ctxt == px.blacs_ctxt);
assert(px.get_local_size() > 0 && AX_istate.get_nbasis() == px.get_local_size());

int nks = c.get_nk();
assert(V_istate.size() == nks);

Parallel_2D pcv;
LR_Util::setup_2d_division(pcv, pmat.get_block_size(), naos, nvirt, pmat.comm_2D, pmat.blacs_ctxt);
Parallel_2D pcx;
LR_Util::setup_2d_division(pcx, pmat.get_block_size(), nocc, naos, pmat.comm_2D, pmat.blacs_ctxt);
for (int isk = 0;isk < nks;++isk)
{
AX_istate.fix_k(isk);
X_istate.fix_k(isk);
c.fix_k(isk);

const int i1 = 1;
const int ivirt = nocc + 1;
const char trans = 'T';
const char dagger = 'C';
const char notrans = 'N'; //c is col major
const double one = 1.0;
const double zero = 0.0;

// VC[naos*nvirt]
container::Tensor cv(DAT::DT_DOUBLE, DEV::CpuDevice, { pcv.get_col_size(), pcv.get_row_size() });
pdgemm_(&notrans, &notrans, &naos, &nvirt, &naos,
&one, V_istate[isk].data<double>(), &i1, &i1, pmat.desc,
c.get_pointer(), &i1, &ivirt, pc.desc,
&zero, cv.data<double>(), &i1, &i1, pcv.desc);

// X^TC^T[nocc*naos]
container::Tensor cx(DAT::DT_DOUBLE, DEV::CpuDevice, { pcx.get_col_size(), pcx.get_row_size() });
pdgemm_(&dagger, &dagger, &nocc, &naos, &nvirt,
&one, X_istate.get_pointer(), &i1, &i1, px.desc,
c.get_pointer(), &i1, &ivirt, pc.desc,
&zero, cx.data<double>(), &i1, &i1, pcx.desc);

//AX_istate=[VC]^T[X^TC^T]^T (nvirt major)
pdgemm_(&trans, &trans, &nvirt, &nocc, &naos,
&one, cv.data<double>(), &i1, &i1, pcv.desc,
cx.data<double>(), &i1, &i1, pcx.desc,
add_on ? &one : &zero, AX_istate.get_pointer(), &i1, &i1, px.desc);
}
}

template <>
void CVCX_virt_pblas(
const std::vector<container::Tensor>& V_istate,
const Parallel_2D& pmat,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& c,
const Parallel_2D& pc,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& X_istate,
const Parallel_2D& px,
const int& naos,
const int& nocc,
const int& nvirt,
psi::Psi<std::complex<double>, psi::DEVICE_CPU>& AX_istate,
const bool add_on)
{
ModuleBase::TITLE("hamilt_lrtd", "CVCX_virt_pblas");
assert(pmat.comm_2D == pc.comm_2D);
assert(pmat.comm_2D == px.comm_2D);
assert(pmat.blacs_ctxt == pc.blacs_ctxt);
assert(pmat.blacs_ctxt == px.blacs_ctxt);
assert(px.get_local_size() > 0 && AX_istate.get_nbasis() == px.get_local_size());

int nks = c.get_nk();
assert(V_istate.size() == nks);

Parallel_2D pcv;
LR_Util::setup_2d_division(pcv, pmat.get_block_size(), naos, nvirt, pmat.comm_2D, pmat.blacs_ctxt);
Parallel_2D pcx;
LR_Util::setup_2d_division(pcx, pmat.get_block_size(), nocc, naos, pmat.comm_2D, pmat.blacs_ctxt);
for (int isk = 0;isk < nks;++isk)
{
AX_istate.fix_k(isk);
X_istate.fix_k(isk);
c.fix_k(isk);

const int i1 = 1;
const int ivirt = nocc + 1;
const char trans = 'T';
const char dagger = 'C';
const char notrans = 'N'; //c is col major
const std::complex<double> one(1.0, 0.0);
const std::complex<double> zero(0.0, 0.0);

// VC[naos*nvirt]
container::Tensor cv(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pcv.get_col_size(), pcv.get_row_size() });
pzgemm_(&notrans, &notrans, &naos, &nvirt, &naos,
&one, V_istate[isk].data<std::complex<double>>(), &i1, &i1, pmat.desc,
c.get_pointer(), &i1, &ivirt, pc.desc,
&zero, cv.data<std::complex<double>>(), &i1, &i1, pcv.desc);

// X^TC^T[nocc*naos]
container::Tensor cx(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { pcx.get_col_size(), pcx.get_row_size() });
pzgemm_(&dagger, &dagger, &nocc, &naos, &nvirt,
&one, X_istate.get_pointer(), &i1, &i1, px.desc,
c.get_pointer(), &i1, &ivirt, pc.desc,
&zero, cx.data<std::complex<double>>(), &i1, &i1, pcx.desc);

//AX_istate=[VC]^T[X^TC^T]^T (nvirt major)
pzgemm_(&trans, &trans, &nvirt, &nocc, &naos,
&one, cv.data<std::complex<double>>(), &i1, &i1, pcv.desc,
cx.data<std::complex<double>>(), &i1, &i1, pcx.desc,
add_on ? &one : &zero, AX_istate.get_pointer(), &i1, &i1, px.desc);
}
}
}
#endif
Loading

0 comments on commit ac629ca

Please sign in to comment.