Skip to content

Commit

Permalink
update and UT for restore_psik; add lacpy interface
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Aug 2, 2023
1 parent 9631f73 commit 707d7bc
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 52 deletions.
4 changes: 2 additions & 2 deletions source/module_base/lapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ extern "C"
void zhegs2_(int *itype, char *uplo, int *n, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);

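// copies all or part of matrix a into b (uplo = 'U': upper triangle, 'L': lower triangle, otherwise the whole matrix)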
void dlacpy_(char *uplo, int *m, int *n, double* a, int *lda, double *b, int *ldb);
void zlacpy_(char *uplo, int *m, int *n, std::complex<double>* a, int *lda, std::complex<double> *b, int *ldb);
void dlacpy_(const char* uplo, const int* m, const int* n, double* a, const int* lda, double* b, const int* ldb);
void zlacpy_(const char* uplo, const int* m, const int* n, std::complex<double>* a, const int* lda, std::complex<double>* b, const int* ldb);

// generates a real elementary reflector H of order n, such that
// H * ( alpha ) = ( beta ), H is unitary.
Expand Down
6 changes: 4 additions & 2 deletions source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ extern "C"
const int *ictxt, const int *lld, int *info);

void pdpotrf_(const char* uplo, const int* n, double* a, const int* ia, const int* ja, const int* desca, int* info);
// void pzpotrf_(char *uplo, int *n, double _Complex *a, int *ia, int *ja, int *desca, int *info);

void pzpotrf_(const char* uplo, const int* n, std::complex<double>* a, const int* ia, const int* ja, const int* desca, int* info);

void pdpotri_(const char* uplo, const int* n, double* a, const int* ia, const int* ja, const int* desca, int* info);

void pzpotri_(const char* uplo, const int* n, std::complex<double>* a, const int* ia, const int* ja, const int* desca, int* info);

void pdtran_(int *m , int *n ,
void pzlacpy_(const char* uplo, const int* m, const int* n, const std::complex<double>* a, const int* ia, const int* ja, const int* desca, std::complex<double>* b, const int* ib, const int* jb, const int* descb);

void pdtran_(int* m, int* n,
double *alpha , double *a , int *ia , int *ja , int *desca ,
double *beta , double *c , int *ic , int *jc , int *descc );

Expand Down
117 changes: 80 additions & 37 deletions source/module_ri/exx_symmetry.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#ifdef __EXX
#include "exx_symmetry.h"
#include <utility>
#include "module_psi/psi.h"
#include "module_cell/module_symmetry/symmetry.h"
#ifdef __MPI
#include <mpi.h>
#include "module_base/scalapack_connector.h"
#else
#include "module_base/lapack_connector.h"
#endif
#include "module_base/lapack_connector.h"

namespace ExxSym
{

std::vector<std::vector<std::complex<double>>> rearange_smat(
const int ikibz,
const std::vector<std::complex<double>> sloc_ikibz,
Expand Down Expand Up @@ -87,72 +84,118 @@ namespace ExxSym
return fullmat;
}

psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik(
/// @brief Restore the wavefunctions of every k-point in a k-star from the one stored
///        for its irreducible k-point (serial LAPACK version).
///
/// For every overlap matrix S(k) in sloc_ik this evaluates
///     c_k = S^{-1}(k) * S(gk) * c_{gk},
/// with S(gk) = sloc_ikibz and c_{gk} = psi_ikibz. All matrices are passed to
/// BLAS/LAPACK as column-major with leading dimension nbasis.
///
/// @param ikibz      [in] index of the irreducible k-point (not read by this routine)
/// @param psi_ikibz  [in] c_{gk}; read through get_pointer() as an nbasis x nbands matrix
///                        (assumes the caller has already selected the wanted k -- TODO confirm)
/// @param sloc_ikibz [in] S(gk), nbasis x nbasis
/// @param sloc_ik    [in] S(k) for each k of the star, each nbasis x nbasis
/// @param nbasis     [in] number of basis functions
/// @param nbands     [in] number of bands
/// @return c_k, constructed as Psi(sloc_ik.size(), nbands, nbasis); slot ik holds the result for sloc_ik[ik]
psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik_lapack(
    const int& ikibz,
    const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& psi_ikibz,
    const std::vector<std::complex<double>>& sloc_ikibz,
    const std::vector<std::vector<std::complex<double>>>& sloc_ik,
    const int& nbasis,
    const int& nbands)
{
    int nkstar = sloc_ik.size();
    psi::Psi<std::complex<double>, psi::DEVICE_CPU> c_k(nkstar, nbands, nbasis);
    psi::Psi<std::complex<double>, psi::DEVICE_CPU> tmpSc(1, nbands, nbasis);
    // only col_maj is considered now
    // 1. S(gk)c_gk
    char transa = 'N';
    char transb = 'N';
    std::complex<double> alpha = 1.0;
    std::complex<double> beta = 0.0;

    zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis,
        &alpha, sloc_ikibz.data(), &nbasis, psi_ikibz.get_pointer(), &nbasis,
        &beta, tmpSc.get_pointer(), &nbasis);

    // 2. c_k = S^{-1}(k)S(gk)c_{gk} for each k
    const std::size_t n = static_cast<std::size_t>(nbasis);
    int ik = 0;
    for (auto sk : sloc_ik) // deliberate copy: sk is overwritten by S^{-1}(k) below, sloc_ik is const
    {
        c_k.fix_k(ik);

        // 2.1 S^{-1}(k) via Cholesky factorization; only the upper triangle is valid on return
        char uplo = 'U';
        int info = -1;
        zpotrf_(&uplo, &nbasis, sk.data(), &nbasis, &info);
        if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k).(info=" + std::to_string(info) + ").");
        zpotri_(&uplo, &nbasis, sk.data(), &nbasis, &info);
        if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when calculating inv(S(k)).(info=" + std::to_string(info) + ").");

        // The inverse of a Hermitian matrix is Hermitian, so the lower triangle must be the
        // CONJUGATE transpose of the upper one (a plain transpose is only right for real S).
        // Column-major storage: element (row, col) lives at col * n + row.
        for (std::size_t col = 0; col < n; ++col)
        {
            for (std::size_t row = col + 1; row < n; ++row)
            {
                sk[col * n + row] = std::conj(sk[row * n + col]);
            }
        }

        // 2.2 S^{-1}(k) * S(gk)c_{gk}
        zgemm_(&transa, &transb, &nbasis, &nbands, &nbasis,
            &alpha, sk.data(), &nbasis, tmpSc.get_pointer(), &nbasis,
            &beta, c_k.get_pointer(), &nbasis);
        ++ik;
    }
    return c_k;
}

#ifdef __MPI
/// @brief Restore the wavefunctions of every k-point in a k-star from the one stored
///        for its irreducible k-point (distributed ScaLAPACK version).
///
/// For every overlap matrix S(k) in sloc_ik this evaluates
///     c_k = S^{-1}(k) * S(gk) * c_{gk},
/// with S(gk) = sloc_ikibz and c_{gk} = psi_ikibz. Matrices use the descriptor pv.desc,
/// wavefunctions use pv.desc_wfc; all global sub-matrix offsets are (1, 1).
///
/// @param ikibz      [in] index of the irreducible k-point (not read by this routine)
/// @param psi_ikibz  [in] local part of c_{gk}, read through get_pointer()
/// @param sloc_ikibz [in] local part of S(gk)
/// @param sloc_ik    [in] local part of S(k) for each k of the star
/// @param nbasis     [in] global number of basis functions
/// @param nbands     [in] global number of bands
/// @param pv         [in] parallel orbitals (descriptors and local sizes for both matrix and wavefunction)
/// @return c_k, constructed as Psi(sloc_ik.size(), pv.ncol_bands, pv.get_row_size());
///         slot ik holds the local result for sloc_ik[ik]
psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik_scalapack(
    const int& ikibz,
    const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& psi_ikibz,
    const std::vector<std::complex<double>>& sloc_ikibz,
    const std::vector<std::vector<std::complex<double>>>& sloc_ik,
    const int& nbasis,
    const int& nbands,
    const Parallel_Orbitals& pv)
{
    int nkstar = sloc_ik.size();
    psi::Psi<std::complex<double>, psi::DEVICE_CPU> c_k(nkstar, pv.ncol_bands, pv.get_row_size());
    psi::Psi<std::complex<double>, psi::DEVICE_CPU> tmpSc(1, pv.ncol_bands, pv.get_row_size());
    // only col_maj is considered now
    // 1. S(gk)c_gk
    char transa = 'N';
    char transb = 'N';
    std::complex<double> alpha = 1.0;
    std::complex<double> beta = 0.0;
    int i1 = 1;
    pzgemm_(&transa, &transb, &nbasis, &nbands, &nbasis,
        &alpha, sloc_ikibz.data(), &i1, &i1, pv.desc,
        psi_ikibz.get_pointer(), &i1, &i1, pv.desc_wfc, &beta,
        tmpSc.get_pointer(), &i1, &i1, pv.desc_wfc);

    // 2. c_k = S^{-1}(k)S(gk)c_{gk} for each k
    int ik = 0;
    for (auto sk : sloc_ik) // deliberate copy: sk is overwritten by S^{-1}(k) below, sloc_ik is const
    {
        c_k.fix_k(ik);

        // 2.1 S^{-1}(k) via Cholesky factorization; only the upper triangle is valid on return
        char uplo = 'U';
        int info = -1;
        pzpotrf_(&uplo, &nbasis, sk.data(), &i1, &i1, pv.desc, &info);
        if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when factorizing S(k).(info=" + std::to_string(info) + ").");
        pzpotri_(&uplo, &nbasis, sk.data(), &i1, &i1, pv.desc, &info);
        if (info != 0) ModuleBase::WARNING_QUIT("restore_psik", "Error when calculating inv(S(k)).(info=" + std::to_string(info) + ").");

        // Complete the Hermitian inverse: invsk = sk^H (conjugate transpose) puts conj(upper)
        // into the lower triangle, then the valid upper triangle of sk is copied on top.
        // A plain transpose ('T') would only be correct for a real S(k).
        std::vector<std::complex<double>> invsk(sk.size());
        // distributed identity; local index = local_col * pv.get_row_size() + local_row
        std::vector<std::complex<double>> ones(sk.size(), 0);
        for (int i = 0; i < nbasis; ++i)
        {
            if (pv.in_this_processor(i, i))
            {
                ones[pv.global2local_col(i) * pv.get_row_size() + pv.global2local_row(i)] = std::complex<double>(1, 0);
            }
        }
        char t = 'C';
        pzgemm_(&t, &t, &nbasis, &nbasis, &nbasis,
            &alpha, sk.data(), &i1, &i1, pv.desc,
            ones.data(), &i1, &i1, pv.desc, &beta,
            invsk.data(), &i1, &i1, pv.desc);
        pzlacpy_(&uplo, &nbasis, &nbasis, sk.data(), &i1, &i1, pv.desc, invsk.data(), &i1, &i1, pv.desc);

        // 2.2 S^{-1}(k) * S(gk)c_{gk}
        pzgemm_(&transa, &transb, &nbasis, &nbands, &nbasis,
            &alpha, invsk.data(), &i1, &i1, pv.desc,
            tmpSc.get_pointer(), &i1, &i1, pv.desc_wfc, &beta,
            c_k.get_pointer(), &i1, &i1, pv.desc_wfc);
        ++ik;
    }
    return c_k;
}
}
#endif
#endif

}
22 changes: 14 additions & 8 deletions source/module_ri/exx_symmetry.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#ifdef __EXX
#pragma once
#include <vector>
#include "module_basis/module_ao/parallel_orbitals.h"
Expand Down Expand Up @@ -34,21 +33,28 @@ namespace ExxSym
/// @param nbands [in] global number of bands
/// @param pv [in] parallel orbitals (for both matrix and wavefunction)
/// @param col_inside [in] whether the matrix is column-major (major means memory continuity)
/// @return
psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik(
/// @return c_k: wavefunction of each k in kstars[ikibz]
#ifdef __MPI
psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik_scalapack(
const int& ikibz,
const std::vector<std::complex<double>>& psi_ikibz,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& psi_ikibz,
const std::vector<std::complex<double>>& sloc_ikibz,
const std::vector<std::vector<std::complex<double>>>& sloc_ik,
const int& nbasis,
const int& nbands,
const Parallel_Orbitals& pv,
const bool col_inside);
const Parallel_Orbitals& pv);
#endif
psi::Psi<std::complex<double>, psi::DEVICE_CPU> restore_psik_lapack(
const int& ikibz,
const psi::Psi<std::complex<double>, psi::DEVICE_CPU>& psi_ikibz,
const std::vector<std::complex<double>>& sloc_ikibz,
const std::vector<std::vector<std::complex<double>>>& sloc_ik,
const int& nbasis,
const int& nbands);

std::vector<std::complex<double>> get_full_smat(
const std::vector<std::complex<double>>& locmat,
const int& nbasis,
const Parallel_2D& p2d,
const bool col_inside);
}
#endif
}
Loading

0 comments on commit 707d7bc

Please sign in to comment.