From 0cbe44db055714b0afa051a3a7053a6fc9712a6b Mon Sep 17 00:00:00 2001 From: uniqueFranky Date: Thu, 26 Dec 2024 21:03:08 +0800 Subject: [PATCH] fix: missed error catching when binding condition expressions --- src/observer/sql/stmt/delete_stmt.cpp | 25 ++++++++++++++++++++++++- src/observer/sql/stmt/select_stmt.cpp | 14 ++++++++++---- src/observer/sql/stmt/update_stmt.cpp | 23 +++++++++++++++++++++++ 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/observer/sql/stmt/delete_stmt.cpp b/src/observer/sql/stmt/delete_stmt.cpp index cba2e72b4..de74142a3 100644 --- a/src/observer/sql/stmt/delete_stmt.cpp +++ b/src/observer/sql/stmt/delete_stmt.cpp @@ -17,6 +17,7 @@ See the Mulan PSL v2 for more details. */ #include "sql/stmt/filter_stmt.h" #include "storage/db/db.h" #include "storage/table/table.h" +#include "sql/parser/expression_binder.h" DeleteStmt::DeleteStmt(Table *table, SimpleFilterStmt *filter_stmt) : table_(table), filter_stmt_(filter_stmt) {} @@ -46,8 +47,30 @@ RC DeleteStmt::create(Db *db, DeleteSqlNode &delete_sql, Stmt *&stmt) std::unordered_map table_map; table_map.insert(std::pair(std::string(table_name), table)); + BinderContext binder_context; + binder_context.add_table(table); + ExpressionBinder expression_binder(binder_context); + RC rc = RC::SUCCESS; + for(auto &condition: delete_sql.conditions) { + if (condition.left_type == ConditionSqlNode::SideType::Expr) { + vector> bound; + rc = expression_binder.bind_expression(condition.left_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } + condition.left_expression = std::move(bound.front()); + } + if (condition.right_type == ConditionSqlNode::SideType::Expr) { + vector> bound; + rc = expression_binder.bind_expression(condition.right_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } + condition.right_expression = std::move(bound.front()); + } + } SimpleFilterStmt *filter_stmt = nullptr; - RC rc = SimpleFilterStmt::create( + rc = SimpleFilterStmt::create( db, table, &table_map, delete_sql.conditions.data(), static_cast(delete_sql.conditions.size()), filter_stmt); if (rc != RC::SUCCESS) { LOG_WARN("failed to create filter statement. rc=%d:%s", rc, strrc(rc)); diff --git a/src/observer/sql/stmt/select_stmt.cpp b/src/observer/sql/stmt/select_stmt.cpp index 63bad08fb..77016fd7a 100644 --- a/src/observer/sql/stmt/select_stmt.cpp +++ b/src/observer/sql/stmt/select_stmt.cpp @@ -98,18 +98,24 @@ RC SelectStmt::create(Db *db, SelectSqlNode &select_sql, Stmt *&stmt) default_table = tables[0]; } - + RC rc = RC::SUCCESS; std::vector simple_conditions; std::vector sub_query_conditions; for(auto &condition: select_sql.conditions) { if(condition.left_type == ConditionSqlNode::SideType::Expr) { vector> bound; - expression_binder.bind_expression(condition.left_expression, bound); + rc = expression_binder.bind_expression(condition.left_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } condition.left_expression = std::move(bound.front()); } if(condition.right_type == ConditionSqlNode::SideType::Expr) { vector> bound; - expression_binder.bind_expression(condition.right_expression, bound); + rc = expression_binder.bind_expression(condition.right_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } condition.right_expression = std::move(bound.front()); } if(condition.left_type == ConditionSqlNode::SideType::SubQuery || condition.right_type == ConditionSqlNode::SideType::SubQuery) { @@ -122,7 +128,7 @@ RC SelectStmt::create(Db *db, SelectSqlNode &select_sql, Stmt *&stmt) } // create simple filter statement in `where` statement SimpleFilterStmt *simple_filter_stmt = nullptr; - RC rc = SimpleFilterStmt::create(db, + rc = SimpleFilterStmt::create(db, default_table, &table_map, simple_conditions.data(), diff --git a/src/observer/sql/stmt/update_stmt.cpp b/src/observer/sql/stmt/update_stmt.cpp index d676fb4b6..9215b0458 100644 --- a/src/observer/sql/stmt/update_stmt.cpp +++ b/src/observer/sql/stmt/update_stmt.cpp @@ -14,6 +14,7 @@ See the Mulan PSL v2 for more details. */ #include "sql/stmt/update_stmt.h" #include "storage/db/db.h" +#include "sql/parser/expression_binder.h" UpdateStmt::UpdateStmt(Table *table, const std::string &attribute_name, Value *values, int value_amount, SimpleFilterStmt *filter_stmt) : table_(table), values_(values), value_amount_(value_amount), filter_stmt_(filter_stmt), attribute_name_(attribute_name) @@ -29,6 +30,28 @@ RC UpdateStmt::create(Db *db, UpdateSqlNode &update, Stmt *&stmt) RC rc = RC::SUCCESS; std::unordered_map table_map; table_map.insert_or_assign(update.relation_name, table); + + BinderContext binder_context; + binder_context.add_table(table); + ExpressionBinder expression_binder(binder_context); + for(auto &condition: update.conditions) { + if (condition.left_type == ConditionSqlNode::SideType::Expr) { + vector> bound; + rc = expression_binder.bind_expression(condition.left_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } + condition.left_expression = std::move(bound.front()); + } + if (condition.right_type == ConditionSqlNode::SideType::Expr) { + vector> bound; + rc = expression_binder.bind_expression(condition.right_expression, bound); + if(OB_FAIL(rc)) { + return rc; + } + condition.right_expression = std::move(bound.front()); + } + } SimpleFilterStmt *filter_stmt = nullptr; rc = SimpleFilterStmt::create(db, table, &table_map, update.conditions.data(), static_cast(update.conditions.size()), filter_stmt);