From 1a44891c557935eec29c6dc0d4dfc4a543fbd39a Mon Sep 17 00:00:00 2001
From: yangxudong
Date: Sat, 19 Feb 2022 13:46:14 +0800
Subject: [PATCH] [feat] add vector retrieve (#109)

* add vector retrieve

* [feat]: remove graphlearn dependency
---
 README.md                                     |   6 +-
 docs/source/feature/feature.rst               |   5 +-
 docs/source/intro.md                          |   6 +-
 docs/source/vector_retrieve.md                | 116 ++++++++++++++++
 .../python/feature_column/feature_column.py   |   8 +-
 easy_rec/python/inference/vector_retrieve.py  | 124 ++++++++++++++++++
 easy_rec/python/test/odps_run.py              |  12 ++
 pai_jobs/easy_rec_flow/easy_rec.lua           |  75 ++++++++++-
 pai_jobs/easy_rec_flow/easy_rec.xml           |  24 ++++
 pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua    |  75 ++++++++++-
 pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml    |  24 ++++
 pai_jobs/run.py                               |  56 +++++++-
 .../create_inner_vector_table.sql             |  35 +++++
 .../vector_retrieve/drop_table.sql            |   3 +
 .../vector_retrieve/run_vector_retrieve.sql   |  16 +++
 15 files changed, 566 insertions(+), 19 deletions(-)
 create mode 100644 docs/source/vector_retrieve.md
 create mode 100644 easy_rec/python/inference/vector_retrieve.py
 create mode 100644 samples/odps_script/vector_retrieve/create_inner_vector_table.sql
 create mode 100644 samples/odps_script/vector_retrieve/drop_table.sql
 create mode 100644 samples/odps_script/vector_retrieve/run_vector_retrieve.sql

diff --git a/README.md b/README.md
index bfaf166fd..effdc2f42 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 &nbsp;
 
-## What is EasyRec?
+## What is EasyRec? 
 
 ![intro.png](docs/images/intro.png)
 
@@ -58,6 +58,10 @@ EasyRec implements state of the art deep learning models used in common recommed
 - Easy to implement [customized models](docs/source/models/user_define.md)
 - Not need to care about data pipelines
 
+### Fast vector retrieve
+
+- Run the [knn algorithm](docs/source/vector_retrieve.md) on vectors in a distributed environment
+
 &nbsp;
 
 ## Get Started
diff --git a/docs/source/feature/feature.rst b/docs/source/feature/feature.rst
index c6715e0d0..65a4cc1e5 100644
--- a/docs/source/feature/feature.rst
+++ b/docs/source/feature/feature.rst
@@ -345,9 +345,8 @@ rank模型中配置相应字段:
 - regularization\_lambda: 变分dropout层的正则化系数设置
 - embedding\_wise\_variational\_dropout: 变分dropout层维度是否为embedding维度(true:embedding维度;false:feature维度;默认false)
 
-
-
-
+备注:
+**这个配置在model_config下面，跟model_class平级**
 
 分隔符
diff --git a/docs/source/intro.md b/docs/source/intro.md
index c54dd171a..48a2fede3 100644
--- a/docs/source/intro.md
+++ b/docs/source/intro.md
@@ -1,6 +1,6 @@
 # EasyRec简介
 
-## What is EasyRec?
+## What is EasyRec? 
 
 ![intro.png](../images/intro.png)
@@ -56,4 +56,8 @@ EasyRec implements state of the art machine learning models used in common recom
 - Easy to implement customized models
 - Not need to care about data pipelines
 
+### Fast vector retrieve
+
+- Run the [`knn algorithm`](vector_retrieve.md) on vectors in a distributed environment
+
 欢迎加入【EasyRec推荐算法交流群】，钉钉群号 : 32260796
diff --git a/docs/source/vector_retrieve.md b/docs/source/vector_retrieve.md
new file mode 100644
index 000000000..d88cc9704
--- /dev/null
+++ b/docs/source/vector_retrieve.md
@@ -0,0 +1,116 @@
# 向量近邻检索

## Pai 命令

```sql
pai -name easy_rec_ext -project algo_public
-Dcmd=vector_retrieve
-Dquery_table=odps://pai_online_project/tables/query_vector_table
-Ddoc_table=odps://pai_online_project/tables/doc_vector_table
-Doutput_table=odps://pai_online_project/tables/result_vector_table
-Dcluster='{"worker" : {"count":3, "cpu":600, "gpu":100, "memory":10000}}'
-Dknn_distance=inner_product
-Dknn_num_neighbours=100
-Dknn_feature_dims=128
-Dknn_index_type=gpu_ivfflat
-Dknn_feature_delimiter=','
-Dbuckets='oss://${oss_bucket}/'
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole'
-DossHost='oss-cn-hangzhou-internal.aliyuncs.com'
```

## 参数说明

| 参数名                  | 默认值           | 参数说明                                                                        |
| --------------------- | ------------- | ------------------------------------------------------------------------- |
| query_table           | 无             | 输入查询表, schema: (id bigint, vector string)                                  |
| doc_table             | 无             | 输入索引表, schema: (id bigint, vector string)                                  |
| output_table          | 无             | 输出表, schema: (query_id bigint, doc_id bigint, distance double)             |
| knn_distance          | inner_product | 计算距离的方法：l2、inner_product                                                   |
| knn_num_neighbours    | 无             | top n, 每个query输出多少个近邻                                                      |
| knn_feature_dims      | 无             | 向量维度                                                                       |
| knn_feature_delimiter | ,             | 向量字符串分隔符                                                                   |
| knn_index_type        | ivfflat       | 向量索引类型：'flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpg'   |
| knn_nlist             | 5             | 聚类的簇个数, number of split cluster on each worker                             |
| knn_nprobe            | 2             | 检索时只考虑距离与输入向量最近的簇个数, number of probe part on each worker                   |
| knn_compress_dim      | 8             | 当index_type为`ivfpq`或`gpu_ivfpq`时, 指定压缩的维度，必须为float属性个数的因子                   |

## 使用示例

### 1. 创建查询表

```sql
create table query_table(pk BIGINT,vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
VALUES
  (1, '0.1,0.2,-0.4,0.5'),
  (2, '-0.1,0.8,0.4,0.5'),
  (3, '0.59,0.2,0.4,0.15'),
  (10, '0.1,-0.2,0.4,-0.5'),
  (20, '-0.1,-0.2,0.4,0.5'),
  (30, '0.5,0.2,0.43,0.15')
;
```

### 2. 创建索引表

```sql
create table doc_table(pk BIGINT,vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
VALUES
  (1, '0.1,0.2,0.4,0.5'),
  (2, '-0.1,0.2,0.4,0.5'),
  (3, '0.5,0.2,0.4,0.5'),
  (10, '0.1,0.2,0.4,0.5'),
  (20, '-0.1,-0.2,0.4,0.5'),
  (30, '0.5,0.2,0.43,0.15')
;
```

### 3. 执行向量检索

```sql
pai -name easy_rec_ext -project algo_public_dev
-Dcmd='vector_retrieve'
-DentryFile='run.py'
-Dquery_table='odps://${project}/tables/query_table/pt=20190410'
-Ddoc_table='odps://${project}/tables/doc_table/pt=20190410'
-Doutput_table='odps://${project}/tables/knn_result_table/pt=20190410'
-Dknn_distance=inner_product
-Dknn_num_neighbours=2
-Dknn_feature_dims=4
-Dknn_index_type='ivfflat'
-Dknn_feature_delimiter=','
-Dbuckets='oss://${oss_bucket}/'
-Darn='acs:ram::${xxxxxxxxxxxxx}:role/AliyunODPSPAIDefaultRole'
-DossHost='oss-cn-shenzhen-internal.aliyuncs.com'
-Dcluster='{
  \"worker\" : {
    \"count\" : 1,
    \"cpu\" : 600
  }
}';
```
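With `knn_distance=inner_product`, the `distance` column written to the output table is the raw dot product between the query and document vectors, so (judging by the sample output below) larger values mean closer matches; with `l2` it is the other way round. A quick sanity check against the sample rows above, in plain Python with values copied from the two example tables:

```python
# query id 1 and doc id 3 from the example tables above
q1 = [0.1, 0.2, -0.4, 0.5]
d3 = [0.5, 0.2, 0.4, 0.5]
print(sum(a * b for a, b in zip(q1, d3)))  # 0.18, matching the first row of the result below
```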
### 4. 查看结果

```sql
SELECT * from knn_result_table where pt='20190410';

-- query  doc  distance
-- 1      3    0.17999999225139618
-- 1      1    0.13999998569488525
-- 2      2    0.5800000429153442
-- 2      1    0.5600000619888306
-- 3      3    0.5699999928474426
-- 3      30   0.5295000076293945
-- 10     30   0.10700000077486038
-- 10     20   -0.0599999874830246
-- 20     20   0.46000003814697266
-- 20     2    0.3800000250339508
-- 30     3    0.5370000004768372
-- 30     30   0.4973999857902527
```
\ No newline at end of file
diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py
index de6724320..5a208591c 100644
--- a/easy_rec/python/feature_column/feature_column.py
+++ b/easy_rec/python/feature_column/feature_column.py
@@ -400,11 +400,7 @@ def parse_sequence_feature(self, config):
     assert config.embedding_dim > 0
 
-    if config.HasField('sequence_combiner'):
-      fc.sequence_combiner = config.sequence_combiner
-      self._deep_columns[feature_name] = fc
-    else:
-      self._add_deep_embedding_column(fc, config)
+    self._add_deep_embedding_column(fc, config)
 
   def _build_partitioner(self, max_partitions):
     if max_partitions > 1:
@@ -477,4 +473,6 @@ def _add_deep_embedding_column(self, fc, config):
     if config.feature_type != config.SequenceFeature:
       self._deep_columns[feature_name] = fc
     else:
+      if config.HasField('sequence_combiner'):
+        fc.sequence_combiner = config.sequence_combiner
       self._sequence_columns[feature_name] = fc
diff --git a/easy_rec/python/inference/vector_retrieve.py b/easy_rec/python/inference/vector_retrieve.py
new file mode 100644
index 000000000..917853484
--- /dev/null
+++ b/easy_rec/python/inference/vector_retrieve.py
@@ -0,0 +1,124 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
from datetime import datetime

import common_io
import numpy as np
import tensorflow as tf
try:
  import graphlearn as gl
except:
  logging.warning(
      'GraphLearn is not installed. You can install it by "pip install http://odps-release.cn-hangzhou.oss-cdn.aliyun-inc.com/graphlearn/tunnel/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl."'  # noqa: E501
  )


if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class VectorRetrieve(object):

  def __init__(self,
               query_table,
               doc_table,
               out_table,
               ndim,
               delimiter=',',
               batch_size=4,
               index_type='ivfflat',
               nlist=10,
               nprobe=2,
               distance=1,
               m=8):
    """Retrieve top n neighbours by query vector.

    Args:
      query_table: query vector table
      doc_table: document vector table
      out_table: output table
      ndim: int, number of feature dimensions
      delimiter: delimiter for feature vectors
      batch_size: query batch size
      index_type: search model `flat`, `ivfflat`, `ivfpq`, `gpu_ivfflat`
      nlist: number of split part on each worker
      nprobe: probe part on each worker
      distance: type of distance, 0 is l2 distance, 1 is inner product (default).
+ m: number of dimensions for each node after compress + """ + self.query_table = query_table + self.doc_table = doc_table + self.out_table = out_table + self.ndim = ndim + self.delimiter = delimiter + self.batch_size = batch_size + + gl.set_inter_threadnum(8) + gl.set_knn_metric(distance) + knn_option = gl.IndexOption() + knn_option.name = 'knn' + knn_option.index_type = index_type + knn_option.nlist = nlist + knn_option.nprobe = nprobe + knn_option.m = m + self.knn_option = knn_option + + def __call__(self, top_n, task_index, task_count, *args, **kwargs): + g = gl.Graph() + g.node( + self.doc_table, + 'doc', + decoder=gl.Decoder( + attr_types=['float'] * self.ndim, attr_delimiter=self.delimiter), + option=self.knn_option) + g.init(task_index=task_index, task_count=task_count) + + query_reader = common_io.table.TableReader( + self.query_table, slice_id=task_index, slice_count=task_count) + num_records = query_reader.get_row_count() + total_batch_num = num_records // self.batch_size + 1.0 + batch_num = 0 + print('total input records: {}'.format(query_reader.get_row_count())) + print('total_batch_num: {}'.format(total_batch_num)) + print('output_table: {}'.format(self.out_table)) + + output_table_writer = common_io.table.TableWriter(self.out_table, + task_index) + count = 0 + while True: + try: + batch_query_nodes, batch_query_feats = zip( + *query_reader.read(self.batch_size, allow_smaller_final_batch=True)) + batch_num += 1.0 + print('{} process: {:.2f}'.format(datetime.now().time(), + batch_num / total_batch_num)) + feats = to_np_array(batch_query_feats, self.delimiter) + rt_ids, rt_dists = g.search('doc', feats, gl.KnnOption(k=top_n)) + + for query_node, nodes, dists in zip(batch_query_nodes, rt_ids, + rt_dists): + query = np.array([query_node] * len(nodes), dtype='int64') + output_table_writer.write( + zip(query, nodes, dists), (0, 1, 2), allow_type_cast=False) + count += 1 + if np.mod(count, 100) == 0: + print('write ', count, ' query nodes totally') + except Exception as e: + print(e) + break + + print('==finished==') + output_table_writer.close() + query_reader.close() + g.close() + + +def to_np_array(batch_query_feats, attr_delimiter): + return np.array( + [map(float, feat.split(attr_delimiter)) for feat in batch_query_feats], + dtype='float32') diff --git a/easy_rec/python/test/odps_run.py b/easy_rec/python/test/odps_run.py index fd8f2f9a7..4b6179fbc 100644 --- a/easy_rec/python/test/odps_run.py +++ b/easy_rec/python/test/odps_run.py @@ -195,6 +195,18 @@ def test_boundary_test(self): tot.start_test() tot.drop_table() + def test_vector_retrieve(self): + start_files = [ + 'vector_retrieve/create_inner_vector_table.sql' + ] + test_files = [ + 'vector_retrieve/run_vector_retrieve.sql' + ] + end_file = ['vector_retrieve/drop_table.sql'] + tot = OdpsTest(start_files, test_files, end_file, odps_oss_config) + tot.start_test() + tot.drop_table() + if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/pai_jobs/easy_rec_flow/easy_rec.lua b/pai_jobs/easy_rec_flow/easy_rec.lua index 45ae08c76..87a165c85 100644 --- a/pai_jobs/easy_rec_flow/easy_rec.lua +++ b/pai_jobs/easy_rec_flow/easy_rec.lua @@ -87,7 +87,10 @@ function getHyperParams(config, cmd, checkpoint_path, model_dir, hpo_param_path, hpo_metric_save_path, saved_model_dir, all_cols, all_col_types, reserved_cols, output_cols, model_outputs, - input_table, output_table, tables, train_tables, + input_table, output_table, tables, query_table, + doc_table, knn_distance, knn_num_neighbours, + knn_feature_dims, 
knn_index_type, knn_feature_delimiter, + knn_nlist, knn_nprobe, knn_compress_dim, train_tables, eval_tables, boundary_table, batch_size, profiling_file, mask_feature_name, extra_params) hyperParameters = "" @@ -130,6 +133,39 @@ function getHyperParams(config, cmd, checkpoint_path, return hyperParameters, cluster, tables, output_table end + if cmd == "vector_retrieve" then + if cluster == nil or cluster == '' then + error('cluster must be set') + end + checkTable(query_table) + checkTable(doc_table) + checkTable(output_table) + hyperParameters = " --cmd=" .. cmd + hyperParameters = hyperParameters .. " --batch_size=" .. batch_size + hyperParameters = hyperParameters .. " --knn_distance=" .. knn_distance + if knn_num_neighbours ~= nil and knn_num_neighbours ~= '' then + hyperParameters = hyperParameters .. ' --knn_num_neighbours=' .. knn_num_neighbours + end + if knn_feature_dims ~= nil and knn_feature_dims ~= '' then + hyperParameters = hyperParameters .. ' --knn_feature_dims=' .. knn_feature_dims + end + hyperParameters = hyperParameters .. " --knn_index_type=" .. knn_index_type + hyperParameters = hyperParameters .. " --knn_feature_delimiter=" .. knn_feature_delimiter + if knn_nlist ~= nil and knn_nlist ~= '' then + hyperParameters = hyperParameters .. ' --knn_nlist=' .. knn_nlist + end + if knn_nprobe ~= nil and knn_nprobe ~= '' then + hyperParameters = hyperParameters .. ' --knn_nprobe=' .. knn_nprobe + end + if knn_compress_dim ~= nil and knn_compress_dim ~= '' then + hyperParameters = hyperParameters .. ' --knn_compress_dim=' .. knn_compress_dim + end + if extra_params ~= nil and extra_params ~= '' then + hyperParameters = hyperParameters .. extra_params + end + return hyperParameters, cluster, tables, output_table + end + if cmd ~= "custom" then checkConfig(config) end @@ -309,12 +345,12 @@ end function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, reservedCols, lifecycle, outputCol, tables, - trainTables, evalTables, boundaryTable) + trainTables, evalTables, boundaryTable, queryTable, docTable) -- all_cols, all_col_types, selected_cols, reserved_cols, -- create_table_sql, add_partition_sql, tables parameter to runTF if cmd ~= 'train' and cmd ~= 'evaluate' and cmd ~= 'predict' and cmd ~= 'export' - and cmd ~= 'evaluate' and cmd ~= 'custom' then - error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom') + and cmd ~= 'evaluate' and cmd ~= 'custom' and cmd ~= 'vector_retrieve' then + error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom, vector_retrieve') end -- for export @@ -341,6 +377,14 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, end end + if cmd == 'vector_retrieve' then + inputTable = queryTable + all_tables[queryTable] = table_id + table_id = table_id + 1 + all_tables[docTable] = table_id + table_id = table_id + 1 + end + if cmd == 'train' then -- merge train table and eval table into all_tables if trainTables ~= '' and trainTables ~= nil then @@ -424,6 +468,29 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, tables = join(tables, ',') + if cmd == 'vector_retrieve' then + if outputTable == nil or outputTable == '' then + error("outputTable is not set") + end + + proj1, table1, partition1 = splitTableParam(outputTable) + out_table_name = proj1 .. "." .. 
table1 + create_sql = '' + add_partition_sql = '' + if partition1 ~= nil and string.len(partition1) ~= 0 then + local partition_names, parition_values = parseParitionSpec(partition1) + create_partition_str = genCreatePartitionStr(partition_names) + create_sql = string.format("create table if not exists %s (query BIGINT, doc BIGINT, distance DOUBLE) partitioned by %s lifecycle %s;", out_table_name, create_partition_str, lifecycle) + add_partition_sql = genAddPartitionStr(partition_names, parition_values) + add_partition_sql = string.format("alter table %s add if not exists partition %s;", out_table_name, add_partition_sql) + else + create_sql = string.format("create table %s (query BIGINT, doc BIGINT, distance DOUBLE) lifecycle %s;", out_table_name, lifecycle) + add_partition_sql = string.format("desc %s;", out_table_name) + end + + return "", "", "", "", create_sql, add_partition_sql, tables + end + -- analyze selected_cols excluded_cols for train, evaluate and predict proj0, table0, partition0 = splitTableParam(inputTable) input_col_types, input_cols = getInputTableColTypes(proj0 .. "." .. table0) diff --git a/pai_jobs/easy_rec_flow/easy_rec.xml b/pai_jobs/easy_rec_flow/easy_rec.xml index 9f0cf1ce9..0dc2dea31 100644 --- a/pai_jobs/easy_rec_flow/easy_rec.xml +++ b/pai_jobs/easy_rec_flow/easy_rec.xml @@ -38,6 +38,17 @@ + + + + + + + + + + + @@ -92,6 +103,8 @@ + + @@ -160,6 +173,17 @@ + + + + + + + + + + + diff --git a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua index 61cf0a48f..eea9ec32e 100644 --- a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua +++ b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.lua @@ -109,7 +109,10 @@ function getHyperParams(config, cmd, checkpoint_path, model_dir, hpo_param_path, hpo_metric_save_path, saved_model_dir, all_cols, all_col_types, reserved_cols, output_cols, model_outputs, - input_table, output_table, tables, train_tables, + input_table, output_table, tables, query_table, + doc_table, knn_distance, knn_num_neighbours, + knn_feature_dims, knn_index_type, knn_feature_delimiter, + knn_nlist, knn_nprobe, knn_compress_dim, train_tables, eval_tables, boundary_table, batch_size, profiling_file, mask_feature_name, extra_params) hyperParameters = "" @@ -152,6 +155,39 @@ function getHyperParams(config, cmd, checkpoint_path, return hyperParameters, cluster, tables, output_table end + if cmd == "vector_retrieve" then + if cluster == nil or cluster == '' then + error('cluster must be set') + end + checkTable(query_table) + checkTable(doc_table) + checkTable(output_table) + hyperParameters = " --cmd=" .. cmd + hyperParameters = hyperParameters .. " --batch_size=" .. batch_size + hyperParameters = hyperParameters .. " --knn_distance=" .. knn_distance + if knn_num_neighbours ~= nil and knn_num_neighbours ~= '' then + hyperParameters = hyperParameters .. ' --knn_num_neighbours=' .. knn_num_neighbours + end + if knn_feature_dims ~= nil and knn_feature_dims ~= '' then + hyperParameters = hyperParameters .. ' --knn_feature_dims=' .. knn_feature_dims + end + hyperParameters = hyperParameters .. " --knn_index_type=" .. knn_index_type + hyperParameters = hyperParameters .. " --knn_feature_delimiter=" .. knn_feature_delimiter + if knn_nlist ~= nil and knn_nlist ~= '' then + hyperParameters = hyperParameters .. ' --knn_nlist=' .. knn_nlist + end + if knn_nprobe ~= nil and knn_nprobe ~= '' then + hyperParameters = hyperParameters .. ' --knn_nprobe=' .. 
knn_nprobe + end + if knn_compress_dim ~= nil and knn_compress_dim ~= '' then + hyperParameters = hyperParameters .. ' --knn_compress_dim=' .. knn_compress_dim + end + if extra_params ~= nil and extra_params ~= '' then + hyperParameters = hyperParameters .. extra_params + end + return hyperParameters, cluster, tables, output_table + end + if cmd ~= "custom" then checkConfig(config) end @@ -331,12 +367,12 @@ end function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, reservedCols, lifecycle, outputCol, tables, - trainTables, evalTables, boundaryTable) + trainTables, evalTables, boundaryTable, queryTable, docTable) -- all_cols, all_col_types, selected_cols, reserved_cols, -- create_table_sql, add_partition_sql, tables parameter to runTF if cmd ~= 'train' and cmd ~= 'evaluate' and cmd ~= 'predict' and cmd ~= 'export' - and cmd ~= 'evaluate' and cmd ~= 'custom' then - error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom') + and cmd ~= 'evaluate' and cmd ~= 'custom' and cmd ~= 'vector_retrieve' then + error('invalid cmd: ' .. cmd .. ', should be one of train, evaluate, predict, evaluate, export, custom, vector_retrieve') end -- for export @@ -363,6 +399,14 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, end end + if cmd == 'vector_retrieve' then + inputTable = queryTable + all_tables[queryTable] = table_id + table_id = table_id + 1 + all_tables[docTable] = table_id + table_id = table_id + 1 + end + if cmd == 'train' then -- merge train table and eval table into all_tables if trainTables ~= '' and trainTables ~= nil then @@ -446,6 +490,29 @@ function parseTable(cmd, inputTable, outputTable, selectedCols, excludedCols, tables = join(tables, ',') + if cmd == 'vector_retrieve' then + if outputTable == nil or outputTable == '' then + error("outputTable is not set") + end + + proj1, table1, partition1 = splitTableParam(outputTable) + out_table_name = proj1 .. "." .. table1 + create_sql = '' + add_partition_sql = '' + if partition1 ~= nil and string.len(partition1) ~= 0 then + local partition_names, parition_values = parseParitionSpec(partition1) + create_partition_str = genCreatePartitionStr(partition_names) + create_sql = string.format("create table if not exists %s (query BIGINT, doc BIGINT, distance DOUBLE) partitioned by %s lifecycle %s;", out_table_name, create_partition_str, lifecycle) + add_partition_sql = genAddPartitionStr(partition_names, parition_values) + add_partition_sql = string.format("alter table %s add if not exists partition %s;", out_table_name, add_partition_sql) + else + create_sql = string.format("create table %s (query BIGINT, doc BIGINT, distance DOUBLE) lifecycle %s;", out_table_name, lifecycle) + add_partition_sql = string.format("desc %s;", out_table_name) + end + + return "", "", "", "", create_sql, add_partition_sql, tables + end + -- analyze selected_cols excluded_cols for train, evaluate and predict proj0, table0, partition0 = splitTableParam(inputTable) input_col_types, input_cols = getInputTableColTypes(proj0 .. "." .. 
table0) diff --git a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml index 07932e384..f03cb855f 100644 --- a/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml +++ b/pai_jobs/easy_rec_flow_ex/easy_rec_ext.xml @@ -40,6 +40,17 @@ + + + + + + + + + + + @@ -94,6 +105,8 @@ + + @@ -162,6 +175,17 @@ + + + + + + + + + + + diff --git a/pai_jobs/run.py b/pai_jobs/run.py index cebadb73e..163eecc72 100644 --- a/pai_jobs/run.py +++ b/pai_jobs/run.py @@ -11,6 +11,7 @@ import easy_rec from easy_rec.python.inference.predictor import Predictor +from easy_rec.python.inference.vector_retrieve import VectorRetrieve from easy_rec.python.utils import config_util from easy_rec.python.utils import fg_util from easy_rec.python.utils import hpo_util @@ -63,6 +64,29 @@ tf.app.flags.DEFINE_string('eval_tables', '', 'tables used for evaluation') tf.app.flags.DEFINE_string('boundary_table', '', 'tables used for boundary') tf.app.flags.DEFINE_string('sampler_table', '', 'tables used for sampler') +tf.app.flags.DEFINE_string('query_table', '', + 'table used for retrieve vector neighbours') +tf.app.flags.DEFINE_string('doc_table', '', + 'table used for be retrieved as indexed vectors') +tf.app.flags.DEFINE_enum('knn_distance', 'inner_product', + ['l2', 'inner_product'], 'type of knn distance') +tf.app.flags.DEFINE_integer('knn_num_neighbours', None, + 'top n neighbours to be retrieved') +tf.app.flags.DEFINE_integer('knn_feature_dims', None, + 'number of feature dimensions') +tf.app.flags.DEFINE_enum( + 'knn_index_type', 'ivfflat', + ['flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpg'], + 'knn index type') +tf.app.flags.DEFINE_string('knn_feature_delimiter', ',', + 'delimiter for feature vectors') +tf.app.flags.DEFINE_integer('knn_nlist', 5, + 'number of split part on each worker') +tf.app.flags.DEFINE_integer('knn_nprobe', 2, + 'number of probe part on each worker') +tf.app.flags.DEFINE_integer( + 'knn_compress_dim', 8, + 'number of dimensions after compress for `ivfpq` and `gpu_ivfpq`') # flags used for evaluate & export tf.app.flags.DEFINE_string( @@ -389,8 +413,38 @@ def main(argv): batch_size=FLAGS.batch_size, slice_id=FLAGS.task_index, slice_num=worker_num) + elif FLAGS.cmd == 'vector_retrieve': + check_param('knn_distance') + assert FLAGS.knn_feature_dims is not None, '`knn_feature_dims` should not be None' + assert FLAGS.knn_num_neighbours is not None, '`knn_num_neighbours` should not be None' + + query_table, doc_table, output_table = FLAGS.query_table, FLAGS.doc_table, FLAGS.outputs + if not query_table: + tables = FLAGS.tables.split(',') + assert len( + tables + ) >= 1, 'at least 1 tables must be specified, but only[%d]: %s' % ( + len(tables), FLAGS.tables) + query_table = tables[0] + doc_table = tables[1] if len(tables) > 1 else query_table + + knn = VectorRetrieve( + query_table, + doc_table, + output_table, + ndim=FLAGS.knn_feature_dims, + distance=1 if FLAGS.knn_distance == 'inner_product' else 0, + delimiter=FLAGS.knn_feature_delimiter, + batch_size=FLAGS.batch_size, + index_type=FLAGS.knn_index_type, + nlist=FLAGS.knn_nlist, + nprobe=FLAGS.knn_nprobe, + m=FLAGS.knn_compress_dim) + worker_hosts = FLAGS.worker_hosts.split(',') + knn(FLAGS.knn_num_neighbours, FLAGS.task_index, len(worker_hosts)) else: - raise ValueError('cmd should be one of train/evaluate/export/predict') + raise ValueError( + 'cmd should be one of train/evaluate/export/predict/vector_retrieve') if __name__ == '__main__': diff --git 
a/samples/odps_script/vector_retrieve/create_inner_vector_table.sql b/samples/odps_script/vector_retrieve/create_inner_vector_table.sql new file mode 100644 index 000000000..e6c3f5be1 --- /dev/null +++ b/samples/odps_script/vector_retrieve/create_inner_vector_table.sql @@ -0,0 +1,35 @@ +drop TABLE IF EXISTS query_vector_{TIME_STAMP}; +create table query_vector_{TIME_STAMP}( + query_id BIGINT + ,vector string +); + +INSERT OVERWRITE TABLE query_vector_{TIME_STAMP} +VALUES + (1, '0.1,0.2,-0.4,0.5'), + (2, '-0.1,0.8,0.4,0.5'), + (3, '0.59,0.2,0.4,0.15'), + (10, '0.1,-0.2,0.4,-0.5'), + (20, '-0.1,-0.2,0.4,0.5'), + (30, '0.5,0.2,0.43,0.15') +; + +desc query_vector_{TIME_STAMP}; + +drop TABLE IF EXISTS doc_vector_{TIME_STAMP}; +create table doc_vector_{TIME_STAMP}( + doc_id BIGINT + ,vector string +); + +INSERT OVERWRITE TABLE doc_vector_{TIME_STAMP} +VALUES + (1, '0.1,0.2,0.4,0.5'), + (2, '-0.1,0.2,0.4,0.5'), + (3, '0.5,0.2,0.4,0.5'), + (10, '0.1,0.2,0.4,0.5'), + (20, '-0.1,-0.2,0.4,0.5'), + (30, '0.5,0.2,0.43,0.15') +; + +desc doc_vector_{TIME_STAMP}; diff --git a/samples/odps_script/vector_retrieve/drop_table.sql b/samples/odps_script/vector_retrieve/drop_table.sql new file mode 100644 index 000000000..3550efc6b --- /dev/null +++ b/samples/odps_script/vector_retrieve/drop_table.sql @@ -0,0 +1,3 @@ +drop TABLE IF EXISTS query_vector_{TIME_STAMP}; +drop TABLE IF EXISTS doc_vector_{TIME_STAMP}; +drop TABLE IF EXISTS result_vector_{TIME_STAMP}; \ No newline at end of file diff --git a/samples/odps_script/vector_retrieve/run_vector_retrieve.sql b/samples/odps_script/vector_retrieve/run_vector_retrieve.sql new file mode 100644 index 000000000..2314a3eea --- /dev/null +++ b/samples/odps_script/vector_retrieve/run_vector_retrieve.sql @@ -0,0 +1,16 @@ +pai -name easy_rec_ext +-Dcmd='vector_retrieve' +-DentryFile='run.py' +-Dquery_table='odps://{ODPS_PROJ_NAME}/tables/query_vector_{TIME_STAMP}' +-Ddoc_table='odps://{ODPS_PROJ_NAME}/tables/doc_vector_{TIME_STAMP}' +-Doutput_table='odps://{ODPS_PROJ_NAME}/tables/result_vector_{TIME_STAMP}' +-Dcluster='{"worker" : {"count":1, "cpu":800, "memory":10000}}' +-Darn={ROLEARN} +-Dbuckets=oss://{OSS_BUCKET_NAME}/ +-DossHost={OSS_ENDPOINT} +-Dknn_distance=inner_product +-Dknn_num_neighbours=2 +-Dknn_feature_dims=4 +-Dknn_index_type='ivfflat' +-Dknn_feature_delimiter=',' +; \ No newline at end of file
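Taken together, the changes above route a new `vector_retrieve` command through the PAI flow (Lua/XML) into `run.py`, which simply instantiates the `VectorRetrieve` class. As a rough standalone sketch of that last step — the argument names follow the class signature added in this patch, the table names, `batch_size` and `top_n` are placeholders, and actually running it still requires the MaxCompute runtime (`common_io` plus GraphLearn) that the PAI job provides:

```python
from easy_rec.python.inference.vector_retrieve import VectorRetrieve

# Mirrors the flag mapping in run.py: knn_feature_dims -> ndim,
# knn_feature_delimiter -> delimiter, knn_index_type -> index_type,
# knn_nlist -> nlist, knn_nprobe -> nprobe, knn_compress_dim -> m,
# and knn_distance=inner_product -> distance=1 (l2 -> 0).
knn = VectorRetrieve(
    'odps://my_project/tables/query_vector_table',   # placeholder query_table
    'odps://my_project/tables/doc_vector_table',     # placeholder doc_table
    'odps://my_project/tables/result_vector_table',  # placeholder output_table
    ndim=4,
    delimiter=',',
    batch_size=1024,
    index_type='ivfflat',
    nlist=5,
    nprobe=2,
    distance=1,
    m=8)

# One call per worker: each worker builds a GraphLearn KNN index over doc_table
# and streams its slice of query_table, writing (query, doc, distance) rows.
knn(top_n=2, task_index=0, task_count=1)
```

Note that only `doc_table` is loaded into the GraphLearn KNN index; `query_table` is only streamed through `common_io` in batches of `batch_size`, so the query side can grow without affecting index build time.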