Skip to content

Commit

Permalink
fix V in AX
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed May 18, 2024
1 parent 1adb47e commit c6b6902
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
7 changes: 5 additions & 2 deletions source/module_beyonddft/AX/AX_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "module_base/scalapack_connector.h"
#include "module_base/tool_title.h"
#include "module_beyonddft/utils/lr_util.h"
#include "module_beyonddft/utils/lr_util_print.h"
namespace hamilt
{
//output: col first, consistent with blas
Expand Down Expand Up @@ -45,7 +46,7 @@ namespace hamilt
int i1 = 1;
int ivirt = nocc + 1;

char transa = 'T';
char transa = 'N';
char transb = 'N';
const double alpha = 1.0;
const double beta = add_on ? 1.0 : 0.0;
Expand All @@ -54,6 +55,7 @@ namespace hamilt
c.get_pointer(), &i1, &i1, pc.desc,
&beta, Vc.data<double>(), &i1, &i1, pVc.desc);

transa = 'T';
// AX_istate = c ^ TVc
// descC puts M(nvirt) to row
pdgemm_(&transa, &transb, &nvirt, &nocc, &naos,
Expand Down Expand Up @@ -101,7 +103,7 @@ namespace hamilt
int i1 = 1;
int ivirt = nocc + 1;

char transa = 'C';
char transa = 'N';
char transb = 'N';
const std::complex<double> alpha(1.0, 0.0);
const std::complex<double> beta = add_on ? std::complex<double>(1.0, 0.0) : std::complex<double>(0.0, 0.0);
Expand All @@ -110,6 +112,7 @@ namespace hamilt
c.get_pointer(), &i1, &i1, pc.desc,
&beta, Vc.data<std::complex<double>>(), &i1, &i1, pVc.desc);

transa = 'C';
// AX_istate = c ^ TVc
// descC puts M(nvirt) to row
pzgemm_(&transa, &transb, &nvirt, &nocc, &naos,
Expand Down
11 changes: 7 additions & 4 deletions source/module_beyonddft/AX/AX_serial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace hamilt
{
for (int mu = 0;mu < naos;++mu)
{
AX_istate(i * nvirt + a) += c(nocc + a, mu) * V_istate[isk].data<double>()[mu * naos + nu] * c(i, nu);
AX_istate(i * nvirt + a) += c(nocc + a, mu) * V_istate[isk].data<double>()[nu * naos + mu] * c(i, nu);
}
}
}
Expand Down Expand Up @@ -63,7 +63,7 @@ namespace hamilt
{
for (int mu = 0;mu < naos;++mu)
{
AX_istate(i * nvirt + a) += std::conj(c(nocc + a, mu)) * std::conj(V_istate[isk].data<std::complex<double>>()[mu * naos + nu]) * c(i, nu);
AX_istate(i * nvirt + a) += std::conj(c(nocc + a, mu)) * V_istate[isk].data<std::complex<double>>()[nu * naos + mu] * c(i, nu);
}
}
}
Expand Down Expand Up @@ -92,14 +92,15 @@ namespace hamilt
// Vc[naos*nocc]
container::Tensor Vc(DAT::DT_DOUBLE, DEV::CpuDevice, { nocc, naos });// (Vc)^T
Vc.zero();
char transa = 'T';
char transa = 'N';
char transb = 'N'; //c is col major
const double alpha = 1.0;
const double beta = add_on ? 1.0 : 0.0;
dgemm_(&transa, &transb, &naos, &nocc, &naos, &alpha,
V_istate[isk].data<double>(), &naos, c.get_pointer(), &naos, &beta,
Vc.data<double>(), &naos);

transa = 'T';
//AX_istate=c^TVc (nvirt major)
dgemm_(&transa, &transb, &nvirt, &nocc, &naos, &alpha,
c.get_pointer(nocc), &naos, Vc.data<double>(), &naos, &beta,
Expand Down Expand Up @@ -127,13 +128,15 @@ namespace hamilt
// Vc[naos*nocc] (V is hermitian)
container::Tensor Vc(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { nocc, naos });// (Vc)^T
Vc.zero();
char transa = 'C';
char transa = 'N';
char transb = 'N'; //c is col major
const std::complex<double> alpha(1.0, 0.0);
const std::complex<double> beta = add_on ? std::complex<double>(1.0, 0.0) : std::complex<double>(0.0, 0.0);
zgemm_(&transa, &transb, &naos, &nocc, &naos, &alpha,
V_istate[isk].data<std::complex<double>>(), &naos, c.get_pointer(), &naos, &beta,
Vc.data<std::complex<double>>(), &naos);

transa = 'C';
//AX_istate=c^\dagger Vc (nvirt major)
zgemm_(&transa, &transb, &nvirt, &nocc, &naos, &alpha,
c.get_pointer(nocc), &naos, Vc.data<std::complex<double>>(), &naos, &beta,
Expand Down

0 comments on commit c6b6902

Please sign in to comment.