Add the code for GPU-Reranking #380

Status: Open. Wants to merge 1 commit into base: master.
37 changes: 37 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/README.md
@@ -0,0 +1,37 @@
# Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

[[Paper]](https://arxiv.org/abs/2012.07620v2)

On the Market-1501 dataset, we accelerate the re-ranking process from **89.2s** to **9.4ms** with one K40m GPU, enabling real-time post-processing.
Our method also achieves comparable or even better retrieval results on four other image retrieval benchmarks,
i.e., VeRi-776, Oxford-5k, Paris-6k, and University-1652, with limited time cost.

## Prerequisites

The code was mainly developed and tested with Python 3.7, PyTorch 1.4.1, CUDA 10.2, and CentOS release 6.10.

The CUDA extension code is included in `extension/`. To compile it:

```shell
cd extension
sh make.sh
```
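
After compilation, both extensions are installed as ordinary Python modules, and each one binds a single `forward` function. A quick sanity check (assuming a CUDA-enabled PyTorch build):

```python
# Both modules are built by make.sh; each exposes one CUDA operator as `forward`.
import build_adjacency_matrix
import gnn_propagate

print(build_adjacency_matrix.forward)
print(gnn_propagate.forward)
```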

## Demo

The demo script `main.py` runs the GNN re-ranking method on pre-extracted features.

```shell
python main.py --data_path PATH_TO_DATA --k1 26 --k2 7
```
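
The re-ranking can also be called directly from Python. A minimal sketch with random placeholder features (real features come from a trained re-ID model and are assumed here to be L2-normalized, so the dot product acts as a cosine similarity):

```python
import torch
import torch.nn.functional as F

from gnn_reranking import gnn_reranking

# Placeholder features: 10 queries and 100 gallery images, 512-dim descriptors.
X_q = F.normalize(torch.randn(10, 512), dim=1).cuda()
X_g = F.normalize(torch.randn(100, 512), dim=1).cuda()

indices = gnn_reranking(X_q, X_g, k1=26, k2=7)  # (10, 100) re-ranked gallery indices
```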

## Citation
```bibtex
@article{zhang2020understanding,
  title={Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective},
  author={Zhang, Xuanmeng and Jiang, Minyue and Zheng, Zhedong and Tan, Xiao and Ding, Errui and Yang, Yi},
  journal={arXiv preprint arXiv:2012.07620},
  year={2020}
}
```

19 changes: 19 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix.cpp
@@ -0,0 +1,19 @@
#include <torch/extension.h>

// Forward declaration; implemented in build_adjacency_matrix_kernel.cu.
at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank);


#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor build_adjacency_matrix(at::Tensor initial_rank) {
    CHECK_INPUT(initial_rank);
    return build_adjacency_matrix_forward(initial_rank);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &build_adjacency_matrix, "build_adjacency_matrix (CUDA)");
}
31 changes: 31 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu
@@ -0,0 +1,31 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

// Turn the (total_num x topk) rank list into a dense (total_num x total_num)
// adjacency matrix: for each sample ii, set A[ii][initial_rank[ii][j]] = 1
// for every one of its top-k neighbours.
__global__ void build_adjacency_matrix_kernel(float* initial_rank, float* A, const int total_num, const int topk, const int nthreads, const int all_num) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < all_num; i += stride) {
        int ii = i / topk;  // row (sample) index; initial_rank[i] holds the neighbour column
        A[ii * total_num + int(initial_rank[i])] = float(1.0);
    }
}

at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank) {
    const auto total_num = initial_rank.size(0);
    const auto topk = initial_rank.size(1);
    const auto all_num = total_num * topk;
    auto A = torch::zeros({total_num, total_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));

    const int threads = 1024;
    const int blocks = (all_num + threads - 1) / threads;

    build_adjacency_matrix_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), total_num, topk, threads, all_num);
    return A;
}
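
In plain PyTorch terms, this kernel scatters ones into each row of A at the columns given by the top-k rank list. A reference sketch that can be used to check the extension's output on small inputs:

```python
import torch

def build_adjacency_matrix_reference(initial_rank):
    """Pure-PyTorch equivalent of build_adjacency_matrix.forward (for checking only).

    initial_rank: (N, topk) tensor of neighbour indices (float or integer).
    Returns an (N, N) float matrix with A[i, j] = 1 iff j is in i's top-k list.
    """
    total_num = initial_rank.size(0)
    A = torch.zeros(total_num, total_num,
                    device=initial_rank.device, dtype=torch.float32)
    A.scatter_(1, initial_rank.long(), 1.0)  # one-hot rows from the rank lists
    return A
```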
37 changes: 37 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/adjacency_matrix/setup.py
@@ -0,0 +1,37 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='build_adjacency_matrix',
    ext_modules=[
        CUDAExtension('build_adjacency_matrix', [
            'build_adjacency_matrix.cpp',
            'build_adjacency_matrix_kernel.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
4 changes: 4 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/make.sh
@@ -0,0 +1,4 @@
cd adjacency_matrix
python setup.py install
cd ../propagation
python setup.py install
21 changes: 21 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate.cpp
@@ -0,0 +1,21 @@
#include <torch/extension.h>

// Forward declaration; implemented in gnn_propagate_kernel.cu.
at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S);


#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor gnn_propagate(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
    CHECK_INPUT(A);
    CHECK_INPUT(initial_rank);
    CHECK_INPUT(S);
    return gnn_propagate_forward(A, initial_rank, S);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &gnn_propagate, "gnn propagate (CUDA)");
}
36 changes: 36 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/propagation/gnn_propagate_kernel.cu
@@ -0,0 +1,36 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

// For every output entry (sample_index, fea), aggregate the rows of A that
// belong to the top-k neighbours of sample_index, weighted by S:
//   A_qe[sample_index][fea] = sum_j S[sample_index][j] * A[initial_rank[sample_index][j]][fea]
__global__ void gnn_propagate_forward_kernel(float* initial_rank, float* A, float* A_qe, float* S, const int sample_num, const int topk, const int total_num) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < total_num; i += stride) {
        int fea = i % sample_num;           // column index
        int sample_index = i / sample_num;  // row index
        float sum = 0.0;
        for (int j = 0; j < topk; j++) {
            int topk_fea_index = int(initial_rank[sample_index * topk + j]) * sample_num + fea;
            sum += A[topk_fea_index] * S[sample_index * topk + j];
        }
        A_qe[i] = sum;
    }
}

at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
    const auto sample_num = A.size(0);
    const auto topk = initial_rank.size(1);

    const auto total_num = sample_num * sample_num;
    auto A_qe = torch::zeros({sample_num, sample_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));

    const int threads = 1024;
    const int blocks = (total_num + threads - 1) / threads;

    gnn_propagate_forward_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), A_qe.data_ptr<float>(), S.data_ptr<float>(), sample_num, topk, total_num);
    return A_qe;
}
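
Equivalently, the kernel computes a weighted aggregation of neighbour rows, A_qe[i] = sum_j S[i, j] * A[initial_rank[i, j]]. A plain-PyTorch reference sketch (memory-heavy, intended only for checking the extension on small inputs):

```python
import torch

def gnn_propagate_reference(A, initial_rank, S):
    """Pure-PyTorch equivalent of gnn_propagate.forward (for checking only).

    A:            (N, N) adjacency matrix.
    initial_rank: (N, k) neighbour indices per row (float or integer).
    S:            (N, k) edge weights for those neighbours.
    """
    neighbour_rows = A[initial_rank.long()]                # (N, k, N) gathered rows
    A_qe = (S.unsqueeze(-1) * neighbour_rows).sum(dim=1)   # weighted sum over k neighbours
    return A_qe
```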
37 changes: 37 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/extension/propagation/setup.py
@@ -0,0 +1,37 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='gnn_propagate',
    ext_modules=[
        CUDAExtension('gnn_propagate', [
            'gnn_propagate.cpp',
            'gnn_propagate_kernel.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
57 changes: 57 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/gnn_reranking.py
@@ -0,0 +1,57 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

import torch
import numpy as np

import build_adjacency_matrix
import gnn_propagate

from utils import *



def gnn_reranking(X_q, X_g, k1, k2):
    query_num, gallery_num = X_q.shape[0], X_g.shape[0]

    # Cosine similarity between all query and gallery features.
    X_u = torch.cat((X_q, X_g), dim=0)
    original_score = torch.mm(X_u, X_u.t())
    del X_u, X_q, X_g

    # Initial ranking list: top-k1 neighbours and their similarity scores.
    S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True)

    # Stage 1: build the k1-nearest-neighbour adjacency matrix on the GPU.
    A = build_adjacency_matrix.forward(initial_rank.float())
    S = S * S

    # Stage 2: propagate neighbour information over the graph (two rounds).
    if k2 != 1:
        for _ in range(2):
            A = A + A.T
            A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float())
            A_norm = torch.norm(A, p=2, dim=1, keepdim=True)
            A = A.div(A_norm.expand_as(A))

    # Final ranking: cosine similarity between refined query and gallery representations.
    cosine_similarity = torch.mm(A[:query_num, :], A[query_num:, :].t())
    del A, S

    L = torch.sort(-cosine_similarity, dim=1)[1]
    L = L.cpu().numpy()
    return L
62 changes: 62 additions & 0 deletions fastreid/evaluation/GPU-Re-Ranking/main.py
@@ -0,0 +1,62 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

import os
import torch
import argparse
import numpy as np

from utils import *
from gnn_reranking import *

parser = argparse.ArgumentParser(description='Reranking_is_GNN')
parser.add_argument('--data_path',
                    type=str,
                    default='../xm_rerank_gpu_2/features/market_88_test.pkl',
                    help='path to dataset')
parser.add_argument('--k1',
                    type=int,
                    default=26,   # Market-1501
                    # default=60, # Veri-776
                    help='parameter k1')
parser.add_argument('--k2',
                    type=int,
                    default=7,    # Market-1501
                    # default=10, # Veri-776
                    help='parameter k2')

args = parser.parse_args()


def main():
    data = load_pickle(args.data_path)

    query_cam = data['query_cam']
    query_label = data['query_label']
    gallery_cam = data['gallery_cam']
    gallery_label = data['gallery_label']

    gallery_feature = torch.FloatTensor(data['gallery_f'])
    query_feature = torch.FloatTensor(data['query_f'])
    query_feature = query_feature.cuda()
    gallery_feature = gallery_feature.cuda()

    indices = gnn_reranking(query_feature, gallery_feature, args.k1, args.k2)
    evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam)

if __name__ == '__main__':
    main()
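
For reference, `main()` expects the pickle at `--data_path` to hold a dict with the keys read above. A hypothetical minimal file with made-up shapes (real features come from a separate extraction step, and `utils.load_pickle` is assumed to read a standard pickle):

```python
import pickle

import numpy as np

# Made-up shapes, for illustration only: 10 queries, 100 gallery images, 512-dim features.
data = {
    'query_f':       np.random.rand(10, 512).astype(np.float32),
    'gallery_f':     np.random.rand(100, 512).astype(np.float32),
    'query_label':   np.random.randint(0, 5, 10),
    'gallery_label': np.random.randint(0, 5, 100),
    'query_cam':     np.random.randint(0, 6, 10),
    'gallery_cam':   np.random.randint(0, 6, 100),
}

with open('market_88_test.pkl', 'wb') as f:
    pickle.dump(data, f)
```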