From 21c8e83863769fb064a3cc746532fa961e5d1548 Mon Sep 17 00:00:00 2001 From: Richard Schneider Date: Sun, 1 Sep 2019 13:48:12 +1200 Subject: [PATCH] fix: reading bytes from a stream #51 --- src/Multiplex/Muxer.cs | 6 +- src/ProtoBufHelper.cs | 4 +- src/Protocols/Message.cs | 4 +- src/Protocols/Ping1.cs | 10 +-- src/SecureCommunication/Secio1.cs | 2 +- src/SecureCommunication/Secio1Stream.cs | 7 +- src/StreamExtensions.cs | 95 +++++++++++++++++++++++++ test/StreamExtensionsTest.cs | 73 +++++++++++++++++++ 8 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 src/StreamExtensions.cs create mode 100644 test/StreamExtensionsTest.cs diff --git a/src/Multiplex/Muxer.cs b/src/Multiplex/Muxer.cs index 7a147ae..3c83df5 100644 --- a/src/Multiplex/Muxer.cs +++ b/src/Multiplex/Muxer.cs @@ -188,11 +188,7 @@ public bool Receiver // Read the payload. var payload = new byte[length]; - int offset = 0; - while (offset < length) - { - offset += await Channel.ReadAsync(payload, offset, length - offset, cancel).ConfigureAwait(false); - } + await Channel.ReadExactAsync(payload, 0, length, cancel).ConfigureAwait(false); // Process the packet Substreams.TryGetValue(header.StreamId, out Substream substream); diff --git a/src/ProtoBufHelper.cs b/src/ProtoBufHelper.cs index 973f680..1d16fe7 100644 --- a/src/ProtoBufHelper.cs +++ b/src/ProtoBufHelper.cs @@ -34,9 +34,7 @@ public static class ProtoBufHelper { var length = await stream.ReadVarint32Async(cancel).ConfigureAwait(false); var bytes = new byte[length]; - for (int offset = 0; offset < length;) { - offset += await stream.ReadAsync(bytes, offset, length - offset, cancel).ConfigureAwait(false); - } + await stream.ReadExactAsync(bytes, 0, length, cancel).ConfigureAwait(false); using (var ms = new MemoryStream(bytes, false)) { diff --git a/src/Protocols/Message.cs b/src/Protocols/Message.cs index fb36c22..fc2f18e 100644 --- a/src/Protocols/Message.cs +++ b/src/Protocols/Message.cs @@ -47,8 +47,8 @@ public static class Message var eol = new byte[1]; var length = await stream.ReadVarint32Async(cancel).ConfigureAwait(false); var buffer = new byte[length - 1]; - await stream.ReadAsync(buffer, 0, length - 1, cancel).ConfigureAwait(false); - await stream.ReadAsync(eol, 0, 1, cancel).ConfigureAwait(false); + await stream.ReadExactAsync(buffer, 0, length - 1, cancel).ConfigureAwait(false); + await stream.ReadExactAsync(eol, 0, 1, cancel).ConfigureAwait(false); if (eol[0] != newline[0]) { log.Error($"length: {length}, bytes: {buffer.ToHexString()}"); diff --git a/src/Protocols/Ping1.cs b/src/Protocols/Ping1.cs index 01a6876..a2f9049 100644 --- a/src/Protocols/Ping1.cs +++ b/src/Protocols/Ping1.cs @@ -48,10 +48,7 @@ public override string ToString() { // Read the message. var request = new byte[PingSize]; - for (int offset = 0; offset < PingSize;) - { - offset += await stream.ReadAsync(request, offset, PingSize - offset, cancel).ConfigureAwait(false); - } + await stream.ReadExactAsync(request, 0, PingSize, cancel).ConfigureAwait(false); log.Debug($"got ping from {connection.RemotePeer}"); // Echo the message @@ -147,10 +144,7 @@ async Task> PingAsync(Peer peer, int count, Cancellation await stream.FlushAsync(cancel).ConfigureAwait(false); var response = new byte[PingSize]; - for (int offset = 0; offset < PingSize;) - { - offset += await stream.ReadAsync(response, offset, PingSize - offset, cancel).ConfigureAwait(false); - } + await stream.ReadExactAsync(response, 0, PingSize, cancel).ConfigureAwait(false); var result = new PingResult { diff --git a/src/SecureCommunication/Secio1.cs b/src/SecureCommunication/Secio1.cs index 232314c..8bc6fe8 100644 --- a/src/SecureCommunication/Secio1.cs +++ b/src/SecureCommunication/Secio1.cs @@ -188,7 +188,7 @@ public override string ToString() // Receive our nonce. var verification = new byte[localNonce.Length]; - await secureStream.ReadAsync(verification, 0, verification.Length, cancel); + await secureStream.ReadExactAsync(verification, 0, verification.Length, cancel); if (!localNonce.SequenceEqual(verification)) { throw new Exception($"SECIO verification message failure."); diff --git a/src/SecureCommunication/Secio1Stream.cs b/src/SecureCommunication/Secio1Stream.cs index f814c9f..68d56be 100644 --- a/src/SecureCommunication/Secio1Stream.cs +++ b/src/SecureCommunication/Secio1Stream.cs @@ -192,12 +192,7 @@ async Task ReadPacketAsync(CancellationToken cancel) async Task ReadPacketBytesAsync(int count, CancellationToken cancel) { byte[] buffer = new byte[count]; - for (int i = 0, n; i < count; i += n) - { - n = await stream.ReadAsync(buffer, i, count - i, cancel).ConfigureAwait(false); - if (n < 1) - throw new EndOfStreamException(); - } + await stream.ReadExactAsync(buffer, 0, count, cancel).ConfigureAwait(false); return buffer; } diff --git a/src/StreamExtensions.cs b/src/StreamExtensions.cs new file mode 100644 index 0000000..41429ec --- /dev/null +++ b/src/StreamExtensions.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace PeerTalk +{ + /// + /// + /// + public static class StreamExtensions + { + /// + /// Asynchronously reads a sequence of bytes from the stream and advances + /// the position within the stream by the number of bytes read. + /// + /// + /// The stream to read from. + /// + /// + /// The buffer to write the data into. + /// + /// + /// The byte offset in at which to begin + /// writing data from the . + /// + /// + /// The number of bytes to read. + /// + /// + /// A task that represents the asynchronous operation. + /// + /// + /// When the does not have + /// bytes. + /// + public static async Task ReadExactAsync(this Stream stream, byte[] buffer, int offset, int length) + { + while (0 < length) + { + var n = await stream.ReadAsync(buffer, offset, length); + if (n == 0) + { + throw new EndOfStreamException(); + } + offset += n; + length -= n; + } + } + + /// + /// Asynchronously reads a sequence of bytes from the stream and advances + /// the position within the stream by the number of bytes read. + /// + /// + /// The stream to read from. + /// + /// + /// The buffer to write the data into. + /// + /// + /// The byte offset in at which to begin + /// writing data from the . + /// + /// + /// The number of bytes to read. + /// + /// + /// Is used to stop the task. + /// + /// + /// A task that represents the asynchronous operation. + /// + /// + /// When the does not have + /// bytes. + /// + public static async Task ReadExactAsync(this Stream stream, byte[] buffer, int offset, int length, CancellationToken cancel) + { + while (0 < length) + { + var n = await stream.ReadAsync(buffer, offset, length, cancel); + if (n == 0) + { + throw new EndOfStreamException(); + } + offset += n; + length -= n; + } + } + } +} diff --git a/test/StreamExtensionsTest.cs b/test/StreamExtensionsTest.cs new file mode 100644 index 0000000..f26b37a --- /dev/null +++ b/test/StreamExtensionsTest.cs @@ -0,0 +1,73 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Ipfs; +using System; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Net.Sockets; + +namespace PeerTalk +{ + [TestClass] + public class StreamExtensionsTest + { + [TestMethod] + public async Task ReadAsync() + { + var expected = new byte[] { 1, 2, 3, 4 }; + using (var ms = new MemoryStream(expected)) + { + var actual = new byte[expected.Length]; + await ms.ReadExactAsync(actual, 0, actual.Length); + CollectionAssert.AreEqual(expected, actual); + } + } + + [TestMethod] + public void ReadAsync_EOS() + { + var expected = new byte[] { 1, 2, 3, 4 }; + var actual = new byte[expected.Length + 1]; + + using (var ms = new MemoryStream(expected)) + { + ExceptionAssert.Throws(() => + { + ms.ReadExactAsync(actual, 0, actual.Length).Wait(); + }); + } + + var cancel = new CancellationTokenSource(); + using (var ms = new MemoryStream(expected)) + { + ExceptionAssert.Throws(() => + { + ms.ReadExactAsync(actual, 0, actual.Length, cancel.Token).Wait(); + }); + } + } + + [TestMethod] + public async Task ReadAsync_Cancel() + { + var expected = new byte[] { 1, 2, 3, 4 }; + var actual = new byte[expected.Length]; + var cancel = new CancellationTokenSource(); + using (var ms = new MemoryStream(expected)) + { + await ms.ReadExactAsync(actual, 0, actual.Length, cancel.Token); + CollectionAssert.AreEqual(expected, actual); + } + + cancel.Cancel(); + using (var ms = new MemoryStream(expected)) + { + ExceptionAssert.Throws(() => + { + ms.ReadExactAsync(actual, 0, actual.Length, cancel.Token).Wait(); + }); + } + } + } +}