Skip to content

Commit

Permalink
Solved the concurrency problem (tested on Linux)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdrakiburrahman committed Oct 7, 2024
1 parent 17d4641 commit 5907432
Showing 1 changed file with 85 additions and 47 deletions.
132 changes: 85 additions & 47 deletions tests/DeltaLake.Tests/Table/KernelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,38 @@ public class KernelTests
@"Failed to commit transaction"
};

/// <remarks>
///
/// Re: "Parallelism":
///
/// The test attempts to simulate real-life concurrency situations (parallel writers across processes),
/// such as performing writes into Delta Partitions and ensure these core scenarios do not regress.
///
/// The test DOES NOT attempt to test concurrency of a singleton <see cref="ITable"/> client etc.,
/// because both the <see cref="Bridge.Table"/> and the <see cref="Kernel.Core.Table"/> does not have
/// the locking mechanisms in place (yet) to guarantee safe concurrent access to underlying state and pointers.
///
/// Once we get there, a single <see cref="ITable"/> will be able to perform both concurrent reads and writes
/// by protecting the underling delta-rs and delta-kernel-rs state with necessary locks.
///
/// </remarks>
[Fact]
public async Task Multi_Partitioned_Table_Parallelized_Bridge_Write_Can_Be_Read_By_Kernel()
{
// Setup
//
int numRowsPerPartition = 10;
int numPartitions = 3;
int numTransactionPerStringPartition = 2;
int numTransactionPerIntegerPartition = 2;
int numRows = numRowsPerPartition * numPartitions * numTransactionPerStringPartition * numTransactionPerIntegerPartition;
int numParallelReads = 1; // TODO: Ensure this runs green with multiple threads before merging

int numWritesPerStringPartition = 3;
int numWritesPerIntegerPartition = 3;

int numReadsPerReader = 1; // TODO: Make this run with multiple passes

int numConcurrentWriters = numPartitions * numWritesPerStringPartition * numWritesPerIntegerPartition;
int numConcurrentReaders = 5;

int numRows = numRowsPerPartition * numConcurrentWriters;

var tempDir = Directory.CreateTempSubdirectory();
using IEngine engine = new DeltaEngine(EngineOptions.Default);
Expand All @@ -54,6 +75,7 @@ public async Task Multi_Partitioned_Table_Parallelized_Bridge_Write_Can_Be_Read_
Configuration = new Dictionary<string, string> { ["delta.dataSkippingNumIndexedCols"] = "32" },
PartitionBy = new[] { partitionStringColumnName, partitionIntegerColumnName },
};
var tableLoadOptions = new TableOptions() { TableLocation = tempDir.FullName };
var tableWriteOptions = new InsertOptions { SaveMode = SaveMode.Append };
var allocator = new NativeMemoryAllocator();
var randomValueGenerator = new Random();
Expand All @@ -75,25 +97,29 @@ public async Task Multi_Partitioned_Table_Parallelized_Bridge_Write_Can_Be_Read_
{
// Exercise: Parallelized writes via Bridge
//
using ITable table = await engine.CreateTableAsync(tableCreateOptions, CancellationToken.None);
// >>> sharedTable: Simulates the original author of the table, doesn't need to support concurrent creates
// >>> threadIsolatedTable: Simulates partitioned writers to the table, needs concurrent writes
//
using ITable sharedTable = await engine.CreateTableAsync(tableCreateOptions, CancellationToken.None);
var writeTasks = new List<Task>();
for (int i = 0; i < numPartitions; i++)
{
for (int j = 0; j < numTransactionPerStringPartition; j++)
for (int j = 0; j < numWritesPerStringPartition; j++)
{
for (int k = 0; k < numTransactionPerIntegerPartition; k++)
for (int k = 0; k < numWritesPerIntegerPartition; k++)
{
writeTasks.Add(Task.Run(async () =>
{
await policy.ExecuteAsync(async () =>
{
using ITable threadIsolatedTable = await engine.LoadTableAsync(tableLoadOptions, CancellationToken.None);
var partition = $"{hostNamePrefix}_{i}";
var recordBatchBuilder = new RecordBatch.Builder(allocator)
.Append(stringColumnName, false, col => col.String(arr => arr.AppendRange(Enumerable.Range(0, numRowsPerPartition).Select(_ => GenerateRandomString(randomValueGenerator)))))
.Append(partitionStringColumnName, false, col => col.String(arr => arr.AppendRange(Enumerable.Range(0, numRowsPerPartition).Select(_ => partition))))
.Append(partitionIntegerColumnName, false, col => col.Int32(arr => arr.AppendRange(Enumerable.Range(0, numRowsPerPartition).Select(_ => i * j * k))))
.Append(intColumnName, false, col => col.Int32(arr => arr.AppendRange(Enumerable.Range(0, numRowsPerPartition).Select(_ => randomValueGenerator.Next()))));
await table.InsertAsync(new[] { recordBatchBuilder.Build() }, schema, tableWriteOptions, CancellationToken.None);
await threadIsolatedTable.InsertAsync(new[] { recordBatchBuilder.Build() }, schema, tableWriteOptions, CancellationToken.None);
});
}));
}
Expand All @@ -102,58 +128,70 @@ await policy.ExecuteAsync(async () =>

await Task.WhenAll(writeTasks);

// Exercise: Parallelized writes via Kernel
//
// >>> sharedTable: Not used
// >>> threadIsolatedTable: Simulates parallel readers to the table, needs concurrent reads
//
var readTasks = new List<Task>();
for (int i = 0; i < numParallelReads; i++)
for (int i = 0; i < numConcurrentReaders; i++)
{
readTasks.Add(Task.Run(async () =>
{
// Exercise: Reads via Kernel
//
Apache.Arrow.Table arrowTable = table.ReadAsArrowTable();
DataFrame dataFrame = table.ReadAsDataFrame();
string stringResult = dataFrame.ToMarkdown();
using ITable threadIsolatedTable = await engine.LoadTableAsync(tableLoadOptions, CancellationToken.None);
// Validate: Data Integrity
// Multiple passes here ensures Kernel scan state is reset per read request
//
Assert.Equal(numRows, arrowTable.RowCount);
Assert.Equal(numRows, dataFrame.Rows.Count);
Assert.Equal(numRows, Regex.Matches(stringResult, hostNamePrefix).Count);
Assert.Equal(numColumns, arrowTable.ColumnCount);
Assert.Equal(numColumns, dataFrame.Columns.Count);
for (int j = 0; j < numReadsPerReader; j++)
{
// Exercise: Reads via Kernel
//
Apache.Arrow.Table arrowTable = threadIsolatedTable.ReadAsArrowTable();
DataFrame dataFrame = threadIsolatedTable.ReadAsDataFrame();
string stringResult = dataFrame.ToMarkdown();
var writerSchemaFieldMap = schema.FieldsList.ToDictionary(field => field.Name);
var kernelSchemaFieldMap = arrowTable.Schema.FieldsList.ToDictionary(field => field.Name);
var bridgeSchemaFieldMap = table.Schema().FieldsList.ToDictionary(field => field.Name);
// Validate: Data Integrity
//
Assert.Equal(numRows, arrowTable.RowCount);
Assert.Equal(numRows, dataFrame.Rows.Count);
Assert.Equal(numRows, Regex.Matches(stringResult, hostNamePrefix).Count);
Assert.Equal(numColumns, arrowTable.ColumnCount);
Assert.Equal(numColumns, dataFrame.Columns.Count);
// Validate: Schema Integrity
//
Assert.Equal(writerSchemaFieldMap.Count, kernelSchemaFieldMap.Count);
Assert.Equal(writerSchemaFieldMap.Count, bridgeSchemaFieldMap.Count);
Assert.Equal(writerSchemaFieldMap.Count, numColumns);
var writerSchemaFieldMap = schema.FieldsList.ToDictionary(field => field.Name);
var kernelSchemaFieldMap = arrowTable.Schema.FieldsList.ToDictionary(field => field.Name);
var bridgeSchemaFieldMap = threadIsolatedTable.Schema().FieldsList.ToDictionary(field => field.Name);
foreach (var kvp in writerSchemaFieldMap)
{
Assert.True(bridgeSchemaFieldMap.ContainsKey(kvp.Key));
Assert.Equal(kvp.Value.DataType, bridgeSchemaFieldMap[kvp.Key].DataType);
}
// Validate: Schema Integrity
//
Assert.Equal(writerSchemaFieldMap.Count, kernelSchemaFieldMap.Count);
Assert.Equal(writerSchemaFieldMap.Count, bridgeSchemaFieldMap.Count);
Assert.Equal(writerSchemaFieldMap.Count, numColumns);
foreach (var kvp in writerSchemaFieldMap)
{
Assert.True(kernelSchemaFieldMap.ContainsKey(kvp.Key));
if (kvp.Key == partitionIntegerColumnName)
foreach (var kvp in writerSchemaFieldMap)
{
// Kernel has a limitation where it can only report back String as the Partition
// values:
//
// >>> https://delta-users.slack.com/archives/C04TRPG3LHZ/p1728178727958499
//
Assert.Equal(StringType.Default, kernelSchemaFieldMap[kvp.Key].DataType);
Assert.Equal(Int32Type.Default, writerSchemaFieldMap[kvp.Key].DataType);
continue;
Assert.True(bridgeSchemaFieldMap.ContainsKey(kvp.Key));
Assert.Equal(kvp.Value.DataType, bridgeSchemaFieldMap[kvp.Key].DataType);
}
else
foreach (var kvp in writerSchemaFieldMap)
{
Assert.Equal(kvp.Value.DataType, kernelSchemaFieldMap[kvp.Key].DataType);
Assert.True(kernelSchemaFieldMap.ContainsKey(kvp.Key));
if (kvp.Key == partitionIntegerColumnName)
{
// Kernel has a limitation where it can only report back String as the Partition
// values:
//
// >>> https://delta-users.slack.com/archives/C04TRPG3LHZ/p1728178727958499
//
Assert.Equal(StringType.Default, kernelSchemaFieldMap[kvp.Key].DataType);
Assert.Equal(Int32Type.Default, writerSchemaFieldMap[kvp.Key].DataType);
continue;
}
else
{
Assert.Equal(kvp.Value.DataType, kernelSchemaFieldMap[kvp.Key].DataType);
}
}
}
}));
Expand Down

0 comments on commit 5907432

Please sign in to comment.