diff --git a/src/DeltaLake/Bridge/src/sql.rs b/src/DeltaLake/Bridge/src/sql.rs
index 793e3d5..0f08c7e 100644
--- a/src/DeltaLake/Bridge/src/sql.rs
+++ b/src/DeltaLake/Bridge/src/sql.rs
@@ -6,7 +6,7 @@ use arrow::{
 use deltalake::datafusion::{
     physical_plan::SendableRecordBatchStream,
     sql::sqlparser::{
-        ast::{Assignment, Expr, Ident, TableAlias, TableFactor, Values},
+        ast::{Assignment, Expr, Ident, TableFactor, Values},
         dialect::{Dialect, GenericDialect},
         keywords::Keyword,
         parser::{Parser, ParserError},
@@ -55,6 +55,7 @@ impl<'a> DeltaLakeParser<'a> {
     }
 
     pub fn parse_merge(&mut self) -> Result {
+        let _ = self.parser.parse_keyword(Keyword::MERGE);
         let into = self.parser.parse_keyword(Keyword::INTO);
 
         let table = self.parser.parse_table_factor()?;
@@ -259,58 +260,58 @@ pub enum MergeClause {
     NotMatchedBySourceDelete(Option),
 }
 
-pub fn extract_table_factor_alias(table: TableFactor) -> Option<TableAlias> {
+pub fn extract_table_factor_alias(table: TableFactor) -> Option<String> {
     match table {
         TableFactor::Table {
-            name: _,
+            name,
             alias,
             args: _,
             with_hints: _,
             version: _,
             partitions: _,
-        } => alias,
+        } => alias.map(|a| a.to_string()).or(Some(name.to_string())),
         TableFactor::Derived {
             lateral: _,
             subquery: _,
             alias,
-        } => alias,
-        TableFactor::TableFunction { expr: _, alias } => alias,
+        } => alias.map(|a| a.to_string()),
+        TableFactor::TableFunction { expr: _, alias } => alias.map(|a| a.to_string()),
         TableFactor::Function {
             lateral: _,
             name: _,
             args: _,
             alias,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
         TableFactor::UNNEST {
             alias,
             array_exprs: _,
             with_offset: _,
             with_offset_alias: _,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
         TableFactor::NestedJoin {
             table_with_joins: _,
             alias,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
         TableFactor::Pivot {
             table: _,
             aggregate_function: _,
             value_column: _,
             pivot_values: _,
             alias,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
         TableFactor::Unpivot {
             table: _,
             value: _,
             name: _,
             columns: _,
             alias,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
         TableFactor::JsonTable {
             json_expr: _,
             json_path: _,
             columns: _,
             alias,
-        } => alias,
+        } => alias.map(|a| a.to_string()),
     }
 }
 
diff --git a/src/DeltaLake/Bridge/src/table.rs b/src/DeltaLake/Bridge/src/table.rs
index 8554778..594759d 100644
--- a/src/DeltaLake/Bridge/src/table.rs
+++ b/src/DeltaLake/Bridge/src/table.rs
@@ -735,8 +735,8 @@ pub extern "C" fn table_merge(
                 None => delete,
             }),
         };
-        match res {
-            Ok(_) => todo!(),
+        mb = match res {
+            Ok(mb) => mb,
             Err(error) => unsafe {
                 callback(
                     std::ptr::null(),
diff --git a/src/DeltaLake/Table/SelectQuery.cs b/src/DeltaLake/Table/SelectQuery.cs
index dbb9276..20064bf 100644
--- a/src/DeltaLake/Table/SelectQuery.cs
+++ b/src/DeltaLake/Table/SelectQuery.cs
@@ -8,11 +8,13 @@ public class SelectQuery
 
         /// <summary>
         /// Create a query with the provided query and default table alias 'deltatable'
        /// </summary>
-        /// <param name="query"></param>
+        /// <param name="query">A SQL SELECT query</param>
+        /// <param name="tableAlias">Alias for the table used in the select query</param>
-        public SelectQuery(string query)
+        public SelectQuery(string query, string tableAlias = "deltatable")
         {
             Query = query;
+            TableAlias = tableAlias;
         }
 
         /// <summary>
@@ -23,6 +25,6 @@ public SelectQuery(string query)
         /// <summary>
         /// The name for the table used in the select query
         /// </summary>
-        public string TableAlias { get; init; } = "deltatable";
+        public string TableAlias { get; init; }
     }
 }
\ No newline at end of file
diff --git a/tests/DeltaLake.Tests/Table/DeleteTests.cs b/tests/DeltaLake.Tests/Table/DeleteTests.cs
index 59cbe40..2a734af 100644
--- a/tests/DeltaLake.Tests/Table/DeleteTests.cs
+++ b/tests/DeltaLake.Tests/Table/DeleteTests.cs
@@ -5,25 +5,33 @@ public class DeleteTests
 {
     public static IEnumerable<object[]> BaseCases()
     {
+#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [1, default(string?), 0];
+#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [1, "test < CAST(0 AS INT)", 1];
         yield return [1, "test >= CAST(0 AS INT)", 0];
         yield return [1, "second == 'test'", 1];
         yield return [1, "second == '0'", 0];
         yield return [1, "third < 1", 0];
+#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [2, default(string?), 0];
+#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [2, "test < CAST(0 AS INT)", 2];
         yield return [2, "test >= CAST(0 AS INT)", 0];
         yield return [2, "second == 'test'", 2];
         yield return [2, "second == '0'", 1];
         yield return [2, "third < 1", 1];
+#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [10, default(string?), 0];
+#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [10, "test < CAST(0 AS INT)", 10];
         yield return [10, "test >= CAST(0 AS INT)", 0];
         yield return [10, "second == 'test'", 10];
         yield return [10, "second == '0'", 9];
         yield return [10, "third < 1", 9];
+#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [100, default(string?), 0];
+#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
         yield return [100, "test < CAST(0 AS INT)", 100];
         yield return [100, "test >= CAST(0 AS INT)", 0];
         yield return [100, "second == 'test'", 100];
diff --git a/tests/DeltaLake.Tests/Table/MergeTests.cs b/tests/DeltaLake.Tests/Table/MergeTests.cs
new file mode 100644
index 0000000..9bcfbba
--- /dev/null
+++ b/tests/DeltaLake.Tests/Table/MergeTests.cs
@@ -0,0 +1,135 @@
+using Apache.Arrow;
+using Apache.Arrow.Memory;
+using DeltaLake.Table;
+
+namespace DeltaLake.Tests.Table;
+public class MergeTests
+{
+    [Fact]
+    public async Task Merge_Memory_Full_Test()
+    {
+        var query = @"MERGE INTO mytable USING newdata
+            ON newdata.test = mytable.test
+            WHEN MATCHED THEN
+            UPDATE SET
+                second = newdata.second,
+                third = newdata.third
+            WHEN NOT MATCHED BY SOURCE THEN DELETE
+            WHEN NOT MATCHED BY TARGET
+            THEN INSERT (
+                test,
+                second,
+                third
+            )
+            VALUES (
+                newdata.test,
+                'inserted data',
+                99
+            )";
+        await BaseMergeTest(query, batches =>
+        {
+            var column1 = batches.SelectMany(batch => ((Int32Array)batch.Column(0)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column1.SequenceEqual([5, 6, 7, 8, 9, 10, 11, 12, 13, 14]));
+            var column2 = batches.SelectMany(batch => ((IReadOnlyList<string>)(StringArray)batch.Column(1)).ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column2.SequenceEqual(["hello", "hello", "hello", "hello", "hello", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"]));
+            var column3 = batches.SelectMany(batch => ((Int64Array)batch.Column(2)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column3.SequenceEqual([99L, 99L, 99L, 99L, 99L, 100L, 100L, 100L, 100L, 100L]));
+        });
+    }
+
+    [Fact]
+    public async Task Merge_Memory_No_Delete_Test()
+    {
+        var query = @"MERGE INTO mytable USING newdata
+            ON newdata.test = mytable.test
+            WHEN MATCHED THEN
+            UPDATE SET
+                second = newdata.second,
+                third = newdata.third
+            WHEN NOT MATCHED BY TARGET
+            THEN INSERT (
+                test,
+                second,
+                third
+            )
+            VALUES (
+                newdata.test,
+                'inserted data',
+                99
+            )";
+        await BaseMergeTest(query, batches =>
+        {
+            var column1 = batches.SelectMany(batch => ((Int32Array)batch.Column(0)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column1.SequenceEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]));
+            var column2 = batches.SelectMany(batch => ((IReadOnlyList<string>)(StringArray)batch.Column(1)).ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column2.SequenceEqual(["0", "1", "2", "3", "4", "hello", "hello", "hello", "hello", "hello", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"]));
+            var column3 = batches.SelectMany(batch => ((Int64Array)batch.Column(2)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column3.SequenceEqual([0L, 1L, 2L, 3L, 4L, 99L, 99L, 99L, 99L, 99L, 100L, 100L, 100L, 100L, 100L]));
+        });
+    }
+
+    [Fact]
+    public async Task Merge_Memory_No_Insert_Test()
+    {
+        var query = @"MERGE INTO mytable USING newdata
+            ON newdata.test = mytable.test
+            WHEN MATCHED THEN
+            UPDATE SET
+                second = newdata.second,
+                third = newdata.third";
+        await BaseMergeTest(query, batches =>
+        {
+            var column1 = batches.SelectMany(batch => ((Int32Array)batch.Column(0)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column1.SequenceEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]));
+            var column2 = batches.SelectMany(batch => ((IReadOnlyList<string>)(StringArray)batch.Column(1)).ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column2.SequenceEqual(["0", "1", "2", "3", "4", "hello", "hello", "hello", "hello", "hello"]));
+            var column3 = batches.SelectMany(batch => ((Int64Array)batch.Column(2)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column3.SequenceEqual([0L, 1L, 2L, 3L, 4L, 100L, 100L, 100L, 100L, 100L]));
+        });
+    }
+
+    [Fact]
+    public async Task Merge_Memory_No_Update_Test()
+    {
+        var query = @"MERGE INTO mytable USING newdata
+            ON newdata.test = mytable.test
+            WHEN NOT MATCHED BY TARGET
+            THEN INSERT (
+                test,
+                second,
+                third
+            )
+            VALUES (
+                newdata.test,
+                'inserted data',
+                99
+            )";
+        await BaseMergeTest(query, batches =>
+        {
+            var column1 = batches.SelectMany(batch => ((Int32Array)batch.Column(0)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column1.SequenceEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]));
+            var column2 = batches.SelectMany(batch => ((IReadOnlyList<string>)(StringArray)batch.Column(1)).ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column2.SequenceEqual(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"]));
+            var column3 = batches.SelectMany(batch => ((Int64Array)batch.Column(2)).Values.ToArray()).OrderBy(i => i).ToArray();
+            Assert.True(column3.SequenceEqual([0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 99L, 99L, 99L, 99L, 99L]));
+        });
+    }
+
+    private async Task BaseMergeTest(string query, Action<List<RecordBatch>> assertions)
+    {
+        var pair = await TableHelpers.SetupTable($"memory://{Guid.NewGuid():N}", 10);
+        using var runtime = pair.runtime;
+        using var table = pair.table;
+        var allocator = new NativeMemoryAllocator();
+        var enumerable = Enumerable.Range(5, 10);
+        var recordBatchBuilder = new RecordBatch.Builder(allocator)
+            .Append("test", false, col => col.Int32(arr => arr.AppendRange(enumerable)))
+            .Append("second", false, col => col.String(arr => arr.AppendRange(enumerable.Select(_ => "hello"))))
+            .Append("third", false, col => col.Int64(arr => arr.AppendRange(enumerable.Select(_ => 100L))));
+        using var rb = recordBatchBuilder.Build();
+        await table.MergeAsync(query, [rb], rb.Schema, CancellationToken.None);
+        Console.WriteLine("merged!");
+        var batches = table.QueryAsync(new SelectQuery("select * from deltatable"), CancellationToken.None).ToBlockingEnumerable().ToList();
+        assertions(batches);
+    }
+}
\ No newline at end of file
diff --git a/tests/DeltaLake.Tests/Table/UpdateTests.cs b/tests/DeltaLake.Tests/Table/UpdateTests.cs
new file mode 100644
index 0000000..bc5284a
--- /dev/null
+++ b/tests/DeltaLake.Tests/Table/UpdateTests.cs
@@ -0,0 +1,71 @@
+using Apache.Arrow;
+using DeltaLake.Table;
+
+namespace DeltaLake.Tests.Table;
+public class UpdateTests
+{
+    private const string UpdateGreaterThanOne = "UPDATE test SET test = test + CAST(1 AS INT) WHERE test > CAST(1 AS INT)";
+    private const string UpdateNoPredicate = "UPDATE test SET test = test + CAST(1 AS INT)";
+
+    private const string UpdatePredicateMiss = "UPDATE test SET test = test + CAST(1 AS INT) WHERE test < CAST(0 as INT)";
+
+    [Theory]
+    [InlineData(1, UpdateGreaterThanOne)]
+    [InlineData(2, UpdateGreaterThanOne)]
+    [InlineData(10, UpdateGreaterThanOne)]
+    [InlineData(100, UpdateGreaterThanOne)]
+    public async Task Memory_Update_Variable_Record_Test(
+        int length,
+        string predicate)
+    {
+        var totalValue = length < 2 ? length - 1 : (length - 2) / 2 * (length + 3) + 1;
+        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
+    }
+
+    [Theory]
+    [InlineData(1, UpdateNoPredicate)]
+    [InlineData(2, UpdateNoPredicate)]
+    [InlineData(10, UpdateNoPredicate)]
+    [InlineData(100, UpdateNoPredicate)]
+    public async Task Memory_Update_No_Predicate_Variable_Record_Test(
+        int length,
+        string predicate)
+    {
+        var totalValue = length == 1 ? 1 : length / 2 * (length + 1);
+        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
+    }
+
+    [Theory]
+    [InlineData(1, UpdatePredicateMiss)]
+    [InlineData(2, UpdatePredicateMiss)]
+    [InlineData(10, UpdatePredicateMiss)]
+    [InlineData(100, UpdatePredicateMiss)]
+    public async Task Memory_Update_Predicate_Miss_Variable_Record_Test(
+        int length,
+        string predicate)
+    {
+        var totalValue = length / 2 * (length - 1);
+        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
+    }
+
+    private async static Task BaseUpdateTest(
+        string path,
+        int length,
+        string query,
+        long expectedTotal)
+    {
+        var data = await TableHelpers.SetupTable(path, length);
+        using var runtime = data.runtime;
+        using var table = data.table;
+        await table.UpdateAsync(query, CancellationToken.None);
+        var queryResult = table.QueryAsync(new SelectQuery("select SUM(test) from test")
+        {
+            TableAlias = "test",
+        },
+        CancellationToken.None).ToBlockingEnumerable().ToList();
+        var totalValue = queryResult.Select(rb => ((Int64Array)rb.Column(0)).Sum(i => i!.Value)).Sum();
+        var totalRecords = queryResult.Select(s => s.Length).Sum();
+        Assert.Equal(1, totalRecords);
+        Assert.Equal(expectedTotal, totalValue);
+    }
+}
\ No newline at end of file