Skip to content

Commit

Permalink
fix: missed error catching when binding condition expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
uniqueFranky committed Dec 26, 2024
1 parent 61ff0a5 commit 0cbe44d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
25 changes: 24 additions & 1 deletion src/observer/sql/stmt/delete_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ See the Mulan PSL v2 for more details. */
#include "sql/stmt/filter_stmt.h"
#include "storage/db/db.h"
#include "storage/table/table.h"
#include "sql/parser/expression_binder.h"

DeleteStmt::DeleteStmt(Table *table, SimpleFilterStmt *filter_stmt) : table_(table), filter_stmt_(filter_stmt) {}

Expand Down Expand Up @@ -46,8 +47,30 @@ RC DeleteStmt::create(Db *db, DeleteSqlNode &delete_sql, Stmt *&stmt)
std::unordered_map<std::string, Table *> table_map;
table_map.insert(std::pair<std::string, Table *>(std::string(table_name), table));

BinderContext binder_context;
binder_context.add_table(table);
ExpressionBinder expression_binder(binder_context);
RC rc = RC::SUCCESS;
for(auto &condition: delete_sql.conditions) {
if (condition.left_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
rc = expression_binder.bind_expression(condition.left_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.left_expression = std::move(bound.front());
}
if (condition.right_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
rc = expression_binder.bind_expression(condition.right_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.right_expression = std::move(bound.front());
}
}
SimpleFilterStmt *filter_stmt = nullptr;
RC rc = SimpleFilterStmt::create(
rc = SimpleFilterStmt::create(
db, table, &table_map, delete_sql.conditions.data(), static_cast<int>(delete_sql.conditions.size()), filter_stmt);
if (rc != RC::SUCCESS) {
LOG_WARN("failed to create filter statement. rc=%d:%s", rc, strrc(rc));
Expand Down
14 changes: 10 additions & 4 deletions src/observer/sql/stmt/select_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,24 @@ RC SelectStmt::create(Db *db, SelectSqlNode &select_sql, Stmt *&stmt)
default_table = tables[0];
}


RC rc = RC::SUCCESS;
std::vector<ConditionSqlNode> simple_conditions;
std::vector<ConditionSqlNode> sub_query_conditions;
for(auto &condition: select_sql.conditions) {
if(condition.left_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
expression_binder.bind_expression(condition.left_expression, bound);
rc = expression_binder.bind_expression(condition.left_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.left_expression = std::move(bound.front());
}
if(condition.right_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
expression_binder.bind_expression(condition.right_expression, bound);
rc = expression_binder.bind_expression(condition.right_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.right_expression = std::move(bound.front());
}
if(condition.left_type == ConditionSqlNode::SideType::SubQuery || condition.right_type == ConditionSqlNode::SideType::SubQuery) {
Expand All @@ -122,7 +128,7 @@ RC SelectStmt::create(Db *db, SelectSqlNode &select_sql, Stmt *&stmt)
}
// create simple filter statement in `where` statement
SimpleFilterStmt *simple_filter_stmt = nullptr;
RC rc = SimpleFilterStmt::create(db,
rc = SimpleFilterStmt::create(db,
default_table,
&table_map,
simple_conditions.data(),
Expand Down
23 changes: 23 additions & 0 deletions src/observer/sql/stmt/update_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the Mulan PSL v2 for more details. */

#include "sql/stmt/update_stmt.h"
#include "storage/db/db.h"
#include "sql/parser/expression_binder.h"

UpdateStmt::UpdateStmt(Table *table, const std::string &attribute_name, Value *values, int value_amount, SimpleFilterStmt *filter_stmt)
: table_(table), values_(values), value_amount_(value_amount), filter_stmt_(filter_stmt), attribute_name_(attribute_name)
Expand All @@ -29,6 +30,28 @@ RC UpdateStmt::create(Db *db, UpdateSqlNode &update, Stmt *&stmt)
RC rc = RC::SUCCESS;
std::unordered_map<std::string, Table *> table_map;
table_map.insert_or_assign(update.relation_name, table);

BinderContext binder_context;
binder_context.add_table(table);
ExpressionBinder expression_binder(binder_context);
for(auto &condition: update.conditions) {
if (condition.left_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
rc = expression_binder.bind_expression(condition.left_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.left_expression = std::move(bound.front());
}
if (condition.right_type == ConditionSqlNode::SideType::Expr) {
vector<unique_ptr<Expression>> bound;
rc = expression_binder.bind_expression(condition.right_expression, bound);
if(OB_FAIL(rc)) {
return rc;
}
condition.right_expression = std::move(bound.front());
}
}
SimpleFilterStmt *filter_stmt = nullptr;
rc = SimpleFilterStmt::create(db, table, &table_map, update.conditions.data(),
static_cast<int>(update.conditions.size()), filter_stmt);
Expand Down

0 comments on commit 0cbe44d

Please sign in to comment.