Enable returning distances from knn and knn_graph #133

Open · wants to merge 4 commits into master
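In short, this PR changes knn and knn_graph so that the distances computed for each neighbor pair can be returned alongside the edge indices. Based on the test changes below, the user-facing call would look roughly as follows (a sketch, not part of the diff; for Euclidean search the returned values are squared distances):

    import torch
    from torch_cluster import knn, knn_graph

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [1.0, -1.0]])
    y = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])

    # Default behavior is unchanged: only the edge indices are returned.
    edge_index = knn(x, y, k=2)

    # With return_distances=True, the call also yields one (squared)
    # distance per returned edge.
    edge_index, distances = knn(x, y, k=2, return_distances=True)
    edge_index, distances = knn_graph(x, k=2, return_distances=True)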
21 changes: 16 additions & 5 deletions csrc/cpu/knn_cpu.cpp
@@ -4,10 +4,14 @@
 #include "utils/KDTreeVectorOfVectorsAdaptor.h"
 #include "utils/nanoflann.hpp"
 
-torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
-                      torch::optional<torch::Tensor> ptr_x,
-                      torch::optional<torch::Tensor> ptr_y, int64_t k,
-                      int64_t num_workers) {
+using torch::indexing::Slice;
+using torch::indexing::None;
+
+std::tuple<torch::Tensor, torch::Tensor>
+knn_cpu(torch::Tensor x, torch::Tensor y,
+        torch::optional<torch::Tensor> ptr_x,
+        torch::optional<torch::Tensor> ptr_y, int64_t k,
+        int64_t num_workers) {
 
   CHECK_CPU(x);
   CHECK_INPUT(x.dim() == 2);
@@ -24,12 +28,14 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
   }
 
   std::vector<size_t> out_vec = std::vector<size_t>();
+  torch::Tensor out_vec_dist_sqr = torch::empty({y.size(0) * k}, y.options());
 
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
     // See: nanoflann/examples/vector_of_vectors_example.cpp
 
     auto x_data = x.data_ptr<scalar_t>();
     auto y_data = y.data_ptr<scalar_t>();
+    auto out_vec_dist_sqr_data = out_vec_dist_sqr.data_ptr<scalar_t>();
     typedef std::vector<std::vector<scalar_t>> vec_t;
 
     if (!ptr_x.has_value()) { // Single example.
@@ -54,6 +60,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
           y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
 
       for (size_t j = 0; j < num_matches; j++) {
+        out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j];
        out_vec.push_back(ret_index[j]);
         out_vec.push_back(i);
       }
@@ -90,6 +97,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
           y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]);
 
       for (size_t j = 0; j < num_matches; j++) {
+        out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j];
         out_vec.push_back(x_start + ret_index[j]);
         out_vec.push_back(i);
       }
@@ -101,5 +109,8 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
   const int64_t size = out_vec.size() / 2;
   auto out = torch::from_blob(out_vec.data(), {size, 2},
                               x.options().dtype(torch::kLong));
-  return out.t().index_select(0, torch::tensor({1, 0}));
+  return std::make_tuple(
+      out.t().index_select(0, torch::tensor({1, 0})),
+      out_vec_dist_sqr.index({Slice(None, size)})
+  );
 }
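A note on the CPU bookkeeping above: out_vec stores two entries (neighbor index, query index) per match, so out_vec.size() / 2 is always the number of matches written so far and doubles as the write offset into the flat out_vec_dist_sqr buffer. A runnable Python sketch of the same logic, with made-up data:

    # Each match appends (neighbor, query) to a flat list, so the match
    # count is always len(pairs) // 2; the squared distance of the pair
    # is written at that offset before the pair itself is appended.
    knn_results = [([2, 3], [1.0, 1.0]),  # query 0: neighbor ids, sq. dists
                   ([0, 1], [1.0, 1.0])]  # query 1
    k = 2
    pairs, dist_sqr = [], [0.0] * (len(knn_results) * k)
    for i, (neighbors, dists) in enumerate(knn_results):
        for idx, d in zip(neighbors, dists):
            dist_sqr[len(pairs) // 2] = d  # same slot as this pair
            pairs.append(idx)              # neighbor in x (column)
            pairs.append(i)                # query in y (row)
    assert dist_sqr == [1.0, 1.0, 1.0, 1.0]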
11 changes: 7 additions & 4 deletions csrc/cpu/knn_cpu.h
@@ -1,8 +1,11 @@
 #pragma once
 
+#include <tuple>
+
 #include "../extensions.h"
 
-torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
-                      torch::optional<torch::Tensor> ptr_x,
-                      torch::optional<torch::Tensor> ptr_y, int64_t k,
-                      int64_t num_workers);
+std::tuple<torch::Tensor, torch::Tensor>
+knn_cpu(torch::Tensor x, torch::Tensor y,
+        torch::optional<torch::Tensor> ptr_x,
+        torch::optional<torch::Tensor> ptr_y, int64_t k,
+        int64_t num_workers);
27 changes: 18 additions & 9 deletions csrc/cuda/knn_cuda.cu
@@ -32,8 +32,9 @@ __global__ void
 knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
            const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
            int64_t *__restrict__ row, int64_t *__restrict__ col,
-           const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
-           const int64_t num_examples, const bool cosine) {
+           scalar_t *__restrict__ dist, const int64_t k, const int64_t n,
+           const int64_t m, const int64_t dim, const int64_t num_examples,
+           const bool cosine) {
 
   const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
   if (n_y >= m)
@@ -80,13 +81,17 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
   for (int64_t e = 0; e < k; e++) {
     row[n_y * k + e] = n_y;
     col[n_y * k + e] = best_idx[e];
+    dist[n_y * k + e] = best_dist[e];
   }
 }
 
-torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
-                       torch::optional<torch::Tensor> ptr_x,
-                       torch::optional<torch::Tensor> ptr_y, const int64_t k,
-                       const bool cosine) {
+std::tuple<torch::Tensor, torch::Tensor>
+knn_cuda(const torch::Tensor x,
+         const torch::Tensor y,
+         torch::optional<torch::Tensor> ptr_x,
+         torch::optional<torch::Tensor> ptr_y,
+         const int64_t k,
+         const bool cosine) {
 
   CHECK_CUDA(x);
   CHECK_CONTIGUOUS(x);
@@ -117,6 +122,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
 
   auto row = torch::empty({y.size(0) * k}, ptr_y.value().options());
   auto col = torch::full({y.size(0) * k}, -1, ptr_y.value().options());
+  auto dist = torch::empty({y.size(0) * k}, y.options());
 
   dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
 
@@ -126,10 +132,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
     knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
-        y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), dist.data_ptr<scalar_t>(),
+        k, x.size(0), y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
   });
 
   auto mask = col != -1;
-  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
+  return std::make_tuple(
+      torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0),
+      dist.masked_select(mask)
+  );
 }
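In the CUDA kernel, best_dist holds whatever metric the neighbors were ranked by, so with cosine=True the returned values are cosine distances (1 minus cosine similarity) rather than squared Euclidean distances; the updated test_knn below checks exactly that convention. A small self-contained check of the expected value (my own example, not part of the PR):

    import math
    import torch

    # For y = [1, 0] and a neighbor x = [1, 1], the angle is pi/4, so the
    # cosine distance is 1 - cos(pi/4), matching the test's expectation.
    a = torch.tensor([1.0, 0.0])
    b = torch.tensor([1.0, 1.0])
    cos_sim = torch.dot(a, b) / (a.norm() * b.norm())
    assert torch.isclose(1 - cos_sim, torch.tensor(1 - math.cos(math.pi / 4)))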
11 changes: 7 additions & 4 deletions csrc/cuda/knn_cuda.h
@@ -1,8 +1,11 @@
 #pragma once
 
+#include <tuple>
+
 #include "../extensions.h"
 
-torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
-                       torch::optional<torch::Tensor> ptr_x,
-                       torch::optional<torch::Tensor> ptr_y, int64_t k,
-                       bool cosine);
+std::tuple<torch::Tensor, torch::Tensor>
+knn_cuda(torch::Tensor x, torch::Tensor y,
+         torch::optional<torch::Tensor> ptr_x,
+         torch::optional<torch::Tensor> ptr_y, int64_t k,
+         bool cosine);
9 changes: 5 additions & 4 deletions csrc/knn.cpp
@@ -15,10 +15,11 @@ PyMODINIT_FUNC PyInit__knn_cpu(void) { return NULL; }
 #endif
 #endif
 
-CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y,
-                              torch::optional<torch::Tensor> ptr_x,
-                              torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
-                              int64_t num_workers) {
+CLUSTER_API std::tuple<torch::Tensor, torch::Tensor> knn(
+    torch::Tensor x, torch::Tensor y,
+    torch::optional<torch::Tensor> ptr_x,
+    torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
+    int64_t num_workers) {
  if (x.device().is_cuda()) {
 #ifdef WITH_CUDA
     return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
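The Python wrapper that exposes the new return_distances flag is not visible in this excerpt (the file list is truncated), but the tests imply it keeps the old single-tensor return by default and surfaces the tuple only on request. A plausible sketch of that pattern, with hypothetical internals:

    from typing import Optional, Tuple, Union

    import torch

    def knn(x: torch.Tensor, y: torch.Tensor, k: int,
            batch_x: Optional[torch.Tensor] = None,
            batch_y: Optional[torch.Tensor] = None,
            cosine: bool = False, num_workers: int = 1,
            return_distances: bool = False
            ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        ptr_x: Optional[torch.Tensor] = None  # batch-to-ptr conversion omitted
        ptr_y: Optional[torch.Tensor] = None
        edge_index, distances = torch.ops.torch_cluster.knn(
            x, y, ptr_x, ptr_y, k, cosine, num_workers)  # op now returns a tuple
        return (edge_index, distances) if return_distances else edge_index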
108 changes: 100 additions & 8 deletions test/test_knn.py
@@ -1,3 +1,4 @@
+import math
 from itertools import product
 
 import pytest
@@ -32,21 +33,73 @@ def test_knn(dtype, device):
     batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
     batch_y = tensor([0, 1], torch.long, device)
 
-    edge_index = knn(x, y, 2)
+    edge_index, distances = knn(x, y, 2, return_distances=True)
     assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([1.0, 1.0, 1.0, 1.0]))
 
-    edge_index = knn(x, y, 2, batch_x, batch_y)
+    edge_index, distances = knn(x, y, 2, batch_x, batch_y,
+                                return_distances=True)
     assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([1.0, 1.0, 1.0, 1.0]))
 
     if x.is_cuda:
-        edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True)
+        edge_index, distances = knn(
+            x, y, 2, batch_x, batch_y, cosine=True, return_distances=True
+        )
         assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
+        assert torch.allclose(distances, distances.new_tensor(
+            [1.0 - math.cos(math.pi / 4.0) for _ in range(4)]
+        ))
 
     # Skipping a batch
     batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device)
     batch_y = tensor([0, 2], torch.long, device)
-    edge_index = knn(x, y, 2, batch_x, batch_y)
+    edge_index, distances = knn(x, y, 2, batch_x, batch_y,
+                                return_distances=True)
     assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([1.0, 1.0, 1.0, 1.0]))
 
 
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_knn_jit(dtype, device):
+    @torch.jit.script
+    def knn_jit(x: torch.Tensor, y: torch.Tensor, k: int,
+                batch_x: torch.Tensor, batch_y: torch.Tensor):
+        return knn(x, y, k, batch_x, batch_y)
+
+    @torch.jit.script
+    def knn_jit_distance(x: torch.Tensor, y: torch.Tensor, k: int,
+                         batch_x: torch.Tensor, batch_y: torch.Tensor):
+        return knn(x, y, k, batch_x, batch_y, return_distances=True)
+
+    x = tensor([
+        [-1, -1],
+        [-1, +1],
+        [+1, +1],
+        [+1, -1],
+        [-1, -1],
+        [-1, +1],
+        [+1, +1],
+        [+1, -1],
+    ], dtype, device)
+    y = tensor([
+        [1, 0],
+        [-1, 0],
+    ], dtype, device)
+
+    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
+    batch_y = tensor([0, 1], torch.long, device)
+
+    edge_index = knn_jit(x, y, 2, batch_x, batch_y)
+    assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
+
+    edge_index, distances = knn_jit_distance(x, y, 2, batch_x, batch_y)
+    assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([1.0, 1.0, 1.0, 1.0]))
+
+
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
@@ -58,23 +111,62 @@ def test_knn_graph(dtype, device):
         [+1, -1],
     ], dtype, device)
 
-    edge_index = knn_graph(x, k=2, flow='target_to_source')
+    edge_index, distances = knn_graph(
+        x, k=2, flow='target_to_source', return_distances=True
+    )
     assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
                                       (2, 3), (3, 0), (3, 2)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([4.0 for _ in range(8)]))
 
-    edge_index = knn_graph(x, k=2, flow='source_to_target')
+    edge_index = knn_graph(
+        x, k=2, flow='source_to_target', return_distances=False
+    )
     assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
                                       (3, 2), (0, 3), (2, 3)])
 
 
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_knn_graph_jit(dtype, device):
+    @torch.jit.script
+    def knn_graph_jit(x: torch.Tensor, k: int):
+        return knn_graph(x, k, flow="target_to_source")
+
+    @torch.jit.script
+    def knn_graph_jit_distance(x: torch.Tensor, k: int):
+        return knn_graph(x, k, flow="target_to_source", return_distances=True)
+
+    x = tensor([
+        [-1, -1],
+        [-1, +1],
+        [+1, +1],
+        [+1, -1],
+    ], dtype, device)
+
+    edge_index = knn_graph_jit(x, k=2)
+    assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
+                                      (2, 3), (3, 0), (3, 2)])
+
+    edge_index, distances = knn_graph_jit_distance(x, k=2)
+    assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),
+                                      (2, 3), (3, 0), (3, 2)])
+    assert torch.allclose(distances,
+                          distances.new_tensor([4.0 for _ in range(8)]))
+
+
 @pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_knn_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
 
-    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)
+    edge_index, distances = knn_graph(
+        x, k=5, flow='target_to_source', loop=True, return_distances=True
+    )
 
     tree = scipy.spatial.cKDTree(x.cpu().numpy())
-    _, col = tree.query(x.cpu(), k=5)
+    dist, col = tree.query(x.cpu(), k=5)
     truth = set([(i, j) for i, ns in enumerate(col) for j in ns])
 
     assert to_set(edge_index.cpu()) == truth
+    assert torch.allclose(
+        distances, torch.from_numpy(dist).to(distances).flatten().pow(2)
+    )