Skip to content

Commit

Permalink
Skip vecdot.
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 10, 2025
1 parent 1b11ab1 commit 30dc8db
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/dartsort/cluster/ppcalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def ppca_e_step(
n_neighb = ndata["n_neighb"]
neighb_nc = ndata["neighb_nc"]
D_neighb = ndata["D_neighb"]
w = ndata["w0"] / ess
w_ = ndata["w1"] / ess
w__ = ndata["w2"] / ess
have_missing = ndata["have_missing"]
Expand Down Expand Up @@ -284,9 +285,9 @@ def ppca_e_step(
# pca-centered data
if yes_pca and have_missing:
CooinvWo = C_oochol.solve(W_o)
xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
Cooinvxcc = C_oochol.solve(xcc.T).T
# Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
# xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
# Cooinvxcc = C_oochol.solve(xcc.T).T
Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
else:
Cooinvxcc = Cooinvxc

Expand All @@ -301,27 +302,37 @@ def ppca_e_step(
e_xcu = xc[:, :, None] * ubar[:, None, :]
if yes_pca and have_missing:
e_mxcu = (Cooinvxc @ C_mo.T)[:, :, None] * ubar[:, None, :]
CmoCooinvWo = C_mo @ CooinvWo
e_mxcu += (uubar @ (W_m - CmoCooinvWo).T).mT
# CmoCooinvWo = C_mo @ CooinvWo
Wm_less_CmoCooinvWo = W_m.addmm(C_mo, CooinvWo, beta=-1)
shp = Wm_less_CmoCooinvWo.shape
# e_mxcu += (uubar @ (W_m - CmoCooinvWo).T).mT
# print(f"{uubar.shape=} {Wm_less_CmoCooinvWo.shape=}")
Wm_less_CmoCooinvWo = Wm_less_CmoCooinvWo.unsqueeze(0)
Wm_less_CmoCooinvWo = Wm_less_CmoCooinvWo.broadcast_to((len(uubar), *shp))
# e_mxcu += uubar.mT @ Wm_less_CmoCooinvWo
e_mxcu.baddbmm_(Wm_less_CmoCooinvWo, uubar)

# take weighted averages
if yes_pca:
mean_ubar = torch.linalg.vecdot(w_, ubar, dim=0)
mean_uubar = torch.linalg.vecdot(w__, uubar, dim=0)
mean_ubar = w @ ubar
mean_uubar = w @ uubar.view(n_neighb, -1)
mean_uubar = mean_uubar.view(uubar.shape[1:])

wx = torch.linalg.vecdot(w_, x, dim=0)
wx = w @ x
if have_missing:
wxbar_m = torch.linalg.vecdot(w_, xbar_m, dim=0)
wxbar_m = w @ xbar_m
ybar = y.new_zeros((rank, nc))
ybar[:, active_subset] = wx.view(rank, neighb_nc)
ybar[:, missing_subset] = wxbar_m.view(rank, nc - neighb_nc)
else:
ybar = wx.view(rank, nc)

if yes_pca:
wxcu = torch.linalg.vecdot(w__, e_xcu, dim=0)
wxcu = w @ e_xcu.view(n_neighb, -1)
wxcu = wxcu.view(e_xcu.shape[1:])
if have_missing and yes_pca:
wmxcu = torch.linalg.vecdot(w__, e_mxcu, dim=0)
wmxcu = w @ e_mxcu.view(n_neighb, -1)
wmxcu = wmxcu.view(e_mxcu.shape[1:])
ycubar = y.new_zeros((rank, nc, M))
ycubar[:, active_subset] = wxcu.view(rank, neighb_nc, M)
ycubar[:, missing_subset] = wmxcu.view(rank, nc - neighb_nc, M)
Expand Down

0 comments on commit 30dc8db

Please sign in to comment.