From 1f90725e8688ca27f424b3b4dbb4dd44891e5261 Mon Sep 17 00:00:00 2001 From: Zhe Chen Date: Mon, 27 Jun 2022 14:22:13 +0900 Subject: [PATCH 1/4] Enable returning distances from knn and knn_graph --- csrc/cpu/knn_cpu.cpp | 21 +++++-- csrc/cpu/knn_cpu.h | 11 ++-- csrc/cuda/knn_cuda.cu | 27 ++++++--- csrc/cuda/knn_cuda.h | 11 ++-- csrc/knn.cpp | 9 +-- test/test_knn.py | 100 +++++++++++++++++++++++++++++++--- torch_cluster/knn.py | 124 ++++++++++++++++++++++++++++++++++++------ 7 files changed, 251 insertions(+), 52 deletions(-) diff --git a/csrc/cpu/knn_cpu.cpp b/csrc/cpu/knn_cpu.cpp index 4ba7da8..7d8deca 100644 --- a/csrc/cpu/knn_cpu.cpp +++ b/csrc/cpu/knn_cpu.cpp @@ -4,10 +4,14 @@ #include "utils/KDTreeVectorOfVectorsAdaptor.h" #include "utils/nanoflann.hpp" -torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, - torch::optional ptr_x, - torch::optional ptr_y, int64_t k, - int64_t num_workers) { +using torch::indexing::Slice; +using torch::indexing::None; + +std::tuple +knn_cpu(torch::Tensor x, torch::Tensor y, + torch::optional ptr_x, + torch::optional ptr_y, int64_t k, + int64_t num_workers) { CHECK_CPU(x); CHECK_INPUT(x.dim() == 2); @@ -24,12 +28,14 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, } std::vector out_vec = std::vector(); + torch::Tensor out_vec_dist_sqr = torch::empty({y.size(0) * k}, y.options()); AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] { // See: nanoflann/examples/vector_of_vectors_example.cpp auto x_data = x.data_ptr(); auto y_data = y.data_ptr(); + auto out_vec_dist_sqr_data = out_vec_dist_sqr.data_ptr(); typedef std::vector> vec_t; if (!ptr_x.has_value()) { // Single example. @@ -54,6 +60,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]); for (size_t j = 0; j < num_matches; j++) { + out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j]; out_vec.push_back(ret_index[j]); out_vec.push_back(i); } @@ -90,6 +97,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, y_data + i * y.size(1), k, &ret_index[0], &out_dist_sqr[0]); for (size_t j = 0; j < num_matches; j++) { + out_vec_dist_sqr_data[out_vec.size() / 2] = out_dist_sqr[j]; out_vec.push_back(x_start + ret_index[j]); out_vec.push_back(i); } @@ -101,5 +109,8 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, const int64_t size = out_vec.size() / 2; auto out = torch::from_blob(out_vec.data(), {size, 2}, x.options().dtype(torch::kLong)); - return out.t().index_select(0, torch::tensor({1, 0})); + return std::make_tuple( + out.t().index_select(0, torch::tensor({1, 0})), + out_vec_dist_sqr.index({Slice(None, size)}) + ); } diff --git a/csrc/cpu/knn_cpu.h b/csrc/cpu/knn_cpu.h index 97f11a4..ff3f53b 100644 --- a/csrc/cpu/knn_cpu.h +++ b/csrc/cpu/knn_cpu.h @@ -1,8 +1,11 @@ #pragma once +#include + #include "../extensions.h" -torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y, - torch::optional ptr_x, - torch::optional ptr_y, int64_t k, - int64_t num_workers); +std::tuple +knn_cpu(torch::Tensor x, torch::Tensor y, + torch::optional ptr_x, + torch::optional ptr_y, int64_t k, + int64_t num_workers); diff --git a/csrc/cuda/knn_cuda.cu b/csrc/cuda/knn_cuda.cu index cae8f28..fe32cca 100644 --- a/csrc/cuda/knn_cuda.cu +++ b/csrc/cuda/knn_cuda.cu @@ -32,8 +32,9 @@ __global__ void knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row, int64_t *__restrict__ col, - const int64_t k, const int64_t n, const int64_t m, const int64_t dim, - const int64_t num_examples, const bool cosine) { + scalar_t *__restrict__ dist, const int64_t k, const int64_t n, + const int64_t m, const int64_t dim, const int64_t num_examples, + const bool cosine) { const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x; if (n_y >= m) @@ -80,13 +81,17 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y, for (int64_t e = 0; e < k; e++) { row[n_y * k + e] = n_y; col[n_y * k + e] = best_idx[e]; + dist[n_y * k + e] = best_dist[e]; } } -torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, - torch::optional ptr_x, - torch::optional ptr_y, const int64_t k, - const bool cosine) { +std::tuple +knn_cuda(const torch::Tensor x, + const torch::Tensor y, + torch::optional ptr_x, + torch::optional ptr_y, + const int64_t k, + const bool cosine) { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); @@ -117,6 +122,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, auto row = torch::empty({y.size(0) * k}, ptr_y.value().options()); auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options()); + auto dist = torch::empty({y.size(0) * k}, y.options()); dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS); @@ -126,10 +132,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, knn_kernel<<>>( x.data_ptr(), y.data_ptr(), ptr_x.value().data_ptr(), ptr_y.value().data_ptr(), - row.data_ptr(), col.data_ptr(), k, x.size(0), - y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine); + row.data_ptr(), col.data_ptr(), dist.data_ptr(), + k, x.size(0), y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine); }); auto mask = col != -1; - return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0); + return std::make_tuple( + torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0), + dist.masked_select(mask) + ); } diff --git a/csrc/cuda/knn_cuda.h b/csrc/cuda/knn_cuda.h index 4e732a8..ad31128 100644 --- a/csrc/cuda/knn_cuda.h +++ b/csrc/cuda/knn_cuda.h @@ -1,8 +1,11 @@ #pragma once +#include + #include "../extensions.h" -torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, - torch::optional ptr_x, - torch::optional ptr_y, int64_t k, - bool cosine); +std::tuple +knn_cuda(torch::Tensor x, torch::Tensor y, + torch::optional ptr_x, + torch::optional ptr_y, int64_t k, + bool cosine); diff --git a/csrc/knn.cpp b/csrc/knn.cpp index cfa1e4b..3c28bb3 100644 --- a/csrc/knn.cpp +++ b/csrc/knn.cpp @@ -15,10 +15,11 @@ PyMODINIT_FUNC PyInit__knn_cpu(void) { return NULL; } #endif #endif -CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y, - torch::optional ptr_x, - torch::optional ptr_y, int64_t k, bool cosine, - int64_t num_workers) { +CLUSTER_API std::tuple knn( + torch::Tensor x, torch::Tensor y, + torch::optional ptr_x, + torch::optional ptr_y, int64_t k, bool cosine, + int64_t num_workers) { if (x.device().is_cuda()) { #ifdef WITH_CUDA return knn_cuda(x, y, ptr_x, ptr_y, k, cosine); diff --git a/test/test_knn.py b/test/test_knn.py index 2600812..3d860c9 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -1,3 +1,4 @@ +import math from itertools import product import pytest @@ -32,21 +33,67 @@ def test_knn(dtype, device): batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch_y = tensor([0, 1], torch.long, device) - edge_index = knn(x, y, 2) + edge_index, distances = knn(x, y, 2, return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) - edge_index = knn(x, y, 2, batch_x, batch_y) + edge_index, distances = knn(x, y, 2, batch_x, batch_y, return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) if x.is_cuda: - edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True) + edge_index, distances = knn( + x, y, 2, batch_x, batch_y, cosine=True, return_distances=True + ) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + assert torch.allclose(distances, distances.new_tensor( + [1.0 - math.cos(math.pi / 4.0) for _ in range(4)] + )) # Skipping a batch batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device) batch_y = tensor([0, 2], torch.long, device) - edge_index = knn(x, y, 2, batch_x, batch_y) + edge_index,distances = knn(x, y, 2, batch_x, batch_y, return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) + + +@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +def test_knn_jit(dtype, device): + @torch.jit.script + def knn_jit(x: torch.Tensor, y: torch.Tensor, k: int, batch_x: torch.Tensor, + batch_y: torch.Tensor): + return knn(x, y, k, batch_x, batch_y) + + @torch.jit.script + def knn_jit_distance(x: torch.Tensor, y: torch.Tensor, k: int, + batch_x: torch.Tensor, batch_y: torch.Tensor): + return knn(x, y, k, batch_x, batch_y, return_distances=True) + + x = tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype, device) + y = tensor([ + [1, 0], + [-1, 0], + ], dtype, device) + + batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) + batch_y = tensor([0, 1], torch.long, device) + + edge_index = knn_jit(x, y, 2, batch_x, batch_y) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + + edge_index, distances = knn_jit_distance(x, y, 2, batch_x, batch_y) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @@ -58,23 +105,60 @@ def test_knn_graph(dtype, device): [+1, -1], ], dtype, device) - edge_index = knn_graph(x, k=2, flow='target_to_source') + edge_index, distances = knn_graph( + x, k=2, flow='target_to_source', return_distances=True + ) assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)]) + assert torch.allclose(distances, distances.new_tensor([4.0 for _ in range(8)])) - edge_index = knn_graph(x, k=2, flow='source_to_target') + edge_index = knn_graph( + x, k=2, flow='source_to_target', return_distances=False + ) assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) +@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +def test_knn_graph_jit(dtype, device): + @torch.jit.script + def knn_graph_jit(x: torch.Tensor, k: int): + return knn_graph(x, k, flow="target_to_source") + + @torch.jit.script + def knn_graph_jit_distance(x: torch.Tensor, k: int): + return knn_graph(x, k, flow="target_to_source", return_distances=True) + + x = tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype, device) + + edge_index = knn_graph_jit(x, k=2) + assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), + (2, 3), (3, 0), (3, 2)]) + + edge_index, distances = knn_graph_jit_distance(x, k=2) + assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), + (2, 3), (3, 0), (3, 2)]) + assert torch.allclose(distances, distances.new_tensor([4.0 for _ in range(8)])) + + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_knn_graph_large(dtype, device): x = torch.randn(1000, 3, dtype=dtype, device=device) - edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True) + edge_index, distances = knn_graph( + x, k=5, flow='target_to_source', loop=True, return_distances=True + ) tree = scipy.spatial.cKDTree(x.cpu().numpy()) - _, col = tree.query(x.cpu(), k=5) + dist, col = tree.query(x.cpu(), k=5) truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) assert to_set(edge_index.cpu()) == truth + assert torch.allclose( + distances, torch.from_numpy(dist).to(distances).flatten().pow(2) + ) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index b09899e..5dfb5a7 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -1,13 +1,18 @@ -from typing import Optional +from typing import Optional, Union, Tuple import torch -@torch.jit.script -def knn(x: torch.Tensor, y: torch.Tensor, k: int, - batch_x: Optional[torch.Tensor] = None, - batch_y: Optional[torch.Tensor] = None, cosine: bool = False, - num_workers: int = 1) -> torch.Tensor: +def _knn_impl( + x: torch.Tensor, + y: torch.Tensor, + k: int, + batch_x: Optional[torch.Tensor] = None, + batch_y: Optional[torch.Tensor] = None, + cosine: bool = False, + num_workers: int = 1, + return_distances: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in :obj:`x`. @@ -31,8 +36,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + return_distances (boolean, optional): if :attr:`return_distances` is + True, there will be an additional returned tensor (same shape as + output.size(1)) representing the distance between the pair of + nodes for each edge. If :attr:`cosine` is False, the returned + distance is the squared Euclidean distance. If :attr:`cosine` is + True, the returned distance is :math:`1 - \frac{\mathbf{a} \cdot + \mathbf{b}}{\left\lVert \mathbf{a} \right\rVert \left\lVert + \mathbf{b} \right\rVert}`. (default: :obj:`False`) - :rtype: :class:`LongTensor` + :rtype: (:class:`LongTensor`, :class:`Tensor`) .. code-block:: python @@ -43,7 +56,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, batch_x = torch.tensor([0, 0, 0, 0]) y = torch.Tensor([[-1, 0], [1, 0]]) batch_y = torch.tensor([0, 0]) - assign_index = knn(x, y, 2, batch_x, batch_y) + assign_index, distance = knn(x, y, 2, batch_x, batch_y, + return_distances=True) """ x = x.view(-1, 1) if x.dim() == 1 else x @@ -71,10 +85,44 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, num_workers) -@torch.jit.script -def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, - loop: bool = False, flow: str = 'source_to_target', - cosine: bool = False, num_workers: int = 1) -> torch.Tensor: +def _knn_return_output( + x: torch.Tensor, + y: torch.Tensor, + k: int, + batch_x: Optional[torch.Tensor] = None, + batch_y: Optional[torch.Tensor] = None, + cosine: bool = False, + num_workers: int = 1, + return_distances: bool = False +) -> torch.Tensor: + output, _ = _knn_impl(x, y, k, batch_x, batch_y, cosine, num_workers, + return_distances) + + return output + + +knn = torch._jit_internal.boolean_dispatch( + arg_name="return_distances", + arg_index=7, + default=False, + if_true=_knn_impl, + if_false=_knn_return_output, + module_name=__name__, + func_name="knn" +) +knn.__doc__ = _knn_impl.__doc__ + + +def _knn_graph_impl( + x: torch.Tensor, + k: int, + batch: Optional[torch.Tensor] = None, + loop: bool = False, + flow: str = 'source_to_target', + cosine: bool = False, + num_workers: int = 1, + return_distances: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: r"""Computes graph edges to the nearest :obj:`k` points. Args: @@ -96,8 +144,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + return_distances (boolean, optional): if :attr:`return_distances` is + True, there will be an additional returned tensor (same shape as + output.size(1)) representing the distance between the pair of + nodes for each edge. If :attr:`cosine` is False, the returned + distance is the squared Euclidean distance. If :attr:`cosine` is + True, the returned distance is :math:`1 - \frac{\mathbf{a} \cdot + \mathbf{b}}{\left\lVert \mathbf{a} \right\rVert \left\lVert + \mathbf{b} \right\rVert}`. (default: :obj:`False`) - :rtype: :class:`LongTensor` + :rtype: (:class:`LongTensor`, :class:`Tensor`) .. code-block:: python @@ -106,12 +162,13 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) batch = torch.tensor([0, 0, 0, 0]) - edge_index = knn_graph(x, k=2, batch=batch, loop=False) + edge_index, distance = knn_graph(x, k=2, batch=batch, loop=False, + return_distances=True) """ assert flow in ['source_to_target', 'target_to_source'] - edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, - num_workers) + edge_index, distances = knn(x, x, k if loop else k + 1, batch, batch, + cosine, num_workers, return_distances=True) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] @@ -121,5 +178,36 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, if not loop: mask = row != col row, col = row[mask], col[mask] - - return torch.stack([row, col], dim=0) + distances = distances[mask] + + edge_index = torch.stack([row, col], dim=0) + + return edge_index, distances + + +def _knn_graph_return_output( + x: torch.Tensor, + k: int, + batch: Optional[torch.Tensor] = None, + loop: bool = False, + flow: str = 'source_to_target', + cosine: bool = False, + num_workers: int = 1, + return_distances: bool = False, +) -> torch.Tensor: + output, _ = _knn_graph_impl(x, k, batch, loop, flow, cosine, num_workers, + return_distances) + + return output + + +knn_graph = torch._jit_internal.boolean_dispatch( + arg_name="return_distances", + arg_index=7, + default=False, + if_true=_knn_graph_impl, + if_false=_knn_graph_return_output, + module_name=__name__, + func_name="knn_graph" +) +knn_graph.__doc__ = _knn_graph_impl.__doc__ From 01e79655443dce456613555ad9f19ec191b64703 Mon Sep 17 00:00:00 2001 From: Zhe Chen Date: Mon, 27 Jun 2022 14:49:28 +0900 Subject: [PATCH 2/4] Fix code style --- test/test_knn.py | 28 ++++++++++++++++++---------- torch_cluster/knn.py | 10 +++++----- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/test/test_knn.py b/test/test_knn.py index 3d860c9..f295d6d 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -35,11 +35,14 @@ def test_knn(dtype, device): edge_index, distances = knn(x, y, 2, return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) - assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) + assert torch.allclose(distances, + distances.new_tensor([1.0, 1.0, 1.0, 1.0])) - edge_index, distances = knn(x, y, 2, batch_x, batch_y, return_distances=True) + edge_index, distances = knn(x, y, 2, batch_x, batch_y, + return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) - assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) + assert torch.allclose(distances, + distances.new_tensor([1.0, 1.0, 1.0, 1.0])) if x.is_cuda: edge_index, distances = knn( @@ -53,16 +56,18 @@ def test_knn(dtype, device): # Skipping a batch batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device) batch_y = tensor([0, 2], torch.long, device) - edge_index,distances = knn(x, y, 2, batch_x, batch_y, return_distances=True) + edge_index, distances = knn(x, y, 2, batch_x, batch_y, + return_distances=True) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) - assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) + assert torch.allclose(distances, + distances.new_tensor([1.0, 1.0, 1.0, 1.0])) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) def test_knn_jit(dtype, device): @torch.jit.script - def knn_jit(x: torch.Tensor, y: torch.Tensor, k: int, batch_x: torch.Tensor, - batch_y: torch.Tensor): + def knn_jit(x: torch.Tensor, y: torch.Tensor, k: int, + batch_x: torch.Tensor, batch_y: torch.Tensor): return knn(x, y, k, batch_x, batch_y) @torch.jit.script @@ -93,7 +98,8 @@ def knn_jit_distance(x: torch.Tensor, y: torch.Tensor, k: int, edge_index, distances = knn_jit_distance(x, y, 2, batch_x, batch_y) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) - assert torch.allclose(distances, distances.new_tensor([1.0, 1.0, 1.0, 1.0])) + assert torch.allclose(distances, + distances.new_tensor([1.0, 1.0, 1.0, 1.0])) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @@ -110,7 +116,8 @@ def test_knn_graph(dtype, device): ) assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)]) - assert torch.allclose(distances, distances.new_tensor([4.0 for _ in range(8)])) + assert torch.allclose(distances, + distances.new_tensor([4.0 for _ in range(8)])) edge_index = knn_graph( x, k=2, flow='source_to_target', return_distances=False @@ -143,7 +150,8 @@ def knn_graph_jit_distance(x: torch.Tensor, k: int): edge_index, distances = knn_graph_jit_distance(x, k=2) assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)]) - assert torch.allclose(distances, distances.new_tensor([4.0 for _ in range(8)])) + assert torch.allclose(distances, + distances.new_tensor([4.0 for _ in range(8)])) @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index 5dfb5a7..f1876c2 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple +from typing import Optional, Tuple import torch @@ -95,8 +95,8 @@ def _knn_return_output( num_workers: int = 1, return_distances: bool = False ) -> torch.Tensor: - output, _ = _knn_impl(x, y, k, batch_x, batch_y, cosine, num_workers, - return_distances) + output, _ = _knn_impl(x, y, k, batch_x, batch_y, cosine, + num_workers, return_distances) return output @@ -168,7 +168,7 @@ def _knn_graph_impl( assert flow in ['source_to_target', 'target_to_source'] edge_index, distances = knn(x, x, k if loop else k + 1, batch, batch, - cosine, num_workers, return_distances=True) + cosine, num_workers, return_distances=True) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] @@ -196,7 +196,7 @@ def _knn_graph_return_output( return_distances: bool = False, ) -> torch.Tensor: output, _ = _knn_graph_impl(x, k, batch, loop, flow, cosine, num_workers, - return_distances) + return_distances) return output From 0497082ad68d585bf96f5eb5584c4c6b83da46cd Mon Sep 17 00:00:00 2001 From: Zhe Chen Date: Mon, 27 Jun 2022 18:51:48 +0900 Subject: [PATCH 3/4] Make coverage tester happy --- torch_cluster/knn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index f1876c2..2039f31 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -94,7 +94,7 @@ def _knn_return_output( cosine: bool = False, num_workers: int = 1, return_distances: bool = False -) -> torch.Tensor: +) -> torch.Tensor: # pragma: no cover output, _ = _knn_impl(x, y, k, batch_x, batch_y, cosine, num_workers, return_distances) From a6a5aca2cee8fcce92301731beb0fc18c26b5265 Mon Sep 17 00:00:00 2001 From: Zhe Chen Date: Mon, 27 Jun 2022 18:54:27 +0900 Subject: [PATCH 4/4] Fix code style --- torch_cluster/knn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index 2039f31..ff64fa4 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -94,7 +94,7 @@ def _knn_return_output( cosine: bool = False, num_workers: int = 1, return_distances: bool = False -) -> torch.Tensor: # pragma: no cover +) -> torch.Tensor: # pragma: no cover output, _ = _knn_impl(x, y, k, batch_x, batch_y, cosine, num_workers, return_distances)