Skip to content

Commit

Permalink
Remove abuse use of cp.einsum and cp.zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Jan 7, 2025
1 parent e4fac68 commit 0dbb558
Showing 1 changed file with 43 additions and 55 deletions.
98 changes: 43 additions & 55 deletions gpu4pyscf/solvent/hessian/pcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def analytic_grad_vmat(pcmobj, dm, mo_coeff, mo_occ, atmlst=None, verbose=None):
intopt_derivative = int3c1e.VHFOpt(mol)
intopt_derivative.build(cutoff = 1e-14, aosym = False)

dIdx_mo = cupy.zeros([len(atmlst), 3, nmo, nocc])
dIdx_mo = cupy.empty([len(atmlst), 3, nmo, nocc])

dIdA = int1e_grids_ip1(mol, grid_coords, charges = q_sym, intopt = intopt_derivative, charge_exponents = charge_exp**2)
aoslice = mol.aoslice_by_atom()
Expand All @@ -203,37 +203,45 @@ def analytic_grad_vmat(pcmobj, dm, mo_coeff, mo_occ, atmlst=None, verbose=None):
# dIdx[i_atom, :, :, :] = 0
# dIdx[i_atom, :, p0:p1, :] += dIdA[:, p0:p1, :]
# dIdx[i_atom, :, :, p0:p1] += dIdA[:, p0:p1, :].transpose(0,2,1)
dIdx_mo[i_atom, :, :, :] += cupy.einsum('ip,dpq,qj->dij', mo_coeff[p0:p1, :].T, dIdA[:, p0:p1, :], mocc)
dIdx_mo[i_atom, :, :, :] += cupy.einsum('ip,dpq,qj->dij', mo_coeff.T, dIdA[:, p0:p1, :].transpose(0,2,1), mocc[p0:p1, :])
dIdA_mo = dIdA[:, p0:p1, :] @ mocc
dIdA_mo = cupy.einsum('ip,dpj->dij', mo_coeff[p0:p1, :].T, dIdA_mo)
dIdB_mo = dIdA[:, p0:p1, :].transpose(0,2,1) @ mocc[p0:p1, :]
dIdB_mo = cupy.einsum('ip,dpj->dij', mo_coeff.T, dIdB_mo)
dIdx_mo[i_atom, :, :, :] = dIdA_mo + dIdB_mo

for i_atom in atmlst:
g0,g1 = gridslice[i_atom]
dIdC = int1e_grids_ip2(mol, grid_coords[g0:g1,:], charges = q_sym[g0:g1],
intopt = intopt_derivative, charge_exponents = charge_exp[g0:g1]**2)
dIdx_mo[i_atom, :, :, :] += cupy.einsum('ip,dpq,qj->dij', mo_coeff.T, dIdC, mocc)
dIdC_mo = dIdC @ mocc
dIdC_mo = cupy.einsum('ip,dpj->dij', mo_coeff.T, dIdC_mo)
dIdx_mo[i_atom, :, :, :] += dIdC_mo

dV_on_molecule_dx_mo = dIdx_mo

inverse_K = cupy.linalg.inv(K)
def append_dS_dot_q(dS, dSii, q, output, atmlst, gridslice):
def get_dS_dot_q(dS, dSii, q, atmlst, gridslice):
output = cupy.einsum('diA,i->Adi', dSii[:,:,atmlst], q)
for i_atom in atmlst:
g0,g1 = gridslice[i_atom]
output[i_atom, :, g0:g1] += cupy.einsum('dij,j->di', dS[:,g0:g1,:], q)
output[i_atom, :, :] -= cupy.einsum('dij,j->di', dS[:,:,g0:g1], q[g0:g1])
output += cupy.einsum('diA,i->Adi', dSii[:,:,atmlst], q)
def append_dST_dot_q(dS, dSii, q, output, atmlst, gridslice):
append_dS_dot_q(-dS.transpose(0,2,1), dSii, q, output, atmlst, gridslice)
return output
def get_dST_dot_q(dS, dSii, q, atmlst, gridslice):
return get_dS_dot_q(-dS.transpose(0,2,1), dSii, q, atmlst, gridslice)

def append_dA_dot_q(dA, q, output, atmlst, gridslice):
output += cupy.einsum('diA,i->Adi', dA[:,:,atmlst], q)
def get_dA_dot_q(dA, q, atmlst, gridslice):
return cupy.einsum('diA,i->Adi', dA[:,:,atmlst], q)

def append_dD_dot_q(dD, q, output, atmlst, gridslice):
def get_dD_dot_q(dD, q, atmlst, gridslice):
output = cupy.zeros([len(atmlst), 3, ngrids])
for i_atom in atmlst:
g0,g1 = gridslice[i_atom]
output[i_atom, :, g0:g1] += cupy.einsum('dij,j->di', dD[:,g0:g1,:], q)
output[i_atom, :, :] -= cupy.einsum('dij,j->di', dD[:,:,g0:g1], q[g0:g1])
def append_dDT_dot_q(dD, q, output, atmlst, gridslice):
append_dD_dot_q(-dD.transpose(0,2,1), q, output, atmlst, gridslice)
return output
def get_dDT_dot_q(dD, q, atmlst, gridslice):
return get_dD_dot_q(-dD.transpose(0,2,1), q, atmlst, gridslice)

if pcmobj.method.upper() in ['C-PCM', 'CPCM', 'COSMO']:
_, dS = get_dD_dS(pcmobj.surface, with_D=False, with_S=True)
Expand All @@ -242,8 +250,7 @@ def append_dDT_dot_q(dD, q, output, atmlst, gridslice):
dF = None

# dR = 0, dK = dS
dSdx_dot_q = cupy.zeros((len(atmlst), 3, ngrids))
append_dS_dot_q(dS, dSii, q_sym, dSdx_dot_q, atmlst, gridslice)
dSdx_dot_q = get_dS_dot_q(dS, dSii, q_sym, atmlst, gridslice)

dqdx_fix_Vq = cupy.einsum('ij,Adj->Adi', inverse_K, dSdx_dot_q)

Expand All @@ -258,51 +265,42 @@ def append_dDT_dot_q(dD, q, output, atmlst, gridslice):
# dK = dS - f_eps/(2*pi) * (dD*A*S + D*dA*S + D*A*dS)
f_eps_over_2pi = f_epsilon/(2.0*PI)

dSdx_dot_q = cupy.zeros((len(atmlst), 3, ngrids))
q = inverse_K @ R @ v_grids
append_dS_dot_q(dS, dSii, q, dSdx_dot_q, atmlst, gridslice)
dSdx_dot_q = get_dS_dot_q(dS, dSii, q, atmlst, gridslice)

DA = D*A
dKdx_dot_q = dSdx_dot_q - f_eps_over_2pi * cupy.einsum('ij,Adj->Adi', DA, dSdx_dot_q)

dAdx_dot_Sq = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, S @ q, dAdx_dot_Sq, atmlst, gridslice)
dAdx_dot_Sq = get_dA_dot_q(dA, S @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_2pi * cupy.einsum('ij,Adj->Adi', D, dAdx_dot_Sq)

AS = (A * S.T).T # It's just diag(A) @ S
dDdx_dot_ASq = cupy.zeros((len(atmlst), 3, ngrids))
append_dD_dot_q(dD, AS @ q, dDdx_dot_ASq, atmlst, gridslice)
dDdx_dot_ASq = get_dD_dot_q(dD, AS @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_2pi * dDdx_dot_ASq

dqdx_fix_Vq = -cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)

dAdx_dot_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, v_grids, dAdx_dot_V, atmlst, gridslice)
dAdx_dot_V = get_dA_dot_q(dA, v_grids, atmlst, gridslice)

dDdx_dot_AV = cupy.zeros((len(atmlst), 3, ngrids))
append_dD_dot_q(dD, A * v_grids, dDdx_dot_AV, atmlst, gridslice)
dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice)

dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))
dqdx_fix_Vq += cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)

invKT_V = inverse_K.T @ v_grids
dDdxT_dot_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dDT_dot_q(dD, invKT_V, dDdxT_dot_invKT_V, atmlst, gridslice)
dDdxT_dot_invKT_V = get_dDT_dot_q(dD, invKT_V, atmlst, gridslice)

DT_invKT_V = D.T @ invKT_V
dAdxT_dot_DT_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, DT_invKT_V, dAdxT_dot_DT_invKT_V, atmlst, gridslice)
dAdxT_dot_DT_invKT_V = get_dA_dot_q(dA, DT_invKT_V, atmlst, gridslice)
dqdx_fix_Vq += f_eps_over_2pi * (cupy.einsum('i,Adi->Adi', A, dDdxT_dot_invKT_V) + dAdxT_dot_DT_invKT_V)

dSdxT_dot_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dST_dot_q(dS, dSii, invKT_V, dSdxT_dot_invKT_V, atmlst, gridslice)
dSdxT_dot_invKT_V = get_dST_dot_q(dS, dSii, invKT_V, atmlst, gridslice)
dKdxT_dot_invKT_V = dSdxT_dot_invKT_V

dKdxT_dot_invKT_V -= f_eps_over_2pi * cupy.einsum('ij,Adj->Adi', AS.T, dDdxT_dot_invKT_V)
dKdxT_dot_invKT_V -= f_eps_over_2pi * cupy.einsum('ij,Adj->Adi', S.T, dAdxT_dot_DT_invKT_V)

dSdxT_dot_AT_DT_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dST_dot_q(dS, dSii, DA.T @ invKT_V, dSdxT_dot_AT_DT_invKT_V, atmlst, gridslice)
dSdxT_dot_AT_DT_invKT_V = get_dST_dot_q(dS, dSii, DA.T @ invKT_V, atmlst, gridslice)
dKdxT_dot_invKT_V -= f_eps_over_2pi * dSdxT_dot_AT_DT_invKT_V

dqdx_fix_Vq += -cupy.einsum('ij,Adj->Adi', R.T @ inverse_K.T, dKdxT_dot_invKT_V)
Expand All @@ -319,31 +317,25 @@ def append_dDT_dot_q(dD, q, output, atmlst, gridslice):
f_eps_over_4pi = f_epsilon/(4.0*PI)

def dK_dot_q(q):
dSdx_dot_q = cupy.zeros((len(atmlst), 3, ngrids))
append_dS_dot_q(dS, dSii, q, dSdx_dot_q, atmlst, gridslice)
dSdx_dot_q = get_dS_dot_q(dS, dSii, q, atmlst, gridslice)

DA = D*A
dKdx_dot_q = dSdx_dot_q - f_eps_over_4pi * cupy.einsum('ij,Adj->Adi', DA, dSdx_dot_q)

dAdx_dot_Sq = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, S @ q, dAdx_dot_Sq, atmlst, gridslice)
dAdx_dot_Sq = get_dA_dot_q(dA, S @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * cupy.einsum('ij,Adj->Adi', D, dAdx_dot_Sq)

AS = (A * S.T).T # It's just diag(A) @ S
dDdx_dot_ASq = cupy.zeros((len(atmlst), 3, ngrids))
append_dD_dot_q(dD, AS @ q, dDdx_dot_ASq, atmlst, gridslice)
dDdx_dot_ASq = get_dD_dot_q(dD, AS @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * dDdx_dot_ASq

dDdxT_dot_q = cupy.zeros((len(atmlst), 3, ngrids))
append_dDT_dot_q(dD, q, dDdxT_dot_q, atmlst, gridslice)
dDdxT_dot_q = get_dDT_dot_q(dD, q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * cupy.einsum('ij,Adj->Adi', AS.T, dDdxT_dot_q)

dAdxT_dot_DT_q = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, D.T @ q, dAdxT_dot_DT_q, atmlst, gridslice)
dAdxT_dot_DT_q = get_dA_dot_q(dA, D.T @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * cupy.einsum('ij,Adj->Adi', S.T, dAdxT_dot_DT_q)

dSdxT_dot_AT_DT_q = cupy.zeros((len(atmlst), 3, ngrids))
append_dST_dot_q(dS, dSii, DA.T @ q, dSdxT_dot_AT_DT_q, atmlst, gridslice)
dSdxT_dot_AT_DT_q = get_dST_dot_q(dS, dSii, DA.T @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * dSdxT_dot_AT_DT_q

return dKdx_dot_q
Expand All @@ -354,22 +346,18 @@ def dK_dot_q(q):
dKdx_dot_q = dK_dot_q(q)
dqdx_fix_Vq = -cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)

dAdx_dot_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, v_grids, dAdx_dot_V, atmlst, gridslice)
dAdx_dot_V = get_dA_dot_q(dA, v_grids, atmlst, gridslice)

dDdx_dot_AV = cupy.zeros((len(atmlst), 3, ngrids))
append_dD_dot_q(dD, A * v_grids, dDdx_dot_AV, atmlst, gridslice)
dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice)

dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))
dqdx_fix_Vq += cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)

invKT_V = inverse_K.T @ v_grids
dDdxT_dot_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dDT_dot_q(dD, invKT_V, dDdxT_dot_invKT_V, atmlst, gridslice)
dDdxT_dot_invKT_V = get_dDT_dot_q(dD, invKT_V, atmlst, gridslice)

DT_invKT_V = D.T @ invKT_V
dAdxT_dot_DT_invKT_V = cupy.zeros((len(atmlst), 3, ngrids))
append_dA_dot_q(dA, DT_invKT_V, dAdxT_dot_DT_invKT_V, atmlst, gridslice)
dAdxT_dot_DT_invKT_V = get_dA_dot_q(dA, DT_invKT_V, atmlst, gridslice)
dqdx_fix_Vq += f_eps_over_2pi * (cupy.einsum('i,Adi->Adi', A, dDdxT_dot_invKT_V) + dAdxT_dot_DT_invKT_V)

dKdx_dot_invKT_V = dK_dot_q(invKT_V)
Expand All @@ -384,7 +372,7 @@ def dK_dot_q(q):
for i_xyz in range(3):
dIdx_from_dqdx = int1e_grids(mol, grid_coords, charges = dqdx_fix_Vq[i_atom, i_xyz, :],
intopt = intopt_fock, charge_exponents = charge_exp**2)
dV_on_molecule_dx_mo[i_atom, i_xyz, :, :] += cupy.einsum('ip,pq,qj->ij', mo_coeff.T, dIdx_from_dqdx, mocc)
dV_on_molecule_dx_mo[i_atom, i_xyz, :, :] += mo_coeff.T @ dIdx_from_dqdx @ mocc

atom_coords = mol.atom_coords(unit='B')
atom_charges = numpy.asarray(mol.atom_charges(), dtype=numpy.float64)
Expand Down Expand Up @@ -416,7 +404,7 @@ def dK_dot_q(q):
invK_R_dVdx = 0.5 * (inverse_K @ R + R.T @ inverse_K.T) @ dV_on_charge_dx[i_atom, i_xyz, :]
dIdx_from_dVdx = int1e_grids(mol, grid_coords, charges = invK_R_dVdx,
intopt = intopt_fock, charge_exponents = charge_exp**2)
dV_on_molecule_dx_mo[i_atom, i_xyz, :, :] += cupy.einsum('ip,pq,qj->ij', mo_coeff.T, dIdx_from_dVdx, mocc)
dV_on_molecule_dx_mo[i_atom, i_xyz, :, :] += mo_coeff.T @ dIdx_from_dVdx @ mocc

t1 = log.timer_debug1('computing solvent grad veff', *t1)
return dV_on_molecule_dx_mo
Expand Down

0 comments on commit 0dbb558

Please sign in to comment.