Skip to content

Commit

Permalink
feat: Support Update rows in a table
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Nov 27, 2024
1 parent abc4b70 commit b38db71
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 12 deletions.
65 changes: 58 additions & 7 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@
#include "duckdb/common/helper.hpp"

#include "duckdb/main/relation.hpp"
#include "duckdb/main/relation/create_table_relation.hpp"
#include <duckdb/main/relation/delete_relation.hpp>
#include "duckdb/main/relation/table_relation.hpp"
#include "duckdb/main/relation/table_function_relation.hpp"
#include "duckdb/main/relation/value_relation.hpp"
#include "duckdb/main/relation/view_relation.hpp"
#include "duckdb/main/relation/aggregate_relation.hpp"
#include "duckdb/main/relation/create_table_relation.hpp"
#include "duckdb/main/relation/cross_product_relation.hpp"
#include "duckdb/main/relation/filter_relation.hpp"
#include "duckdb/main/relation/join_relation.hpp"
#include "duckdb/main/relation/limit_relation.hpp"
#include "duckdb/main/relation/order_relation.hpp"
#include "duckdb/main/relation/projection_relation.hpp"
#include "duckdb/main/relation/setop_relation.hpp"
#include "duckdb/main/relation/table_function_relation.hpp"
#include "duckdb/main/relation/table_relation.hpp"
#include "duckdb/main/relation/value_relation.hpp"
#include "duckdb/main/relation/view_relation.hpp"
#include <duckdb/main/relation/delete_relation.hpp>

#include <duckdb/main/relation/update_relation.hpp>

namespace duckdb {
const std::unordered_map<std::string, std::string> SubstraitToDuckDB::function_names_remap = {
Expand Down Expand Up @@ -753,6 +755,48 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
}
}


//! Transform a Substrait UpdateRel into a DuckDB UpdateRelation.
//!
//! The last entry of the named table is used as the table name; when more than one name is
//! given, the first entry is used as the schema name. Each transformation's column target is
//! resolved to a column name through the table schema carried by the update rel.
//!
//! Throws InvalidInputException when the named table is empty or a column target does not index
//! into the table schema, and CatalogException when the target table cannot be found.
shared_ptr<Relation> SubstraitToDuckDB::TransformUpdateOp(const substrait::Rel &sop) {
	auto &supdate = sop.update();
	auto &nobj = supdate.named_table();
	if (nobj.names_size() == 0) {
		throw InvalidInputException("Named object must have at least one name");
	}
	auto table_idx = nobj.names_size() - 1;
	auto table_name = nobj.names(table_idx);
	string schema_name;
	if (table_idx > 0) {
		schema_name = nobj.names(0);
	}

	// Bind by reference: the schema is only read, so copying the message is wasted work.
	auto &schema = supdate.table_schema();
	unique_ptr<ParsedExpression> condition;
	if (supdate.has_condition()) {
		condition = TransformExpr(supdate.condition());
	}

	vector<unique_ptr<ParsedExpression>> transformations;
	vector<string> columns;
	for (int i = 0; i < supdate.transformations_size(); i++) {
		auto &transformation = supdate.transformations(i);
		auto column_target = transformation.column_target();
		// The target comes straight from the plan; reject it here instead of indexing the
		// repeated field out of range.
		if (column_target < 0 || column_target >= schema.names_size()) {
			throw InvalidInputException("Update column target " + to_string(column_target) +
			                            " is out of range for a table schema with " + to_string(schema.names_size()) +
			                            " names");
		}
		columns.push_back(schema.names(column_target));
		transformations.emplace_back(TransformExpr(transformation.transformation()));
	}

	// Look the table up in the schema the plan names, so the existence check agrees with the
	// schema handed to the UpdateRelation below; fall back to the default schema otherwise.
	const string lookup_schema = schema_name.empty() ? string(DEFAULT_SCHEMA) : schema_name;
	auto table_info = TableInfo(*context, lookup_schema, table_name);
	if (!table_info) {
		throw CatalogException("Table '%s' does not exist!", table_name);
	}

	// The table relation is built only to obtain the context wrapper matching the locking mode.
	shared_ptr<TableRelation> table;
	if (acquire_lock) {
		table = make_shared_ptr<TableRelation>(context, std::move(table_info));
	} else {
		auto context_wrapper = make_shared_ptr<RelationContextWrapper>(context);
		table = make_shared_ptr<TableRelation>(context_wrapper, std::move(table_info));
	}
	return make_shared_ptr<UpdateRelation>(table->context, std::move(condition), schema_name, table_name,
	                                       std::move(columns), std::move(transformations));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
switch (sop.rel_type_case()) {
Expand All @@ -776,6 +820,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
return TransformSetOp(sop, names);
case substrait::Rel::RelTypeCase::kWrite:
return TransformWriteOp(sop);
case substrait::Rel::RelTypeCase::kUpdate:
return TransformUpdateOp(sop);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -835,7 +881,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
}
}

if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) {
switch (sop.input().rel_type_case()) {

Check warning on line 884 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

21 enumeration values not handled in switch: 'REL_TYPE_NOT_SET', 'kRead', 'kFilter'... [-Wswitch]

Check warning on line 884 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

21 enumeration values not handled in switch: 'REL_TYPE_NOT_SET', 'kRead', 'kFilter'... [-Wswitch]
case substrait::Rel::RelTypeCase::kWrite: {
auto write = sop.input().write();
switch (write.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
Expand All @@ -847,6 +894,10 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
return child;
}
}
case substrait::Rel::RelTypeCase::kUpdate:
return child;
}


return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
}
Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SubstraitToDuckDB {
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformUpdateOp(const substrait::Rel &sop);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
Expand Down
3 changes: 3 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class DuckDBToSubstrait {
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
substrait::Rel *TransformUpdateTable(LogicalOperator &dop);
void UpdateColumnReferences(substrait::Expression * expression, const vector<ColumnIndex> vector);

static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
176 changes: 171 additions & 5 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,6 @@ substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) {
substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &logical_delete = dop.Cast<LogicalDelete>();
auto &table = logical_delete.table;
if (logical_delete.children.size() != 1) {
throw InternalException("Delete table expected one child, found " + to_string(logical_delete.children.size()));
}
Expand All @@ -1524,10 +1523,6 @@ substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE);
writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT);

auto named_table = writeRel->mutable_named_table();
named_table->add_names(table.schema.name);
named_table->add_names(table.name);

SetNamedTable(logical_delete.table, writeRel);
auto schema = new substrait::NamedStruct();
SetTableSchema(logical_delete.table, schema);
Expand All @@ -1538,6 +1533,175 @@ substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
return rel;
}

//! Return the column ids of the LogicalGet directly below an update projection.
//!
//! Throws InternalException unless the projection has exactly one child. When that child is not
//! a LogicalGet, an empty list is returned.
const vector<ColumnIndex> &GetColumnIds(const LogicalProjection &dproj) {
	const auto child_count = dproj.children.size();
	if (child_count != 1) {
		throw InternalException("Update table projection expected 1 child, found " + to_string(child_count));
	}

	// Shared empty result for every child type other than LOGICAL_GET.
	static const vector<ColumnIndex> empty;

	auto &child = *dproj.children[0];
	return child.type == LogicalOperatorType::LOGICAL_GET ? child.Cast<LogicalGet>().GetColumnIds() : empty;
}

//! Transform a LogicalUpdate into a Substrait UpdateRel.
//!
//! The update must have a single LogicalProjection child whose own single child is a LogicalGet
//! or a LogicalFilter. The first `columns.size()` expressions of the transformed projection
//! become the update transformations, with their column references rewritten through the column
//! ids returned by GetColumnIds. The update condition is copied from the filter of the read rel,
//! or from the filter rel below an inner project.
//!
//! Returns a newly allocated Rel; as with the other Transform* methods the caller takes ownership.
//! Throws InternalException when the operator tree does not have the expected shape.
substrait::Rel *DuckDBToSubstrait::TransformUpdateTable(LogicalOperator &dop) {
	auto &logical_update = dop.Cast<LogicalUpdate>();
	auto &table = logical_update.table;

	// Validate the operator tree before allocating anything.
	if (logical_update.children.size() != 1) {
		throw InternalException("Update table expected one child, found " + to_string(logical_update.children.size()));
	}
	auto &update_child = *logical_update.children[0];
	if (update_child.type != LogicalOperatorType::LOGICAL_PROJECTION) {
		throw InternalException("Update table expected projection child, found " +
		                        LogicalOperatorToString(update_child.type));
	}

	auto &dproj = update_child.Cast<LogicalProjection>();
	if (dproj.expressions.size() < logical_update.columns.size()) {
		throw InternalException("Update table expected at least %d expressions, found %d",
		                        logical_update.columns.size(), dproj.expressions.size());
	}
	if (dproj.children.size() != 1) {
		throw InternalException("Update table projection expected 1 child, found " + to_string(dproj.children.size()));
	}
	if (dproj.children[0]->type != LogicalOperatorType::LOGICAL_GET &&
	    dproj.children[0]->type != LogicalOperatorType::LOGICAL_FILTER) {
		throw InternalException("Update table projection expected get as child, found " +
		                        LogicalOperatorToString(dproj.children[0]->type));
	}

	// Own the result until the very end so that a throw below does not leak it.
	unique_ptr<substrait::Rel> rel(new substrait::Rel());
	auto update_rel = rel->mutable_update();

	auto named_table = update_rel->mutable_named_table();
	named_table->add_names(table.schema.name);
	named_table->add_names(table.name);

	// Fill the schema in place instead of allocating it separately and handing it over.
	SetTableSchema(logical_update.table, update_rel->mutable_table_schema());

	// fix column references in the transformations using the column ids
	auto &columnIds = GetColumnIds(dproj);

	// The transformed child is only read from (its pieces are copied into update_rel), so it
	// must be released here rather than abandoned.
	unique_ptr<substrait::Rel> input(TransformOp(update_child));
	auto &project_rel = input->project();

	const auto column_count = logical_update.columns.size();
	if (static_cast<size_t>(project_rel.expressions_size()) < column_count) {
		throw InternalException("Update table expected at least " + to_string(column_count) +
		                        " projected expressions, found " + to_string(project_rel.expressions_size()));
	}
	for (size_t i = 0; i < column_count; i++) {
		auto transformation = update_rel->add_transformations();
		auto mutable_expression = transformation->mutable_transformation();
		mutable_expression->CopyFrom(project_rel.expressions(static_cast<int>(i)));
		UpdateColumnReferences(mutable_expression, columnIds);
		transformation->set_column_target(logical_update.columns[i].index);
	}

	auto &project_input = project_rel.input();
	switch (project_input.rel_type_case()) {
	case substrait::Rel::RelTypeCase::kRead: {
		auto &read_rel = project_input.read();
		if (read_rel.has_filter()) {
			update_rel->mutable_condition()->CopyFrom(read_rel.filter());
		}
		break;
	}
	case substrait::Rel::RelTypeCase::kProject: {
		auto &inner_input = project_input.project().input();
		if (inner_input.rel_type_case() != substrait::Rel::RelTypeCase::kFilter) {
			throw InternalException("Unsupported input type " + to_string(inner_input.rel_type_case()));
		}
		if (inner_input.filter().has_condition()) {
			update_rel->mutable_condition()->CopyFrom(inner_input.filter().condition());
		}
		break;
	}
	default:
		throw InternalException("Unsupported input type " + to_string(project_input.rel_type_case()));
	}
	return rel.release();
}

void DuckDBToSubstrait::UpdateColumnReferences(substrait::Expression * expr, const vector<ColumnIndex> columnIds) {
if (columnIds.empty()) {
return;
}
switch (expr->rex_type_case()) {
case substrait::Expression::RexTypeCase::kLiteral:
return;
case substrait::Expression::RexTypeCase::kSelection: {
auto &fieldRef = expr->selection();
if (fieldRef.has_direct_reference() && fieldRef.direct_reference().has_struct_field()) {
auto inputColumnIdx = fieldRef.direct_reference().struct_field().field();
auto newColumnId = columnIds[inputColumnIdx].GetPrimaryIndex();
expr->mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(newColumnId);
}
return;
}
case substrait::Expression::RexTypeCase::kScalarFunction: {
auto mutable_function = expr->mutable_scalar_function();
for (int i = 0; i < mutable_function->arguments_size(); i++) {
if (mutable_function->arguments(i).has_value()) {
auto mutable_arg = mutable_function->mutable_arguments(i)->mutable_value();
UpdateColumnReferences(mutable_arg, columnIds);
}
}
return;
}
case substrait::Expression::RexTypeCase::kIfThen: {
if (expr->if_then().has_else_()) {
UpdateColumnReferences(expr->mutable_if_then()->mutable_else_(), columnIds);
}
for (int i = 0; i < expr->if_then().ifs_size(); i++) {
UpdateColumnReferences(expr->mutable_if_then()->mutable_ifs(i)->mutable_if_(), columnIds);
UpdateColumnReferences(expr->mutable_if_then()->mutable_ifs(i)->mutable_then(), columnIds);
}
return;
}
case substrait::Expression::RexTypeCase::kCast: {
auto mutable_cast = expr->mutable_cast();
if (mutable_cast->has_input()) {
UpdateColumnReferences(mutable_cast->mutable_input(), columnIds);
}
return;
}
case substrait::Expression::RexTypeCase::kSingularOrList:
return;
case substrait::Expression::RexTypeCase::kNested: {
auto nested = expr->mutable_nested();
if (nested->has_struct_()) {
for (int i = 0; i < nested->struct_().fields_size(); i++) {
UpdateColumnReferences(nested->mutable_struct_()->mutable_fields(i), columnIds);
}
} else if (nested->has_list()) {
for (int i = 0; i < nested->list().values_size(); i++) {
UpdateColumnReferences(nested->mutable_list()->mutable_values(i), columnIds);
}
} else if (nested->has_map()) {
for (int i = 0; i < nested->map().key_values_size(); i++) {
UpdateColumnReferences(nested->mutable_map()->mutable_key_values(i)->mutable_key(), columnIds);
UpdateColumnReferences(nested->mutable_map()->mutable_key_values(i)->mutable_value(), columnIds);
}
}
return;
}
case substrait::Expression::RexTypeCase::kSubquery:
return;
default:
throw InternalException("Unsupported expression type " + to_string(expr->rex_type_case()));
}
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1576,6 +1740,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformInsertTable(dop);
case LogicalOperatorType::LOGICAL_DELETE:
return TransformDeleteTable(dop);
case LogicalOperatorType::LOGICAL_UPDATE:
return TransformUpdateTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand Down
43 changes: 43 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,46 @@ TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 0, {2, 6}));
REQUIRE(CHECK_COLUMN(result, 1, {4, 8}));
}


// Runs an UPDATE without a WHERE clause through ExecuteViaSubstraitJSON and then reads the
// table back through ExecuteViaSubstrait.
TEST_CASE("Test C UpdateRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// No WHERE clause: the statement targets every row.
ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20");
auto result = ExecuteViaSubstrait(con, "SELECT * from employees");
// Columns 0-2 are not assigned by the UPDATE.
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2}));
// Column 3 is presumably salary (CreateEmployeeTable is not visible here); all five values
// are expected to have been scaled.
REQUIRE(CHECK_COLUMN(result, 3, {144000, 96000, 60000, 114000, 72000}));
}

// Runs an UPDATE restricted by an equality condition on employee_id through
// ExecuteViaSubstraitJSON and then reads the table back through ExecuteViaSubstrait.
TEST_CASE("Test C UpdateRows with simple condition using Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// The condition is on a column other than the one being assigned.
ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20 where employee_id = 1");
auto result = ExecuteViaSubstrait(con, "SELECT * from employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2}));
// Only the first value in column 3 (presumably salary) is expected to differ from the
// freshly created table — TODO confirm against CreateEmployeeTable.
REQUIRE(CHECK_COLUMN(result, 3, {144000, 80000, 50000, 95000, 60000}));
}

// Runs an UPDATE whose condition reads the same column it assigns through
// ExecuteViaSubstraitJSON and then reads the table back through ExecuteViaSubstrait.
TEST_CASE("Test C UpdateRows with condition on a column using Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// salary appears both in the SET expression and in the WHERE condition.
ExecuteViaSubstraitJSON(con, "UPDATE employees SET salary = salary * 1.20 where salary < 100000");
auto result = ExecuteViaSubstrait(con, "SELECT * from employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2}));
// The first value in column 3 (presumably salary) is expected to stay at 120000, which does
// not satisfy the condition; the remaining four are expected to have been scaled.
REQUIRE(CHECK_COLUMN(result, 3, {120000, 96000, 60000, 114000, 72000}));
}

0 comments on commit b38db71

Please sign in to comment.