diff --git a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp index ac5e92616b..ea785361f1 100644 --- a/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp @@ -160,7 +160,7 @@ void cal_ddpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance1, ll); + const double rl = pow_int(distance1, ll); // derivative of wave functions with respect to atom positions. const double tmpdphi_rly = (dtmp - tmp * ll / distance1) / rl * rly[idx_lm] / distance1; @@ -268,8 +268,8 @@ void cal_ddpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance, ll); - const double r_lp2 = pow(distance, ll + 2); + const double rl = pow_int(distance, ll); + const double r_lp2 =rl * distance * distance; // d/dr (R_l / r^l) const double tmpdphi = (dtmp - tmp * ll / distance) / rl; diff --git a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp index 87644dbcfa..8c6a4ce637 100644 --- a/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp +++ b/source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp @@ -68,8 +68,9 @@ void cal_dpsir_ylm( double distance = std::sqrt(dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2]); ModuleBase::Ylm::grad_rl_sph_harm(ucell.atoms[it].nwl, dr[0], dr[1], dr[2], rly, grly.get_ptr_2D()); - if (distance < 1e-9) + if (distance < 1e-9) { distance = 1e-9; +} const double position = distance / delta_r; @@ -115,7 +116,7 @@ void cal_dpsir_ylm( const int ll = atom->iw2l[iw]; const int idx_lm = atom->iw2_ylm[iw]; - const double rl = pow(distance, ll); + const double rl = pow_int(distance, ll); // 3D wave functions p_psi[iw] = tmp * rly[idx_lm] / rl; diff --git a/source/module_hamilt_lcao/module_gint/gint_tools.h b/source/module_hamilt_lcao/module_gint/gint_tools.h index 178ccf6fa5..c8771dfbe6 100644 --- a/source/module_hamilt_lcao/module_gint/gint_tools.h +++ b/source/module_hamilt_lcao/module_gint/gint_tools.h @@ -136,6 +136,29 @@ class Gint_inout namespace Gint_Tools { +// if exponent is an integer between 0 and 5 (the most common cases in gint), +// pow_int is much faster than std::pow +inline double pow_int(const double base, const int exp) +{ + switch (exp) + { + case 0: + return 1.0; + case 1: + return base; + case 2: + return base * base; + case 3: + return base * base * base; + case 4: + return base * base * base * base; + case 5: + return base * base * base * base * base; + default: + double result = std::pow(base, exp); + return result; + } +} // vindex[pw.bxyz] int* get_vindex(const int bxyz, const int bx, diff --git a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp index e01518c668..1868559d16 100644 --- a/source/module_hamilt_lcao/module_gint/grid_meshball.cpp +++ b/source/module_hamilt_lcao/module_gint/grid_meshball.cpp @@ -127,7 +127,7 @@ double Grid_MeshBall::deal_with_atom_spillage(const double *pos) cell[ip] = i*this->bigcell_vec1[ip] + j*this->bigcell_vec2[ip] + k*this->bigcell_vec3[ip]; - dx += std::pow(cell[ip] - pos[ip], 2); + dx += (cell[ip] - pos[ip]) * (cell[ip] - pos[ip]); } r2 = std::min(dx, r2); } diff --git a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh index 5c5882be8f..31ccf3ca2c 100644 --- a/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh +++ b/source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh @@ -5,6 +5,30 @@ namespace GintKernel { +// if exponent is an integer between 0 and 5 (the most common cases in gint), +// pow_int is much faster than std::pow +static __device__ double pow_int(double base, int exp) +{ + switch (exp) + { + case 0: + return 1.0; + case 1: + return base; + case 2: + return base * base; + case 3: + return base * base * base; + case 4: + return base * base * base * base; + case 5: + return base * base * base * base * base; + default: + double result = pow(base, exp); + return result; + } +} + static __device__ void interp_rho(const double dist, const double delta_r, const int atype, @@ -145,7 +169,7 @@ static __device__ void interp_f(const double dist, // Extract information from atom_iw2_* arrays const int ll = atom_iw2_l[it_nw_iw]; const int idx_lm = atom_iw2_ylm[it_nw_iw]; - const double rl = pow(dist, ll); + const double rl = pow_int(dist, ll); const double rl_r = 1.0 / rl; const double dist_r = 1 / dist; const int dpsi_idx = psi_idx * 3;