From ce3a2432dc06f379c406f0d6b533c5a38c2ed758 Mon Sep 17 00:00:00 2001 From: Bradley Grainger Date: Tue, 15 Aug 2023 10:22:16 -0700 Subject: [PATCH] Combine async methods that were split into two parts. These were originally separated so that a Task wasn't allocated for the case of synchronous completion. Now with ValueTask (and much better code generation of async methods by the C# compiler), this code level optimisation is no longer necessary to avoid unnecessary work and allocations. Simplify code by removing the HandleTimeout method that was doing the same thing as SetFailed. Note that Benchmark.NET reports (for a simplified version of the code with a do-nothing ReadPayloadAsync method) that the previous code took 40ns and the new code takes 60ns. So there is still overhead of going through a compiler-generated async state machine instead of writing it out by hand. However, this may be insignificant compared to the time taken for async network I/O, and this change makes the code easier to maintain. --- src/MySqlConnector/Core/ServerSession.cs | 74 ++++--------------- .../Protocol/Serialization/ProtocolUtility.cs | 41 ++++------ 2 files changed, 31 insertions(+), 84 deletions(-) diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index f20992f27..10a92607c 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -906,81 +906,44 @@ public ValueTask ReceiveAsync(IOBehavior ioBehavior, CancellationTo } // Continues a conversation with the server by receiving a response to a packet sent with 'Send' or 'SendReply'. - public ValueTask ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + public async ValueTask ReceiveReplyAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) { - ValueTask> task; - try - { - VerifyConnected(); - task = m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior); - } - catch (Exception ex) + if (CreateExceptionForInvalidState() is { } exception) { - Log.FailedInReceiveReplyAsync(m_logger, ex, Id); - if ((ex as MySqlException)?.ErrorCode == MySqlErrorCode.CommandTimeoutExpired) - HandleTimeout(); - task = ValueTaskExtensions.FromException>(ex); + Log.FailedInReceiveReplyAsync(m_logger, exception, Id); + throw exception; } - if (task.IsCompletedSuccessfully) - { - var payload = new PayloadData(task.Result); - if (payload.HeaderByte != ErrorPayload.Signature) - return new ValueTask(payload); - - var exception = CreateExceptionForErrorPayload(payload.Span); - return ValueTaskExtensions.FromException(exception); - } - - return ReceiveReplyAsyncAwaited(task); - } - - private async ValueTask ReceiveReplyAsyncAwaited(ValueTask> task) - { ArraySegment bytes; try { - bytes = await task.ConfigureAwait(false); + bytes = await m_payloadHandler!.ReadPayloadAsync(m_payloadCache, ProtocolErrorBehavior.Throw, ioBehavior).ConfigureAwait(false); } catch (Exception ex) { SetFailed(ex); - if (ex is MySqlException { ErrorCode: MySqlErrorCode.CommandTimeoutExpired }) - HandleTimeout(); throw; } + var payload = new PayloadData(bytes); if (payload.HeaderByte == ErrorPayload.Signature) throw CreateExceptionForErrorPayload(payload.Span); + return payload; } // Continues a conversation with the server by sending a reply to a packet received with 'Receive' or 'ReceiveReply'. - public ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) + public async ValueTask SendReplyAsync(PayloadData payload, IOBehavior ioBehavior, CancellationToken cancellationToken) { - ValueTask task; - try - { - VerifyConnected(); - task = m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior); - } - catch (Exception ex) + if (CreateExceptionForInvalidState() is { } exception) { - Log.FailedInSendReplyAsync(m_logger, ex, Id); - task = ValueTaskExtensions.FromException(ex); + Log.FailedInSendReplyAsync(m_logger, exception, Id); + throw exception; } - if (task.IsCompletedSuccessfully) - return task; - - return SendReplyAsyncAwaited(task); - } - - private async ValueTask SendReplyAsyncAwaited(ValueTask task) - { try { - await task.ConfigureAwait(false); + await m_payloadHandler!.WritePayloadAsync(payload.Memory, ioBehavior).ConfigureAwait(false); } catch (Exception ex) { @@ -1001,20 +964,15 @@ public static void ThrowIfStatementContainsDelimiter(MySqlException exception, I } } - internal void HandleTimeout() - { - if (OwningConnection is not null && OwningConnection.TryGetTarget(out var connection)) - connection.SetState(ConnectionState.Closed); - } - - private void VerifyConnected() + private Exception? CreateExceptionForInvalidState() { lock (m_lock) { if (m_state == State.Closed) - throw new ObjectDisposedException(nameof(ServerSession)); + return new ObjectDisposedException(nameof(ServerSession)); if (m_state != State.Connected && m_state != State.Querying && m_state != State.CancelingQuery && m_state != State.ClearingPendingCancellation && m_state != State.Closing) - throw new InvalidOperationException("ServerSession is not connected."); + return new InvalidOperationException("ServerSession is not connected."); + return null; } } diff --git a/src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs b/src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs index 9736004dc..4c3096001 100644 --- a/src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs +++ b/src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs @@ -473,39 +473,28 @@ public static async ValueTask> ReadPayloadAsync(BufferedByteR } } - public static ValueTask WritePayloadAsync(IByteHandler byteHandler, Func getNextSequenceNumber, ReadOnlyMemory payload, IOBehavior ioBehavior) + public static async ValueTask WritePayloadAsync(IByteHandler byteHandler, Func getNextSequenceNumber, ReadOnlyMemory payload, IOBehavior ioBehavior) { - return payload.Length <= MaxPacketSize ? WritePacketAsync(byteHandler, getNextSequenceNumber(), payload, ioBehavior) : - WritePayloadAsyncAwaited(byteHandler, getNextSequenceNumber, payload, ioBehavior); - - static async ValueTask WritePayloadAsyncAwaited(IByteHandler byteHandler, Func getNextSequenceNumber, ReadOnlyMemory payload, IOBehavior ioBehavior) + var buffer = ArrayPool.Shared.Rent(Math.Min(MaxPacketSize, payload.Length) + 4); + try { - for (var bytesSent = 0; bytesSent < payload.Length; bytesSent += MaxPacketSize) + var bytesSent = 0; + do { var contents = payload.Slice(bytesSent, Math.Min(MaxPacketSize, payload.Length - bytesSent)); - await WritePacketAsync(byteHandler, getNextSequenceNumber(), contents, ioBehavior).ConfigureAwait(false); - } - } - } + var bufferLength = contents.Length + 4; - private static ValueTask WritePacketAsync(IByteHandler byteHandler, int sequenceNumber, ReadOnlyMemory contents, IOBehavior ioBehavior) - { - var bufferLength = contents.Length + 4; - var buffer = ArrayPool.Shared.Rent(bufferLength); - SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3); - buffer[3] = (byte) sequenceNumber; - contents.CopyTo(buffer.AsMemory()[4..]); - var task = byteHandler.WriteBytesAsync(new ArraySegment(buffer, 0, bufferLength), ioBehavior); - if (task.IsCompletedSuccessfully) - { - ArrayPool.Shared.Return(buffer); - return default; - } - return WritePacketAsyncAwaited(task, buffer); + SerializationUtility.WriteUInt32((uint) contents.Length, buffer, 0, 3); + buffer[3] = (byte) getNextSequenceNumber(); + contents.CopyTo(buffer.AsMemory(4)); - static async ValueTask WritePacketAsyncAwaited(ValueTask task, byte[] buffer) + await byteHandler.WriteBytesAsync(new ArraySegment(buffer, 0, bufferLength), ioBehavior).ConfigureAwait(false); + bytesSent += contents.Length; + } + while (bytesSent < payload.Length); + } + finally { - await task.ConfigureAwait(false); ArrayPool.Shared.Return(buffer); } }