Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Davidson iterations for tddft on GPU #305

Merged
merged 10 commits into from
Jan 20, 2025
209 changes: 124 additions & 85 deletions gpu4pyscf/tdscf/_lr_eig.py

Large diffs are not rendered by default.

35 changes: 24 additions & 11 deletions gpu4pyscf/tdscf/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@

import numpy as np
import cupy as cp
import scipy.linalg
from pyscf import gto
from pyscf import lib
from pyscf.tdscf import rhf as tdhf_cpu
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh, eig as lr_eig, real_eig
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh, real_eig
from gpu4pyscf import scf
from gpu4pyscf.lib.cupy_helper import contract, tag_array
from gpu4pyscf.lib import utils
Expand Down Expand Up @@ -53,7 +51,7 @@ def gen_tda_operation(mf, fock_ao=None, singlet=True, wfnsym=None):
orbo2 = orbo * 2. # *2 for double occupancy

e_ia = hdiag = mo_energy[viridx] - mo_energy[occidx,None]
hdiag = hdiag.ravel().get()
hdiag = hdiag.ravel()
vresp = mf.gen_response(singlet=singlet, hermi=0)
nocc, nvir = e_ia.shape

Expand All @@ -66,7 +64,7 @@ def vind(zs):
v1mo = contract('xpq,qo->xpo', v1ao, orbo)
v1mo = contract('xpo,pv->xov', v1mo, orbv.conj())
v1mo += zs * e_ia
return v1mo.reshape(v1mo.shape[0],-1).get()
return v1mo.reshape(v1mo.shape[0],-1)

return vind, hdiag

Expand Down Expand Up @@ -100,11 +98,15 @@ class TDBase(lib.StreamObject):
get_ab = NotImplemented

def get_precond(self, hdiag):
threshold_t=1.0e-4
def precond(x, e, *args):
if isinstance(e, np.ndarray):
e = e[0]
n_states = x.shape[0]
diagd = cp.repeat(hdiag.reshape(1,-1), n_states, axis=0)
e = e.reshape(-1,1)
diagd = hdiag - (e-self.level_shift)
diagd[abs(diagd)<1e-8] = 1e-8
diagd = cp.where(abs(diagd) < threshold_t, cp.sign(diagd)*threshold_t, diagd)
a_size = x.shape[1]//2
diagd[:,a_size:] = diagd[:,a_size:]*(-1)
return x/diagd
return precond

Expand Down Expand Up @@ -170,6 +172,17 @@ def _contract_multipole(tdobj, ints, hermi=True, xy=None):
class TDA(TDBase):
__doc__ = tdhf_cpu.TDA.__doc__

def get_precond(self, hdiag):
    '''Build the diagonal preconditioner used by the TDA eigensolver.

    Args:
        hdiag: Diagonal of the TDA matrix. It is broadcast against one
            eigenvalue estimate per trial vector.

    Returns:
        A function ``precond(x, e, *args)`` that returns
        ``x / (hdiag - (e - self.level_shift))``. ``x`` holds one trial
        vector per row and ``e`` is an array with the matching eigenvalue
        estimate for each row. Denominators whose magnitude is below
        ``1e-4`` are replaced by ``+/-1e-4``.
    '''
    threshold_t = 1.0e-4

    def precond(x, e, *args):
        # Column vector of eigenvalues broadcasts against hdiag, giving one
        # shifted diagonal per row of x.
        diagd = hdiag - (e.reshape(-1, 1) - self.level_shift)
        # sign(0) is 0, which would turn an exactly-zero denominator into
        # another zero and make the division below produce inf/nan. Treat
        # zero as positive so the clamp always yields +/-threshold_t.
        sign = cp.sign(diagd)
        sign[sign == 0] = 1
        diagd = cp.where(abs(diagd) < threshold_t, sign*threshold_t, diagd)
        return x/diagd
    return precond

def gen_vind(self, mf=None):
'''Generate function to compute Ax'''
if mf is None:
Expand Down Expand Up @@ -228,7 +241,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down Expand Up @@ -291,10 +304,10 @@ def vind(zs):
v1_top += xs * e_ia # AX
v1_bot += ys * e_ia # (A*)Y
return cp.hstack((v1_top.reshape(nz,nocc*nvir),
-v1_bot.reshape(nz,nocc*nvir))).get()
-v1_bot.reshape(nz,nocc*nvir)))

hdiag = cp.hstack([hdiag.ravel(), -hdiag.ravel()])
return vind, hdiag.get()
return vind, hdiag


class TDHF(TDBase):
Expand Down
7 changes: 3 additions & 4 deletions gpu4pyscf/tdscf/rks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
import cupy as cp
from pyscf import lib
from pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.dft.rks import KohnShamDFT
from gpu4pyscf.lib.cupy_helper import contract, tag_array, transpose_sum
from gpu4pyscf.lib import logger
Expand Down Expand Up @@ -54,7 +54,6 @@ def gen_vind(self, mf=None):
d_ia = e_ia ** .5
ed_ia = e_ia * d_ia
hdiag = e_ia.ravel() ** 2
hdiag = hdiag.get()
vresp = mf.gen_response(singlet=singlet, hermi=1)
nocc, nvir = e_ia.shape

Expand All @@ -71,7 +70,7 @@ def vind(zs):
v1mo = contract('xpo,pv->xov', v1mo, orbv)
v1mo += zs * ed_ia
v1mo *= d_ia
return v1mo.reshape(v1mo.shape[0],-1).get()
return v1mo.reshape(v1mo.shape[0],-1)

return vind, hdiag

Expand All @@ -95,7 +94,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down
8 changes: 4 additions & 4 deletions gpu4pyscf/tdscf/tests/test_tdrhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def test_tda_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().TDA().set(singlet=False).gen_vind()[0](zs)
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

df_mf = self.df_mf
ref = df_mf.to_cpu().TDA().set(singlet=True).gen_vind()[0](zs)
dat = df_mf.TDA().set(singlet=True).gen_vind()[0](cp.asarray(zs))
dat = df_mf.TDA().set(singlet=True).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tdhf_vind(self):
Expand All @@ -140,12 +140,12 @@ def test_tdhf_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,2,nocc,nvir)
ref = mf.to_cpu().TDHF().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDHF().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDHF().set(singlet=True).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

df_mf = self.df_mf
ref = df_mf.to_cpu().TDHF().set(singlet=False).gen_vind()[0](zs)
dat = df_mf.TDHF().set(singlet=False).gen_vind()[0](zs)
dat = df_mf.TDHF().set(singlet=False).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions gpu4pyscf/tdscf/tests/test_tdrks.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_tda_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().TDA().set(singlet=False).gen_vind()[0](zs)
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tddft_vind(self):
Expand All @@ -261,7 +261,7 @@ def test_tddft_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,2,nocc,nvir)
ref = mf.to_cpu().TDDFT().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDDFT().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDDFT().set(singlet=True).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_casida_tddft_vind(self):
Expand All @@ -271,7 +271,7 @@ def test_casida_tddft_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().CasidaTDDFT().gen_vind()[0](zs)
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions gpu4pyscf/tdscf/tests/test_tduhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_tda_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDA().set().gen_vind()[0](zs)
dat = mf.TDA().set().gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tdhf_vind(self):
Expand All @@ -111,7 +111,7 @@ def test_tdhf_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,2,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDHF().set().gen_vind()[0](zs)
dat = mf.TDHF().set().gen_vind()[0](zs)
dat = mf.TDHF().set().gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions gpu4pyscf/tdscf/tests/test_tduks.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_tda_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDA().gen_vind()[0](zs)
dat = mf.TDA().gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tddft_vind(self):
Expand All @@ -198,7 +198,7 @@ def test_tddft_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,2,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDDFT().gen_vind()[0](zs)
dat = mf.TDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.TDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_casida_tddft_vind(self):
Expand All @@ -209,7 +209,7 @@ def test_casida_tddft_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().CasidaTDDFT().gen_vind()[0](zs)
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
22 changes: 11 additions & 11 deletions gpu4pyscf/tdscf/uhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def gen_tda_operation(mf, fock_ao=None, wfnsym=None):
e_ia_a = mo_energy[0][viridxa] - mo_energy[0][occidxa,None]
e_ia_b = mo_energy[1][viridxb] - mo_energy[1][occidxb,None]
e_ia = cp.hstack((e_ia_a.reshape(-1), e_ia_b.reshape(-1)))
hdiag = e_ia.get()
hdiag = e_ia
nocca, nvira = e_ia_a.shape
noccb, nvirb = e_ia_b.shape

Expand All @@ -88,7 +88,7 @@ def vind(zs):
v1a += za * e_ia_a
v1b += zb * e_ia_b
hx = cp.hstack((v1a.reshape(nz,-1), v1b.reshape(nz,-1)))
return hx.get()
return hx

return vind, hdiag

Expand Down Expand Up @@ -185,7 +185,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down Expand Up @@ -258,7 +258,7 @@ def gen_vind(self):
orbva = mo_coeff[0][:,viridxa]
orbov = (orbob, orbva)
e_ia = mo_energy[0][viridxa] - mo_energy[1][occidxb,None]
hdiag = e_ia.ravel().get()
hdiag = e_ia.ravel()

elif extype == 1:
occidxa = mo_occ[0] > 0
Expand All @@ -267,7 +267,7 @@ def gen_vind(self):
orbvb = mo_coeff[1][:,viridxb]
orbov = (orboa, orbvb)
e_ia = mo_energy[1][viridxb] - mo_energy[0][occidxa,None]
hdiag = e_ia.ravel().get()
hdiag = e_ia.ravel()

vresp = gen_uhf_response_sf(
mf, hermi=0, collinear=self.collinear,
Expand All @@ -283,7 +283,7 @@ def vind(zs):
v1mo = contract('xpq,qo->xpo', v1ao, orbo)
v1mo = contract('xpo,pv->xov', v1mo, orbv.conj())
v1mo += zs * e_ia
return v1mo.reshape(len(v1mo), -1).get()
return v1mo.reshape(len(v1mo), -1)

return vind, hdiag

Expand Down Expand Up @@ -461,10 +461,10 @@ def vind(zs):
v1_bot[:,:nocca*nvira] += v1a_bot.reshape(nz,-1)
v1_top[:,nocca*nvira:] += v1b_top.reshape(nz,-1)
v1_bot[:,nocca*nvira:] += v1b_bot.reshape(nz,-1)
return cp.hstack([v1_top, -v1_bot]).get()
return cp.hstack([v1_top, -v1_bot])

hdiag = cp.hstack([hdiag.ravel(), -hdiag.ravel()])
return vind, hdiag.get()
return vind, hdiag


class TDHF(TDBase):
Expand Down Expand Up @@ -578,9 +578,9 @@ def gen_vind(self):

extype = self.extype
if extype == 0:
hdiag = cp.hstack([e_ia_b2a.ravel(), -e_ia_a2b.ravel()]).get()
hdiag = cp.hstack([e_ia_b2a.ravel(), -e_ia_a2b.ravel()])
else:
hdiag = cp.hstack([e_ia_a2b.ravel(), -e_ia_b2a.ravel()]).get()
hdiag = cp.hstack([e_ia_a2b.ravel(), -e_ia_b2a.ravel()])

vresp = gen_uhf_response_sf(
mf, hermi=0, collinear=self.collinear,
Expand Down Expand Up @@ -681,7 +681,7 @@ def vind(zs):
v1_top += zs_a2b * e_ia_a2b
v1_bot += zs_b2a * e_ia_b2a
hx = cp.hstack([v1_top.reshape(nz,-1), -v1_bot.reshape(nz,-1)])
return hx.get()
return hx

return vind, hdiag

Expand Down
8 changes: 4 additions & 4 deletions gpu4pyscf/tdscf/uks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cupy as cp
from pyscf import symm
from pyscf import lib
from pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.dft.rks import KohnShamDFT
from gpu4pyscf.lib.cupy_helper import contract, tag_array, transpose_sum
from gpu4pyscf.lib import logger
Expand Down Expand Up @@ -69,7 +69,7 @@ def gen_vind(self, mf=None):
d_ia = e_ia**.5
ed_ia = e_ia * d_ia
hdiag = e_ia ** 2
hdiag = hdiag.get()
hdiag = hdiag
vresp = mf.gen_response(mo_coeff, mo_occ, hermi=1)
nocca, nvira = e_ia_a.shape
noccb, nvirb = e_ia_b.shape
Expand All @@ -96,7 +96,7 @@ def vind(zs):
hx = cp.hstack((v1a.reshape(nz,-1), v1b.reshape(nz,-1)))
hx += ed_ia * zs
hx *= d_ia
return hx.get()
return hx

return vind, hdiag

Expand All @@ -120,7 +120,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down
Loading