Skip to content

Commit

Permalink
Write to the writing Pipeline directly
Browse files Browse the repository at this point in the history
See:
https://blog.marcgravell.com/2018/07/pipe-dreams-part-3.html#writes-and-wrongs

* Fixup API validation
* Add assertion
* Fix a missing Advance/1 call
* Add FUTURE
* Replace span slices with range operator.
* Fix name violations
  • Loading branch information
lukebakken committed Jul 23, 2024
1 parent fc82cbc commit 2ef446a
Show file tree
Hide file tree
Showing 53 changed files with 366 additions and 240 deletions.
2 changes: 1 addition & 1 deletion RabbitMQ.Stream.Client/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public OutgoingMsg(byte publisherId, ulong publishingId, Message data)
public Message Data => data;
public int SizeNeeded => 0;

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
throw new NotImplementedException();
}
Expand Down
14 changes: 8 additions & 6 deletions RabbitMQ.Stream.Client/CloseRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// 2.0, and the Mozilla Public License, version 2.0.
// Copyright (c) 2017-2023 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.

using System;
using System.Buffers;

namespace RabbitMQ.Stream.Client
{
Expand All @@ -20,13 +20,15 @@ public CloseRequest(uint correlationId, string reason)

public int SizeNeeded => 10 + WireFormatting.StringSize(reason);

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
var span = writer.GetSpan(SizeNeeded);
var offset = WireFormatting.WriteUInt16(span, Key);
offset += WireFormatting.WriteUInt16(span.Slice(offset), ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span.Slice(offset), correlationId);
offset += WireFormatting.WriteUInt16(span.Slice(offset), 1); //ok code
offset += WireFormatting.WriteString(span.Slice(offset), reason);
offset += WireFormatting.WriteUInt16(span[offset..], ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span[offset..], correlationId);
offset += WireFormatting.WriteUInt16(span[offset..], 1); //ok code
offset += WireFormatting.WriteString(span[offset..], reason);
writer.Advance(offset);
return offset;
}
}
Expand Down
3 changes: 2 additions & 1 deletion RabbitMQ.Stream.Client/CloseResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ public CloseResponse(uint correlationId, ResponseCode responseCode)

public ResponseCode ResponseCode => responseCode;

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
throw new NotImplementedException();
}

internal static int Read(ReadOnlySequence<byte> frame, out CloseResponse command)
{
var offset = WireFormatting.ReadUInt16(frame, out _);
Expand Down
5 changes: 3 additions & 2 deletions RabbitMQ.Stream.Client/CommandVersionsRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// 2.0, and the Mozilla Public License, version 2.0.
// Copyright (c) 2017-2023 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.

using System;
using System.Buffers;

namespace RabbitMQ.Stream.Client;

Expand All @@ -29,8 +29,9 @@ public int SizeNeeded
}
}

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
var span = writer.GetSpan(SizeNeeded);
var offset = WireFormatting.WriteUInt16(span, Key);
offset += WireFormatting.WriteUInt16(span[offset..], ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span[offset..], _correlationId);
Expand Down
6 changes: 5 additions & 1 deletion RabbitMQ.Stream.Client/CommandVersionsResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ private CommandVersionsResponse(uint correlationId, ResponseCode responseCode, L
}

public int SizeNeeded { get => throw new NotImplementedException(); }
public int Write(Span<byte> span) => throw new NotImplementedException();

public uint CorrelationId { get; }
public ResponseCode ResponseCode { get; }
Expand All @@ -47,4 +46,9 @@ internal static int Read(ReadOnlySequence<byte> frame, out CommandVersionsRespon
command = new CommandVersionsResponse(correlation, (ResponseCode)responseCode, commands);
return offset;
}

public int Write(IBufferWriter<byte> writer)
{
throw new NotImplementedException();
}
}
8 changes: 4 additions & 4 deletions RabbitMQ.Stream.Client/Compression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ public int Write(Span<byte> span)
var offset = 0;
foreach (var msg in messages)
{
offset += WireFormatting.WriteUInt32(span.Slice(offset), (uint)msg.Size);
offset += msg.Write(span.Slice(offset));
offset += WireFormatting.WriteUInt32(span[offset..], (uint)msg.Size);
offset += msg.Write(span[offset..]);
}

return offset;
Expand Down Expand Up @@ -89,8 +89,8 @@ public void Compress(List<Message> messages)
var offset = 0;
foreach (var msg in messages)
{
offset += WireFormatting.WriteUInt32(span.Slice(offset), (uint)msg.Size);
offset += msg.Write(span.Slice(offset));
offset += WireFormatting.WriteUInt32(span[offset..], (uint)msg.Size);
offset += msg.Write(span[offset..]);
}

using var compressedMemory = new MemoryStream();
Expand Down
129 changes: 94 additions & 35 deletions RabbitMQ.Stream.Client/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ internal static class ConnectionClosedReason

public class Connection : IDisposable
{
private readonly Socket socket;
private readonly PipeWriter writer;
private readonly PipeReader reader;
private readonly Socket _socket;
private readonly PipeWriter _writer;
private readonly PipeReader _reader;
private readonly Task _incomingFramesTask;
private readonly Func<Memory<byte>, Task> commandCallback;
private readonly Func<string, Task> closedCallback;
private readonly Func<Memory<byte>, Task> _commandCallback;
private readonly Func<string, Task> _closedCallback;
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1, 1);
private int numFrames;
private bool isClosed = false;
Expand Down Expand Up @@ -56,13 +56,13 @@ private Connection(Socket socket, Func<Memory<byte>, Task> callback,
Func<string, Task> closedCallBack, SslOption sslOption, ILogger logger)
{
_logger = logger;
this.socket = socket;
commandCallback = callback;
closedCallback = closedCallBack;
_socket = socket;
_commandCallback = callback;
_closedCallback = closedCallBack;
var networkStream = new NetworkStream(socket);
var stream = MaybeTcpUpgrade(networkStream, sslOption);
writer = PipeWriter.Create(stream);
reader = PipeReader.Create(stream);
_writer = PipeWriter.Create(stream);
_reader = PipeReader.Create(stream);
// ProcessIncomingFrames is dropped as soon as the connection is closed
// no need to stop it manually when the connection is closed
_incomingFramesTask = Task.Run(ProcessIncomingFrames);
Expand Down Expand Up @@ -101,17 +101,50 @@ public static async Task<Connection> Create(EndPoint endpoint, Func<Memory<byte>
return new Connection(socket, commandCallback, closedCallBack, sslOption, logger);
}

public async ValueTask<bool> Write<T>(T command) where T : struct, ICommand
public ValueTask<bool> Write<T>(T command) where T : struct, ICommand
{
await WriteCommand(command).ConfigureAwait(false);
// we return true to indicate that the command was written
// In this PR https://github.com/rabbitmq/rabbitmq-stream-dotnet-client/pull/220
// we made all WriteCommand async so await is enough to indicate that the command was written
// We decided to keep the return value to avoid a breaking change
return true;
if (!_writeLock.Wait(0))
{
// https://blog.marcgravell.com/2018/07/pipe-dreams-part-3.html
var writeSlowPath = WriteCommandAsyncSlowPath(command);
writeSlowPath.ConfigureAwait(false);
return writeSlowPath;
}
else
{
var release = true;
try
{
var payloadSize = WriteCommandPayloadSize(command);
var written = command.Write(_writer);
Debug.Assert(payloadSize == written);
var flush = _writer.FlushAsync();
flush.ConfigureAwait(false);
if (flush.IsCompletedSuccessfully)
{
// we return true to indicate that the command was written
// In this PR https://github.com/rabbitmq/rabbitmq-stream-dotnet-client/pull/220
// we made all WriteCommand async so await is enough to indicate that the command was written
// We decided to keep the return value to avoid a breaking change
return ValueTask.FromResult(true);
}
else
{
release = false;
return AwaitFlushThenRelease(flush);
}
}
finally
{
if (release)
{
_writeLock.Release();
}
}
}
}

private async Task WriteCommand<T>(T command) where T : struct, ICommand
private async ValueTask<bool> WriteCommandAsyncSlowPath<T>(T command) where T : struct, ICommand
{
if (Token.IsCancellationRequested)
{
Expand All @@ -127,18 +160,45 @@ private async Task WriteCommand<T>(T command) where T : struct, ICommand
await _writeLock.WaitAsync(Token).ConfigureAwait(false);
try
{
var size = command.SizeNeeded;
var mem = new byte[4 + size]; // + 4 to write the size
WireFormatting.WriteUInt32(mem, (uint)size);
var written = command.Write(mem.AsSpan()[4..]);
await writer.WriteAsync(new ReadOnlyMemory<byte>(mem), Token).ConfigureAwait(false);
Debug.Assert(size == written);
await writer.FlushAsync(Token).ConfigureAwait(false);
var payloadSize = WriteCommandPayloadSize(command);
var written = command.Write(_writer);
Debug.Assert(payloadSize == written);
await _writer.FlushAsync().ConfigureAwait(false);
}
finally
{
_writeLock.Release();
}

return true;
}

private async ValueTask<bool> AwaitFlushThenRelease(ValueTask<FlushResult> task)
{
try
{
await task.ConfigureAwait(false);
}
finally
{
_writeLock.Release();
}

return true;
}

private int WriteCommandPayloadSize<T>(T command) where T : struct, ICommand
{
/*
* TODO FUTURE
* This code could be moved into a common base class for all outgoing
* commands
*/
var payloadSize = command.SizeNeeded;
var mem = new byte[4 + payloadSize]; // + 4 to write the size
var written = WireFormatting.WriteUInt32(mem, (uint)payloadSize);
_writer.Advance(written);
return payloadSize;
}

private async Task ProcessIncomingFrames()
Expand All @@ -148,9 +208,9 @@ private async Task ProcessIncomingFrames()
{
while (!isClosed)
{
if (!reader.TryRead(out var result))
if (!_reader.TryRead(out var result))
{
result = await reader.ReadAsync(Token).ConfigureAwait(false);
result = await _reader.ReadAsync(Token).ConfigureAwait(false);
}

var buffer = result.Buffer;
Expand All @@ -166,16 +226,15 @@ private async Task ProcessIncomingFrames()
while (TryReadFrame(ref buffer, out var frame) && !isClosed)
{
// Let's rent some memory to copy the frame from the network stream. This memory will be reclaimed once the frame has been handled.

var memory =
ArrayPool<byte>.Shared.Rent((int)frame.Length).AsMemory(0, (int)frame.Length);
frame.CopyTo(memory.Span);

await commandCallback(memory).ConfigureAwait(false);
await _commandCallback(memory).ConfigureAwait(false);
numFrames += 1;
}

reader.AdvanceTo(buffer.Start, buffer.End);
_reader.AdvanceTo(buffer.Start, buffer.End);
}
}
catch (OperationCanceledException e)
Expand Down Expand Up @@ -206,8 +265,8 @@ private async Task ProcessIncomingFrames()
"TCP Connection Closed ClientId: {ClientId}, Reason {Reason}. IsCancellationRequested {Token} ",
ClientId, _closedReason, Token.IsCancellationRequested);
// Mark the PipeReader as complete
await reader.CompleteAsync(caught).ConfigureAwait(false);
closedCallback?.Invoke(_closedReason)!.ConfigureAwait(false);
await _reader.CompleteAsync(caught).ConfigureAwait(false);
_closedCallback?.Invoke(_closedReason)!.ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -247,9 +306,9 @@ public void Dispose()
}

isClosed = true;
writer.Complete();
reader.Complete();
socket.Close();
_writer.Complete();
_reader.Complete();
_socket.Close();
if (!_incomingFramesTask.Wait(Consts.MidWait))
{
_logger?.LogWarning("ProcessIncomingFrames reader task did not exit in {MidWait}",
Expand Down
2 changes: 1 addition & 1 deletion RabbitMQ.Stream.Client/ConsumerUpdateQueryResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private ConsumerUpdateQueryResponse(uint correlationId, byte subscriptionId, byt
public bool IsActive => active == 1;
public int SizeNeeded => throw new NotImplementedException();

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
throw new NotImplementedException();
}
Expand Down
14 changes: 8 additions & 6 deletions RabbitMQ.Stream.Client/ConsumerUpdateRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// 2.0, and the Mozilla Public License, version 2.0.
// Copyright (c) 2017-2023 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.

using System;
using System.Buffers;

namespace RabbitMQ.Stream.Client;

Expand All @@ -27,13 +27,15 @@ public int SizeNeeded
}
}

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
var span = writer.GetSpan(SizeNeeded);
var offset = WireFormatting.WriteUInt16(span, Key);
offset += WireFormatting.WriteUInt16(span.Slice(offset), ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span.Slice(offset), _correlationId);
offset += WireFormatting.WriteUInt16(span.Slice(offset), (ushort)ResponseCode.Ok);
offset += OffsetSpecification.Write(span.Slice(offset));
offset += WireFormatting.WriteUInt16(span[offset..], ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span[offset..], _correlationId);
offset += WireFormatting.WriteUInt16(span[offset..], (ushort)ResponseCode.Ok);
offset += OffsetSpecification.Write(span[offset..]);
writer.Advance(offset);
return offset;
}

Expand Down
20 changes: 12 additions & 8 deletions RabbitMQ.Stream.Client/Create.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,23 @@ public int SizeNeeded
}
}

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
var span = writer.GetSpan(SizeNeeded);

var offset = WireFormatting.WriteUInt16(span, Key);
offset += WireFormatting.WriteUInt16(span.Slice(offset), ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span.Slice(offset), correlationId);
offset += WireFormatting.WriteString(span.Slice(offset), stream);
offset += WireFormatting.WriteInt32(span.Slice(offset), arguments.Count);
offset += WireFormatting.WriteUInt16(span[offset..], ((ICommand)this).Version);
offset += WireFormatting.WriteUInt32(span[offset..], correlationId);
offset += WireFormatting.WriteString(span[offset..], stream);
offset += WireFormatting.WriteInt32(span[offset..], arguments.Count);

foreach (var (key, value) in arguments)
{
offset += WireFormatting.WriteString(span.Slice(offset), key);
offset += WireFormatting.WriteString(span.Slice(offset), value);
offset += WireFormatting.WriteString(span[offset..], key);
offset += WireFormatting.WriteString(span[offset..], value);
}

writer.Advance(offset);
return offset;
}
}
Expand All @@ -71,10 +74,11 @@ public CreateResponse(uint correlationId, ushort responseCode)

public ResponseCode ResponseCode => (ResponseCode)responseCode;

public int Write(Span<byte> span)
public int Write(IBufferWriter<byte> writer)
{
throw new NotImplementedException();
}

internal static int Read(ReadOnlySequence<byte> frame, out CreateResponse command)
{
var offset = WireFormatting.ReadUInt16(frame, out _);
Expand Down
Loading

0 comments on commit 2ef446a

Please sign in to comment.