Skip to content

Commit

Permalink
Combine async methods that were split into two parts.
Browse files Browse the repository at this point in the history
These were originally separated so that a Task wasn't allocated for the case of synchronous completion. Now with ValueTask (and much better code generation of async methods by the C# compiler), this code level optimisation is no longer necessary to avoid unnecessary work and allocations.

Simplify code by removing the HandleTimeout method that was doing the same thing as SetFailed.

Note that Benchmark.NET reports (for a simplified version of ReceiveReplyAsync with a do-nothing ReadPayloadAsync method) that the previous code took 40ns and the new code takes 60ns. So there is still overhead of going through a compiler-generated async state machine instead of writing it out by hand. However, this may be insignificant compared to the time taken for async network I/O, and this change makes the code easier to maintain.
  • Loading branch information
bgrainger committed Aug 15, 2023
1 parent f09893e commit e4519e3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 135 deletions.
73 changes: 30 additions & 43 deletions src/MySqlConnector/Core/ResultSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,61 +225,48 @@ public async Task<bool> ReadAsync(IOBehavior ioBehavior, CancellationToken cance
return row;
}

private ValueTask<Row?> ScanRowAsync(IOBehavior ioBehavior, Row? row, CancellationToken cancellationToken)
private async ValueTask<Row?> ScanRowAsync(IOBehavior ioBehavior, Row? row, CancellationToken cancellationToken)
{
// if we've already read past the end of this resultset, Read returns false
if (BufferState is ResultSetState.HasMoreData or ResultSetState.NoMoreData or ResultSetState.None)
return new ValueTask<Row?>(default(Row?));
return null;

var payloadValueTask = Session.ReceiveReplyAsync(ioBehavior, CancellationToken.None);
return payloadValueTask.IsCompletedSuccessfully
? new ValueTask<Row?>(ScanRowAsyncRemainder(this, payloadValueTask.Result, row))
: new ValueTask<Row?>(ScanRowAsyncAwaited(this, payloadValueTask.AsTask(), row, cancellationToken));
PayloadData payload;
try
{
payload = await Session.ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
}
catch (MySqlException ex)
{
BufferState = State = ResultSetState.NoMoreData;
if (ex.ErrorCode == MySqlErrorCode.QueryInterrupted && cancellationToken.IsCancellationRequested)
throw new OperationCanceledException(ex.Message, ex, cancellationToken);
if (ex.ErrorCode == MySqlErrorCode.QueryInterrupted && Command.CancellableCommand.IsTimedOut)
throw MySqlException.CreateForTimeout(ex);
throw;
}

static async Task<Row?> ScanRowAsyncAwaited(ResultSet resultSet, Task<PayloadData> payloadTask, Row? row, CancellationToken token)
if (payload.HeaderByte == EofPayload.Signature)
{
PayloadData payloadData;
try
if (Session.SupportsDeprecateEof && OkPayload.IsOk(payload.Span, Session.SupportsDeprecateEof))
{
payloadData = await payloadTask.ConfigureAwait(false);
var ok = OkPayload.Create(payload.Span, Session.SupportsDeprecateEof, Session.SupportsSessionTrack);
BufferState = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData;
return null;
}
catch (MySqlException ex)
if (!Session.SupportsDeprecateEof && EofPayload.IsEof(payload))
{
resultSet.BufferState = resultSet.State = ResultSetState.NoMoreData;
if (ex.ErrorCode == MySqlErrorCode.QueryInterrupted && token.IsCancellationRequested)
throw new OperationCanceledException(ex.Message, ex, token);
if (ex.ErrorCode == MySqlErrorCode.QueryInterrupted && resultSet.Command.CancellableCommand.IsTimedOut)
throw MySqlException.CreateForTimeout(ex);
throw;
var eof = EofPayload.Create(payload.Span);
BufferState = (eof.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData;
return null;
}
return ScanRowAsyncRemainder(resultSet, payloadData, row);
}

static Row? ScanRowAsyncRemainder(ResultSet resultSet, PayloadData payload, Row? row)
{
if (payload.HeaderByte == EofPayload.Signature)
{
var span = payload.Span;
if (resultSet.Session.SupportsDeprecateEof && OkPayload.IsOk(span, resultSet.Session.SupportsDeprecateEof))
{
var ok = OkPayload.Create(span, resultSet.Session.SupportsDeprecateEof, resultSet.Session.SupportsSessionTrack);
resultSet.BufferState = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData;
return null;
}
if (!resultSet.Session.SupportsDeprecateEof && EofPayload.IsEof(payload))
{
var eof = EofPayload.Create(span);
resultSet.BufferState = (eof.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData;
return null;
}
}

row ??= new Row(resultSet.Command.TryGetPreparedStatements() is not null, resultSet);
row.SetData(payload.Memory);
resultSet.m_hasRows = true;
resultSet.BufferState = ResultSetState.ReadingRows;
return row;
}
row ??= new Row(Command.TryGetPreparedStatements() is not null, this);
row.SetData(payload.Memory);
m_hasRows = true;
BufferState = ResultSetState.ReadingRows;
return row;
}

#pragma warning disable CA1822 // Mark members as static
Expand Down
74 changes: 16 additions & 58 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -906,81 +906,44 @@ public ValueTask<PayloadData> ReceiveAsync(IOBehavior ioBehavior, CancellationTo
}

// Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'.
public ValueTask<PayloadData> ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
public async ValueTask<PayloadData> ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
ValueTask<ArraySegment<byte>> task;
try
{
VerifyConnected();
task = m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior);
}
catch (Exception ex)
if (CreateExceptionForInvalidState() is { } exception)
{
Log.FailedInReceiveReplyAsync(m_logger, ex, Id);
if ((ex as MySqlException)?.ErrorCode == MySqlErrorCode.CommandTimeoutExpired)
HandleTimeout();
task = ValueTaskExtensions.FromException<ArraySegment<byte>>(ex);
Log.FailedInReceiveReplyAsync(m_logger, exception, Id);
throw exception;
}

if (task.IsCompletedSuccessfully)
{
var payload = new PayloadData(task.Result);
if (payload.HeaderByte != ErrorPayload.Signature)
return new ValueTask<PayloadData>(payload);

var exception = CreateExceptionForErrorPayload(payload.Span);
return ValueTaskExtensions.FromException<PayloadData>(exception);
}

return ReceiveReplyAsyncAwaited(task);
}

private async ValueTask<PayloadData> ReceiveReplyAsyncAwaited(ValueTask<ArraySegment<byte>> task)
{
ArraySegment<byte> bytes;
try
{
bytes = await task.ConfigureAwait(false);
bytes = await m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior).ConfigureAwait(false);
}
catch (Exception ex)
{
SetFailed(ex);
if (ex is MySqlException { ErrorCode: MySqlErrorCode.CommandTimeoutExpired })
HandleTimeout();
throw;
}

var payload = new PayloadData(bytes);
if (payload.HeaderByte == ErrorPayload.Signature)
throw CreateExceptionForErrorPayload(payload.Span);

return payload;
}

// Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'.
public ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken)
public async ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
ValueTask task;
try
{
VerifyConnected();
task = m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior);
}
catch (Exception ex)
if (CreateExceptionForInvalidState() is { } exception)
{
Log.FailedInSendReplyAsync(m_logger, ex, Id);
task = ValueTaskExtensions.FromException(ex);
Log.FailedInSendReplyAsync(m_logger, exception, Id);
throw exception;
}

if (task.IsCompletedSuccessfully)
return task;

return SendReplyAsyncAwaited(task);
}

private async ValueTask SendReplyAsyncAwaited(ValueTask task)
{
try
{
await task.ConfigureAwait(false);
await m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand All @@ -1001,20 +964,15 @@ public static void ThrowIfStatementContainsDelimiter(MySqlException exception, I
}
}

internal void HandleTimeout()
{
if (OwningConnection is not null && OwningConnection.TryGetTarget(out var connection))
connection.SetState(ConnectionState.Closed);
}

private void VerifyConnected()
private Exception? CreateExceptionForInvalidState()
{
lock (m_lock)
{
if (m_state == State.Closed)
throw new ObjectDisposedException(nameof(ServerSession));
return new ObjectDisposedException(nameof(ServerSession));
if (m_state != State.Connected && m_state != State.Querying && m_state != State.CancelingQuery && m_state != State.ClearingPendingCancellation && m_state != State.Closing)
throw new InvalidOperationException("ServerSession is not connected.");
return new InvalidOperationException("ServerSession is not connected.");
return null;
}
}

Expand Down
11 changes: 3 additions & 8 deletions src/MySqlConnector/MySqlDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,20 @@ private void ActivateResultSet(CancellationToken cancellationToken)
m_hasWarnings = m_resultSet.WarningCount != 0;
}

private ValueTask ScanResultSetAsync(IOBehavior ioBehavior, ResultSet resultSet, CancellationToken cancellationToken)
private async ValueTask ScanResultSetAsync(IOBehavior ioBehavior, ResultSet resultSet, CancellationToken cancellationToken)
{
if (!m_hasMoreResults)
return default;
return;

if (resultSet.BufferState is ResultSetState.NoMoreData or ResultSetState.None)
{
m_hasMoreResults = false;
return default;
return;
}

if (resultSet.BufferState != ResultSetState.HasMoreData)
throw new InvalidOperationException($"Invalid state: {resultSet.BufferState}");

return new ValueTask(ScanResultSetAsyncAwaited(ioBehavior, resultSet, cancellationToken));
}

private async Task ScanResultSetAsyncAwaited(IOBehavior ioBehavior, ResultSet resultSet, CancellationToken cancellationToken)
{
using (Command!.CancellableCommand.RegisterCancel(cancellationToken))
{
try
Expand Down
41 changes: 15 additions & 26 deletions src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,39 +473,28 @@ public static async ValueTask<ArraySegment<byte>> ReadPayloadAsync(BufferedByteR
}
}

public static ValueTask WritePayloadAsync(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
public static async ValueTask WritePayloadAsync(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
{
return payload.Length <= MaxPacketSize ? WritePacketAsync(byteHandler, getNextSequenceNumber(), payload, ioBehavior) :
WritePayloadAsyncAwaited(byteHandler, getNextSequenceNumber, payload, ioBehavior);

static async ValueTask WritePayloadAsyncAwaited(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
var buffer = ArrayPool<byte>.Shared.Rent(Math.Min(MaxPacketSize, payload.Length) + 4);
try
{
for (var bytesSent = 0; bytesSent < payload.Length; bytesSent += MaxPacketSize)
var bytesSent = 0;
do
{
var contents = payload.Slice(bytesSent, Math.Min(MaxPacketSize, payload.Length - bytesSent));
await WritePacketAsync(byteHandler, getNextSequenceNumber(), contents, ioBehavior).ConfigureAwait(false);
}
}
}
var bufferLength = contents.Length + 4;

private static ValueTask WritePacketAsync(IByteHandler byteHandler, int sequenceNumber, ReadOnlyMemory<byte> contents, IOBehavior ioBehavior)
{
var bufferLength = contents.Length + 4;
var buffer = ArrayPool<byte>.Shared.Rent(bufferLength);
SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3);
buffer[3] = (byte) sequenceNumber;
contents.CopyTo(buffer.AsMemory()[4..]);
var task = byteHandler.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, bufferLength), ioBehavior);
if (task.IsCompletedSuccessfully)
{
ArrayPool<byte>.Shared.Return(buffer);
return default;
}
return WritePacketAsyncAwaited(task, buffer);
SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3);
buffer[3] = (byte) getNextSequenceNumber();
contents.CopyTo(buffer.AsMemory(4));

static async ValueTask WritePacketAsyncAwaited(ValueTask task, byte[] buffer)
await byteHandler.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, bufferLength), ioBehavior).ConfigureAwait(false);
bytesSent += contents.Length;
}
while (bytesSent < payload.Length);
}
finally
{
await task.ConfigureAwait(false);
ArrayPool<byte>.Shared.Return(buffer);
}
}
Expand Down

0 comments on commit e4519e3

Please sign in to comment.