diff --git a/examples/sinkhorn_loss_functional/main.py b/examples/sinkhorn_loss_functional/main.py index d2564a8..8bb2482 100644 --- a/examples/sinkhorn_loss_functional/main.py +++ b/examples/sinkhorn_loss_functional/main.py @@ -32,13 +32,14 @@ print('Set B') print(set_b) + # Condition P*1_d = a and P^T*1_d = b a = torch.ones(set_a.shape[0:2], requires_grad=False, - device=set_a.device) / set_a.shape[1] + device=set_a.device) b = torch.ones(set_b.shape[0:2], requires_grad=False, - device=set_b.device) / set_b.shape[1] + device=set_b.device) # Compute the cost matrix M = pairwise_distances(set_a, set_b, p=args.lp_distance) @@ -47,11 +48,19 @@ print(M) # Compute the transport matrix between each pair of sets in the minibatch with default parameters - P = sinkhorn(a, b, M, 1e-3) - + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) + print('Transport Matrix') print(P) + print('Condition error') + + aprox_a = P.sum(2) + aprox_b = P.sum(1) + + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) + # Compute the loss loss = (M * P).sum(2).sum(1) diff --git a/examples/sinkhorn_loss_unbalanced/main.py b/examples/sinkhorn_loss_unbalanced/main.py new file mode 100644 index 0000000..0a6abd5 --- /dev/null +++ b/examples/sinkhorn_loss_unbalanced/main.py @@ -0,0 +1,74 @@ +import argparse +import torch +from fml.functional import pairwise_distances, sinkhorn + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Sinkhorn loss using the functional interface.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set1_size', '-sz1', type=int, default=5, + help='Set size.') + parser.add_argument('--set2_size', '-sz2', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set1_size = args.set1_size + set2_size = args.set2_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set1_size, point_dim]) + set_b = torch.rand([minibatch_size, set2_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Condition P*1 = a and P^T*1 = b + a = torch.ones(set_a.shape[0:2], + requires_grad=False, + device=set_a.device) + + b = torch.ones(set_b.shape[0:2], + requires_grad=False, + device=set_b.device) + # Have the same total mass than set_a + b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True) + + # Compute the cost matrix + M = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Distance') + print(M) + + # Compute the transport matrix between each pair of sets in the minibatch with default parameters + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) + + print('Transport Matrix') + print(P) + + print('Condition error') + + aprox_a = P.sum(2) + aprox_b = P.sum(1) + + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) + + # Compute the loss + loss = (M * P).sum(2).sum(1) + + print('Loss') + print(loss) + diff --git a/examples/sinkhorn_loss_weighted/main.py b/examples/sinkhorn_loss_weighted/main.py new file mode 100644 index 0000000..0ffd4d2 --- /dev/null +++ b/examples/sinkhorn_loss_weighted/main.py @@ -0,0 +1,73 @@ +import argparse +import torch +from fml.functional import pairwise_distances, sinkhorn + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Sinkhorn loss using the functional interface.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Condition P*1 = a and P^T*1 = b + a = torch.rand(set_a.shape[0:2], + requires_grad=False, + device=set_a.device) + # Keep an average mass of 1 per node + a = a * set_a.shape[1] / a.sum(1, keepdim=True) + + b = torch.rand(set_b.shape[0:2], + requires_grad=False, + device=set_b.device) + # Have the same total mass than set_a + b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True) + + # Compute the cost matrix + M = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Distance') + print(M) + + # Compute the transport matrix between each pair of sets in the minibatch with default parameters + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) + + print('Transport Matrix') + print(P) + + print('Condition error') + + aprox_a = P.sum(2) + aprox_b = P.sum(1) + + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) + + # Compute the loss + loss = (M * P).sum(2).sum(1) + + print('Loss') + print(loss) + diff --git a/fml/functional.py b/fml/functional.py index 86698e6..9600c1a 100644 --- a/fml/functional.py +++ b/fml/functional.py @@ -15,7 +15,6 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2): raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape) if len(b.shape) != 3: raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape) - return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3) @@ -69,8 +68,9 @@ def sinkhorn(a: torch.Tensor, b: torch.Tensor, M: torch.Tensor, eps: float, raise ValueError("Got unexpected shape for tensor b (%s). Expected [nb, n] where M has shape [nb, m, n]." % str(b.shape)) + # Initialize the iteration with the change of variable u = torch.zeros(a.shape, dtype=a.dtype, device=a.device) - v = torch.zeros(b.shape, dtype=b.dtype, device=b.device) + v = eps * torch.log(b) M_t = torch.transpose(M, 1, 2) @@ -97,7 +97,7 @@ def stabilized_log_sum_exp(x): break log_P = (-M + u.unsqueeze(2) + v.unsqueeze(1)) / eps - + P = torch.exp(log_P) return P diff --git a/fml/nn.py b/fml/nn.py index c875c1c..d74ba5d 100644 --- a/fml/nn.py +++ b/fml/nn.py @@ -42,14 +42,14 @@ def forward(self, predicted, expected, a=None, b=None): if a is None: a = torch.ones(predicted.shape[0:2], requires_grad=False, - device=predicted.device) / predicted.shape[1] + device=predicted.device) else: a = a.to(predicted.device) if b is None: b = torch.ones(predicted.shape[0:2], requires_grad=False, - device=predicted.device) / predicted.shape[1] + device=predicted.device) else: b = b.to(predicted.device)