From 77d2f2e518bf7ca9a5c30bed9199ec0ace400c79 Mon Sep 17 00:00:00 2001
From: Kamil Sobol <61715331+kasobol-msft@users.noreply.github.com>
Date: Fri, 1 May 2020 13:35:36 -0700
Subject: [PATCH] fix ingress/egress reporting on outgoing data.
---
.../Core/ByteCountingStream.cs | 44 +++++++++++++----
Lib/Common/Core/Util/StreamExtensions.cs | 5 +-
.../Shared/Protocol/HttpContentFactory.cs | 6 ++-
.../Blob/BlobReadStreamTest.cs | 48 +++++++++++++++++++
4 files changed, 91 insertions(+), 12 deletions(-)
diff --git a/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs b/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs
index 5f1878297..11d2cf2b8 100644
--- a/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs
+++ b/Lib/ClassLibraryCommon/Core/ByteCountingStream.cs
@@ -31,17 +31,19 @@ internal class ByteCountingStream : Stream
{
private readonly Stream wrappedStream;
private readonly RequestResult requestObject;
+ private readonly bool reverseCapture;
///
/// Initializes a new instance of the ByteCountingStream class with an expandable capacity initialized to zero.
///
- public ByteCountingStream(Stream wrappedStream, RequestResult requestObject)
+ public ByteCountingStream(Stream wrappedStream, RequestResult requestObject, bool reverseCapture = false)
: base()
{
CommonUtility.AssertNotNull("WrappedStream", wrappedStream);
CommonUtility.AssertNotNull("RequestObject", requestObject);
this.wrappedStream = wrappedStream;
this.requestObject = requestObject;
+ this.reverseCapture = reverseCapture;
}
public override bool CanRead
@@ -105,14 +107,14 @@ public override long Seek(long offset, SeekOrigin origin)
public override int Read(byte[] buffer, int offset, int count)
{
int read = this.wrappedStream.Read(buffer, offset, count);
- this.requestObject.IngressBytes += read;
+ this.CaptureRead(read);
return read;
}
public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
int read = await this.wrappedStream.ReadAsync(buffer, offset, count, cancellationToken);
- this.requestObject.IngressBytes += read;
+ this.CaptureRead(read);
return read;
}
@@ -122,7 +124,7 @@ public override int ReadByte()
if (val != -1)
{
- ++this.requestObject.IngressBytes;
+ this.CaptureRead(1);
}
return val;
@@ -155,7 +157,7 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, Asy
public override int EndRead(IAsyncResult asyncResult)
{
int read = this.wrappedStream.EndRead(asyncResult);
- this.requestObject.IngressBytes += read;
+ this.CaptureRead(read);
return read;
}
@@ -171,7 +173,7 @@ public override int EndRead(IAsyncResult asyncResult)
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
IAsyncResult res = this.wrappedStream.BeginWrite(buffer, offset, count, callback, state);
- this.requestObject.EgressBytes += count;
+ this.CaptureWrite(count);
return res;
}
@@ -187,19 +189,43 @@ public override void EndWrite(IAsyncResult asyncResult)
public override void Write(byte[] buffer, int offset, int count)
{
this.wrappedStream.Write(buffer, offset, count);
- this.requestObject.EgressBytes += count;
+ this.CaptureWrite(count);
}
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await this.wrappedStream.WriteAsync(buffer, offset, count, cancellationToken);
- this.requestObject.EgressBytes += count;
+ this.CaptureWrite(count);
}
public override void WriteByte(byte value)
{
this.wrappedStream.WriteByte(value);
- ++this.requestObject.EgressBytes;
+ this.CaptureWrite(1);
+ }
+
+ private void CaptureWrite(int count)
+ {
+ if (reverseCapture)
+ {
+ this.requestObject.IngressBytes += count;
+ }
+ else
+ {
+ this.requestObject.EgressBytes += count;
+ }
+ }
+
+ private void CaptureRead(int count)
+ {
+ if (reverseCapture)
+ {
+ this.requestObject.EgressBytes += count;
+ }
+ else
+ {
+ this.requestObject.IngressBytes += count;
+ }
}
protected override void Dispose(bool disposing)
diff --git a/Lib/Common/Core/Util/StreamExtensions.cs b/Lib/Common/Core/Util/StreamExtensions.cs
index 387f08986..a09e3048d 100644
--- a/Lib/Common/Core/Util/StreamExtensions.cs
+++ b/Lib/Common/Core/Util/StreamExtensions.cs
@@ -251,11 +251,12 @@ private static int MinBytesToRead(long? val1, int val2)
///
/// A reference to the original stream
/// An object that represents the result of a physical request.
+ /// A flag indicating that ingress/egress bytes should be capture in reverse.
///
[DebuggerNonUserCode]
- internal static Stream WrapWithByteCountingStream(this Stream stream, RequestResult result)
+ internal static Stream WrapWithByteCountingStream(this Stream stream, RequestResult result, bool reverseCapture=false)
{
- return new ByteCountingStream(stream, result);
+ return new ByteCountingStream(stream, result, reverseCapture);
}
#endif
diff --git a/Lib/Common/Shared/Protocol/HttpContentFactory.cs b/Lib/Common/Shared/Protocol/HttpContentFactory.cs
index 680c06e2a..109796ed4 100644
--- a/Lib/Common/Shared/Protocol/HttpContentFactory.cs
+++ b/Lib/Common/Shared/Protocol/HttpContentFactory.cs
@@ -19,6 +19,7 @@ namespace Microsoft.Azure.Storage.Shared.Protocol
{
using Microsoft.Azure.Storage.Core;
using Microsoft.Azure.Storage.Core.Executor;
+ using Microsoft.Azure.Storage.Core.Util;
using System;
using System.IO;
using System.Net.Http;
@@ -28,7 +29,10 @@ internal static class HttpContentFactory
public static HttpContent BuildContentFromStream(Stream stream, long offset, long? length, Checksum checksum, RESTCommand cmd, OperationContext operationContext)
{
stream.Seek(offset, SeekOrigin.Begin);
-
+
+#if !(WINDOWS_RT || NETCORE)
+ stream = stream.WrapWithByteCountingStream(cmd.CurrentResult, true);
+#endif
HttpContent retContent = new StreamContent(new NonCloseableStream(stream));
retContent.Headers.ContentLength = length;
if (checksum?.MD5 != null)
diff --git a/Test/ClassLibraryCommon/Blob/BlobReadStreamTest.cs b/Test/ClassLibraryCommon/Blob/BlobReadStreamTest.cs
index 1c73e032e..b5427bae0 100644
--- a/Test/ClassLibraryCommon/Blob/BlobReadStreamTest.cs
+++ b/Test/ClassLibraryCommon/Blob/BlobReadStreamTest.cs
@@ -342,6 +342,54 @@ public void BlockBlobReadStreamBasicTest()
}
}
+ [TestMethod]
+ [Description("Download a blob using CloudBlobStream With Ingress/Egress bytes tracking")]
+ [TestCategory(ComponentCategory.Blob)]
+ [TestCategory(TestTypeCategory.UnitTest)]
+ [TestCategory(SmokeTestCategory.NonSmoke)]
+ [TestCategory(TenantTypeCategory.DevStore), TestCategory(TenantTypeCategory.DevFabric), TestCategory(TenantTypeCategory.Cloud)]
+ public void BlockBlobReadStreamBasicWithIngressEgressBytesTest()
+ {
+ int bufferSize = 5 * 1024 * 1024;
+ byte[] buffer = GetRandomBuffer(bufferSize);
+ CloudBlobContainer container = GetRandomContainerReference();
+ try
+ {
+ container.Create();
+
+ CloudBlockBlob blob = container.GetBlockBlobReference("blob1");
+ using (MemoryStream wholeBlob = new MemoryStream(buffer))
+ {
+ OperationContext operationContext = new OperationContext();
+ blob.UploadFromStream(wholeBlob, null, null, operationContext);
+ Assert.AreEqual(bufferSize, operationContext.LastResult.EgressBytes);
+ Assert.AreEqual(0, operationContext.LastResult.IngressBytes);
+ }
+
+ using (MemoryStream wholeBlob = new MemoryStream(buffer))
+ {
+ OperationContext operationContext = new OperationContext();
+ using (Stream blobStream = blob.OpenRead(operationContext: operationContext))
+ {
+ TestHelper.AssertStreamsAreEqual(wholeBlob, blobStream);
+ }
+ long totalIngress = 0;
+ long totalEggress = 0;
+ foreach (var result in operationContext.RequestResults)
+ {
+ totalIngress += result.IngressBytes;
+ totalEggress += result.EgressBytes;
+ }
+ Assert.AreEqual(bufferSize, totalIngress);
+ Assert.AreEqual(0, totalEggress);
+ }
+ }
+ finally
+ {
+ container.DeleteIfExists();
+ }
+ }
+
[TestMethod]
[Description("Download a blob using CloudBlobStream")]
[TestCategory(ComponentCategory.Blob)]