From 9f6a19c56c38aa351e55eda5877ee2448b7b3d97 Mon Sep 17 00:00:00 2001 From: zhagnlu <1542303831@qq.com> Date: Tue, 16 Jan 2024 19:48:38 +0800 Subject: [PATCH] fix: increase expr recursion depth to avoid parse failed (#29860) (#30021) pr: #29860 Signed-off-by: luzhang Co-authored-by: luzhang --- internal/core/src/query/Plan.cpp | 10 +++++++++- tests/python_client/testcases/test_query.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 979b9e80bddff..e74e3871112fd 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -77,7 +77,15 @@ CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size) { proto::plan::PlanNode plan_node; - plan_node.ParseFromArray(serialized_expr_plan, size); + google::protobuf::io::ArrayInputStream array_stream(serialized_expr_plan, + size); + google::protobuf::io::CodedInputStream input_stream(&array_stream); + input_stream.SetRecursionLimit(std::numeric_limits::max()); + + auto res = plan_node.ParsePartialFromCodedStream(&input_stream); + if (!res) { + throw SegcoreError(UnexpectedError, "parse plan node proto failed"); + } return ProtoParser(schema).CreateRetrievePlan(plan_node); } diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index ac4a3d658bbcf..17f09ba0a25d7 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -2352,6 +2352,24 @@ def test_query_using_all_types_of_default_value(self): assert res[ct.default_bool_field_name] is False assert res[ct.default_string_field_name] == "abc" + @pytest.mark.tags(CaseLabel.L0) + def test_query_multi_logical_exprs(self): + """ + target: test the scenario which query with many logical expressions + method: 1. create collection + 3. query the expr that like: int64 == 0 || int64 == 1 ........ + expected: run successfully + """ + c_name = cf.gen_unique_str(prefix) + collection_w = self.init_collection_wrap(name=c_name) + df = cf.gen_default_dataframe_data() + collection_w.insert(df) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.load() + multi_exprs = " || ".join(f'{default_int_field_name} == {i}' for i in range(60)) + _, check_res = collection_w.query(multi_exprs, output_fields=[f'{default_int_field_name}']) + assert(check_res == True) + class TestQueryString(TestcaseBase): """