diff --git a/docs/features/requestaggregation.rst b/docs/features/requestaggregation.rst index f1123cf19..c9a642388 100644 --- a/docs/features/requestaggregation.rst +++ b/docs/features/requestaggregation.rst @@ -217,9 +217,12 @@ Below is an example of an aggregator that you could implement for your solution: Gotchas ------- -You cannot use Routes with specific **RequestIdKeys** as this would be crazy complicated to track. +* You cannot use Routes with specific **RequestIdKeys** as this would be crazy complicated to track. +* Aggregation only supports the ``GET`` HTTP verb. +* Aggregation allows for the forwarding of ``HttpRequest.Body`` to downstream services by duplicating the body data. + Form data and attached files should also be forwarded. + It is essential to always specify the ``Content-Length`` header in requests to upstream; otherwise, Ocelot will log warnings like *"Aggregation does not support body copy without Content-Length header!"*. -Aggregation only supports the ``GET`` HTTP verb. """" diff --git a/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs b/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs index 17609dc18..43a98fcd3 100644 --- a/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs +++ b/src/Ocelot/Multiplexer/MultiplexingMiddleware.cs @@ -19,8 +19,7 @@ public class MultiplexingMiddleware : OcelotMiddleware public MultiplexingMiddleware(RequestDelegate next, IOcelotLoggerFactory loggerFactory, - IResponseAggregatorFactory factory - ) + IResponseAggregatorFactory factory) : base(loggerFactory.CreateLogger()) { _factory = factory; @@ -184,7 +183,7 @@ private IEnumerable> ProcessRouteWithComplexAggregation(Aggreg /// The cloned Http context. private async Task ProcessRouteAsync(HttpContext sourceContext, DownstreamRoute route, List placeholders = null) { - var newHttpContext = CreateThreadContext(sourceContext); + var newHttpContext = await CreateThreadContextAsync(sourceContext); CopyItemsToNewContext(newHttpContext, sourceContext, placeholders); newHttpContext.Items.UpsertDownstreamRoute(route); @@ -208,14 +207,15 @@ private static void CopyItemsToNewContext(HttpContext target, HttpContext source /// /// The base http context. /// The cloned context. - private static HttpContext CreateThreadContext(HttpContext source) + protected virtual async Task CreateThreadContextAsync(HttpContext source) { - var from = source.Request; + var from = source.Request; + var bodyStream = await CloneRequestBodyAsync(from, source.RequestAborted); var target = new DefaultHttpContext { Request = { - Body = from.Body, // TODO Consider stream cloning for multiple reads + Body = bodyStream, ContentLength = from.ContentLength, ContentType = from.ContentType, Host = from.Host, @@ -237,12 +237,13 @@ private static HttpContext CreateThreadContext(HttpContext source) RequestAborted = source.RequestAborted, User = source.User, }; - foreach (var header in from.Headers) { target.Request.Headers[header.Key] = header.Value.ToArray(); - } - + } + + // Once the downstream request is completed and the downstream response has been read, the downstream response object can dispose of the body's Stream object + target.Response.RegisterForDisposeAsync(bodyStream); // manage Stream lifetime by HttpResponse object return target; } @@ -255,5 +256,29 @@ protected virtual Task MapAsync(HttpContext httpContext, Route route, List CloneRequestBodyAsync(HttpRequest request, CancellationToken aborted) + { + request.EnableBuffering(); + if (request.Body.Position != 0) + { + Logger.LogWarning("Ocelot does not support body copy without stream in initial position 0"); + return request.Body; + } + + var targetBuffer = new MemoryStream(); + if (request.ContentLength is not null) + { + await request.Body.CopyToAsync(targetBuffer, (int)request.ContentLength, aborted); + targetBuffer.Position = 0; + request.Body.Position = 0; + } + else + { + Logger.LogWarning("Aggregation does not support body copy without Content-Length header!"); + } + + return targetBuffer; + } } diff --git a/test/Ocelot.AcceptanceTests/AggregateTests.cs b/test/Ocelot.AcceptanceTests/AggregateTests.cs index daf3bdbfa..6257e27e6 100644 --- a/test/Ocelot.AcceptanceTests/AggregateTests.cs +++ b/test/Ocelot.AcceptanceTests/AggregateTests.cs @@ -15,6 +15,7 @@ using Ocelot.DependencyInjection; using Ocelot.Middleware; using Ocelot.Multiplexer; +using System.Text; namespace Ocelot.AcceptanceTests { @@ -598,114 +599,209 @@ public void Should_return_response_200_with_user_forwarding() } } - private void GivenServiceIsRunning(string baseUrl, int statusCode, string responseBody) + [Fact] + [Trait("Bug", "2039")] + public void Should_return_response_200_with_copied_body_sent_on_multiple_services() { - _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, async context => - { - context.Response.StatusCode = statusCode; - await context.Response.WriteAsync(responseBody); - }); + var port1 = PortFinder.GetRandomPort(); + var port2 = PortFinder.GetRandomPort(); + var route1 = GivenRoute(port1, "/Service1", "Service1", "/Sub1"); + var route2 = GivenRoute(port2, "/Service2", "Service2", "/Sub2"); + var configuration = GivenConfiguration(route1, route2); + var requestBody = @"{""id"":1,""response"":""fromBody-#REPLACESTRING#""}"; + var sub1ResponseContent = @"{""id"":1,""response"":""fromBody-s1""}"; + var sub2ResponseContent = @"{""id"":1,""response"":""fromBody-s2""}"; + var expected = $"{{\"Service1\":{sub1ResponseContent},\"Service2\":{sub2ResponseContent}}}"; + + this.Given(x => x.GivenServiceIsRunning(0, port1, "/Sub1", 200, reqBody => reqBody.Replace("#REPLACESTRING#", "s1"))) + .Given(x => x.GivenServiceIsRunning(1, port2, "/Sub2", 200, reqBody => reqBody.Replace("#REPLACESTRING#", "s2"))) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + .When(x => WhenIGetUrlWithBodyOnTheApiGateway("/", requestBody)) + .Then(x => ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) + .And(x => ThenTheResponseBodyShouldBe(expected)) + .BDDfy(); + } + + [Fact] + [Trait("Bug", "2039")] + public void Should_return_response_200_with_copied_form_sent_on_multiple_services() + { + var port1 = PortFinder.GetRandomPort(); + var port2 = PortFinder.GetRandomPort(); + var route1 = GivenRoute(port1, "/Service1", "Service1", "/Sub1"); + var route2 = GivenRoute(port2, "/Service2", "Service2", "/Sub2"); + var configuration = GivenConfiguration(route1, route2); + + var formValues = new[] + { + new KeyValuePair("param1", "value1"), + new KeyValuePair("param2", "from-form-REPLACESTRING"), + }; + + var sub1ResponseContent = "\"[key:param1=value1¶m2=from-form-s1]\""; + var sub2ResponseContent = "\"[key:param1=value1¶m2=from-form-s2]\""; + var expected = $"{{\"Service1\":{sub1ResponseContent},\"Service2\":{sub2ResponseContent}}}"; + + this.Given(x => x.GivenServiceIsRunning(0, port1, "/Sub1", 200, (IFormCollection reqForm) => FormatFormCollection(reqForm).Replace("REPLACESTRING", "s1"))) + .Given(x => x.GivenServiceIsRunning(1, port2, "/Sub2", 200, (IFormCollection reqForm) => FormatFormCollection(reqForm).Replace("REPLACESTRING", "s2"))) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + .When(x => WhenIGetUrlWithFormOnTheApiGateway("/", "key", formValues)) + .Then(x => ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) + .And(x => ThenTheResponseBodyShouldBe(expected)) + .BDDfy(); + } + + private static string FormatFormCollection(IFormCollection reqForm) + { + var sb = new StringBuilder() + .Append('"'); + + foreach (var kvp in reqForm) + { + sb.Append($"[{kvp.Key}:{kvp.Value}]"); + } + + return sb + .Append('"') + .ToString(); + } + + private void GivenServiceIsRunning(string baseUrl, int statusCode, string responseBody) + { + _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, async context => + { + context.Response.StatusCode = statusCode; + await context.Response.WriteAsync(responseBody); + }); } private void GivenServiceIsRunning(int index, int port, string basePath, int statusCode, string responseBody) - { - var baseUrl = $"{Uri.UriSchemeHttp}://localhost:{port}"; - _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, basePath, async context => - { - _downstreamPaths[index] = !string.IsNullOrEmpty(context.Request.PathBase.Value) ? context.Request.PathBase.Value : context.Request.Path.Value; - - if (_downstreamPaths[index] != basePath) + => GivenServiceIsRunning(index, port, basePath, statusCode, + async context => { - context.Response.StatusCode = statusCode; - await context.Response.WriteAsync("downstream path didn't match base path"); - } - else + await context.Response.WriteAsync(responseBody); + }); + + private void GivenServiceIsRunning(int index, int port, string basePath, int statusCode, Func responseFromBody) + => GivenServiceIsRunning(index, port, basePath, statusCode, + async context => + { + var requestBody = await new StreamReader(context.Request.Body).ReadToEndAsync(); + var responseBody = responseFromBody(requestBody); + await context.Response.WriteAsync(responseBody); + }); + + private void GivenServiceIsRunning(int index, int port, string basePath, int statusCode, Func responseFromForm) + => GivenServiceIsRunning(index, port, basePath, statusCode, + async context => { - context.Response.StatusCode = statusCode; + var responseBody = responseFromForm(context.Request.Form); await context.Response.WriteAsync(responseBody); - } - }); + }); + + private void GivenServiceIsRunning(int index, int port, string basePath, int statusCode, Action processContext) + { + var baseUrl = DownstreamUrl(port); + _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, basePath, async context => + { + _downstreamPaths[index] = !string.IsNullOrEmpty(context.Request.PathBase.Value) + ? context.Request.PathBase.Value + : context.Request.Path.Value; + + if (_downstreamPaths[index] != basePath) + { + context.Response.StatusCode = (int)HttpStatusCode.NotFound; + await context.Response.WriteAsync("downstream path doesn't match base path"); + } + else + { + context.Response.StatusCode = statusCode; + processContext?.Invoke(context); + } + }); } private void GivenOcelotIsRunningWithSpecificAggregatorsRegisteredInDi() where TAggregator : class, IDefinedAggregator - where TDependency : class - { - _webHostBuilder = new WebHostBuilder(); - - _webHostBuilder - .ConfigureAppConfiguration((hostingContext, config) => - { - config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); - var env = hostingContext.HostingEnvironment; - config.AddJsonFile("appsettings.json", true, false) - .AddJsonFile($"appsettings.{env.EnvironmentName}.json", true, false); - config.AddJsonFile(_ocelotConfigFileName, true, false); - config.AddEnvironmentVariables(); - }) - .ConfigureServices(s => - { - s.AddSingleton(_webHostBuilder); - s.AddSingleton(); - s.AddOcelot() - .AddSingletonDefinedAggregator(); - }) - .Configure(a => { a.UseOcelot().Wait(); }); - - _ocelotServer = new TestServer(_webHostBuilder); - _ocelotClient = _ocelotServer.CreateClient(); + where TDependency : class + { + _webHostBuilder = new WebHostBuilder(); + + _webHostBuilder + .ConfigureAppConfiguration((hostingContext, config) => + { + config.SetBasePath(hostingContext.HostingEnvironment.ContentRootPath); + var env = hostingContext.HostingEnvironment; + config.AddJsonFile("appsettings.json", true, false) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", true, false); + config.AddJsonFile(_ocelotConfigFileName, true, false); + config.AddEnvironmentVariables(); + }) + .ConfigureServices(s => + { + s.AddSingleton(_webHostBuilder); + s.AddSingleton(); + s.AddOcelot() + .AddSingletonDefinedAggregator(); + }) + .Configure(a => { a.UseOcelot().Wait(); }); + + _ocelotServer = new TestServer(_webHostBuilder); + _ocelotClient = _ocelotServer.CreateClient(); } - private void ThenTheDownstreamUrlPathShouldBe(string expectedDownstreamPathOne, string expectedDownstreamPath) - { - _downstreamPaths[0].ShouldBe(expectedDownstreamPathOne); - _downstreamPaths[1].ShouldBe(expectedDownstreamPath); + private void ThenTheDownstreamUrlPathShouldBe(string expectedDownstreamPathOne, string expectedDownstreamPath) + { + _downstreamPaths[0].ShouldBe(expectedDownstreamPathOne); + _downstreamPaths[1].ShouldBe(expectedDownstreamPath); } - private static FileRoute GivenRoute(int port, string upstream, string key) => new() - { - DownstreamPathTemplate = "/", - DownstreamScheme = Uri.UriSchemeHttp, - DownstreamHostAndPorts = new() { new FileHostAndPort("localhost", port) }, - UpstreamPathTemplate = upstream, - UpstreamHttpMethod = new() { HttpMethods.Get }, - Key = key, + private static FileRoute GivenRoute(int port, string upstream, string key, string downstream = null) => new() + { + DownstreamPathTemplate = downstream ?? "/", + DownstreamScheme = Uri.UriSchemeHttp, + DownstreamHostAndPorts = new() { new("localhost", port) }, + UpstreamPathTemplate = upstream, + UpstreamHttpMethod = new() { HttpMethods.Get }, + Key = key, }; - private static new FileConfiguration GivenConfiguration(params FileRoute[] routes) - { - var obj = Steps.GivenConfiguration(routes); - obj.Aggregates.Add( - new() - { - UpstreamPathTemplate = "/", - UpstreamHost = "localhost", - RouteKeys = routes.Select(r => r.Key).ToList(), // [ "Laura", "Tom" ], - } - ); - return obj; - } + private static new FileConfiguration GivenConfiguration(params FileRoute[] routes) + { + var obj = Steps.GivenConfiguration(routes); + obj.Aggregates.Add( + new() + { + UpstreamPathTemplate = "/", + UpstreamHost = "localhost", + RouteKeys = routes.Select(r => r.Key).ToList(), // [ "Laura", "Tom" ], + } + ); + return obj; + } } - public class FakeDep - { + public class FakeDep + { } - public class FakeDefinedAggregator : IDefinedAggregator - { - public FakeDefinedAggregator(FakeDep dep) - { - } - - public async Task Aggregate(List responses) - { - var one = await responses[0].Items.DownstreamResponse().Content.ReadAsStringAsync(); - var two = await responses[1].Items.DownstreamResponse().Content.ReadAsStringAsync(); - - var merge = $"{one}, {two}"; - merge = merge.Replace("Hello", "Bye").Replace("{", "").Replace("}", ""); - var headers = responses.SelectMany(x => x.Items.DownstreamResponse().Headers).ToList(); - return new DownstreamResponse(new StringContent(merge), HttpStatusCode.OK, headers, "some reason"); - } + public class FakeDefinedAggregator : IDefinedAggregator + { + public FakeDefinedAggregator(FakeDep dep) + { + } + + public async Task Aggregate(List responses) + { + var one = await responses[0].Items.DownstreamResponse().Content.ReadAsStringAsync(); + var two = await responses[1].Items.DownstreamResponse().Content.ReadAsStringAsync(); + + var merge = $"{one}, {two}"; + merge = merge.Replace("Hello", "Bye").Replace("{", "").Replace("}", ""); + var headers = responses.SelectMany(x => x.Items.DownstreamResponse().Headers).ToList(); + return new DownstreamResponse(new StringContent(merge), HttpStatusCode.OK, headers, "some reason"); + } } -} +} diff --git a/test/Ocelot.AcceptanceTests/Steps.cs b/test/Ocelot.AcceptanceTests/Steps.cs index 7f95a8a9d..5acb57cbc 100644 --- a/test/Ocelot.AcceptanceTests/Steps.cs +++ b/test/Ocelot.AcceptanceTests/Steps.cs @@ -190,8 +190,8 @@ protected virtual void DeleteOcelotConfig(params string[] files) { Console.WriteLine(e); } - } - } + } + } public void ThenTheResponseBodyHeaderIs(string key, string value) { @@ -817,6 +817,29 @@ public void WhenIGetUrlOnTheApiGatewayAndDontWait(string url) _ocelotClient.GetAsync(url); } + public void WhenIGetUrlWithBodyOnTheApiGateway(string url, string body) + { + var request = new HttpRequestMessage(HttpMethod.Get, url) + { + Content = new StringContent(body), + }; + _response = _ocelotClient.SendAsync(request).Result; + } + + public void WhenIGetUrlWithFormOnTheApiGateway(string url, string name, IEnumerable> values) + { + var content = new MultipartFormDataContent(); + var dataContent = new FormUrlEncodedContent(values); + content.Add(dataContent, name); + content.Headers.ContentDisposition = new ContentDispositionHeaderValue("form-data"); + + var request = new HttpRequestMessage(HttpMethod.Get, url) + { + Content = content, + }; + _response = _ocelotClient.SendAsync(request).Result; + } + public void WhenICancelTheRequest() { _ocelotClient.CancelPendingRequests(); diff --git a/test/Ocelot.UnitTests/Multiplexing/MultiplexingMiddlewareTests.cs b/test/Ocelot.UnitTests/Multiplexing/MultiplexingMiddlewareTests.cs index 95f41e863..b94c247ff 100644 --- a/test/Ocelot.UnitTests/Multiplexing/MultiplexingMiddlewareTests.cs +++ b/test/Ocelot.UnitTests/Multiplexing/MultiplexingMiddlewareTests.cs @@ -61,14 +61,14 @@ public void should_not_multiplex() [Fact] [Trait("Bug", "1396")] - public void CreateThreadContext_CopyUser_ToTarget() + public async Task CreateThreadContextAsync_CopyUser_ToTarget() { // Arrange - GivenUser("test", "Copy", nameof(CreateThreadContext_CopyUser_ToTarget)); + GivenUser("test", "Copy", nameof(CreateThreadContextAsync_CopyUser_ToTarget)); // Act - var method = _middleware.GetType().GetMethod("CreateThreadContext", BindingFlags.NonPublic | BindingFlags.Static); - var actual = (HttpContext)method.Invoke(_middleware, new object[] { _httpContext }); + var method = _middleware.GetType().GetMethod("CreateThreadContextAsync", BindingFlags.NonPublic | BindingFlags.Instance); + var actual = await (Task)method.Invoke(_middleware, new object[] { _httpContext }); // Assert AssertUsers(actual); @@ -188,8 +188,8 @@ public async Task Should_Create_As_Many_Contexts_As_Routes_And_Map_Is_Called_Onc ItExpr.IsAny(), ItExpr.Is>(list => list.Count == routesCount) ); - } - + } + [Fact] [Trait("PR", "1826")] public async Task Should_Not_Call_ProcessSingleRoute_Or_Map_If_No_Route() @@ -212,7 +212,30 @@ public async Task Should_Not_Call_ProcessSingleRoute_Or_Map_If_No_Route() ItExpr.IsAny(), ItExpr.IsAny(), ItExpr.IsAny>()); - } + } + + [Theory] + [Trait("Bug", "2039")] + [InlineData(1)] // Times.Never() + [InlineData(2)] // Times.Exactly(2) + [InlineData(3)] // Times.Exactly(3) + [InlineData(4)] // Times.Exactly(4) + public async Task Should_Call_CloneRequestBodyAsync_Each_Time_Per_Requests(int numberOfRoutes) + { + // Arrange + var mock = MockMiddlewareFactory(null, null); + GivenUser("test", "Invoke", nameof(Should_Call_CloneRequestBodyAsync_Each_Time_Per_Requests)); + GivenTheFollowing(GivenDefaultRoute(numberOfRoutes)); + + // Act + await WhenIMultiplex(); + + // Assert + mock.Protected().Verify>("CloneRequestBodyAsync", + numberOfRoutes > 1 ? Times.Exactly(numberOfRoutes) : Times.Never(), + ItExpr.IsAny(), + ItExpr.IsAny()); + } [Fact] [Trait("PR", "1826")]