Skip to content

Commit

Permalink
refactor: abstracting batching into its own class
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Frowen committed Apr 25, 2024
1 parent 996b52a commit 4b77799
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 88 deletions.
62 changes: 12 additions & 50 deletions Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public class AckSystem : IDisposable
/// <summary>PacketType, ack sequence, mask</summary>
public const int ACK_HEADER_SIZE = sizeof(byte) + sizeof(ushort) + sizeof(ulong);

public const int RELIABLE_MESSAGE_LENGTH_SIZE = sizeof(ushort);
public const int FRAGMENT_INDEX_SIZE = sizeof(byte);

/// <summary>Smallest size a header for reliable packet, <see cref="RELIABLE_HEADER_SIZE"/> + 2 bytes per message</summary>
public const int MIN_RELIABLE_HEADER_SIZE = RELIABLE_HEADER_SIZE + RELIABLE_MESSAGE_LENGTH_SIZE;
public const int MIN_RELIABLE_HEADER_SIZE = RELIABLE_HEADER_SIZE + Batch.MESSAGE_LENGTH_SIZE;

/// <summary>Smallest size a header for reliable packet, <see cref="RELIABLE_HEADER_SIZE"/> + 1 byte for fragment index</summary>
public const int MIN_RELIABLE_FRAGMENT_HEADER_SIZE = RELIABLE_HEADER_SIZE + FRAGMENT_INDEX_SIZE;
Expand Down Expand Up @@ -66,7 +65,7 @@ public class AckSystem : IDisposable
private float _lastSentTime;
private ushort _lastSentAck;
private int _emptyAckCount = 0;
private ReliablePacket _nextBatch;
private readonly Batch _batch;

/// <summary>
///
Expand All @@ -83,6 +82,7 @@ public AckSystem(IRawConnection connection, Config config, int maxPacketSize, IT
_bufferPool = bufferPool;
_reliablePool = new Pool<ReliablePacket>(ReliablePacket.CreateNew, 0, config.MaxReliablePacketsInSendBufferPerConnection);
_metrics = metrics;
_batch = new ReliableBatch(maxPacketSize, CreateReliableBuffer, SendReliablePacket);

_ackTimeout = config.TimeBeforeEmptyAck;
_emptyAckLimit = config.EmptyAckLimit;
Expand Down Expand Up @@ -112,6 +112,9 @@ public void Dispose()
{
var removeSafety = new HashSet<ByteBuffer>();

if (_batch is IDisposable disposable)
disposable.Dispose();

_sentAckablePackets.ClearAndRelease((packet) =>
{
Debug.Assert(packet.IsValid());
Expand Down Expand Up @@ -142,7 +145,7 @@ public void Dispose()


/// <summary>
/// Gets next Reliable packet in order, packet consists for multiple messsages
/// Gets next Reliable packet in order, packet consists for multiple messages
/// <para>[length, message, length, message, ...]</para>
/// </summary>
/// <param name="packet"></param>
Expand Down Expand Up @@ -191,12 +194,7 @@ public ReliableReceived GetNextFragment()

public void Update()
{
if (_nextBatch != null)
{
SendReliablePacket(_nextBatch);
_nextBatch = null;
}

_batch.Flush();

// todo send ack if not recently been sent
// ack only packet sent if no other sent within last frame
Expand Down Expand Up @@ -308,8 +306,6 @@ public void SendNotify(byte[] inPacket, int inOffset, int inLength, INotifyCallB
}
}



public void SendReliable(byte[] message, int offset, int length)
{
if (_sentAckablePackets.IsFull)
Expand All @@ -325,37 +321,16 @@ public void SendReliable(byte[] message, int offset, int length)
// if there is existing batch, send it first
// we need to do this so that fragmented message arrive in order
// if we dont, a message sent after maybe be added to batch and then have earlier order than fragmented message
if (_nextBatch != null)
{
SendReliablePacket(_nextBatch);
_nextBatch = null;
}

_batch.Flush();
SendFragmented(message, offset, length);
return;
}


if (_nextBatch == null)
{
_nextBatch = CreateReliableBuffer(PacketType.Reliable);
}

var msgLength = length + RELIABLE_MESSAGE_LENGTH_SIZE;
var batchLength = _nextBatch.Length;
if (batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendReliablePacket(_nextBatch);

_nextBatch = CreateReliableBuffer(PacketType.Reliable);
}

AddToBatch(_nextBatch, message, offset, length);
_batch.AddMessage(message, offset, length);
}

/// <summary>
/// Splits messsage into multiple packets
/// Splits message into multiple packets
/// <para>Note: this might just send 1 packet if length is equal to size.
/// This might happen because fragmented header is 1 less that batched header</para>
/// </summary>
Expand Down Expand Up @@ -408,18 +383,6 @@ private ReliablePacket CreateReliableBuffer(PacketType packetType)
return packet;
}

private static void AddToBatch(ReliablePacket packet, byte[] message, int offset, int length)
{
var array = packet.Buffer.array;
var packetOffset = packet.Length;

ByteUtils.WriteUShort(array, ref packetOffset, (ushort)length);
Buffer.BlockCopy(message, offset, array, packetOffset, length);
packetOffset += length;

packet.Length = packetOffset;
}

private void SendReliablePacket(ReliablePacket reliable)
{
ThrowIfBufferLimitReached();
Expand Down Expand Up @@ -447,7 +410,6 @@ private void ThrowIfBufferLimitReached()
}
}


/// <summary>
/// Receives incoming Notify packet
/// <para>Ignores duplicate or late packets</para>
Expand Down Expand Up @@ -743,7 +705,7 @@ public bool IsNotValid()
}
}

private class ReliablePacket
public class ReliablePacket
{
public ushort LastSequence;
public int Length;
Expand Down
127 changes: 127 additions & 0 deletions Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using System;

namespace Mirage.SocketLayer
{
public abstract class Batch
{
public const int MESSAGE_LENGTH_SIZE = 2;

private readonly int _maxPacketSize;

public Batch(int maxPacketSize)
{
_maxPacketSize = maxPacketSize;
}

protected abstract bool Created { get; }
protected abstract byte[] GetBatch();
protected abstract ref int GetBatchLength();

protected abstract void CreateNewBatch();
protected abstract void SendAndReset();

public void AddMessage(byte[] message, int offset, int length)
{
if (Created)
{
var msgLength = length + MESSAGE_LENGTH_SIZE;
var batchLength = GetBatchLength();
if (batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendAndReset();
}
}

if (!Created)
CreateNewBatch();

AddToBatch(message, offset, length);
}

private void AddToBatch(byte[] message, int offset, int length)
{
var batch = GetBatch();
ref var batchLength = ref GetBatchLength();
ByteUtils.WriteUShort(batch, ref batchLength, checked((ushort)length));
Buffer.BlockCopy(message, offset, batch, batchLength, length);
batchLength += length;
}

public void Flush()
{
if (Created)
SendAndReset();
}
}

public class ArrayBatch : Batch
{
private readonly Action<byte[], int> _send;
private readonly PacketType _packetType;

private readonly byte[] _batch;
private int _batchLength;

public ArrayBatch(int maxPacketSize, Action<byte[], int> send, PacketType reliable)
: base(maxPacketSize)
{
_batch = new byte[maxPacketSize];
_send = send;
_packetType = reliable;
}

protected override bool Created => _batchLength > 0;

protected override byte[] GetBatch() => _batch;
protected override ref int GetBatchLength() => ref _batchLength;

protected override void CreateNewBatch()
{
_batch[0] = (byte)_packetType;
_batchLength = 1;
}

protected override void SendAndReset()
{
_send.Invoke(_batch, _batchLength);
_batchLength = 0;
}
}

public class ReliableBatch : Batch, IDisposable
{
private AckSystem.ReliablePacket _nextBatch;
private readonly Func<PacketType, AckSystem.ReliablePacket> _createReliableBuffer;
private readonly Action<AckSystem.ReliablePacket> _sendReliablePacket;

public ReliableBatch(int maxPacketSize, Func<PacketType, AckSystem.ReliablePacket> createReliableBuffer, Action<AckSystem.ReliablePacket> sendReliablePacket)
: base(maxPacketSize)
{
_createReliableBuffer = createReliableBuffer;
_sendReliablePacket = sendReliablePacket;
}

protected override bool Created => _nextBatch != null;

protected override byte[] GetBatch() => _nextBatch.Buffer.array;
protected override ref int GetBatchLength() => ref _nextBatch.Length;

protected override void CreateNewBatch()
{
_nextBatch = _createReliableBuffer.Invoke(PacketType.Reliable);
}

protected override void SendAndReset()
{
_sendReliablePacket.Invoke(_nextBatch);
_nextBatch = null;
}

void IDisposable.Dispose()
{
_nextBatch.Buffer.Release();
_nextBatch = null;
}
}
}
11 changes: 11 additions & 0 deletions Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,28 @@ namespace Mirage.SocketLayer
/// </summary>
internal sealed class NoReliableConnection : Connection
{
private const int HEADER_SIZE = 1 + MESSAGE_LENGTH_SIZE;
private const int MESSAGE_LENGTH_SIZE = 2;
private const int HEADER_SIZE = 1 + Batch.MESSAGE_LENGTH_SIZE;

private byte[] _nextBatch;
private int _batchLength;
private readonly Batch _nextBatchReliable;

internal NoReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, int maxPacketSize, Time time, ILogger logger, Metrics metrics)
: base(peer, endPoint, dataHandler, config, maxPacketSize, time, logger, metrics)
{
_nextBatch = new byte[maxPacketSize];
CreateNewBatch();
_nextBatchReliable = new ArrayBatch(maxPacketSize, SendBatchInternal, PacketType.Reliable);

if (maxPacketSize > ushort.MaxValue)
{
throw new ArgumentException($"Max package size can not bigger than {ushort.MaxValue}. NoReliableConnection uses 2 bytes for message length, maxPacketSize over that value will mean that message will be incorrectly batched.");
}
}

private void SendBatchInternal(byte[] batch, int length)
{
_peer.Send(this, batch, length);
}

// just sue SendReliable for unreliable/notify
// note: we dont need to pass in that it is reliable, receiving doesn't really care what channel it is
public override void SendUnreliable(byte[] packet, int offset, int length) => SendReliable(packet, offset, length);
public override void SendNotify(byte[] packet, int offset, int length, INotifyCallBack callBacks)
{
Expand All @@ -53,37 +56,10 @@ public override void SendReliable(byte[] message, int offset, int length)
throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_maxPacketSize - HEADER_SIZE}");
}


var msgLength = length + MESSAGE_LENGTH_SIZE;
if (_batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendBatch();
}

AddToBatch(message, offset, length);
_nextBatchReliable.AddMessage(message, offset, length);
_metrics?.OnSendMessageReliable(length);
}

private void SendBatch()
{
_peer.Send(this, _nextBatch, _batchLength);
CreateNewBatch();
}

private void CreateNewBatch()
{
_nextBatch[0] = (byte)PacketType.Reliable;
_batchLength = 1;
}

private void AddToBatch(byte[] message, int offset, int length)
{
ByteUtils.WriteUShort(_nextBatch, ref _batchLength, checked((ushort)length));
Buffer.BlockCopy(message, offset, _nextBatch, _batchLength, length);
_batchLength += length;
}

internal override void ReceiveReliablePacket(Packet packet)
{
HandleReliableBatched(packet.Buffer.array, 1, packet.Length);
Expand All @@ -96,10 +72,7 @@ internal override void ReceiveReliablePacket(Packet packet)

public override void FlushBatch()
{
if (_batchLength > 1)
{
SendBatch();
}
_nextBatchReliable.Flush();
}

internal override bool IsValidSize(Packet packet)
Expand Down

0 comments on commit 4b77799

Please sign in to comment.