Skip to content

Commit

Permalink
Feature : Add deepks_v_delta, which can help to train DeepKS model wi…
Browse files Browse the repository at this point in the history
…th loss term about Hamiltonian, psi and band (deepmodeling#4594)

* Add deepks_v_delta, which can help to train DeepKS model with loss term about Hamiltonian, psi and band.
    When deepks_out_labels equals 1 , it can output labels about Hamiltonian. Meanwhile, it will output v_delta_precalc when deepks_v_delta equals 1, and output psialpha and grad_evdm to save disk memory when deepks_v_delta equal 2.

* debug: I used the index of nlm_save incorrectly

* add integration test

* add check functions for unit test

* add doc for INPUT parameter deepks_v_delta

* [pre-commit.ci lite] apply automatic fixes

* transfer datatype of h_mat from double* to std::vector

* use mn_size to prevent calculate many times; initilize iic; avoid using GlobalV::NLOCAL in out_deepks_labels

* modify the integration test into a smaller example

* [pre-commit.ci lite] apply automatic fixes

* add const

* Fix: bug in merge

* place model file and jle.orb in Model_ProjOrb to save memory

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2024
1 parent 3deedee commit 796c911
Show file tree
Hide file tree
Showing 44 changed files with 1,158 additions and 17,938 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ toolchain.tar.gz
time.json
*.pyc
__pycache__
abacus.json
abacus.json
*.npy
9 changes: 9 additions & 0 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
- [bessel\_descriptor\_smooth](#bessel_descriptor_smooth)
- [bessel\_descriptor\_sigma](#bessel_descriptor_sigma)
- [deepks\_bandgap](#deepks_bandgap)
- [deepks\_v\_delta](#deepks_v_delta)
- [deepks\_out\_unittest](#deepks_out_unittest)
- [OFDFT: orbital free density functional theory](#ofdft-orbital-free-density-functional-theory)
- [of\_kinetic](#of_kinetic)
Expand Down Expand Up @@ -1943,6 +1944,14 @@ Warning: this function is not robust enough for the current version. Please try
- **Description**: include bandgap label for DeePKS training
- **Default**: False

### deepks_v_delta

- **Type**: int
- **Availability**: numerical atomic orbital basis
- **Description**: Include V_delta label for DeePKS training. When `deepks_out_labels` is true and `deepks_v_delta` > 0, ABACUS will output h_base.npy, v_delta.npy and h_tot.npy(h_tot=h_base+v_delta).
Meanwhile, when `deepks_v_delta` equals 1, ABACUS will also output v_delta_precalc.npy, which is used to calculate V_delta during DeePKS training. However, when the number of atoms grows, the size of v_delta_precalc.npy will be very large. In this case, it's recommended to set `deepks_v_delta` as 2, and ABACUS will output psialpha.npy and grad_evdm.npy but not v_delta_precalc.npy. These two files are small and can be used to calculate v_delta_precalc in the procedure of training DeePKS.
- **Default**: 0

### deepks_out_unittest

- **Type**: Boolean
Expand Down
21 changes: 11 additions & 10 deletions source/module_base/global_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int out_ndigits = 8;
std::string DFT_FUNCTIONAL = "default";
double XC_TEMPERATURE = 0.0;
int NSPIN = 1; // LDA
bool TWO_EFERMI = 0; // two fermi energy, exist only magnetization is fixed.
bool TWO_EFERMI = false; // two fermi energy, exist only magnetization is fixed.
double nupdown = 0.0;
int CURRENT_K = 0;
int CAL_FORCE = 0; // if cal_force >1, means do the grid integration 'cal_force' times.
Expand Down Expand Up @@ -86,9 +86,9 @@ int NQX = 10000; // number of points describing reciprocal radial tab
int NQXQ = 10000; // number of points describing reciprocal radial tab for Q

int NURSE = 0; // used for debug.
bool COLOUR = 0;
bool GAMMA_ONLY_LOCAL = 0; // mohan add 2010-10-20
bool GAMMA_ONLY_PW = 0; // mohan add 2012-06-05
bool COLOUR = false;
bool GAMMA_ONLY_LOCAL = false; // mohan add 2010-10-20
bool GAMMA_ONLY_PW = false; // mohan add 2012-06-05

int T_IN_H = 1; // mohan add 2010-11-28
int VL_IN_H = 1;
Expand Down Expand Up @@ -202,8 +202,9 @@ double soc_lambda = 1.0;
bool FINAL_SCF = false; // LiuXh add 20180619

bool deepks_out_labels = false; // caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16
bool deepks_scf = false; // caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16
bool deepks_bandgap = false; // for bandgap label. QO added 2021-12-15
bool deepks_scf = false; // caoyu add 2021-10-16 for DeePKS, wenfei 2022-1-16
bool deepks_bandgap = false; // for bandgap label. QO added 2021-12-15
int deepks_v_delta = 0; // for v_delta label. xinyuan added 2023-2-15
bool deepks_out_unittest = false;

bool deepks_equiv = false;
Expand Down Expand Up @@ -251,8 +252,8 @@ double MIXING_BETA_MAG = 1.6;
double MIXING_GG0_MAG = 1.00;
double MIXING_GG0_MIN = 0.1;
double MIXING_ANGLE = 0.0;
bool MIXING_TAU = 0;
bool MIXING_DMR = 0;
bool MIXING_TAU = false;
bool MIXING_DMR = false;

//==========================================================
// device flags added by denghui
Expand Down Expand Up @@ -280,8 +281,8 @@ int out_interval = 1; // convert from out_hsR_interval liuyu 2023-04-18
//==========================================================
// Deltaspin related
//==========================================================
bool sc_mag_switch = 0;
bool decay_grad_switch = 0;
bool sc_mag_switch = false;
bool decay_grad_switch = false;
double sc_thr = 1.0e-6;
int nsc = 100;
int nsc_min = 2;
Expand Down
4 changes: 3 additions & 1 deletion source/module_base/global_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ extern bool deepks_scf; //(need libnpy and libtorch) if set 1, a trained
// would be needed to cal V_delta and F_delta
extern bool deepks_bandgap; // for bandgap label. QO added 2021-12-15

extern bool deepks_equiv; // whether to use equviariant version of DeePKS
extern int deepks_v_delta; // for v_delta label. xinyuan added 2023-2-15

extern bool deepks_equiv; //whether to use equviariant version of DeePKS

extern bool deepks_setorb;

Expand Down
14 changes: 11 additions & 3 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,13 +878,13 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
// 1) print Hamiltonian and Overlap matrix
if (this->conv_elec || iter == GlobalV::SCF_NMAX)
{
if (!GlobalV::GAMMA_ONLY_LOCAL && hsolver::HSolverLCAO<TK>::out_mat_hs[0])
if (!GlobalV::GAMMA_ONLY_LOCAL && (hsolver::HSolverLCAO<TK>::out_mat_hs[0] || GlobalV::deepks_v_delta))
{
this->GK.renew(true);
}
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
{
if (hsolver::HSolverLCAO<TK>::out_mat_hs[0])
if (hsolver::HSolverLCAO<TK>::out_mat_hs[0]|| GlobalV::deepks_v_delta)
{
this->p_hamilt->updateHk(ik);
}
Expand Down Expand Up @@ -920,6 +920,12 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
this->ParaV,
GlobalV::DRANK);
}
#ifdef __DEEPKS
if(GlobalV::deepks_v_delta)
{
GlobalC::ld.save_h_mat(h_mat.p,this->ParaV.nloc);
}
#endif
}
}
}
Expand Down Expand Up @@ -1180,14 +1186,16 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
LDI.out_deepks_labels(this->pelec->f_en.etot,
this->pelec->klist->get_nks(),
GlobalC::ucell.nat,
GlobalV::NLOCAL,
this->pelec->ekb,
this->pelec->klist->kvec_d,
GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
&(this->ParaV),
*(this->psi),
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM());
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
GlobalV::deepks_v_delta);

ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
#endif
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/module_deepks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ if(ENABLE_DEEPKS)
LCAO_deepks_psialpha.cpp
LCAO_deepks_torch.cpp
LCAO_deepks_vdelta.cpp
LCAO_deepks_hmat.cpp
LCAO_deepks_interface.cpp
)

Expand Down
64 changes: 64 additions & 0 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ LCAO_Deepks::~LCAO_Deepks()
}

del_gdmx();

}

void LCAO_Deepks::init(
Expand Down Expand Up @@ -140,6 +141,16 @@ void LCAO_Deepks::init(

this->pv = &pv_in;

if(GlobalV::deepks_v_delta)
{
//allocate and init h_mat
if(GlobalV::GAMMA_ONLY_LOCAL)
{
int nloc=this->pv->nloc;
this->h_mat.resize(nloc,0.0);
}
}

return;
}

Expand Down Expand Up @@ -417,4 +428,57 @@ void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
return;
}

void LCAO_Deepks::init_v_delta_pdm_shell(const int nks,const int nlocal)
{

this->v_delta_pdm_shell = new double**** [nks];

const int mn_size=(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1);
for (int iks=0; iks<nks; iks++)
{
this->v_delta_pdm_shell[iks] = new double*** [nlocal];

for (int mu=0; mu<nlocal; mu++)
{
this->v_delta_pdm_shell[iks][mu] = new double** [nlocal];

for (int nu=0; nu<nlocal; nu++)
{
this->v_delta_pdm_shell[iks][mu][nu] = new double* [this->inlmax];

for(int inl = 0; inl < this->inlmax; inl++)
{
this->v_delta_pdm_shell[iks][mu][nu][inl] = new double [mn_size];
ModuleBase::GlobalFunc::ZEROS(v_delta_pdm_shell[iks][mu][nu][inl], mn_size);
}
}
}
}

return;
}

void LCAO_Deepks::del_v_delta_pdm_shell(const int nks,const int nlocal)
{
for (int iks=0; iks<nks; iks++)
{
for (int mu=0; mu<nlocal; mu++)
{
for (int nu=0; nu<nlocal; nu++)
{
for (int inl = 0;inl < this->inlmax; inl++)
{
delete[] this->v_delta_pdm_shell[iks][mu][nu][inl];
}
delete[] this->v_delta_pdm_shell[iks][mu][nu];
}
delete[] this->v_delta_pdm_shell[iks][mu];
}
delete[] this->v_delta_pdm_shell[iks];
}
delete[] this->v_delta_pdm_shell;

return;
}

#endif
63 changes: 63 additions & 0 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class LCAO_Deepks
///\rho_{HL} = c_{L, \mu}c_{L,\nu} - c_{H, \mu}c_{H,\nu} \f$ (for gamma_only)
ModuleBase::matrix o_delta;

///(Unit: Ry) Hamiltonian matrix
std::vector<double> h_mat;

/// Correction term to the Hamiltonian matrix: \f$\langle\psi|V_\delta|\psi\rangle\f$ (for gamma only)
std::vector<double> H_V_delta;
/// Correction term to Hamiltonian, for multi-k
Expand Down Expand Up @@ -156,6 +159,14 @@ class LCAO_Deepks
// orbital_precalc:[1,NAt,NDscrpt]; gvdm*orbital_pdm_shell
torch::Tensor orbital_precalc_tensor;

// v_delta_pdm_shell[nks,nlocal,nlocal,Inl,nm*nm] = overlap * overlap
double***** v_delta_pdm_shell;
// v_delta_precalc[nks,nlocal,nlocal,NAt,NDscrpt] = gvdm * v_delta_pdm_shell;
torch::Tensor v_delta_precalc_tensor;
//for v_delta==2 , new v_delta_precalc storage method
torch::Tensor psialpha_tensor;
torch::Tensor gevdm_tensor;

/// size of descriptor(projector) basis set
int n_descriptor;

Expand Down Expand Up @@ -229,6 +240,10 @@ class LCAO_Deepks
// for bandgap label calculation; QO added on 2022-1-7
void init_orbital_pdm_shell(const int nks);
void del_orbital_pdm_shell(const int nks);

//for v_delta label calculation; xinyuan added on 2023-2-22
void init_v_delta_pdm_shell(const int nks,const int nlocal);
void del_v_delta_pdm_shell(const int nks,const int nlocal);

//-------------------
// LCAO_deepks_psialpha.cpp
Expand Down Expand Up @@ -443,6 +458,12 @@ class LCAO_Deepks
// 11. cal_orbital_precalc_k : orbital_precalc is usted for training with orbital label,
// for multi-k case, which equals gvdm * orbital_pdm_shell,
// orbital_pdm_shell[1,Inl,nm*nm] = dm_hl_k * overlap * overlap
//12. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
// which equals gvdm * v_delta_pdm_shell,
// v_delta_pdm_shell = overlap * overlap
//13. check_v_delta_precalc : check v_delta_precalc
//14. prepare_psialpha : prepare psialpha for outputting npy file
//15. prepare_gevdm : prepare gevdm for outputting npy file

public:
/// Calculates descriptors
Expand Down Expand Up @@ -493,6 +514,28 @@ class LCAO_Deepks
const LCAO_Orbitals& orb,
Grid_Driver& GridD);

//calculates v_delta_precalc
void cal_v_delta_precalc(const int nlocal,
const int nat,
const UnitCell &ucell,
const LCAO_Orbitals &orb,
Grid_Driver &GridD);
void check_v_delta_precalc(const int nat, const int nks,const int nlocal);

// prepare psialpha for outputting npy file
void prepare_psialpha(const int nlocal,
const int nat,
const UnitCell &ucell,
const LCAO_Orbitals &orb,
Grid_Driver &GridD);
void check_vdp_psialpha(const int nat, const int nks, const int nlocal);

// prepare gevdm for outputting npy file
void prepare_gevdm(
const int nat,
const LCAO_Orbitals &orb);
void check_vdp_gevdm(const int nat);

private:
const Parallel_Orbitals* pv;
void cal_gvdm(const int nat);
Expand All @@ -518,6 +561,10 @@ class LCAO_Deepks
// 7. save_npy_s : stress
// 8. save_npy_o: orbital
// 9. save_npy_orbital_precalc: orbital_precalc -> orbital_precalc.npy
//10. save_npy_h : Hamiltonian
//11. save_npy_v_delta_precalc : v_delta_precalc
//12. save_npy_psialpha : psialpha
//13. save_npy_gevdm : grav_evdm , can use psialpha and gevdm to calculate v_delta_precalc

public:
/// print density matrices
Expand Down Expand Up @@ -557,6 +604,12 @@ class LCAO_Deepks

void load_npy_gedm(const int nat);

//xinyuan added on 2023-2-20
void save_npy_h(const ModuleBase::matrix &H,const std::string &h_file,const int nlocal);//just for gamma only
void save_npy_v_delta_precalc(const int nat, const int nks,const int nlocal);
void save_npy_psialpha(const int nat, const int nks,const int nlocal);
void save_npy_gevdm(const int nat);

//-------------------
// LCAO_deepks_mpi.cpp
//-------------------
Expand All @@ -573,6 +626,16 @@ class LCAO_Deepks
int ndim, // second dimension
double** mat); // the array being reduced
#endif

//-------------------
// LCAO_deepks_hmat.cpp
//-------------------
void save_h_mat(const double *h_mat_in,const int nloc);
void save_h_mat(const std::complex<double> *h_mat_in,const int nloc);
//Collect data in h_in to matrix h_out. Note that left lower trianger in h_out is filled
void collect_h_mat(const std::vector<double> h_in,ModuleBase::matrix &h_out,const int nlocal);//just for gamma only
void check_h_mat(const ModuleBase::matrix &H,const std::string &h_file,const int nlocal);//just for gamma only

};

namespace GlobalC
Expand Down
Loading

0 comments on commit 796c911

Please sign in to comment.