Skip to content

Commit

Permalink
Merge pull request #3 from priba/master
Browse files Browse the repository at this point in the history
Initialization for the sinkhorn iterations
  • Loading branch information
fwilliams authored Feb 20, 2019
2 parents bd212b8 + a1ec9e0 commit 8ae5b8f
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 9 deletions.
17 changes: 13 additions & 4 deletions examples/sinkhorn_loss_functional/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
print('Set B')
print(set_b)

# Condition P*1_d = a and P^T*1_d = b
a = torch.ones(set_a.shape[0:2],
requires_grad=False,
device=set_a.device) / set_a.shape[1]
device=set_a.device)

b = torch.ones(set_b.shape[0:2],
requires_grad=False,
device=set_b.device) / set_b.shape[1]
device=set_b.device)

# Compute the cost matrix
M = pairwise_distances(set_a, set_b, p=args.lp_distance)
Expand All @@ -47,11 +48,19 @@
print(M)

# Compute the transport matrix between each pair of sets in the minibatch with default parameters
P = sinkhorn(a, b, M, 1e-3)

P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)
print('Transport Matrix')
print(P)

print('Condition error')

aprox_a = P.sum(2)
aprox_b = P.sum(1)

print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))

# Compute the loss
loss = (M * P).sum(2).sum(1)

Expand Down
74 changes: 74 additions & 0 deletions examples/sinkhorn_loss_unbalanced/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import argparse
import torch
from fml.functional import pairwise_distances, sinkhorn

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='Sinkhorn loss using the functional interface.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set1_size', '-sz1', type=int, default=5,
help='Set size.')
parser.add_argument('--set2_size', '-sz2', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')
parser.add_argument('--lp_distance', '-p', type=int, default=2,
help='p for the Lp-distance.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set1_size = args.set1_size
set2_size = args.set2_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set1_size, point_dim])
set_b = torch.rand([minibatch_size, set2_size, point_dim])

print('Set A')
print(set_a)

print('Set B')
print(set_b)

# Condition P*1 = a and P^T*1 = b
a = torch.ones(set_a.shape[0:2],
requires_grad=False,
device=set_a.device)

b = torch.ones(set_b.shape[0:2],
requires_grad=False,
device=set_b.device)
# Have the same total mass than set_a
b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True)

# Compute the cost matrix
M = pairwise_distances(set_a, set_b, p=args.lp_distance)

print('Distance')
print(M)

# Compute the transport matrix between each pair of sets in the minibatch with default parameters
P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)

print('Transport Matrix')
print(P)

print('Condition error')

aprox_a = P.sum(2)
aprox_b = P.sum(1)

print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))

# Compute the loss
loss = (M * P).sum(2).sum(1)

print('Loss')
print(loss)

73 changes: 73 additions & 0 deletions examples/sinkhorn_loss_weighted/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import torch
from fml.functional import pairwise_distances, sinkhorn

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='Sinkhorn loss using the functional interface.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set_size', '-sz', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')
parser.add_argument('--lp_distance', '-p', type=int, default=2,
help='p for the Lp-distance.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set_size = args.set_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set_size, point_dim])
set_b = torch.rand([minibatch_size, set_size, point_dim])

print('Set A')
print(set_a)

print('Set B')
print(set_b)

# Condition P*1 = a and P^T*1 = b
a = torch.rand(set_a.shape[0:2],
requires_grad=False,
device=set_a.device)
# Keep an average mass of 1 per node
a = a * set_a.shape[1] / a.sum(1, keepdim=True)

b = torch.rand(set_b.shape[0:2],
requires_grad=False,
device=set_b.device)
# Have the same total mass than set_a
b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True)

# Compute the cost matrix
M = pairwise_distances(set_a, set_b, p=args.lp_distance)

print('Distance')
print(M)

# Compute the transport matrix between each pair of sets in the minibatch with default parameters
P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)

print('Transport Matrix')
print(P)

print('Condition error')

aprox_a = P.sum(2)
aprox_b = P.sum(1)

print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))

# Compute the loss
loss = (M * P).sum(2).sum(1)

print('Loss')
print(loss)

6 changes: 3 additions & 3 deletions fml/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2):
raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape)
if len(b.shape) != 3:
raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape)

return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3)


Expand Down Expand Up @@ -69,8 +68,9 @@ def sinkhorn(a: torch.Tensor, b: torch.Tensor, M: torch.Tensor, eps: float,
raise ValueError("Got unexpected shape for tensor b (%s). Expected [nb, n] where M has shape [nb, m, n]." %
str(b.shape))

# Initialize the iteration with the change of variable
u = torch.zeros(a.shape, dtype=a.dtype, device=a.device)
v = torch.zeros(b.shape, dtype=b.dtype, device=b.device)
v = eps * torch.log(b)

M_t = torch.transpose(M, 1, 2)

Expand All @@ -97,7 +97,7 @@ def stabilized_log_sum_exp(x):
break

log_P = (-M + u.unsqueeze(2) + v.unsqueeze(1)) / eps

P = torch.exp(log_P)

return P
Expand Down
4 changes: 2 additions & 2 deletions fml/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def forward(self, predicted, expected, a=None, b=None):
if a is None:
a = torch.ones(predicted.shape[0:2],
requires_grad=False,
device=predicted.device) / predicted.shape[1]
device=predicted.device)
else:
a = a.to(predicted.device)

if b is None:
b = torch.ones(predicted.shape[0:2],
requires_grad=False,
device=predicted.device) / predicted.shape[1]
device=predicted.device)
else:
b = b.to(predicted.device)

Expand Down

0 comments on commit 8ae5b8f

Please sign in to comment.