Skip to content

Commit

Permalink
Combine async methods that were split into two parts.
Browse files Browse the repository at this point in the history
These were originally separated so that a Task wasn't allocated for the case of synchronous completion. Now with ValueTask (and much better code generation of async methods by the C# compiler), this code level optimisation is no longer necessary to avoid unnecessary work and allocations.

Simplify code by removing the HandleTimeout method that was doing the same thing as SetFailed.

Note that Benchmark.NET reports (for a simplified version of the code with a do-nothing ReadPayloadAsync method) that the previous code took 40ns and the new code takes 60ns. So there is still overhead of going through a compiler-generated async state machine instead of writing it out by hand. However, this may be insignificant compared to the time taken for async network I/O, and this change makes the code easier to maintain.
  • Loading branch information
bgrainger committed Aug 15, 2023
1 parent f09893e commit ce3a243
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 84 deletions.
74 changes: 16 additions & 58 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -906,81 +906,44 @@ public ValueTask<PayloadData> ReceiveAsync(IOBehavior ioBehavior, CancellationTo
}

// Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'.
public ValueTask<PayloadData> ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
public async ValueTask<PayloadData> ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
ValueTask<ArraySegment<byte>> task;
try
{
VerifyConnected();
task = m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior);
}
catch (Exception ex)
if (CreateExceptionForInvalidState() is { } exception)
{
Log.FailedInReceiveReplyAsync(m_logger, ex, Id);
if ((ex as MySqlException)?.ErrorCode == MySqlErrorCode.CommandTimeoutExpired)
HandleTimeout();
task = ValueTaskExtensions.FromException<ArraySegment<byte>>(ex);
Log.FailedInReceiveReplyAsync(m_logger, exception, Id);
throw exception;
}

if (task.IsCompletedSuccessfully)
{
var payload = new PayloadData(task.Result);
if (payload.HeaderByte != ErrorPayload.Signature)
return new ValueTask<PayloadData>(payload);

var exception = CreateExceptionForErrorPayload(payload.Span);
return ValueTaskExtensions.FromException<PayloadData>(exception);
}

return ReceiveReplyAsyncAwaited(task);
}

private async ValueTask<PayloadData> ReceiveReplyAsyncAwaited(ValueTask<ArraySegment<byte>> task)
{
ArraySegment<byte> bytes;
try
{
bytes = await task.ConfigureAwait(false);
bytes = await m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior).ConfigureAwait(false);
}
catch (Exception ex)
{
SetFailed(ex);
if (ex is MySqlException { ErrorCode: MySqlErrorCode.CommandTimeoutExpired })
HandleTimeout();
throw;
}

var payload = new PayloadData(bytes);
if (payload.HeaderByte == ErrorPayload.Signature)
throw CreateExceptionForErrorPayload(payload.Span);

return payload;
}

// Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'.
public ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken)
public async ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
ValueTask task;
try
{
VerifyConnected();
task = m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior);
}
catch (Exception ex)
if (CreateExceptionForInvalidState() is { } exception)
{
Log.FailedInSendReplyAsync(m_logger, ex, Id);
task = ValueTaskExtensions.FromException(ex);
Log.FailedInSendReplyAsync(m_logger, exception, Id);
throw exception;
}

if (task.IsCompletedSuccessfully)
return task;

return SendReplyAsyncAwaited(task);
}

private async ValueTask SendReplyAsyncAwaited(ValueTask task)
{
try
{
await task.ConfigureAwait(false);
await m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand All @@ -1001,20 +964,15 @@ public static void ThrowIfStatementContainsDelimiter(MySqlException exception, I
}
}

internal void HandleTimeout()
{
if (OwningConnection is not null && OwningConnection.TryGetTarget(out var connection))
connection.SetState(ConnectionState.Closed);
}

private void VerifyConnected()
private Exception? CreateExceptionForInvalidState()
{
lock (m_lock)
{
if (m_state == State.Closed)
throw new ObjectDisposedException(nameof(ServerSession));
return new ObjectDisposedException(nameof(ServerSession));
if (m_state != State.Connected && m_state != State.Querying && m_state != State.CancelingQuery && m_state != State.ClearingPendingCancellation && m_state != State.Closing)
throw new InvalidOperationException("ServerSession is not connected.");
return new InvalidOperationException("ServerSession is not connected.");
return null;
}
}

Expand Down
41 changes: 15 additions & 26 deletions src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,39 +473,28 @@ public static async ValueTask<ArraySegment<byte>> ReadPayloadAsync(BufferedByteR
}
}

public static ValueTask WritePayloadAsync(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
public static async ValueTask WritePayloadAsync(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
{
return payload.Length <= MaxPacketSize ? WritePacketAsync(byteHandler, getNextSequenceNumber(), payload, ioBehavior) :
WritePayloadAsyncAwaited(byteHandler, getNextSequenceNumber, payload, ioBehavior);

static async ValueTask WritePayloadAsyncAwaited(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
var buffer = ArrayPool<byte>.Shared.Rent(Math.Min(MaxPacketSize, payload.Length) + 4);
try
{
for (var bytesSent = 0; bytesSent < payload.Length; bytesSent += MaxPacketSize)
var bytesSent = 0;
do
{
var contents = payload.Slice(bytesSent, Math.Min(MaxPacketSize, payload.Length - bytesSent));
await WritePacketAsync(byteHandler, getNextSequenceNumber(), contents, ioBehavior).ConfigureAwait(false);
}
}
}
var bufferLength = contents.Length + 4;

private static ValueTask WritePacketAsync(IByteHandler byteHandler, int sequenceNumber, ReadOnlyMemory<byte> contents, IOBehavior ioBehavior)
{
var bufferLength = contents.Length + 4;
var buffer = ArrayPool<byte>.Shared.Rent(bufferLength);
SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3);
buffer[3] = (byte) sequenceNumber;
contents.CopyTo(buffer.AsMemory()[4..]);
var task = byteHandler.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, bufferLength), ioBehavior);
if (task.IsCompletedSuccessfully)
{
ArrayPool<byte>.Shared.Return(buffer);
return default;
}
return WritePacketAsyncAwaited(task, buffer);
SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3);
buffer[3] = (byte) getNextSequenceNumber();
contents.CopyTo(buffer.AsMemory(4));

static async ValueTask WritePacketAsyncAwaited(ValueTask task, byte[] buffer)
await byteHandler.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, bufferLength), ioBehavior).ConfigureAwait(false);
bytesSent += contents.Length;
}
while (bytesSent < payload.Length);
}
finally
{
await task.ConfigureAwait(false);
ArrayPool<byte>.Shared.Return(buffer);
}
}
Expand Down

0 comments on commit ce3a243

Please sign in to comment.