Skip to content

Commit

Permalink
added merge tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mightyshazam committed Feb 9, 2024
1 parent 4972cb8 commit baec815
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 17 deletions.
25 changes: 13 additions & 12 deletions src/DeltaLake/Bridge/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use arrow::{
use deltalake::datafusion::{
physical_plan::SendableRecordBatchStream,
sql::sqlparser::{
ast::{Assignment, Expr, Ident, TableAlias, TableFactor, Values},
ast::{Assignment, Expr, Ident, TableFactor, Values},
dialect::{Dialect, GenericDialect},
keywords::Keyword,
parser::{Parser, ParserError},
Expand Down Expand Up @@ -55,6 +55,7 @@ impl<'a> DeltaLakeParser<'a> {
}

pub fn parse_merge(&mut self) -> Result<Statement, ParserError> {
let _ = self.parser.parse_keyword(Keyword::MERGE);
let into = self.parser.parse_keyword(Keyword::INTO);

let table = self.parser.parse_table_factor()?;
Expand Down Expand Up @@ -259,58 +260,58 @@ pub enum MergeClause {
NotMatchedBySourceDelete(Option<Expr>),
}

pub fn extract_table_factor_alias(table: TableFactor) -> Option<TableAlias> {
pub fn extract_table_factor_alias(table: TableFactor) -> Option<String> {
match table {
TableFactor::Table {
name: _,
name,
alias,
args: _,
with_hints: _,
version: _,
partitions: _,
} => alias,
} => alias.map(|a| a.to_string()).or(Some(name.to_string())),
TableFactor::Derived {
lateral: _,
subquery: _,
alias,
} => alias,
TableFactor::TableFunction { expr: _, alias } => alias,
} => alias.map(|a| a.to_string()),
TableFactor::TableFunction { expr: _, alias } => alias.map(|a| a.to_string()),
TableFactor::Function {
lateral: _,
name: _,
args: _,
alias,
} => alias,
} => alias.map(|a| a.to_string()),
TableFactor::UNNEST {
alias,
array_exprs: _,
with_offset: _,
with_offset_alias: _,
} => alias,
} => alias.map(|a| a.to_string()),
TableFactor::NestedJoin {
table_with_joins: _,
alias,
} => alias,
} => alias.map(|a| a.to_string()),
TableFactor::Pivot {
table: _,
aggregate_function: _,
value_column: _,
pivot_values: _,
alias,
} => alias,
} => alias.map(|a| a.to_string()),
TableFactor::Unpivot {
table: _,
value: _,
name: _,
columns: _,
alias,
} => alias,
} => alias.map(|a| a.to_string()),
TableFactor::JsonTable {
json_expr: _,
json_path: _,
columns: _,
alias,
} => alias,
} => alias.map(|a| a.to_string()),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/DeltaLake/Bridge/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,8 @@ pub extern "C" fn table_merge(
None => delete,
}),
};
match res {
Ok(_) => todo!(),
mb = match res {
Ok(mb) => mb,
Err(error) => unsafe {
callback(
std::ptr::null(),
Expand Down
8 changes: 5 additions & 3 deletions src/DeltaLake/Table/SelectQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ public class SelectQuery
/// <summary>
/// Create a query with the provided query and default table alias 'deltatable'
/// </summary>
/// <param name="query"></param>
/// <param name="query">A SQL SELECT query</param>
/// <param name="tableAlias">Alias for the table used in the select query</param>

public SelectQuery(string query)
public SelectQuery(string query, string tableAlias = "deltatable")
{
Query = query;
TableAlias = tableAlias;
}

/// <summary>
Expand All @@ -23,6 +25,6 @@ public SelectQuery(string query)
/// <summary>
/// The name for the table used in the select query
/// </summary>
public string TableAlias { get; init; } = "deltatable";
public string TableAlias { get; init; }
}
}
8 changes: 8 additions & 0 deletions tests/DeltaLake.Tests/Table/DeleteTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,33 @@ public class DeleteTests
{
public static IEnumerable<object[]> BaseCases()
{
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [1, default(string?), 0];
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [1, "test < CAST(0 AS INT)", 1];
yield return [1, "test >= CAST(0 AS INT)", 0];
yield return [1, "second == 'test'", 1];
yield return [1, "second == '0'", 0];
yield return [1, "third < 1", 0];
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [2, default(string?), 0];
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [2, "test < CAST(0 AS INT)", 2];
yield return [2, "test >= CAST(0 AS INT)", 0];
yield return [2, "second == 'test'", 2];
yield return [2, "second == '0'", 1];
yield return [2, "third < 1", 1];
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [10, default(string?), 0];
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [10, "test < CAST(0 AS INT)", 10];
yield return [10, "test >= CAST(0 AS INT)", 0];
yield return [10, "second == 'test'", 10];
yield return [10, "second == '0'", 9];
yield return [10, "third < 1", 9];
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [100, default(string?), 0];
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
yield return [100, "test < CAST(0 AS INT)", 100];
yield return [100, "test >= CAST(0 AS INT)", 0];
yield return [100, "second == 'test'", 100];
Expand Down
135 changes: 135 additions & 0 deletions tests/DeltaLake.Tests/Table/MergeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using Apache.Arrow;
using Apache.Arrow.Memory;
using DeltaLake.Table;

namespace DeltaLake.Tests.Table;
/// <summary>
/// End-to-end tests for <c>MergeAsync</c> against an in-memory table.
/// Every test starts from a 10-row table and merges a 10-row source batch whose
/// <c>test</c> column holds 5..14, <c>second</c> holds "hello" and <c>third</c> holds 100.
/// The expectations below imply the initial table holds test = 0..9, second = "0".."9"
/// and third = 0..9, so rows 5..9 overlap with the source batch.
/// </summary>
public class MergeTests
{
    /// <summary>
    /// Update matched rows, delete target rows missing from the source, and insert
    /// source rows missing from the target.
    /// </summary>
    [Fact]
    public async Task Merge_Memory_Full_Test()
    {
        var query = @"MERGE INTO mytable USING newdata
ON newdata.test = mytable.test
WHEN MATCHED THEN
UPDATE SET
second = newdata.second,
third = newdata.third
WHEN NOT MATCHED BY SOURCE THEN DELETE
WHEN NOT MATCHED BY TARGET
THEN INSERT (
test,
second,
third
)
VALUES (
newdata.test,
'inserted data',
99
)";
        await BaseMergeTest(
            query,
            batches => AssertColumns(
                batches,
                [5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                ["hello", "hello", "hello", "hello", "hello", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"],
                [99L, 99L, 99L, 99L, 99L, 100L, 100L, 100L, 100L, 100L]));
    }

    /// <summary>
    /// Update matched rows and insert new ones; unmatched target rows are kept.
    /// </summary>
    [Fact]
    public async Task Merge_Memory_No_Delete_Test()
    {
        var query = @"MERGE INTO mytable USING newdata
ON newdata.test = mytable.test
WHEN MATCHED THEN
UPDATE SET
second = newdata.second,
third = newdata.third
WHEN NOT MATCHED BY TARGET
THEN INSERT (
test,
second,
third
)
VALUES (
newdata.test,
'inserted data',
99
)";
        await BaseMergeTest(
            query,
            batches => AssertColumns(
                batches,
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                ["0", "1", "2", "3", "4", "hello", "hello", "hello", "hello", "hello", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"],
                [0L, 1L, 2L, 3L, 4L, 99L, 99L, 99L, 99L, 99L, 100L, 100L, 100L, 100L, 100L]));
    }

    /// <summary>
    /// Only update matched rows; nothing is inserted or deleted.
    /// </summary>
    [Fact]
    public async Task Merge_Memory_No_Insert_Test()
    {
        var query = @"MERGE INTO mytable USING newdata
ON newdata.test = mytable.test
WHEN MATCHED THEN
UPDATE SET
second = newdata.second,
third = newdata.third";
        await BaseMergeTest(
            query,
            batches => AssertColumns(
                batches,
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                ["0", "1", "2", "3", "4", "hello", "hello", "hello", "hello", "hello"],
                [0L, 1L, 2L, 3L, 4L, 100L, 100L, 100L, 100L, 100L]));
    }

    /// <summary>
    /// Only insert source rows missing from the target; existing rows are untouched.
    /// </summary>
    [Fact]
    public async Task Merge_Memory_No_Update_Test()
    {
        var query = @"MERGE INTO mytable USING newdata
ON newdata.test = mytable.test
WHEN NOT MATCHED BY TARGET
THEN INSERT (
test,
second,
third
)
VALUES (
newdata.test,
'inserted data',
99
)";
        await BaseMergeTest(
            query,
            batches => AssertColumns(
                batches,
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
                ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "inserted data", "inserted data", "inserted data", "inserted data", "inserted data"],
                [0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 99L, 99L, 99L, 99L, 99L]));
    }

    /// <summary>
    /// Flattens the three columns across all batches, sorts each column independently
    /// and compares it with the expected values.
    /// </summary>
    /// <param name="batches">Record batches returned by the select query.</param>
    /// <param name="expectedTest">Sorted expected values of column 0 (<c>test</c>).</param>
    /// <param name="expectedSecond">Sorted expected values of column 1 (<c>second</c>).</param>
    /// <param name="expectedThird">Sorted expected values of column 2 (<c>third</c>).</param>
    private static void AssertColumns(
        IReadOnlyList<RecordBatch> batches,
        int[] expectedTest,
        string[] expectedSecond,
        long[] expectedThird)
    {
        // Assert.Equal reports the first differing element, unlike Assert.True(SequenceEqual).
        var actualTest = batches
            .SelectMany(batch => ((Int32Array)batch.Column(0)).Values.ToArray())
            .OrderBy(value => value)
            .ToArray();
        Assert.Equal(expectedTest, actualTest);

        // Ordinal ordering keeps the comparison independent of the machine's culture.
        var actualSecond = batches
            .SelectMany(batch => ((IReadOnlyList<string>)(StringArray)batch.Column(1)).ToArray())
            .OrderBy(value => value, StringComparer.Ordinal)
            .ToArray();
        Assert.Equal(expectedSecond, actualSecond);

        var actualThird = batches
            .SelectMany(batch => ((Int64Array)batch.Column(2)).Values.ToArray())
            .OrderBy(value => value)
            .ToArray();
        Assert.Equal(expectedThird, actualThird);
    }

    /// <summary>
    /// Creates a fresh 10-row in-memory table, merges the source batch into it with
    /// <paramref name="query"/>, reads the whole table back and runs <paramref name="assertions"/>.
    /// </summary>
    /// <param name="query">The MERGE statement under test.</param>
    /// <param name="assertions">Checks to run against the table contents after the merge.</param>
    private async Task BaseMergeTest(string query, Action<IReadOnlyList<RecordBatch>> assertions)
    {
        // A unique memory:// URI per run keeps the tests isolated from each other.
        var pair = await TableHelpers.SetupTable($"memory://{Guid.NewGuid():N}", 10);
        using var runtime = pair.runtime;
        using var table = pair.table;
        var allocator = new NativeMemoryAllocator();
        var enumerable = Enumerable.Range(5, 10);
        var recordBatchBuilder = new RecordBatch.Builder(allocator)
            .Append("test", false, col => col.Int32(arr => arr.AppendRange(enumerable)))
            .Append("second", false, col => col.String(arr => arr.AppendRange(enumerable.Select(_ => "hello"))))
            .Append("third", false, col => col.Int64(arr => arr.AppendRange(enumerable.Select(_ => 100L))));
        using var rb = recordBatchBuilder.Build();
        await table.MergeAsync(query, [rb], rb.Schema, CancellationToken.None);

        // Consume the result stream asynchronously instead of blocking the test thread.
        var batches = new List<RecordBatch>();
        await foreach (var batch in table.QueryAsync(new SelectQuery("select * from deltatable"), CancellationToken.None))
        {
            batches.Add(batch);
        }

        assertions(batches);
    }
}
71 changes: 71 additions & 0 deletions tests/DeltaLake.Tests/Table/UpdateTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using Apache.Arrow;
using DeltaLake.Table;

namespace DeltaLake.Tests.Table;
/// <summary>
/// End-to-end tests for <c>UpdateAsync</c> against an in-memory table.
/// The expected totals assume <c>TableHelpers.SetupTable(path, length)</c> fills the
/// <c>test</c> column with 0..length-1 (consistent with the original formulas for
/// every length exercised here).
/// </summary>
public class UpdateTests
{
    private const string UpdateGreaterThanOne = "UPDATE test SET test = test + CAST(1 AS INT) WHERE test > CAST(1 AS INT)";
    private const string UpdateNoPredicate = "UPDATE test SET test = test + CAST(1 AS INT)";

    private const string UpdatePredicateMiss = "UPDATE test SET test = test + CAST(1 AS INT) WHERE test < CAST(0 as INT)";

    /// <summary>
    /// Increments only the rows whose value is greater than one.
    /// </summary>
    /// <param name="length">Number of rows in the table.</param>
    /// <param name="predicate">The full UPDATE statement to run (not just a WHERE clause).</param>
    [Theory]
    [InlineData(1, UpdateGreaterThanOne)]
    [InlineData(2, UpdateGreaterThanOne)]
    [InlineData(10, UpdateGreaterThanOne)]
    [InlineData(100, UpdateGreaterThanOne)]
    public async Task Memory_Update_Variable_Record_Test(
        int length,
        string predicate)
    {
        // Sum of 0..length-1, plus one for each of the max(0, length - 2) rows above 1.
        // Multiplying before dividing keeps the formula exact for odd lengths as well.
        var totalValue = length * (length - 1) / 2 + Math.Max(0, length - 2);
        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
    }

    /// <summary>
    /// Increments every row.
    /// </summary>
    /// <param name="length">Number of rows in the table.</param>
    /// <param name="predicate">The full UPDATE statement to run (not just a WHERE clause).</param>
    [Theory]
    [InlineData(1, UpdateNoPredicate)]
    [InlineData(2, UpdateNoPredicate)]
    [InlineData(10, UpdateNoPredicate)]
    [InlineData(100, UpdateNoPredicate)]
    public async Task Memory_Update_No_Predicate_Variable_Record_Test(
        int length,
        string predicate)
    {
        // Every row is incremented, so the total is the sum of 1..length.
        var totalValue = length * (length + 1) / 2;
        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
    }

    /// <summary>
    /// Runs an update whose predicate matches no rows, so the data must be unchanged.
    /// </summary>
    /// <param name="length">Number of rows in the table.</param>
    /// <param name="predicate">The full UPDATE statement to run (not just a WHERE clause).</param>
    [Theory]
    [InlineData(1, UpdatePredicateMiss)]
    [InlineData(2, UpdatePredicateMiss)]
    [InlineData(10, UpdatePredicateMiss)]
    [InlineData(100, UpdatePredicateMiss)]
    public async Task Memory_Update_Predicate_Miss_Variable_Record_Test(
        int length,
        string predicate)
    {
        // No row is touched, so the total stays the sum of 0..length-1.
        var totalValue = length * (length - 1) / 2;
        await BaseUpdateTest($"memory://{Guid.NewGuid():N}", length, predicate, totalValue);
    }

    /// <summary>
    /// Creates a table with <paramref name="length"/> rows, runs <paramref name="query"/>
    /// against it and checks that <c>SUM(test)</c> yields exactly one row equal to
    /// <paramref name="expectedTotal"/>.
    /// </summary>
    /// <param name="path">Table URI; callers pass a unique <c>memory://</c> location.</param>
    /// <param name="length">Number of rows to create.</param>
    /// <param name="query">The UPDATE statement under test.</param>
    /// <param name="expectedTotal">Expected value of <c>SUM(test)</c> after the update.</param>
    private async static Task BaseUpdateTest(
        string path,
        int length,
        string query,
        long expectedTotal)
    {
        var data = await TableHelpers.SetupTable(path, length);
        using var runtime = data.runtime;
        using var table = data.table;
        await table.UpdateAsync(query, CancellationToken.None);

        // The UPDATE statements refer to the table as "test", so query it under that alias.
        var select = new SelectQuery("select SUM(test) from test")
        {
            TableAlias = "test",
        };

        // Consume the result stream asynchronously instead of blocking the test thread.
        long totalValue = 0;
        var totalRecords = 0;
        await foreach (var batch in table.QueryAsync(select, CancellationToken.None))
        {
            totalValue += ((Int64Array)batch.Column(0)).Sum(i => i!.Value);
            totalRecords += batch.Length;
        }

        Assert.Equal(1, totalRecords);
        Assert.Equal(expectedTotal, totalValue);
    }
}

0 comments on commit baec815

Please sign in to comment.