
Commit: greenkhorn
tfzhou committed Jun 30, 2022
1 parent ec77f97 commit 1c4a778
Showing 1 changed file with 72 additions and 18 deletions: lib/models/modules/sinkhorn.py
@@ -3,29 +3,83 @@


 def distributed_sinkhorn(out, sinkhorn_iterations=3, epsilon=0.05):
-    Q = torch.exp(out / epsilon).t()  # K x B
-    B = Q.shape[1]
-    K = Q.shape[0]
+    L = torch.exp(out / epsilon).t()  # K x B
+    B = L.shape[1]
+    K = L.shape[0]

     # make the matrix sum to 1
-    sum_Q = torch.sum(Q)
-    Q /= sum_Q
+    sum_L = torch.sum(L)
+    L /= sum_L

     for _ in range(sinkhorn_iterations):
         # normalize each row: total weight per prototype must be 1/K
-        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
-        Q /= sum_of_rows
-        Q /= K
+        L /= torch.sum(L, dim=1, keepdim=True)
+        L /= K

         # normalize each column: total weight per sample must be 1/B
-        Q /= torch.sum(Q, dim=0, keepdim=True)
-        Q /= B
+        L /= torch.sum(L, dim=0, keepdim=True)
+        L /= B

-    Q *= B  # the colomns must sum to 1 so that Q is an assignment
-    Q = Q.t()
+    L *= B  # the columns must sum to 1 so that L is an assignment
+    L = L.t()

-    indexs = torch.argmax(Q, dim=1)
-    # Q = torch.nn.functional.one_hot(indexs, num_classes=Q.shape[1]).float()
-    Q = F.gumbel_softmax(Q, tau=0.5, hard=True)
+    indexs = torch.argmax(L, dim=1)
+    # L = torch.nn.functional.one_hot(indexs, num_classes=L.shape[1]).float()
+    L = F.gumbel_softmax(L, tau=0.5, hard=True)

-    return Q, indexs
+    return L, indexs
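
The rename from Q to L leaves the behavior unchanged: each Sinkhorn pass alternately rescales rows and columns of exp(out / epsilon) so that rows (prototypes) sum to 1/K and columns (samples) sum to 1/B, and the final gumbel_softmax turns the soft transport plan into hard one-hot assignments. A minimal usage sketch; the sizes and random scores are illustrative, not from this repo, and it assumes torch and torch.nn.functional as F are imported at the top of the file:

import torch

B, K = 1024, 10                        # samples x prototypes (illustrative sizes)
out = torch.randn(B, K)                # hypothetical similarity scores
Q, indexs = distributed_sinkhorn(out)

print(Q.shape)       # torch.Size([1024, 10]), one-hot rows after gumbel_softmax
print(indexs.shape)  # torch.Size([1024]), prototype index per sample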


+def distributed_greenkhorn(out, sinkhorn_iterations=100, epsilon=0.05):
+    L = torch.exp(out / epsilon).t()  # K x B
+    K = L.shape[0]
+    B = L.shape[1]
+
+    # make the matrix sum to 1
+    sum_L = torch.sum(L)
+    L /= sum_L
+
+    # target marginals: uniform over the K prototypes and the B samples
+    r = torch.ones((K,), dtype=L.dtype).to(L.device) / K
+    c = torch.ones((B,), dtype=L.dtype).to(L.device) / B
+
+    r_sum = torch.sum(L, axis=1)
+    c_sum = torch.sum(L, axis=0)
+
+    # gain of fixing each row/column marginal (1e-5 guards the log)
+    r_gain = r_sum - r + r * torch.log(r / r_sum + 1e-5)
+    c_gain = c_sum - c + c * torch.log(c / c_sum + 1e-5)
+
+    for _ in range(sinkhorn_iterations):
+        # greedily pick whichever row or column deviates most from its target
+        i = torch.argmax(r_gain)
+        j = torch.argmax(c_gain)
+
+        if r_gain[i] > c_gain[j]:
+            # rescale row i so its sum matches the target marginal r[i]
+            L[i, :] = L[i, :] * (r[i] / r_sum[i])
+        else:
+            # rescale column j so its sum matches the target marginal c[j]
+            L[:, j] = L[:, j] * (c[j] / c_sum[j])
+
+        L = L / torch.sum(L)
+        r_sum = torch.sum(L, axis=1)
+        c_sum = torch.sum(L, axis=0)
+
+        r_gain = r_sum - r + r * torch.log(r / r_sum + 1e-5)
+        c_gain = c_sum - c + c * torch.log(c / c_sum + 1e-5)
+
+    L = L.t()
+
+    indexs = torch.argmax(L, dim=1)
+    # straight-through hard assignment, matching distributed_sinkhorn above
+    L = F.gumbel_softmax(L, tau=0.5, hard=True)
+
+    return L, indexs
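
For context, Greenkhorn (Altschuler et al., 2017) is a greedy variant of Sinkhorn: instead of rescaling every row and column on each pass, it rescales only the single row or column whose marginal is furthest from its target, measured by rho(a, b) = b - a + a * log(a / b), which is what r_gain and c_gain compute above (with 1e-5 guarding the log). A toy sketch of one greedy row update; the names rho and P are illustrative, not from the file:

import torch

def rho(a, b, eps=1e-5):
    # rho(a, b) >= 0, and equals 0 exactly when the marginal b matches a
    return b - a + a * torch.log(a / b + eps)

P = torch.tensor([[0.5, 0.1],
                  [0.1, 0.3]])      # toy 2x2 plan, entries sum to 1
r = torch.full((2,), 0.5)           # target row marginals
r_gain = rho(r, P.sum(dim=1))       # row sums are [0.6, 0.4]
i = torch.argmax(r_gain)            # row 1 deviates more
P[i, :] *= r[i] / P[i, :].sum()     # rescale it; row 1 now sums to 0.5

Each iteration touches a single row or column, so one Greenkhorn step is much cheaper than a full Sinkhorn pass, at the cost of needing more iterations (the defaults here are 100 vs. 3 above).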
