diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index f995233..40b4b33 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -22,13 +22,8 @@ #include "duckdb/common/helper.hpp" #include "duckdb/main/relation.hpp" -#include "duckdb/main/relation/create_table_relation.hpp" -#include -#include "duckdb/main/relation/table_relation.hpp" -#include "duckdb/main/relation/table_function_relation.hpp" -#include "duckdb/main/relation/value_relation.hpp" -#include "duckdb/main/relation/view_relation.hpp" #include "duckdb/main/relation/aggregate_relation.hpp" +#include "duckdb/main/relation/create_table_relation.hpp" #include "duckdb/main/relation/cross_product_relation.hpp" #include "duckdb/main/relation/filter_relation.hpp" #include "duckdb/main/relation/join_relation.hpp" @@ -36,6 +31,13 @@ #include "duckdb/main/relation/order_relation.hpp" #include "duckdb/main/relation/projection_relation.hpp" #include "duckdb/main/relation/setop_relation.hpp" +#include "duckdb/main/relation/table_function_relation.hpp" +#include "duckdb/main/relation/table_relation.hpp" +#include "duckdb/main/relation/value_relation.hpp" +#include "duckdb/main/relation/view_relation.hpp" +#include + +#include namespace duckdb { const std::unordered_map SubstraitToDuckDB::function_names_remap = { @@ -753,6 +755,48 @@ shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s } } + +shared_ptr SubstraitToDuckDB::TransformUpdateOp(const substrait::Rel &sop) { + auto &supdate = sop.update(); + auto &nobj = supdate.named_table(); + if (nobj.names_size() == 0) { + throw InvalidInputException("Named object must have at least one name"); + } + auto table_idx = nobj.names_size() - 1; + auto table_name = nobj.names(table_idx); + string schema_name; + if (table_idx > 0) { + schema_name = nobj.names(0); + } + + auto schema = supdate.table_schema(); + unique_ptr condition = nullptr; + if (supdate.has_condition()) { + condition = std::move(TransformExpr(supdate.condition())); + } + vector> transformations; + vector columns; + for (int i = 0; i < supdate.transformations_size(); i++) { + auto &transformation = supdate.transformations(i); + auto column_target = transformation.column_target(); + columns.push_back(schema.names(column_target)); + transformations.emplace_back(TransformExpr(transformation.transformation())); + } + + shared_ptr table; + auto context_wrapper = make_shared_ptr(context); + auto table_info = TableInfo(*context, DEFAULT_SCHEMA, table_name); + if (!table_info) { + throw CatalogException("Table '%s' does not exist!", table_name); + } + if (acquire_lock) { + table = make_shared_ptr(context, std::move(table_info)); + } else { + table = make_shared_ptr(context_wrapper, std::move(table_info)); + } + return make_shared_ptr(table->context, std::move(condition), schema_name, table_name, columns, std::move(transformations)); +} + shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop, const google::protobuf::RepeatedPtrField *names) { switch (sop.rel_type_case()) { @@ -776,6 +820,8 @@ shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop, return TransformSetOp(sop, names); case substrait::Rel::RelTypeCase::kWrite: return TransformWriteOp(sop); + case substrait::Rel::RelTypeCase::kUpdate: + return TransformUpdateOp(sop); default: throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case())); } @@ -835,7 +881,8 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot } } - if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) { + switch (sop.input().rel_type_case()) { + case substrait::Rel::RelTypeCase::kWrite: { auto write = sop.input().write(); switch (write.op()) { case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: { @@ -847,6 +894,10 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot return child; } } + case substrait::Rel::RelTypeCase::kUpdate: + return child; + } + return make_shared_ptr(child, std::move(expressions), aliases); } diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp index 7fbcfa4..d941005 100644 --- a/src/include/from_substrait.hpp +++ b/src/include/from_substrait.hpp @@ -80,6 +80,7 @@ class SubstraitToDuckDB { shared_ptr TransformSetOp(const substrait::Rel &sop, const google::protobuf::RepeatedPtrField *names = nullptr); shared_ptr TransformWriteOp(const substrait::Rel &sop); + shared_ptr TransformUpdateOp(const substrait::Rel &sop); //! Transform Substrait Expressions to DuckDB Expressions unique_ptr TransformExpr(const substrait::Expression &sexpr, diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index 7466395..2ca13c6 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -72,6 +72,9 @@ class DuckDBToSubstrait { substrait::Rel *TransformCreateTable(LogicalOperator &dop); substrait::Rel *TransformInsertTable(LogicalOperator &dop); substrait::Rel *TransformDeleteTable(LogicalOperator &dop); + substrait::Rel *TransformUpdateTable(LogicalOperator &dop); + void UpdateColumnReferences(substrait::Expression * expression, const vector vector); + static substrait::Rel *TransformDummyScan(); //! Methods to transform different LogicalGet Types (e.g., Table, Parquet) //! To Substrait; diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 6ff0b0a..18aa7f4 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -1515,7 +1515,6 @@ substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) { substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) { auto rel = new substrait::Rel(); auto &logical_delete = dop.Cast(); - auto &table = logical_delete.table; if (logical_delete.children.size() != 1) { throw InternalException("Delete table expected one child, found " + to_string(logical_delete.children.size())); } @@ -1524,10 +1523,6 @@ substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) { writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE); writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT); - auto named_table = writeRel->mutable_named_table(); - named_table->add_names(table.schema.name); - named_table->add_names(table.name); - SetNamedTable(logical_delete.table, writeRel); auto schema = new substrait::NamedStruct(); SetTableSchema(logical_delete.table, schema); @@ -1538,6 +1533,175 @@ substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) { return rel; } +const vector & GetColumnIds(const LogicalProjection & dproj) { + if (dproj.children.size() != 1) { + throw InternalException("Update table projection expected 1 child, found " + to_string(dproj.children.size())); + } + + if (dproj.children[0]->type == LogicalOperatorType::LOGICAL_GET) { + return dproj.children[0]->Cast().GetColumnIds(); + } + // if (dproj.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { + // int columnCount = dproj.expressions.size(); + // vector columnIds; + // auto &inner_proj = dproj.children[0]->Cast(); + // + // } + static const vector empty; + return empty; +} + +substrait::Rel *DuckDBToSubstrait::TransformUpdateTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &logical_update = dop.Cast(); + auto &table = logical_update.table; + if (logical_update.children.size() != 1) { + throw InternalException("Delete table expected one child, found " + to_string(logical_update.children.size())); + } + + auto update_rel = rel->mutable_update(); + + auto named_table = update_rel->mutable_named_table(); + named_table->add_names(table.schema.name); + named_table->add_names(table.name); + + auto schema = new substrait::NamedStruct(); + SetTableSchema(logical_update.table, schema); + update_rel->set_allocated_table_schema(schema); + + if (logical_update.children.size() != 1) { + throw InternalException("Update table expected one child, found " + to_string(logical_update.children.size())); + } + if (logical_update.children[0]->type != LogicalOperatorType::LOGICAL_PROJECTION) { + throw InternalException("Update table expected projection child, found " + + LogicalOperatorToString(logical_update.children[0]->type)); + } + + auto &dproj = dop.children[0]->Cast(); + if (dproj.expressions.size() < logical_update.columns.size()) { + throw InternalException("Update table expected %d expressions, found %d", logical_update.expressions.size(), + dproj.expressions.size()); + } + if (dproj.children.size() != 1) { + throw InternalException("Update table projection expected 1 child, found " + to_string(dproj.children.size())); + } + if (dproj.children[0]->type != LogicalOperatorType::LOGICAL_GET && + dproj.children[0]->type != LogicalOperatorType::LOGICAL_FILTER) { + throw InternalException("Update table projection expected get as child, found " + + LogicalOperatorToString(dproj.children[0]->type)); + } + + // fix column references in the transformations using the column ids + auto &columnIds = GetColumnIds(dproj); + // auto &columnIds = dproj.children[0]->Cast().GetColumnIds(); + + substrait::Rel *input = TransformOp(*logical_update.children[0]); + auto &project_rel = input->project(); + for (int i = 0; i < logical_update.columns.size(); i++) { + auto transformation = update_rel->add_transformations(); + auto mutable_expression = transformation->mutable_transformation(); + mutable_expression->CopyFrom(project_rel.expressions(i)); + UpdateColumnReferences(mutable_expression, columnIds); + transformation->set_column_target(logical_update.columns[i].index); + } + + switch (project_rel.input().rel_type_case()) { + case substrait::Rel::RelTypeCase::kRead: { + auto &read_rel = project_rel.input().read(); + if (read_rel.has_filter()) { + auto condition = new substrait::Expression(read_rel.filter()); + update_rel->set_allocated_condition(condition); + } + break; + } + case substrait::Rel::RelTypeCase::kProject: { + auto &inner_project_rel = project_rel.input().project(); + if (inner_project_rel.input().rel_type_case() != substrait::Rel::RelTypeCase::kFilter) { + throw InternalException("Unsupported input type " + to_string(inner_project_rel.input().rel_type_case())); + } + if (inner_project_rel.input().has_filter() && inner_project_rel.input().filter().has_condition()) { + auto condition = new substrait::Expression(inner_project_rel.input().filter().condition()); + + update_rel->set_allocated_condition(condition); + } + break; + } + default: + throw InternalException("Unsupported input type " + to_string(project_rel.input().rel_type_case())); + } + return rel; +} + +void DuckDBToSubstrait::UpdateColumnReferences(substrait::Expression * expr, const vector columnIds) { + if (columnIds.empty()) { + return; + } + switch (expr->rex_type_case()) { + case substrait::Expression::RexTypeCase::kLiteral: + return; + case substrait::Expression::RexTypeCase::kSelection: { + auto &fieldRef = expr->selection(); + if (fieldRef.has_direct_reference() && fieldRef.direct_reference().has_struct_field()) { + auto inputColumnIdx = fieldRef.direct_reference().struct_field().field(); + auto newColumnId = columnIds[inputColumnIdx].GetPrimaryIndex(); + expr->mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(newColumnId); + } + return; + } + case substrait::Expression::RexTypeCase::kScalarFunction: { + auto mutable_function = expr->mutable_scalar_function(); + for (int i = 0; i < mutable_function->arguments_size(); i++) { + if (mutable_function->arguments(i).has_value()) { + auto mutable_arg = mutable_function->mutable_arguments(i)->mutable_value(); + UpdateColumnReferences(mutable_arg, columnIds); + } + } + return; + } + case substrait::Expression::RexTypeCase::kIfThen: { + if (expr->if_then().has_else_()) { + UpdateColumnReferences(expr->mutable_if_then()->mutable_else_(), columnIds); + } + for (int i = 0; i < expr->if_then().ifs_size(); i++) { + UpdateColumnReferences(expr->mutable_if_then()->mutable_ifs(i)->mutable_if_(), columnIds); + UpdateColumnReferences(expr->mutable_if_then()->mutable_ifs(i)->mutable_then(), columnIds); + } + return; + } + case substrait::Expression::RexTypeCase::kCast: { + auto mutable_cast = expr->mutable_cast(); + if (mutable_cast->has_input()) { + UpdateColumnReferences(mutable_cast->mutable_input(), columnIds); + } + return; + } + case substrait::Expression::RexTypeCase::kSingularOrList: + return; + case substrait::Expression::RexTypeCase::kNested: { + auto nested = expr->mutable_nested(); + if (nested->has_struct_()) { + for (int i = 0; i < nested->struct_().fields_size(); i++) { + UpdateColumnReferences(nested->mutable_struct_()->mutable_fields(i), columnIds); + } + } else if (nested->has_list()) { + for (int i = 0; i < nested->list().values_size(); i++) { + UpdateColumnReferences(nested->mutable_list()->mutable_values(i), columnIds); + } + } else if (nested->has_map()) { + for (int i = 0; i < nested->map().key_values_size(); i++) { + UpdateColumnReferences(nested->mutable_map()->mutable_key_values(i)->mutable_key(), columnIds); + UpdateColumnReferences(nested->mutable_map()->mutable_key_values(i)->mutable_value(), columnIds); + } + } + return; + } + case substrait::Expression::RexTypeCase::kSubquery: + return; + default: + throw InternalException("Unsupported expression type " + to_string(expr->rex_type_case())); + } +} + substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { switch (dop.type) { case LogicalOperatorType::LOGICAL_FILTER: @@ -1576,6 +1740,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformInsertTable(dop); case LogicalOperatorType::LOGICAL_DELETE: return TransformDeleteTable(dop); + case LogicalOperatorType::LOGICAL_UPDATE: + return TransformUpdateTable(dop); default: throw NotImplementedException(LogicalOperatorToString(dop.type)); } diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index 27ab432..90fc484 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -320,3 +320,46 @@ TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") { REQUIRE(CHECK_COLUMN(result, 0, {2, 6})); REQUIRE(CHECK_COLUMN(result, 1, {4, 8})); } + + +TEST_CASE("Test C UpdateRows with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20"); + auto result = ExecuteViaSubstrait(con, "SELECT * from employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2})); + REQUIRE(CHECK_COLUMN(result, 3, {144000, 96000, 60000, 114000, 72000})); +} + +TEST_CASE("Test C UpdateRows with simple condition using Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20 where employee_id = 1"); + auto result = ExecuteViaSubstrait(con, "SELECT * from employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2})); + REQUIRE(CHECK_COLUMN(result, 3, {144000, 80000, 50000, 95000, 60000})); +} + +TEST_CASE("Test C UpdateRows with condition on a column using Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20 where salary < 100000"); + auto result = ExecuteViaSubstrait(con, "SELECT * from employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2})); + REQUIRE(CHECK_COLUMN(result, 3, {120000, 96000, 60000, 114000, 72000})); +} \ No newline at end of file