Skip to content

Commit

Permalink
Perf: optimize function pow in module_gint (deepmodeling#4680)
Browse files Browse the repository at this point in the history
* optimize function pow in module_gint

* fix a bug

* rename function pow in module_gint

* optimize the calculation of r_lp2

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: Mohan Chen <[email protected]>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 14, 2024
1 parent ad08479 commit 98f0682
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 7 deletions.
6 changes: 3 additions & 3 deletions source/module_hamilt_lcao/module_gint/cal_ddpsir_ylm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void cal_ddpsir_ylm(
const int ll = atom->iw2l[iw];
const int idx_lm = atom->iw2_ylm[iw];

const double rl = pow(distance1, ll);
const double rl = pow_int(distance1, ll);

// derivative of wave functions with respect to atom positions.
const double tmpdphi_rly = (dtmp - tmp * ll / distance1) / rl * rly[idx_lm] / distance1;
Expand Down Expand Up @@ -268,8 +268,8 @@ void cal_ddpsir_ylm(
const int ll = atom->iw2l[iw];
const int idx_lm = atom->iw2_ylm[iw];

const double rl = pow(distance, ll);
const double r_lp2 = pow(distance, ll + 2);
const double rl = pow_int(distance, ll);
const double r_lp2 =rl * distance * distance;

// d/dr (R_l / r^l)
const double tmpdphi = (dtmp - tmp * ll / distance) / rl;
Expand Down
5 changes: 3 additions & 2 deletions source/module_hamilt_lcao/module_gint/cal_dpsir_ylm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ void cal_dpsir_ylm(
double distance = std::sqrt(dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2]);

ModuleBase::Ylm::grad_rl_sph_harm(ucell.atoms[it].nwl, dr[0], dr[1], dr[2], rly, grly.get_ptr_2D());
if (distance < 1e-9)
if (distance < 1e-9) {
distance = 1e-9;
}

const double position = distance / delta_r;

Expand Down Expand Up @@ -115,7 +116,7 @@ void cal_dpsir_ylm(
const int ll = atom->iw2l[iw];
const int idx_lm = atom->iw2_ylm[iw];

const double rl = pow(distance, ll);
const double rl = pow_int(distance, ll);

// 3D wave functions
p_psi[iw] = tmp * rly[idx_lm] / rl;
Expand Down
23 changes: 23 additions & 0 deletions source/module_hamilt_lcao/module_gint/gint_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,29 @@ class Gint_inout

namespace Gint_Tools
{
// Integer-exponent power helper.
// Exponents 0 through 5 (the most common cases in gint) are evaluated by
// plain repeated multiplication, which is much faster than std::pow.
// Every other exponent, including negative ones, is forwarded to std::pow.
inline double pow_int(const double base, const int exp)
{
    // Outside the fast range: defer to the general-purpose routine.
    if (exp < 0 || exp > 5)
    {
        return std::pow(base, exp);
    }

    if (exp == 0)
    {
        return 1.0;
    }
    if (exp == 1)
    {
        return base;
    }

    // Build the product one factor at a time, left to right, so the rounding
    // sequence is the same as writing base * base * ... * base by hand.
    const double squared = base * base;
    if (exp == 2)
    {
        return squared;
    }

    const double cubed = squared * base;
    if (exp == 3)
    {
        return cubed;
    }

    const double fourth = cubed * base;
    if (exp == 4)
    {
        return fourth;
    }

    // Only exp == 5 can reach this point.
    return fourth * base;
}
// vindex[pw.bxyz]
int* get_vindex(const int bxyz,
const int bx,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/module_gint/grid_meshball.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ double Grid_MeshBall::deal_with_atom_spillage(const double *pos)
cell[ip] = i*this->bigcell_vec1[ip] +
j*this->bigcell_vec2[ip] +
k*this->bigcell_vec3[ip];
dx += std::pow(cell[ip] - pos[ip], 2);
dx += (cell[ip] - pos[ip]) * (cell[ip] - pos[ip]);
}
r2 = std::min(dx, r2);
}
Expand Down
26 changes: 25 additions & 1 deletion source/module_hamilt_lcao/module_gint/kernels/cuda/interp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,30 @@

namespace GintKernel
{
// Device-side integer-exponent power helper.
// Exponents 0 through 5 (the most common cases in gint) are evaluated by
// plain repeated multiplication, which is much faster than the general pow.
// Every other exponent, including negative ones, is forwarded to pow.
static __device__ double pow_int(double base, int exp)
{
    // Outside the fast range: defer to the general-purpose routine.
    if (exp < 0 || exp > 5)
    {
        return pow(base, exp);
    }

    if (exp == 0)
    {
        return 1.0;
    }
    if (exp == 1)
    {
        return base;
    }

    // Build the product one factor at a time, left to right, so the rounding
    // sequence is the same as writing base * base * ... * base by hand.
    const double squared = base * base;
    if (exp == 2)
    {
        return squared;
    }

    const double cubed = squared * base;
    if (exp == 3)
    {
        return cubed;
    }

    const double fourth = cubed * base;
    if (exp == 4)
    {
        return fourth;
    }

    // Only exp == 5 can reach this point.
    return fourth * base;
}

static __device__ void interp_rho(const double dist,
const double delta_r,
const int atype,
Expand Down Expand Up @@ -145,7 +169,7 @@ static __device__ void interp_f(const double dist,
// Extract information from atom_iw2_* arrays
const int ll = atom_iw2_l[it_nw_iw];
const int idx_lm = atom_iw2_ylm[it_nw_iw];
const double rl = pow(dist, ll);
const double rl = pow_int(dist, ll);
const double rl_r = 1.0 / rl;
const double dist_r = 1 / dist;
const int dpsi_idx = psi_idx * 3;
Expand Down

0 comments on commit 98f0682

Please sign in to comment.