-
Notifications
You must be signed in to change notification settings - Fork 345
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add vector retrieve * [feat]: remove graphlearn dependency
- Loading branch information
1 parent
f7e36c7
commit 1a44891
Showing
15 changed files
with
566 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# 向量近邻检索 | ||
|
||
## Pai 命令 | ||
|
||
```sql | ||
pai -name easy_rec_ext -project algo_public | ||
-Dcmd=vector_retrieve | ||
-Dquery_table=odps://pai_online_project/tables/query_vector_table | ||
-Ddoc_table=odps://pai_online_project/tables/doc_vector_table | ||
-Doutput_table=odps://pai_online_project/tables/result_vector_table | ||
-Dcluster='{"worker" : {"count":3, "cpu":600, "gpu":100, "memory":10000}}' | ||
-Dknn_distance=inner_product | ||
-Dknn_num_neighbours=100 | ||
-Dknn_feature_dims=128 | ||
-Dknn_index_type=gpu_ivfflat | ||
-Dknn_feature_delimiter=',' | ||
-Dbuckets='oss://${oss_bucket}/' | ||
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole' | ||
-DossHost='oss-cn-hangzhou-internal.aliyuncs.com' | ||
``` | ||
|
||
## 参数说明 | ||
|
||
| 参数名 | 默认值 | 参数说明 | | ||
| --------------------- | ------------- | ------------------------------------------------------------------------- | | ||
| query_table | 无 | 输入查询表, schema: (id bigint, vector string) | | ||
| doc_table | 无 | 输入索引表, schema: (id bigint, vector string) | | ||
| output_table | 无 | 输出表, schema: (query_id bigint, doc_id bigint, distance double) | | ||
| knn_distance | inner_product | 计算距离的方法:l2、inner_product | | ||
| knn_num_neighbours | 无 | top n, 每个query输出多少个近邻 | | ||
| knn_feature_dims | 无 | 向量维度 | | ||
| knn_feature_delimiter | , | 向量字符串分隔符 | | ||
| knn_index_type | ivfflat | 向量索引类型:'flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpq' | | ||
| knn_nlist | 5 | 聚类的簇个数, number of split cluster on each worker | | ||
| knn_nprobe | 2 | 检索时只考虑距离与输入向量最近的簇个数, number of probe part on each worker | | ||
| knn_compress_dim | 8 | 当index_type为`ivfpq` and `gpu_ivfpq`时, 指定压缩的维度,必须为float属性个数的因子 | | ||
|
||
## 使用示例 | ||
|
||
### 1. 创建查询表 | ||
|
||
```sql | ||
create table query_table(pk BIGINT,vector string) partitioned by (pt string); | ||
|
||
INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410') | ||
VALUES | ||
(1, '0.1,0.2,-0.4,0.5'), | ||
(2, '-0.1,0.8,0.4,0.5'), | ||
(3, '0.59,0.2,0.4,0.15'), | ||
(10, '0.1,-0.2,0.4,-0.5'), | ||
(20, '-0.1,-0.2,0.4,0.5'), | ||
(30, '0.5,0.2,0.43,0.15') | ||
; | ||
``` | ||
|
||
### 2. 创建索引表 | ||
|
||
```sql | ||
create table doc_table(pk BIGINT,vector string) partitioned by (pt string); | ||
|
||
INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410') | ||
VALUES | ||
(1, '0.1,0.2,0.4,0.5'), | ||
(2, '-0.1,0.2,0.4,0.5'), | ||
(3, '0.5,0.2,0.4,0.5'), | ||
(10, '0.1,0.2,0.4,0.5'), | ||
(20, '-0.1,-0.2,0.4,0.5'), | ||
(30, '0.5,0.2,0.43,0.15') | ||
; | ||
``` | ||
|
||
### 3. 执行向量检索 | ||
|
||
```sql | ||
pai -name easy_rec_ext -project algo_public_dev | ||
-Dcmd='vector_retrieve' | ||
-DentryFile='run.py' | ||
-Dquery_table='odps://${project}/tables/query_table/pt=20190410' | ||
-Ddoc_table='odps://${project}/tables/doc_table/pt=20190410' | ||
-Doutput_table='odps://${project}/tables/knn_result_table/pt=20190410' | ||
-Dknn_distance=inner_product | ||
-Dknn_num_neighbours=2 | ||
-Dknn_feature_dims=4 | ||
-Dknn_index_type='ivfflat' | ||
-Dknn_feature_delimiter=',' | ||
-Dbuckets='oss://${oss_bucket}/' | ||
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole' | ||
-DossHost='oss-cn-shenzhen-internal.aliyuncs.com' | ||
-Dcluster='{ | ||
\"worker\" : { | ||
\"count\" : 1, | ||
\"cpu\" : 600 | ||
} | ||
}'; | ||
``` | ||
|
||
### 4. 查看结果 | ||
|
||
```sql | ||
SELECT * from knn_result_table where pt='20190410'; | ||
|
||
-- query doc distance | ||
-- 1 3 0.17999999225139618 | ||
-- 1 1 0.13999998569488525 | ||
-- 2 2 0.5800000429153442 | ||
-- 2 1 0.5600000619888306 | ||
-- 3 3 0.5699999928474426 | ||
-- 3 30 0.5295000076293945 | ||
-- 10 30 0.10700000077486038 | ||
-- 10 20 -0.0599999874830246 | ||
-- 20 20 0.46000003814697266 | ||
-- 20 2 0.3800000250339508 | ||
-- 30 3 0.5370000004768372 | ||
-- 30 30 0.4973999857902527 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# -*- encoding:utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import logging | ||
from datetime import datetime | ||
|
||
import common_io | ||
import numpy as np | ||
import tensorflow as tf | ||
# graphlearn supplies the KNN index used by VectorRetrieve below; it is an
# optional dependency, so a missing install is reported but not fatal here
# (any later use of `gl` will raise NameError).
try:
  import graphlearn as gl
except ImportError:
  # logging.warning is the callable; logging.WARN is just the int level
  # constant and calling it raises TypeError.
  logging.warning(
      'GraphLearn is not installed. You can install it by "pip install http://odps-release.cn-hangzhou.oss-cdn.aliyun-inc.com/graphlearn/tunnel/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl."'  # noqa: E501
  )
|
||
|
||
if tf.__version__ >= '2.0': | ||
tf = tf.compat.v1 | ||
|
||
|
||
class VectorRetrieve(object):
  """Nearest-neighbour retrieval of doc vectors for query vectors.

  Builds a GraphLearn KNN index over ``doc_table`` and, when called, streams
  ``query_table`` in batches, writing (query_id, doc_id, distance) rows to
  ``out_table``. Requires the optional ``graphlearn`` and ``common_io``
  packages to be importable.
  """

  def __init__(self,
               query_table,
               doc_table,
               out_table,
               ndim,
               delimiter=',',
               batch_size=4,
               index_type='ivfflat',
               nlist=10,
               nprobe=2,
               distance=1,
               m=8):
    """Retrieve top n neighbours by query vector.

    Args:
      query_table: query vector table (rows of id + delimited float string).
      doc_table: document vector table with the same row layout.
      out_table: output table receiving (query_id, doc_id, distance) rows.
      ndim: int, number of feature dimensions per vector.
      delimiter: delimiter splitting the feature-vector string.
      batch_size: number of query rows read and searched per iteration.
      index_type: search model `flat`, `ivfflat`, `ivfpq`, `gpu_ivfflat`.
      nlist: number of split part on each worker.
      nprobe: probe part on each worker.
      distance: type of distance, 0 is l2 distance(default), 1 is inner product.
      m: number of dimensions for each node after compress (pq indexes).
    """
    self.query_table = query_table
    self.doc_table = doc_table
    self.out_table = out_table
    self.ndim = ndim
    self.delimiter = delimiter
    self.batch_size = batch_size

    # Configure the GraphLearn KNN index once; the option object is reused
    # when the doc graph is built in __call__. If the graphlearn import at
    # module scope failed, `gl` is undefined and this raises NameError.
    gl.set_inter_threadnum(8)
    gl.set_knn_metric(distance)
    knn_option = gl.IndexOption()
    knn_option.name = 'knn'
    knn_option.index_type = index_type
    knn_option.nlist = nlist
    knn_option.nprobe = nprobe
    knn_option.m = m
    self.knn_option = knn_option

  def __call__(self, top_n, task_index, task_count, *args, **kwargs):
    """Run the retrieval for this worker's slice of the query table.

    Args:
      top_n: number of neighbours returned per query vector.
      task_index: this worker's index (also used as reader/writer slice id).
      task_count: total number of workers sharing the tables.
    """
    # Build the doc graph: each node carries ndim float attrs and is
    # indexed with the KNN option prepared in __init__.
    g = gl.Graph()
    g.node(
        self.doc_table,
        'doc',
        decoder=gl.Decoder(
            attr_types=['float'] * self.ndim, attr_delimiter=self.delimiter),
        option=self.knn_option)
    g.init(task_index=task_index, task_count=task_count)

    # Each worker reads only its own slice of the query table.
    query_reader = common_io.table.TableReader(
        self.query_table, slice_id=task_index, slice_count=task_count)
    num_records = query_reader.get_row_count()
    # Float batch count: only used for the progress ratio printed below.
    total_batch_num = num_records // self.batch_size + 1.0
    batch_num = 0
    print('total input records: {}'.format(query_reader.get_row_count()))
    print('total_batch_num: {}'.format(total_batch_num))
    print('output_table: {}'.format(self.out_table))

    output_table_writer = common_io.table.TableWriter(self.out_table,
                                                      task_index)
    count = 0
    while True:
      try:
        # read() yields (id, feature_string) pairs; zip(*) splits them
        # into parallel tuples of ids and feature strings.
        batch_query_nodes, batch_query_feats = zip(
            *query_reader.read(self.batch_size, allow_smaller_final_batch=True))
        batch_num += 1.0
        print('{} process: {:.2f}'.format(datetime.now().time(),
                                          batch_num / total_batch_num))
        feats = to_np_array(batch_query_feats, self.delimiter)
        rt_ids, rt_dists = g.search('doc', feats, gl.KnnOption(k=top_n))

        # Emit one (query_id, doc_id, distance) row per neighbour; columns
        # are written to table indices (0, 1, 2).
        for query_node, nodes, dists in zip(batch_query_nodes, rt_ids,
                                            rt_dists):
          query = np.array([query_node] * len(nodes), dtype='int64')
          output_table_writer.write(
              zip(query, nodes, dists), (0, 1, 2), allow_type_cast=False)
          count += 1
          if np.mod(count, 100) == 0:
            print('write ', count, ' query nodes totally')
      except Exception as e:
        # NOTE(review): this broad except is relied on to end the loop when
        # the reader is exhausted, but it also silently terminates on any
        # other error (search/write failures) — confirm this is intended.
        print(e)
        break

    print('==finished==')
    output_table_writer.close()
    query_reader.close()
    g.close()
|
||
|
||
def to_np_array(batch_query_feats, attr_delimiter):
  """Parse delimited feature strings into a float32 matrix.

  Args:
    batch_query_feats: sequence of strings, each a vector whose components
      are separated by ``attr_delimiter``.
    attr_delimiter: delimiter between components in each string.

  Returns:
    np.ndarray of dtype float32 with shape (len(batch_query_feats), ndim).
  """
  # Use a list comprehension instead of map(): in Python 3 map() returns an
  # iterator, and np.array over a list of iterators cannot build the numeric
  # float32 matrix this function promises.
  return np.array(
      [[float(v) for v in feat.split(attr_delimiter)]
       for feat in batch_query_feats],
      dtype='float32')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.