From aef3e6b9f53ad8c8aa8c9ade92a572f60d0672a1 Mon Sep 17 00:00:00 2001 From: Sergii <109989060+sergio-str@users.noreply.github.com> Date: Tue, 7 May 2024 14:09:19 +0300 Subject: [PATCH] #1590 Use correct interval for request counting (#1592) * Use correct interval for request counting * Minor fixes, return correct counter value when ban period elapsed * Revert "Use correct interval for request counting" This reverts commit 7d232c7042f789d5ea32834f8c33d1adfb92ec24. * Revert "Artificial commit, initiate CI" This reverts commit e723dfac839f78961eb4ce0068c36eb0ab52e30c. * CA1822 Member 'XYZ' does not access instance data and can be marked as static * Quick code review by @raman-m * Rate Limiting feature name should match folder name * namespace `Ocelot.RateLimiting` * Extract `IRateLimitCore` interface * Remove useless `ClientRateLimitProcessor` class * Rename to `IRateLimitStorage` and dev docs * Wrap services as a feature * Review `IRateLimitCore` interface and dev docs * The middleware class prefix should match the feature name * Add some basic `RateLimitCoreTests` * Rename to `IRateLimiting` * Refactor rate limiting core * Remove redundant `SaveCounter` from the interface * Thread safe storage operations * Coalesce in return statement * Convert to file-scoped namespace * Use expression body * Unit tests for #1590 user scenario * Move test class to separate feature folder * Inherit from `Steps` * Refactoring: Follow the DRY principle * Acceptance test for #1590 user scenario * Update feature docs --------- Co-authored-by: raman-m --- docs/features/ratelimiting.rst | 92 ++++-- src/Ocelot/DependencyInjection/Features.cs | 16 ++ .../DependencyInjection/OcelotBuilder.cs | 4 +- .../Middleware/OcelotPipelineExtensions.cs | 2 +- .../RateLimit/ClientRateLimitProcessor.cs | 35 --- src/Ocelot/RateLimit/ClientRequestIdentity.cs | 18 -- ...DistributedCacheRateLimitCounterHandler.cs | 42 --- .../RateLimit/IRateLimitCounterHandler.cs | 13 - .../MemoryCacheRateLimitCounterHandler.cs | 28 -- .../RateLimitMiddlewareExtensions.cs | 12 - src/Ocelot/RateLimit/RateLimitCore.cs | 147 ---------- src/Ocelot/RateLimit/RateLimitCounter.cs | 21 -- src/Ocelot/RateLimit/RateLimitHeaders.cs | 23 -- .../RateLimiting/ClientRequestIdentity.cs | 15 + .../DistributedCacheRateLimitStorage.cs | 32 +++ src/Ocelot/RateLimiting/IRateLimitStorage.cs | 16 ++ src/Ocelot/RateLimiting/IRateLimiting.cs | 59 ++++ .../MemoryCacheRateLimitStorage.cs | 25 ++ .../Middleware/RateLimitingMiddleware.cs} | 36 ++- .../RateLimitingMiddlewareExtensions.cs | 11 + .../QuotaExceededError.cs | 2 +- src/Ocelot/RateLimiting/RateLimitCounter.cs | 29 ++ src/Ocelot/RateLimiting/RateLimitHeaders.cs | 19 ++ src/Ocelot/RateLimiting/RateLimiting.cs | 194 +++++++++++++ .../ClientRateLimitTests.cs | 217 -------------- .../RateLimiting/ClientRateLimitingTests.cs | 182 ++++++++++++ .../ClientRateLimitMiddlewareTests.cs | 183 ------------ .../RateLimitingMiddlewareTests.cs | 218 ++++++++++++++ .../RateLimiting/RateLimitingTests.cs | 268 ++++++++++++++++++ 29 files changed, 1165 insertions(+), 794 deletions(-) create mode 100644 src/Ocelot/DependencyInjection/Features.cs delete mode 100644 src/Ocelot/RateLimit/ClientRateLimitProcessor.cs delete mode 100644 src/Ocelot/RateLimit/ClientRequestIdentity.cs delete mode 100644 src/Ocelot/RateLimit/DistributedCacheRateLimitCounterHandler.cs delete mode 100644 src/Ocelot/RateLimit/IRateLimitCounterHandler.cs delete mode 100644 src/Ocelot/RateLimit/MemoryCacheRateLimitCounterHandler.cs delete mode 100644 src/Ocelot/RateLimit/Middleware/RateLimitMiddlewareExtensions.cs delete mode 100644 src/Ocelot/RateLimit/RateLimitCore.cs delete mode 100644 src/Ocelot/RateLimit/RateLimitCounter.cs delete mode 100644 src/Ocelot/RateLimit/RateLimitHeaders.cs create mode 100644 src/Ocelot/RateLimiting/ClientRequestIdentity.cs create mode 100644 src/Ocelot/RateLimiting/DistributedCacheRateLimitStorage.cs create mode 100644 src/Ocelot/RateLimiting/IRateLimitStorage.cs create mode 100644 src/Ocelot/RateLimiting/IRateLimiting.cs create mode 100644 src/Ocelot/RateLimiting/MemoryCacheRateLimitStorage.cs rename src/Ocelot/{RateLimit/Middleware/ClientRateLimitMiddleware.cs => RateLimiting/Middleware/RateLimitingMiddleware.cs} (83%) create mode 100644 src/Ocelot/RateLimiting/Middleware/RateLimitingMiddlewareExtensions.cs rename src/Ocelot/{RateLimit => RateLimiting}/QuotaExceededError.cs (89%) create mode 100644 src/Ocelot/RateLimiting/RateLimitCounter.cs create mode 100644 src/Ocelot/RateLimiting/RateLimitHeaders.cs create mode 100644 src/Ocelot/RateLimiting/RateLimiting.cs delete mode 100644 test/Ocelot.AcceptanceTests/ClientRateLimitTests.cs create mode 100644 test/Ocelot.AcceptanceTests/RateLimiting/ClientRateLimitingTests.cs delete mode 100644 test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs create mode 100644 test/Ocelot.UnitTests/RateLimiting/RateLimitingMiddlewareTests.cs create mode 100644 test/Ocelot.UnitTests/RateLimiting/RateLimitingTests.cs diff --git a/docs/features/ratelimiting.rst b/docs/features/ratelimiting.rst index 9a69f4ded..94db1db5d 100644 --- a/docs/features/ratelimiting.rst +++ b/docs/features/ratelimiting.rst @@ -1,35 +1,50 @@ Rate Limiting ============= +`What's rate limiting? `_ + +* `Rate limiting | Wikipedia `_ +* `Rate Limiting pattern | Azure Architecture Center | Microsoft Learn `_ +* `Rate Limiting | Ask Google `_ + Ocelot Own Implementation ------------------------- -Ocelot supports rate limiting of upstream requests so that your downstream services do not become overloaded. +Ocelot provides *rate limiting* for upstream requests to prevent downstream services from becoming overwhelmed. [#f1]_ -The authors of this feature were inspired by `@catcherwong article `_ to finally write this documentation. -This feature was added by `@geffzhang `_ on GitHub! Thanks very much! +Rate Limit by Client's Header +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -To get rate limiting working for a Route you need to add the following JSON to it: +To implement *rate limiting* for a Route, you need to incorporate the following JSON configuration: .. code-block:: json "RateLimitOptions": { - "ClientWhitelist": [], + "ClientWhitelist": [], // array of strings "EnableRateLimiting": true, - "Period": "1s", - "PeriodTimespan": 1, + "Period": "1s", // seconds, minutes, hours, days + "PeriodTimespan": 1, // only seconds "Limit": 1 } -* **ClientWhitelist** - This is an array that contains the whitelist of the client. - It means that the client in this array will not be affected by the rate limiting. -* **EnableRateLimiting** - This value specifies enable endpoint rate limiting. -* **Period** - This value specifies the period that the limit applies to, such as ``1s``, ``5m``, ``1h``, ``1d`` and so on. - If you make more requests in the period than the limit allows then you need to wait for **PeriodTimespan** to elapse before you make another request. -* **PeriodTimespan** - This value specifies that we can retry after a certain number of seconds. -* **Limit** - This value specifies the maximum number of requests that a client can make in a defined period. +* **ClientWhitelist** - An array containing the whitelisted clients. Clients listed here will be exempt from rate limiting. + For more information on the **ClientIdHeader** option, refer to the :ref:`rl-global-configuration` section. +* **EnableRateLimiting** - This setting enables rate limiting on endpoints. +* **Period** - This parameter defines the duration for which the limit is applicable, such as ``1s`` (seconds), ``5m`` (minutes), ``1h`` (hours), and ``1d`` (days). + If you reach the exact **Limit** of requests, the excess occurs immediately, and the **PeriodTimespan** begins. + You must wait for the **PeriodTimespan** duration to pass before making another request. + Should you exceed the number of requests within the period more than the **Limit** permits, the **QuotaExceededMessage** will appear in the response, accompanied by the **HttpStatusCode**. +* **PeriodTimespan** - This parameter indicates the time in **seconds** after which a retry is permissible. + During this interval, the **QuotaExceededMessage** will appear in the response, accompanied by an **HttpStatusCode**. + Clients are advised to consult the ``Retry-After`` header to determine the timing of subsequent requests. +* **Limit** - This parameter defines the upper limit of requests a client is allowed to make within a specified **Period**. + +.. _rl-global-configuration: + +Global Configuration +^^^^^^^^^^^^^^^^^^^^ -You can also set the following in the **GlobalConfiguration** part of **ocelot.json**: +You can set the following in the ``GlobalConfiguration`` section of `ocelot.json`_: .. code-block:: json @@ -38,33 +53,48 @@ You can also set the following in the **GlobalConfiguration** part of **ocelot.j "RateLimitOptions": { "DisableRateLimitHeaders": false, "QuotaExceededMessage": "Customize Tips!", - "HttpStatusCode": 123, - "ClientIdHeader": "Test" + "HttpStatusCode": 418, // I'm a teapot + "ClientIdHeader": "MyRateLimiting" } } -* **DisableRateLimitHeaders** - This value specifies whether ``X-Rate-Limit`` and ``Retry-After`` headers are disabled. -* **QuotaExceededMessage** - This value specifies the exceeded message. -* **HttpStatusCode** - This value specifies the returned HTTP status code when rate limiting occurs. -* **ClientIdHeader** - Allows you to specifiy the header that should be used to identify clients. By default it is ``ClientId`` +* **DisableRateLimitHeaders** - Determines if the ``X-Rate-Limit`` and ``Retry-After`` headers are disabled. +* **QuotaExceededMessage** - Defines the message displayed when the quota is exceeded. It is optional and the default message is informative. +* **HttpStatusCode** - Indicates the HTTP status code returned during *rate limiting*. The default value is **429** (`Too Many Requests`_). +* **ClientIdHeader** - Specifies the header used to identify clients, with ``ClientId`` as the default. Future and ASP.NET Core Implementation -------------------------------------- -The Ocelot team considers to redesign *Rate Limiting* feature, -because of `Announcing Rate Limiting for .NET `_ by Brennan Conroy on July 13th, 2022. -There is no decision at the moment, and the old version of the feature is included as a part of release `20.0 `_ for .NET 7. +The Ocelot team is contemplating a redesign of the *Rate Limiting* feature following the `Announcing Rate Limiting for .NET`_ by Brennan Conroy on July 13th, 2022. +Currently, no decision has been made, and the previous version of the feature remains part of the `20.0`_ release for .NET 7. [#f2]_ -See more about new feature being added into ASP.NET Core 7.0 release: +Discover the new features being introduced in the ASP.NET Core 7.0 release: -* `RateLimiter Class `_, since ASP.NET Core **7.0** -* `System.Threading.RateLimiting `_ NuGet package -* `Rate limiting middleware in ASP.NET Core `_ article by Arvin Kahbazi, Maarten Balliauw, and Rick Anderson +* The `RateLimiter Class `_, available since ASP.NET Core 7.0 +* The `System.Threading.RateLimiting `_ NuGet package +* The `Rate limiting middleware in ASP.NET Core `_ article by Arvin Kahbazi, Maarten Balliauw, and Rick Anderson -However, it makes sense to keep the old implementation as a Ocelot built-in native feature, but we are going to migrate to the new Rate Limiter from ``Microsoft.AspNetCore.RateLimiting`` namespace. +While retaining the old implementation as an Ocelot built-in feature makes sense, we plan to transition to the new Rate Limiter from the ``Microsoft.AspNetCore.RateLimiting`` namespace. +Please share your thoughts with us in the `Discussions `_ space of the repository. |octocat| + +"""" + +.. [#f1] Historically, the *"Ocelot Own Rate Limiting"* feature is one of the oldest and first features of Ocelot. This feature was delivered in PR `37`_ by `@geffzhang`_ on GitHub. Many thanks! It was initially released in version `1.3.2`_. The authors were inspired by `@catcherwong article`_ to write this documentation. +.. [#f2] Since PR `37`_ and version `1.3.2`_, the Ocelot team has reviewed and redesigned the feature to provide stable behavior. The fix for bug `1590`_ (PR `1592`_) was released as part of version `23.3`_. + +.. _Announcing Rate Limiting for .NET: https://devblogs.microsoft.com/dotnet/announcing-rate-limiting-for-dotnet/ +.. _ocelot.json: https://github.com/ThreeMammals/Ocelot/blob/main/test/Ocelot.ManualTest/ocelot.json +.. _@geffzhang: https://github.com/ThreeMammals/Ocelot/commits?author=geffzhang +.. _@catcherwong article: http://www.c-sharpcorner.com/article/building-api-gateway-using-ocelot-in-asp-net-core-rate-limiting-part-four/ +.. _Too Many Requests: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/429 +.. _37: https://github.com/ThreeMammals/Ocelot/pull/37 +.. _1590: https://github.com/ThreeMammals/Ocelot/issues/1590 +.. _1592: https://github.com/ThreeMammals/Ocelot/pull/1592 +.. _1.3.2: https://github.com/ThreeMammals/Ocelot/releases/tag/1.3.2 +.. _20.0: https://github.com/ThreeMammals/Ocelot/releases/tag/20.0.0 +.. _23.3: https://github.com/ThreeMammals/Ocelot/releases/tag/23.3.0 .. |octocat| image:: https://github.githubassets.com/images/icons/emoji/octocat.png :alt: octocat :width: 23 - -Please, share your opinion to us in the `Discussions `_ space of the repository. |octocat| diff --git a/src/Ocelot/DependencyInjection/Features.cs b/src/Ocelot/DependencyInjection/Features.cs new file mode 100644 index 000000000..51f836ea9 --- /dev/null +++ b/src/Ocelot/DependencyInjection/Features.cs @@ -0,0 +1,16 @@ +using Microsoft.Extensions.DependencyInjection; +using Ocelot.RateLimiting; + +namespace Ocelot.DependencyInjection; + +public static class Features +{ + /// + /// Ocelot feature: Rate Limiting. + /// + /// The services collection to add the feature to. + /// The same object. + public static IServiceCollection AddRateLimiting(this IServiceCollection services) => services + .AddSingleton() + .AddSingleton(); +} diff --git a/src/Ocelot/DependencyInjection/OcelotBuilder.cs b/src/Ocelot/DependencyInjection/OcelotBuilder.cs index a72ec3cbf..a501014cc 100644 --- a/src/Ocelot/DependencyInjection/OcelotBuilder.cs +++ b/src/Ocelot/DependencyInjection/OcelotBuilder.cs @@ -27,7 +27,7 @@ using Ocelot.Multiplexer; using Ocelot.PathManipulation; using Ocelot.QueryStrings; -using Ocelot.RateLimit; +using Ocelot.RateLimiting; using Ocelot.Request.Creator; using Ocelot.Request.Mapper; using Ocelot.Requester; @@ -109,7 +109,7 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo Services.TryAddSingleton(); Services.TryAddSingleton(); Services.TryAddSingleton(); - Services.TryAddSingleton(); + Services.AddRateLimiting(); // Feature: Rate Limiting Services.TryAddSingleton(); Services.TryAddSingleton(); Services.TryAddSingleton(); diff --git a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs index 01ec573fb..16f4a5cff 100644 --- a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs +++ b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs @@ -12,7 +12,7 @@ using Ocelot.LoadBalancer.Middleware; using Ocelot.Multiplexer; using Ocelot.QueryStrings.Middleware; -using Ocelot.RateLimit.Middleware; +using Ocelot.RateLimiting.Middleware; using Ocelot.Request.Middleware; using Ocelot.Requester.Middleware; using Ocelot.RequestId.Middleware; diff --git a/src/Ocelot/RateLimit/ClientRateLimitProcessor.cs b/src/Ocelot/RateLimit/ClientRateLimitProcessor.cs deleted file mode 100644 index 14ebc0594..000000000 --- a/src/Ocelot/RateLimit/ClientRateLimitProcessor.cs +++ /dev/null @@ -1,35 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Ocelot.Configuration; - -namespace Ocelot.RateLimit -{ - public class ClientRateLimitProcessor - { - private readonly RateLimitCore _core; - - public ClientRateLimitProcessor(IRateLimitCounterHandler counterHandler) - { - _core = new RateLimitCore(counterHandler); - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitOptions option) - { - return _core.ProcessRequest(requestIdentity, option); - } - - public int RetryAfterFrom(DateTime timestamp, RateLimitRule rule) - { - return _core.RetryAfterFrom(timestamp, rule); - } - - public RateLimitHeaders GetRateLimitHeaders(HttpContext context, ClientRequestIdentity requestIdentity, RateLimitOptions option) - { - return _core.GetRateLimitHeaders(context, requestIdentity, option); - } - - public TimeSpan ConvertToTimeSpan(string timeSpan) - { - return _core.ConvertToTimeSpan(timeSpan); - } - } -} diff --git a/src/Ocelot/RateLimit/ClientRequestIdentity.cs b/src/Ocelot/RateLimit/ClientRequestIdentity.cs deleted file mode 100644 index b67b7c5a9..000000000 --- a/src/Ocelot/RateLimit/ClientRequestIdentity.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace Ocelot.RateLimit -{ - public class ClientRequestIdentity - { - public ClientRequestIdentity(string clientId, string path, string httpverb) - { - ClientId = clientId; - Path = path; - HttpVerb = httpverb; - } - - public string ClientId { get; } - - public string Path { get; } - - public string HttpVerb { get; } - } -} diff --git a/src/Ocelot/RateLimit/DistributedCacheRateLimitCounterHandler.cs b/src/Ocelot/RateLimit/DistributedCacheRateLimitCounterHandler.cs deleted file mode 100644 index c98e256ad..000000000 --- a/src/Ocelot/RateLimit/DistributedCacheRateLimitCounterHandler.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Microsoft.Extensions.Caching.Distributed; -using Newtonsoft.Json; - -namespace Ocelot.RateLimit -{ - public class DistributedCacheRateLimitCounterHandler : IRateLimitCounterHandler - { - private readonly IDistributedCache _memoryCache; - - public DistributedCacheRateLimitCounterHandler(IDistributedCache memoryCache) - { - _memoryCache = memoryCache; - } - - public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) - { - _memoryCache.SetString(id, JsonConvert.SerializeObject(counter), new DistributedCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); - } - - public bool Exists(string id) - { - var stored = _memoryCache.GetString(id); - return !string.IsNullOrEmpty(stored); - } - - public RateLimitCounter? Get(string id) - { - var stored = _memoryCache.GetString(id); - if (!string.IsNullOrEmpty(stored)) - { - return JsonConvert.DeserializeObject(stored); - } - - return null; - } - - public void Remove(string id) - { - _memoryCache.Remove(id); - } - } -} diff --git a/src/Ocelot/RateLimit/IRateLimitCounterHandler.cs b/src/Ocelot/RateLimit/IRateLimitCounterHandler.cs deleted file mode 100644 index c17d04f7c..000000000 --- a/src/Ocelot/RateLimit/IRateLimitCounterHandler.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace Ocelot.RateLimit -{ - public interface IRateLimitCounterHandler - { - bool Exists(string id); - - RateLimitCounter? Get(string id); - - void Remove(string id); - - void Set(string id, RateLimitCounter counter, TimeSpan expirationTime); - } -} diff --git a/src/Ocelot/RateLimit/MemoryCacheRateLimitCounterHandler.cs b/src/Ocelot/RateLimit/MemoryCacheRateLimitCounterHandler.cs deleted file mode 100644 index 1a030d511..000000000 --- a/src/Ocelot/RateLimit/MemoryCacheRateLimitCounterHandler.cs +++ /dev/null @@ -1,28 +0,0 @@ -using Microsoft.Extensions.Caching.Memory; - -namespace Ocelot.RateLimit -{ - public class MemoryCacheRateLimitCounterHandler : IRateLimitCounterHandler - { - private readonly IMemoryCache _memoryCache; - - public MemoryCacheRateLimitCounterHandler(IMemoryCache memoryCache) - { - _memoryCache = memoryCache; - } - - public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) - { - _memoryCache.Set(id, counter, new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); - } - - public bool Exists(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter); - - public RateLimitCounter? Get(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter) ? counter : null; - - public void Remove(string id) - { - _memoryCache.Remove(id); - } - } -} diff --git a/src/Ocelot/RateLimit/Middleware/RateLimitMiddlewareExtensions.cs b/src/Ocelot/RateLimit/Middleware/RateLimitMiddlewareExtensions.cs deleted file mode 100644 index 91609c67f..000000000 --- a/src/Ocelot/RateLimit/Middleware/RateLimitMiddlewareExtensions.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Microsoft.AspNetCore.Builder; - -namespace Ocelot.RateLimit.Middleware -{ - public static class RateLimitMiddlewareExtensions - { - public static IApplicationBuilder UseRateLimiting(this IApplicationBuilder builder) - { - return builder.UseMiddleware(); - } - } -} diff --git a/src/Ocelot/RateLimit/RateLimitCore.cs b/src/Ocelot/RateLimit/RateLimitCore.cs deleted file mode 100644 index dddf8a772..000000000 --- a/src/Ocelot/RateLimit/RateLimitCore.cs +++ /dev/null @@ -1,147 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Ocelot.Configuration; -using System.Globalization; -using System.Security.Cryptography; - -namespace Ocelot.RateLimit -{ - public class RateLimitCore - { - private readonly IRateLimitCounterHandler _counterHandler; - private static readonly object ProcessLocker = new(); - - public RateLimitCore(IRateLimitCounterHandler counterStore) - { - _counterHandler = counterStore; - } - - public RateLimitCounter ProcessRequest(ClientRequestIdentity requestIdentity, RateLimitOptions option) - { - var counter = new RateLimitCounter(DateTime.UtcNow, 1); - var rule = option.RateLimitRule; - - var counterId = ComputeCounterKey(requestIdentity, option); - - // serial reads and writes - lock (ProcessLocker) - { - var entry = _counterHandler.Get(counterId); - if (entry.HasValue) - { - // entry has not expired - if (entry.Value.Timestamp + TimeSpan.FromSeconds(rule.PeriodTimespan) >= DateTime.UtcNow) - { - // increment request count - var totalRequests = entry.Value.TotalRequests + 1; - - // deep copy - counter = new RateLimitCounter(entry.Value.Timestamp, totalRequests); - } - } - } - - if (counter.TotalRequests > rule.Limit) - { - var retryAfter = RetryAfterFrom(counter.Timestamp, rule); - if (retryAfter > 0) - { - var expirationTime = TimeSpan.FromSeconds(rule.PeriodTimespan); - _counterHandler.Set(counterId, counter, expirationTime); - } - else - { - _counterHandler.Remove(counterId); - } - } - else - { - var expirationTime = ConvertToTimeSpan(rule.Period); - _counterHandler.Set(counterId, counter, expirationTime); - } - - return counter; - } - - public void SaveRateLimitCounter(ClientRequestIdentity requestIdentity, RateLimitOptions option, RateLimitCounter counter, TimeSpan expirationTime) - { - var counterId = ComputeCounterKey(requestIdentity, option); - var rule = option.RateLimitRule; - - // stores: id (string) - timestamp (datetime) - total_requests (long) - _counterHandler.Set(counterId, counter, expirationTime); - } - - public RateLimitHeaders GetRateLimitHeaders(HttpContext context, ClientRequestIdentity requestIdentity, RateLimitOptions option) - { - var rule = option.RateLimitRule; - RateLimitHeaders headers; - var counterId = ComputeCounterKey(requestIdentity, option); - var entry = _counterHandler.Get(counterId); - if (entry.HasValue) - { - headers = new RateLimitHeaders(context, rule.Period, - (rule.Limit - entry.Value.TotalRequests).ToString(), - (entry.Value.Timestamp + ConvertToTimeSpan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo) - ); - } - else - { - headers = new RateLimitHeaders(context, - rule.Period, - rule.Limit.ToString(), - (DateTime.UtcNow + ConvertToTimeSpan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo)); - } - - return headers; - } - - public string ComputeCounterKey(ClientRequestIdentity requestIdentity, RateLimitOptions option) - { - var key = $"{option.RateLimitCounterPrefix}_{requestIdentity.ClientId}_{option.RateLimitRule.Period}_{requestIdentity.HttpVerb}_{requestIdentity.Path}"; - - var idBytes = Encoding.UTF8.GetBytes(key); - - byte[] hashBytes; - - using (var algorithm = SHA1.Create()) - { - hashBytes = algorithm.ComputeHash(idBytes); - } - - return BitConverter.ToString(hashBytes).Replace("-", string.Empty); - } - - public int RetryAfterFrom(DateTime timestamp, RateLimitRule rule) - { - var secondsPast = Convert.ToInt32((DateTime.UtcNow - timestamp).TotalSeconds); - var retryAfter = Convert.ToInt32(TimeSpan.FromSeconds(rule.PeriodTimespan).TotalSeconds); - retryAfter = retryAfter > 1 ? retryAfter - secondsPast : 1; - return retryAfter; - } - - public TimeSpan ConvertToTimeSpan(string timeSpan) - { - var l = timeSpan.Length - 1; - var value = timeSpan.Substring(0, l); - var type = timeSpan.Substring(l, 1); - - switch (type) - { - case "d": - return TimeSpan.FromDays(double.Parse(value)); - - case "h": - return TimeSpan.FromHours(double.Parse(value)); - - case "m": - return TimeSpan.FromMinutes(double.Parse(value)); - - case "s": - return TimeSpan.FromSeconds(double.Parse(value)); - - default: - throw new FormatException($"{timeSpan} can't be converted to TimeSpan, unknown type {type}"); - } - } - } -} diff --git a/src/Ocelot/RateLimit/RateLimitCounter.cs b/src/Ocelot/RateLimit/RateLimitCounter.cs deleted file mode 100644 index 4e869d440..000000000 --- a/src/Ocelot/RateLimit/RateLimitCounter.cs +++ /dev/null @@ -1,21 +0,0 @@ -using Newtonsoft.Json; - -namespace Ocelot.RateLimit -{ - /// - /// Stores the initial access time and the numbers of calls made from that point. - /// - public struct RateLimitCounter - { - [JsonConstructor] - public RateLimitCounter(DateTime timestamp, long totalRequests) - { - Timestamp = timestamp; - TotalRequests = totalRequests; - } - - public DateTime Timestamp { get; } - - public long TotalRequests { get; } - } -} diff --git a/src/Ocelot/RateLimit/RateLimitHeaders.cs b/src/Ocelot/RateLimit/RateLimitHeaders.cs deleted file mode 100644 index 67d7596ce..000000000 --- a/src/Ocelot/RateLimit/RateLimitHeaders.cs +++ /dev/null @@ -1,23 +0,0 @@ -using Microsoft.AspNetCore.Http; - -namespace Ocelot.RateLimit -{ - public class RateLimitHeaders - { - public RateLimitHeaders(HttpContext context, string limit, string remaining, string reset) - { - Context = context; - Limit = limit; - Remaining = remaining; - Reset = reset; - } - - public HttpContext Context { get; } - - public string Limit { get; } - - public string Remaining { get; } - - public string Reset { get; } - } -} diff --git a/src/Ocelot/RateLimiting/ClientRequestIdentity.cs b/src/Ocelot/RateLimiting/ClientRequestIdentity.cs new file mode 100644 index 000000000..b73fcbbbb --- /dev/null +++ b/src/Ocelot/RateLimiting/ClientRequestIdentity.cs @@ -0,0 +1,15 @@ +namespace Ocelot.RateLimiting; + +public class ClientRequestIdentity +{ + public ClientRequestIdentity(string clientId, string path, string httpverb) + { + ClientId = clientId; + Path = path; + HttpVerb = httpverb; + } + + public string ClientId { get; } + public string Path { get; } + public string HttpVerb { get; } +} diff --git a/src/Ocelot/RateLimiting/DistributedCacheRateLimitStorage.cs b/src/Ocelot/RateLimiting/DistributedCacheRateLimitStorage.cs new file mode 100644 index 000000000..b7fb79de3 --- /dev/null +++ b/src/Ocelot/RateLimiting/DistributedCacheRateLimitStorage.cs @@ -0,0 +1,32 @@ +using Microsoft.Extensions.Caching.Distributed; +using Newtonsoft.Json; + +namespace Ocelot.RateLimiting; + +/// +/// Custom storage based on a distributed cache of a remote/local services. +/// +/// +/// See the interface docs for more details. +/// +public class DistributedCacheRateLimitStorage : IRateLimitStorage +{ + private readonly IDistributedCache _memoryCache; + + public DistributedCacheRateLimitStorage(IDistributedCache memoryCache) => _memoryCache = memoryCache; + + public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) + => _memoryCache.SetString(id, JsonConvert.SerializeObject(counter), new DistributedCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); + + public bool Exists(string id) => !string.IsNullOrEmpty(_memoryCache.GetString(id)); + + public RateLimitCounter? Get(string id) + { + var stored = _memoryCache.GetString(id); + return !string.IsNullOrEmpty(stored) + ? JsonConvert.DeserializeObject(stored) + : null; + } + + public void Remove(string id) => _memoryCache.Remove(id); +} diff --git a/src/Ocelot/RateLimiting/IRateLimitStorage.cs b/src/Ocelot/RateLimiting/IRateLimitStorage.cs new file mode 100644 index 000000000..1044998b1 --- /dev/null +++ b/src/Ocelot/RateLimiting/IRateLimitStorage.cs @@ -0,0 +1,16 @@ +namespace Ocelot.RateLimiting; + +/// +/// Defines a storage for keeping of rate limiting data. +/// +/// Concrete classes should be based on solutions with excellent performance, such as in-memory solutions. +public interface IRateLimitStorage +{ + bool Exists(string id); + + RateLimitCounter? Get(string id); + + void Remove(string id); + + void Set(string id, RateLimitCounter counter, TimeSpan expirationTime); +} diff --git a/src/Ocelot/RateLimiting/IRateLimiting.cs b/src/Ocelot/RateLimiting/IRateLimiting.cs new file mode 100644 index 000000000..684d2f70e --- /dev/null +++ b/src/Ocelot/RateLimiting/IRateLimiting.cs @@ -0,0 +1,59 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration; + +namespace Ocelot.RateLimiting; + +/// +/// Defines basic Rate Limiting functionality. +/// +public interface IRateLimiting +{ + /// Retrieves the key for the attached storage. + /// See the interface. + /// The current representation of the request. + /// The options of rate limiting. + /// A value of the key. + string GetStorageKey(ClientRequestIdentity identity, RateLimitOptions options); + + /// + /// Gets required information to create wanted headers in upper contexts (middleware, etc). + /// + /// The current context. + /// The current representation of the request. + /// The options of rate limiting. + /// A value. + RateLimitHeaders GetHeaders(HttpContext context, ClientRequestIdentity identity, RateLimitOptions options); + + /// + /// Main entry point to process the current request and apply the limiting rule. + /// + /// Warning! The method performs the storage operations which should be thread safe. + /// The representation of current request. + /// The current rate limiting options. + /// A value. + RateLimitCounter ProcessRequest(ClientRequestIdentity identity, RateLimitOptions options); + + /// + /// Counts requests based on the current counter state and taking into account the limiting rule. + /// + /// Old counter with starting moment inside. + /// The limiting rule. + /// A value. + RateLimitCounter Count(RateLimitCounter? entry, RateLimitRule rule); + + /// + /// Gets the seconds to wait for the next retry by starting moment and the rule. + /// + /// The method must be called after the counting by the method is completed; otherwise it doesn't make sense. + /// The counter with starting moment inside. + /// The limiting rule. + /// A value in seconds. + double RetryAfter(RateLimitCounter counter, RateLimitRule rule); + + /// + /// Converts to time span from a string, such as "1s", "1m", "1h", "1d". + /// + /// The string value with dimentions: '1s', '1m', '1h', '1d'. + /// A value. + TimeSpan ToTimespan(string timespan); +} diff --git a/src/Ocelot/RateLimiting/MemoryCacheRateLimitStorage.cs b/src/Ocelot/RateLimiting/MemoryCacheRateLimitStorage.cs new file mode 100644 index 000000000..7451dac97 --- /dev/null +++ b/src/Ocelot/RateLimiting/MemoryCacheRateLimitStorage.cs @@ -0,0 +1,25 @@ +using Microsoft.Extensions.Caching.Memory; + +namespace Ocelot.RateLimiting; + +/// +/// Default storage based on the memory cache of the local web server instance. +/// +/// +/// See the interface docs for more details. +/// +public class MemoryCacheRateLimitStorage : IRateLimitStorage +{ + private readonly IMemoryCache _memoryCache; + + public MemoryCacheRateLimitStorage(IMemoryCache memoryCache) => _memoryCache = memoryCache; + + public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) + => _memoryCache.Set(id, counter, new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); + + public bool Exists(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter); + + public RateLimitCounter? Get(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter) ? counter : null; + + public void Remove(string id) => _memoryCache.Remove(id); +} diff --git a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs b/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddleware.cs similarity index 83% rename from src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs rename to src/Ocelot/RateLimiting/Middleware/RateLimitingMiddleware.cs index 47571046f..b407733ae 100644 --- a/src/Ocelot/RateLimit/Middleware/ClientRateLimitMiddleware.cs +++ b/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddleware.cs @@ -2,21 +2,23 @@ using Ocelot.Configuration; using Ocelot.Logging; using Ocelot.Middleware; +using System.Globalization; -namespace Ocelot.RateLimit.Middleware +namespace Ocelot.RateLimiting.Middleware { - public class ClientRateLimitMiddleware : OcelotMiddleware + public class RateLimitingMiddleware : OcelotMiddleware { private readonly RequestDelegate _next; - private readonly ClientRateLimitProcessor _processor; + private readonly IRateLimiting _limiter; - public ClientRateLimitMiddleware(RequestDelegate next, - IOcelotLoggerFactory loggerFactory, - IRateLimitCounterHandler counterHandler) - : base(loggerFactory.CreateLogger()) + public RateLimitingMiddleware( + RequestDelegate next, + IOcelotLoggerFactory factory, + IRateLimiting limiter) + : base(factory.CreateLogger()) { _next = next; - _processor = new ClientRateLimitProcessor(counterHandler); + _limiter = limiter; } public async Task Invoke(HttpContext httpContext) @@ -48,26 +50,20 @@ public async Task Invoke(HttpContext httpContext) if (rule.Limit > 0) { // increment counter - var counter = _processor.ProcessRequest(identity, options); + var counter = _limiter.ProcessRequest(identity, options); // check if limit is reached if (counter.TotalRequests > rule.Limit) { - //compute retry after value - var retryAfter = _processor.RetryAfterFrom(counter.Timestamp, rule); - - // log blocked request - LogBlockedRequest(httpContext, identity, counter, rule, downstreamRoute); - - var retrystring = retryAfter.ToString(System.Globalization.CultureInfo.InvariantCulture); + var retryAfter = _limiter.RetryAfter(counter, rule); // compute retry after value based on counter state + LogBlockedRequest(httpContext, identity, counter, rule, downstreamRoute); // log blocked request virtually // break execution - var ds = ReturnQuotaExceededResponse(httpContext, options, retrystring); + var ds = ReturnQuotaExceededResponse(httpContext, options, retryAfter.ToString(CultureInfo.InvariantCulture)); httpContext.Items.UpsertDownstreamResponse(ds); // Set Error httpContext.Items.SetError(new QuotaExceededError(GetResponseMessage(options), options.HttpStatusCode)); - return; } } @@ -75,7 +71,7 @@ public async Task Invoke(HttpContext httpContext) //set X-Rate-Limit headers for the longest period if (!options.DisableRateLimitHeaders) { - var headers = _processor.GetRateLimitHeaders(httpContext, identity, options); + var headers = _limiter.GetHeaders(httpContext, identity, options); httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers); } @@ -123,7 +119,7 @@ public virtual DownstreamResponse ReturnQuotaExceededResponse(HttpContext httpCo if (!option.DisableRateLimitHeaders) { - http.Headers.TryAddWithoutValidation("Retry-After", retryAfter); + http.Headers.TryAddWithoutValidation("Retry-After", retryAfter); // in seconds, not date string } return new DownstreamResponse(http); diff --git a/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddlewareExtensions.cs b/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddlewareExtensions.cs new file mode 100644 index 000000000..68268cb40 --- /dev/null +++ b/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddlewareExtensions.cs @@ -0,0 +1,11 @@ +using Microsoft.AspNetCore.Builder; + +namespace Ocelot.RateLimiting.Middleware; + +public static class RateLimitingMiddlewareExtensions +{ + public static IApplicationBuilder UseRateLimiting(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } +} diff --git a/src/Ocelot/RateLimit/QuotaExceededError.cs b/src/Ocelot/RateLimiting/QuotaExceededError.cs similarity index 89% rename from src/Ocelot/RateLimit/QuotaExceededError.cs rename to src/Ocelot/RateLimiting/QuotaExceededError.cs index 9c98dc5a6..a46cb4c78 100644 --- a/src/Ocelot/RateLimit/QuotaExceededError.cs +++ b/src/Ocelot/RateLimiting/QuotaExceededError.cs @@ -1,6 +1,6 @@ using Ocelot.Errors; -namespace Ocelot.RateLimit +namespace Ocelot.RateLimiting { public class QuotaExceededError : Error { diff --git a/src/Ocelot/RateLimiting/RateLimitCounter.cs b/src/Ocelot/RateLimiting/RateLimitCounter.cs new file mode 100644 index 000000000..2507a0433 --- /dev/null +++ b/src/Ocelot/RateLimiting/RateLimitCounter.cs @@ -0,0 +1,29 @@ +using Newtonsoft.Json; + +namespace Ocelot.RateLimiting; + +/// +/// Stores the initial access time and the numbers of calls made from that point. +/// +public struct RateLimitCounter +{ + [JsonConstructor] + public RateLimitCounter(DateTime startedAt, DateTime? exceededAt, long totalRequests) + { + StartedAt = startedAt; + ExceededAt = exceededAt; + TotalRequests = totalRequests; + } + + /// The moment when the counting was started. + /// A value of the moment. + public DateTime StartedAt { get; } + + /// The moment when the limit was exceeded. + /// A value of the moment. + public DateTime? ExceededAt { get; } + + /// Total number of requests counted. + /// A value of total number. + public long TotalRequests { get; set; } +} diff --git a/src/Ocelot/RateLimiting/RateLimitHeaders.cs b/src/Ocelot/RateLimiting/RateLimitHeaders.cs new file mode 100644 index 000000000..860e0d6bb --- /dev/null +++ b/src/Ocelot/RateLimiting/RateLimitHeaders.cs @@ -0,0 +1,19 @@ +using Microsoft.AspNetCore.Http; + +namespace Ocelot.RateLimiting; + +public class RateLimitHeaders +{ + public RateLimitHeaders(HttpContext context, string limit, string remaining, string reset) + { + Context = context; + Limit = limit; + Remaining = remaining; + Reset = reset; + } + + public HttpContext Context { get; } + public string Limit { get; } + public string Remaining { get; } + public string Reset { get; } +} diff --git a/src/Ocelot/RateLimiting/RateLimiting.cs b/src/Ocelot/RateLimiting/RateLimiting.cs new file mode 100644 index 000000000..9edf4a310 --- /dev/null +++ b/src/Ocelot/RateLimiting/RateLimiting.cs @@ -0,0 +1,194 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration; +using System.Globalization; +using System.Security.Cryptography; + +namespace Ocelot.RateLimiting; + +public class RateLimiting : IRateLimiting +{ + private readonly IRateLimitStorage _storage; + private static readonly object ProcessLocker = new(); + + public RateLimiting(IRateLimitStorage storage) + { + _storage = storage; + } + + /// + /// Main entry point to process the current request and apply the limiting rule. + /// + /// Warning! The method performs the storage operations which MUST BE thread safe. + /// The representation of current request. + /// The current rate limiting options. + /// A value. + public virtual RateLimitCounter ProcessRequest(ClientRequestIdentity identity, RateLimitOptions options) + { + RateLimitCounter counter; + var rule = options.RateLimitRule; + var counterId = GetStorageKey(identity, options); + + // Serial reads/writes from/to the storage which must be thread safe + lock (ProcessLocker) + { + var entry = _storage.Get(counterId); + counter = Count(entry, rule); + var expiration = ToTimespan(rule.Period); // default expiration is set for the Period value + if (counter.TotalRequests > rule.Limit) + { + var retryAfter = RetryAfter(counter, rule); // the calculation depends on the counter returned from CountRequests + if (retryAfter > 0) + { + // Rate Limit exceeded, ban period is active + expiration = TimeSpan.FromSeconds(rule.PeriodTimespan); // current state should expire in the storage after ban period + } + else + { + // Ban period elapsed, start counting + _storage.Remove(counterId); // the store can delete the element on its own using an expiration mechanism, but let's force the element to be deleted + counter = new RateLimitCounter(DateTime.UtcNow, null, 1); + } + } + + _storage.Set(counterId, counter, expiration); + } + + return counter; + } + + /// + /// Counts requests based on the current counter state and taking into account the limiting rule. + /// + /// Old counter with starting moment inside. + /// The limiting rule. + /// A value. + public virtual RateLimitCounter Count(RateLimitCounter? entry, RateLimitRule rule) + { + var now = DateTime.UtcNow; + if (!entry.HasValue) // no entry, start counting + { + return new RateLimitCounter(now, null, 1); // current request is the 1st one + } + + var counter = entry.Value; + var total = counter.TotalRequests + 1; // increment request count + var startedAt = counter.StartedAt; + if (startedAt + ToTimespan(rule.Period) >= now) // counting Period is active + { + var exceededAt = total >= rule.Limit && !counter.ExceededAt.HasValue // current request number equals to the limit + ? now // the exceeding moment is now, the next request will fail but the current one doesn't + : counter.ExceededAt; + return new RateLimitCounter(startedAt, exceededAt, total); // deep copy + } + + var wasExceededAt = counter.ExceededAt; + return wasExceededAt + TimeSpan.FromSeconds(rule.PeriodTimespan) >= now // ban PeriodTimespan is active + ? new RateLimitCounter(startedAt, wasExceededAt, total) // still count + : new RateLimitCounter(now, null, 1); // Ban PeriodTimespan elapsed, start counting NOW! + } + + public virtual RateLimitHeaders GetHeaders(HttpContext context, ClientRequestIdentity identity, RateLimitOptions options) + { + RateLimitHeaders headers; + RateLimitCounter? entry; + lock (ProcessLocker) + { + var counterId = GetStorageKey(identity, options); + entry = _storage.Get(counterId); + } + + var rule = options.RateLimitRule; + if (entry.HasValue) + { + headers = new RateLimitHeaders(context, + limit: rule.Period, + remaining: (rule.Limit - entry.Value.TotalRequests).ToString(), + reset: (entry.Value.StartedAt + ToTimespan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo)); + } + else + { + headers = new RateLimitHeaders(context, + limit: rule.Period, // TODO Double check + remaining: rule.Limit.ToString(), // TODO Double check + reset: (DateTime.UtcNow + ToTimespan(rule.Period)).ToUniversalTime().ToString("o", DateTimeFormatInfo.InvariantInfo)); + } + + return headers; + } + + public virtual string GetStorageKey(ClientRequestIdentity identity, RateLimitOptions options) + { + var key = $"{options.RateLimitCounterPrefix}_{identity.ClientId}_{options.RateLimitRule.Period}_{identity.HttpVerb}_{identity.Path}"; + var idBytes = Encoding.UTF8.GetBytes(key); + + byte[] hashBytes; + using (var algorithm = SHA1.Create()) + { + hashBytes = algorithm.ComputeHash(idBytes); + } + + return BitConverter.ToString(hashBytes).Replace("-", string.Empty); + } + + /// + /// Gets the seconds to wait for the next retry by starting moment and the rule. + /// + /// The method must be called after the one. + /// The counter state. + /// The current rule. + /// An value of seconds. + public virtual double RetryAfter(RateLimitCounter counter, RateLimitRule rule) + { + const double defaultSeconds = 1.0D; // one second + var periodTimespan = rule.PeriodTimespan < defaultSeconds + ? defaultSeconds // allow values which are greater or equal to 1 second + : rule.PeriodTimespan; // good value + var now = DateTime.UtcNow; + if (counter.StartedAt + ToTimespan(rule.Period) >= now) // counting Period is active + { + return counter.TotalRequests < rule.Limit + ? 0.0D // happy path, no need to retry, current request is valid + : counter.ExceededAt.HasValue + ? periodTimespan - (now - counter.ExceededAt.Value).TotalSeconds // minus seconds past + : periodTimespan; // exceeding not yet detected -> let's ban for whole period + } + + if (counter.ExceededAt.HasValue && // limit exceeding was happen + counter.ExceededAt + TimeSpan.FromSeconds(periodTimespan) >= now) // ban PeriodTimespan is active + { + var startedAt = counter.ExceededAt.Value; // ban period was started at + double secondsPast = (now - startedAt).TotalSeconds; + double retryAfter = periodTimespan - secondsPast; + return retryAfter; // it can be negative, which means the wait in PeriodTimespan seconds has ended + } + + return 0.0D; // ban period elapsed, no need to retry, current request is valid + } + + /// + /// Converts to time span from a string, such as "1s", "1m", "1h", "1d". + /// + /// The string value with dimentions: '1s', '1m', '1h', '1d'. + /// A value. + /// By default if the value dimension can't be detected. + public virtual TimeSpan ToTimespan(string timespan) + { + if (string.IsNullOrEmpty(timespan)) + { + return TimeSpan.Zero; + } + + var len = timespan.Length - 1; + var value = timespan.Substring(0, len); + var type = timespan.Substring(len, 1); + + return type switch + { + "d" => TimeSpan.FromDays(double.Parse(value)), + "h" => TimeSpan.FromHours(double.Parse(value)), + "m" => TimeSpan.FromMinutes(double.Parse(value)), + "s" => TimeSpan.FromSeconds(double.Parse(value)), + _ => throw new FormatException($"{timespan} can't be converted to TimeSpan, unknown type {type}"), + }; + } +} diff --git a/test/Ocelot.AcceptanceTests/ClientRateLimitTests.cs b/test/Ocelot.AcceptanceTests/ClientRateLimitTests.cs deleted file mode 100644 index dad9af3dc..000000000 --- a/test/Ocelot.AcceptanceTests/ClientRateLimitTests.cs +++ /dev/null @@ -1,217 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Ocelot.Configuration.File; - -namespace Ocelot.AcceptanceTests -{ - public class ClientRateLimitTests : IDisposable - { - private readonly Steps _steps; - private int _counterOne; - private readonly ServiceHandler _serviceHandler; - - public ClientRateLimitTests() - { - _serviceHandler = new ServiceHandler(); - _steps = new Steps(); - } - - [Fact] - public void should_call_withratelimiting() - { - var port = PortFinder.GetRandomPort(); - - var configuration = new FileConfiguration - { - Routes = new List - { - new() - { - DownstreamPathTemplate = "/api/ClientRateLimit", - DownstreamHostAndPorts = new List - { - new() - { - Host = "localhost", - Port = port, - }, - }, - DownstreamScheme = "http", - UpstreamPathTemplate = "/api/ClientRateLimit", - UpstreamHttpMethod = new List { "Get" }, - RequestIdKey = _steps.RequestIdKey, - RateLimitOptions = new FileRateLimitRule - { - EnableRateLimiting = true, - ClientWhitelist = new List(), - Limit = 3, - Period = "1s", - PeriodTimespan = 1000, - }, - }, - }, - GlobalConfiguration = new FileGlobalConfiguration - { - RateLimitOptions = new FileRateLimitOptions - { - ClientIdHeader = "ClientId", - DisableRateLimitHeaders = false, - QuotaExceededMessage = string.Empty, - RateLimitCounterPrefix = string.Empty, - HttpStatusCode = 428, - }, - RequestIdKey = "oceclientrequest", - }, - }; - - this.Given(x => x.GivenThereIsAServiceRunningOn($"http://localhost:{port}", "/api/ClientRateLimit")) - .And(x => _steps.GivenThereIsAConfiguration(configuration)) - .And(x => _steps.GivenOcelotIsRunning()) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 2)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(428)) - .BDDfy(); - } - - [Fact] - public void should_wait_for_period_timespan_to_elapse_before_making_next_request() - { - var port = PortFinder.GetRandomPort(); - - var configuration = new FileConfiguration - { - Routes = new List - { - new() - { - DownstreamPathTemplate = "/api/ClientRateLimit", - DownstreamHostAndPorts = new List - { - new() - { - Host = "localhost", - Port = port, - }, - }, - DownstreamScheme = "http", - UpstreamPathTemplate = "/api/ClientRateLimit", - UpstreamHttpMethod = new List { "Get" }, - RequestIdKey = _steps.RequestIdKey, - - RateLimitOptions = new FileRateLimitRule - { - EnableRateLimiting = true, - ClientWhitelist = new List(), - Limit = 3, - Period = "1s", - PeriodTimespan = 2, - }, - }, - }, - GlobalConfiguration = new FileGlobalConfiguration - { - RateLimitOptions = new FileRateLimitOptions - { - ClientIdHeader = "ClientId", - DisableRateLimitHeaders = false, - QuotaExceededMessage = string.Empty, - RateLimitCounterPrefix = string.Empty, - HttpStatusCode = 428, - }, - RequestIdKey = "oceclientrequest", - }, - }; - - this.Given(x => x.GivenThereIsAServiceRunningOn($"http://localhost:{port}", "/api/ClientRateLimit")) - .And(x => _steps.GivenThereIsAConfiguration(configuration)) - .And(x => _steps.GivenOcelotIsRunning()) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 2)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(428)) - .And(x => _steps.GivenIWait(1000)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(428)) - .And(x => _steps.GivenIWait(1000)) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .BDDfy(); - } - - [Fact] - public void should_call_middleware_withWhitelistClient() - { - var port = PortFinder.GetRandomPort(); - - var configuration = new FileConfiguration - { - Routes = new List - { - new() - { - DownstreamPathTemplate = "/api/ClientRateLimit", - DownstreamHostAndPorts = new List - { - new() - { - Host = "localhost", - Port = port, - }, - }, - DownstreamScheme = "http", - UpstreamPathTemplate = "/api/ClientRateLimit", - UpstreamHttpMethod = new List { "Get" }, - RequestIdKey = _steps.RequestIdKey, - - RateLimitOptions = new FileRateLimitRule - { - EnableRateLimiting = true, - ClientWhitelist = new List { "ocelotclient1"}, - Limit = 3, - Period = "1s", - PeriodTimespan = 100, - }, - }, - }, - GlobalConfiguration = new FileGlobalConfiguration - { - RateLimitOptions = new FileRateLimitOptions - { - ClientIdHeader = "ClientId", - DisableRateLimitHeaders = false, - QuotaExceededMessage = string.Empty, - RateLimitCounterPrefix = string.Empty, - }, - RequestIdKey = "oceclientrequest", - }, - }; - - this.Given(x => x.GivenThereIsAServiceRunningOn($"http://localhost:{port}", "/api/ClientRateLimit")) - .And(x => _steps.GivenThereIsAConfiguration(configuration)) - .And(x => _steps.GivenOcelotIsRunning()) - .When(x => _steps.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 4)) - .Then(x => _steps.ThenTheStatusCodeShouldBe(200)) - .BDDfy(); - } - - private void GivenThereIsAServiceRunningOn(string baseUrl, string basePath) - { - _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, basePath, context => - { - _counterOne++; - context.Response.StatusCode = 200; - context.Response.WriteAsync(_counterOne.ToString()); - return Task.CompletedTask; - }); - } - - public void Dispose() - { - _steps.Dispose(); - } - } -} diff --git a/test/Ocelot.AcceptanceTests/RateLimiting/ClientRateLimitingTests.cs b/test/Ocelot.AcceptanceTests/RateLimiting/ClientRateLimitingTests.cs new file mode 100644 index 000000000..4dd80e7ec --- /dev/null +++ b/test/Ocelot.AcceptanceTests/RateLimiting/ClientRateLimitingTests.cs @@ -0,0 +1,182 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration.File; + +namespace Ocelot.AcceptanceTests.RateLimiting; + +public sealed class ClientRateLimitingTests : Steps, IDisposable +{ + const int OK = (int)HttpStatusCode.OK; + const int TooManyRequests = (int)HttpStatusCode.TooManyRequests; + + private int _counterOne; + private readonly ServiceHandler _serviceHandler; + + public ClientRateLimitingTests() + { + _serviceHandler = new ServiceHandler(); + } + + public override void Dispose() + { + _serviceHandler.Dispose(); + base.Dispose(); + } + + [Fact] + [Trait("Feat", "37")] + public void Should_call_with_rate_limiting() + { + var port = PortFinder.GetRandomPort(); + var route = GivenRoute(port, null, null, new(), 3, "1s", 1); // periods are equal + var configuration = GivenConfigurationWithRateLimitOptions(route); + this.Given(x => x.GivenThereIsAServiceRunningOn(DownstreamUrl(port), "/api/ClientRateLimit")) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + .When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 2)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1)) + .Then(x => ThenTheStatusCodeShouldBe(TooManyRequests)) + .BDDfy(); + } + + [Fact] + [Trait("Feat", "37")] + public void Should_wait_for_period_timespan_to_elapse_before_making_next_request() + { + var port = PortFinder.GetRandomPort(); + var route = GivenRoute(port, "/api/ClientRateLimit?count={count}", "/ClientRateLimit/?{count}", new(), 3, "1s", 2); + var configuration = GivenConfigurationWithRateLimitOptions(route); + _counterOne = 0; + this.Given(x => x.GivenThereIsAServiceRunningOn(DownstreamUrl(port), "/api/ClientRateLimit")) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 2)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) + .Then(x => ThenTheStatusCodeShouldBe(TooManyRequests)) + .And(x => GivenIWait(1000)) + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) + .Then(x => ThenTheStatusCodeShouldBe(TooManyRequests)) + .And(x => GivenIWait(1000)) + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .And(x => ThenTheResponseBodyShouldBe("4")) // total 4 OK responses + .BDDfy(); + } + + private int _count = 0; + private int Count() => ++_count; + private string Url() => $"/ClientRateLimit/?{Count()}"; + + private void WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Func urlDelegate, long times) + { + for (long i = 0; i < times; i++) + { + var url = urlDelegate.Invoke(); + WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(url, 1); + } + } + + [Fact] + [Trait("Feat", "37")] + public void Should_call_middleware_with_white_list_client() + { + var port = PortFinder.GetRandomPort(); + var route = GivenRoute(port, null, null, whitelist: new() { "ocelotclient1" }, 3, "3s", 2); // main period is greater than ban one + var configuration = GivenConfigurationWithRateLimitOptions(route); + this.Given(x => x.GivenThereIsAServiceRunningOn(DownstreamUrl(port), "/api/ClientRateLimit")) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + .When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 4)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .BDDfy(); + } + + [Fact] + [Trait("Bug", "1590")] + public void StatusShouldNotBeEqualTo429_PeriodTimespanValueIsGreaterThanPeriod() + { + _counterOne = 0; + + // Bug scenario + const string period = "1s"; + const double periodTimespan = /*30*/3; // but decrease 30 to 3 secs, "no wasting time" life hack + const long limit = 100L; + + var port = PortFinder.GetRandomPort(); + var route = GivenRoute(port, "/api/ClientRateLimit?count={count}", "/ClientRateLimit/?{count}", new(), + limit, period, periodTimespan); // bug scenario, adapted + var configuration = GivenConfigurationWithRateLimitOptions(route); + this.Given(x => x.GivenThereIsAServiceRunningOn(DownstreamUrl(port), "/api/ClientRateLimit")) + .And(x => GivenThereIsAConfiguration(configuration)) + .And(x => GivenOcelotIsRunning()) + + // main scenario + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, route.RateLimitOptions.Limit)) // 100 times to reach the limit + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .And(x => ThenTheResponseBodyShouldBe(route.RateLimitOptions.Limit.ToString())) // total 100 OK responses + + // extra scenario + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) // 101st request should fail + .Then(x => ThenTheStatusCodeShouldBe(TooManyRequests)) + .And(x => GivenIWait((int)TimeSpan.FromSeconds(route.RateLimitOptions.PeriodTimespan).TotalMilliseconds)) // in 3 secs PeriodTimespan will elapse + .When(x => x.WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(Url, 1)) + .Then(x => ThenTheStatusCodeShouldBe(OK)) + .And(x => ThenTheResponseBodyShouldBe("101")) // total 101 OK responses + .BDDfy(); + } + + private void GivenThereIsAServiceRunningOn(string baseUrl, string basePath) + { + _serviceHandler.GivenThereIsAServiceRunningOn(baseUrl, basePath, context => + { + _counterOne++; + context.Response.StatusCode = OK; + context.Response.WriteAsync(_counterOne.ToString()); + return Task.CompletedTask; + }); + } + + private FileRoute GivenRoute(int port, string downstream, string upstream, List whitelist, long limit, string period, double periodTimespan) => new() + { + DownstreamPathTemplate = downstream ?? "/api/ClientRateLimit", + DownstreamHostAndPorts = new() + { + new("localhost", port), + }, + DownstreamScheme = Uri.UriSchemeHttp, + UpstreamPathTemplate = upstream ?? "/api/ClientRateLimit", + UpstreamHttpMethod = new() { HttpMethods.Get }, + RequestIdKey = RequestIdKey, + RateLimitOptions = new FileRateLimitRule + { + EnableRateLimiting = true, + ClientWhitelist = whitelist ?? new() { "ocelotclient1" }, + Limit = limit, + Period = period ?? "1s", + PeriodTimespan = periodTimespan, + }, + }; + + private static FileConfiguration GivenConfigurationWithRateLimitOptions(params FileRoute[] routes) + { + var config = GivenConfiguration(routes); + config.GlobalConfiguration = new() + { + RateLimitOptions = new() + { + ClientIdHeader = "ClientId", + DisableRateLimitHeaders = false, + QuotaExceededMessage = "Exceeding!", + RateLimitCounterPrefix = "ABC", + HttpStatusCode = TooManyRequests, // 429 + }, + RequestIdKey = "OcelotClientRequest", + }; + return config; + } +} diff --git a/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs b/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs deleted file mode 100644 index 1c2267ae2..000000000 --- a/test/Ocelot.UnitTests/RateLimit/ClientRateLimitMiddlewareTests.cs +++ /dev/null @@ -1,183 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Caching.Memory; -using Ocelot.Configuration; -using Ocelot.Configuration.Builder; -using Ocelot.Logging; -using Ocelot.Middleware; -using Ocelot.RateLimit; -using Ocelot.RateLimit.Middleware; -using Ocelot.Request.Middleware; - -namespace Ocelot.UnitTests.RateLimit -{ - public class ClientRateLimitMiddlewareTests : UnitTest - { - private readonly IRateLimitCounterHandler _rateLimitCounterHandler; - private readonly Mock _loggerFactory; - private readonly Mock _logger; - private readonly ClientRateLimitMiddleware _middleware; - private readonly RequestDelegate _next; - private DownstreamResponse _downstreamResponse; - private readonly string _url; - - public ClientRateLimitMiddlewareTests() - { - _url = "http://localhost:51879"; - var cacheEntryOptions = new MemoryCacheOptions(); - _rateLimitCounterHandler = new MemoryCacheRateLimitCounterHandler(new MemoryCache(cacheEntryOptions)); - _loggerFactory = new Mock(); - _logger = new Mock(); - _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); - _next = context => Task.CompletedTask; - _middleware = new ClientRateLimitMiddleware(_next, _loggerFactory.Object, _rateLimitCounterHandler); - } - - [Fact] - public void should_call_middleware_and_ratelimiting() - { - var upstreamTemplate = new UpstreamPathTemplateBuilder().Build(); - - var downstreamRoute = new DownstreamRouteBuilder() - .WithEnableRateLimiting(true) - .WithRateLimitOptions(new RateLimitOptions(true, "ClientId", () => new List(), false, string.Empty, string.Empty, new RateLimitRule("1s", 100, 3), 429)) - .WithUpstreamHttpMethod(new List { "Get" }) - .WithUpstreamPathTemplate(upstreamTemplate) - .Build(); - - var route = new RouteBuilder() - .WithDownstreamRoute(downstreamRoute) - .WithUpstreamHttpMethod(new List { "Get" }) - .Build(); - - var downstreamRouteHolder = new Ocelot.DownstreamRouteFinder.DownstreamRouteHolder(new List(), route); - - this.Given(x => x.WhenICallTheMiddlewareMultipleTimes(2, downstreamRouteHolder)) - .Then(x => x.ThenThereIsNoDownstreamResponse()) - .When(x => x.WhenICallTheMiddlewareMultipleTimes(3, downstreamRouteHolder)) - .Then(x => x.ThenTheResponseIs429()) - .BDDfy(); - } - - [Fact] - public void should_call_middleware_withWhitelistClient() - { - var downstreamRoute = new Ocelot.DownstreamRouteFinder.DownstreamRouteHolder(new List(), - new RouteBuilder() - .WithDownstreamRoute(new DownstreamRouteBuilder() - .WithEnableRateLimiting(true) - .WithRateLimitOptions( - new RateLimitOptions(true, "ClientId", () => new List { "ocelotclient2" }, false, string.Empty, string.Empty, new RateLimitRule("1s", 100, 3), 429)) - .WithUpstreamHttpMethod(new List { "Get" }) - .Build()) - .WithUpstreamHttpMethod(new List { "Get" }) - .Build()); - - this.Given(x => x.WhenICallTheMiddlewareWithWhiteClient(downstreamRoute)) - .Then(x => x.ThenThereIsNoDownstreamResponse()) - .BDDfy(); - } - - private void WhenICallTheMiddlewareMultipleTimes(int times, Ocelot.DownstreamRouteFinder.DownstreamRouteHolder downstreamRoute) - { - var httpContexts = new List(); - - for (var i = 0; i < times; i++) - { - var httpContext = new DefaultHttpContext - { - Response = - { - Body = new FakeStream(), - }, - }; - httpContext.Items.UpsertDownstreamRoute(downstreamRoute.Route.DownstreamRoute[0]); - httpContext.Items.UpsertTemplatePlaceholderNameAndValues(downstreamRoute.TemplatePlaceholderNameAndValues); - httpContext.Items.UpsertDownstreamRoute(downstreamRoute); - var clientId = "ocelotclient1"; - var request = new HttpRequestMessage(new HttpMethod("GET"), _url); - httpContext.Items.UpsertDownstreamRequest(new DownstreamRequest(request)); - httpContext.Request.Headers.TryAdd("ClientId", clientId); - httpContexts.Add(httpContext); - } - - foreach (var httpContext in httpContexts) - { - _middleware.Invoke(httpContext).GetAwaiter().GetResult(); - var ds = httpContext.Items.DownstreamResponse(); - _downstreamResponse = ds; - } - } - - private void WhenICallTheMiddlewareWithWhiteClient(Ocelot.DownstreamRouteFinder.DownstreamRouteHolder downstreamRoute) - { - var clientId = "ocelotclient2"; - - for (var i = 0; i < 10; i++) - { - var httpContext = new DefaultHttpContext - { - Response = - { - Body = new FakeStream(), - }, - }; - httpContext.Items.UpsertDownstreamRoute(downstreamRoute.Route.DownstreamRoute[0]); - httpContext.Items.UpsertTemplatePlaceholderNameAndValues(downstreamRoute.TemplatePlaceholderNameAndValues); - httpContext.Items.UpsertDownstreamRoute(downstreamRoute); - var request = new HttpRequestMessage(new HttpMethod("GET"), _url); - request.Headers.Add("ClientId", clientId); - httpContext.Items.UpsertDownstreamRequest(new DownstreamRequest(request)); - httpContext.Request.Headers.TryAdd("ClientId", clientId); - _middleware.Invoke(httpContext).GetAwaiter().GetResult(); - var ds = httpContext.Items.DownstreamResponse(); - _downstreamResponse = ds; - } - } - - private void ThenTheResponseIs429() - { - var code = (int)_downstreamResponse.StatusCode; - code.ShouldBe(429); - } - - private void ThenThereIsNoDownstreamResponse() - { - _downstreamResponse.ShouldBeNull(); - } - } - - internal class FakeStream : Stream - { - public override void Flush() - { - //do nothing - //throw new System.NotImplementedException(); - } - - public override int Read(byte[] buffer, int offset, int count) - { - throw new System.NotImplementedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new System.NotImplementedException(); - } - - public override void SetLength(long value) - { - throw new System.NotImplementedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - //do nothing - } - - public override bool CanRead { get; } - public override bool CanSeek { get; } - public override bool CanWrite => true; - public override long Length { get; } - public override long Position { get; set; } - } -} diff --git a/test/Ocelot.UnitTests/RateLimiting/RateLimitingMiddlewareTests.cs b/test/Ocelot.UnitTests/RateLimiting/RateLimitingMiddlewareTests.cs new file mode 100644 index 000000000..29c3b0dee --- /dev/null +++ b/test/Ocelot.UnitTests/RateLimiting/RateLimitingMiddlewareTests.cs @@ -0,0 +1,218 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Memory; +using Ocelot.Configuration; +using Ocelot.Configuration.Builder; +using Ocelot.Logging; +using Ocelot.Middleware; +using Ocelot.RateLimiting; +using Ocelot.RateLimiting.Middleware; +using Ocelot.Request.Middleware; +using System.Text; +using _DownstreamRouteHolder_ = Ocelot.DownstreamRouteFinder.DownstreamRouteHolder; +using _RateLimiting_ = Ocelot.RateLimiting.RateLimiting; + +namespace Ocelot.UnitTests.RateLimiting; + +public class RateLimitingMiddlewareTests : UnitTest +{ + private readonly IRateLimitStorage _storage; + private readonly Mock _loggerFactory; + private readonly Mock _logger; + private readonly RateLimitingMiddleware _middleware; + private readonly RequestDelegate _next; + private readonly IRateLimiting _rateLimiting; + private readonly List _downstreamResponses; + private readonly string _url; + + public RateLimitingMiddlewareTests() + { + _url = "http://localhost:51879"; + var cacheEntryOptions = new MemoryCacheOptions(); + _storage = new MemoryCacheRateLimitStorage(new MemoryCache(cacheEntryOptions)); + _loggerFactory = new Mock(); + _logger = new Mock(); + _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); + _next = context => Task.CompletedTask; + _rateLimiting = new _RateLimiting_(_storage); + _middleware = new RateLimitingMiddleware(_next, _loggerFactory.Object, _rateLimiting); + _downstreamResponses = new(); + } + + [Fact] + [Trait("Feat", "37")] + public async Task Should_call_middleware_and_ratelimiting() + { + // Arrange + const long limit = 3L; + var upstreamTemplate = new UpstreamPathTemplateBuilder() + .Build(); + var downstreamRoute = new DownstreamRouteBuilder() + .WithEnableRateLimiting(true) + .WithRateLimitOptions(new( + enableRateLimiting: true, + clientIdHeader: "ClientId", + getClientWhitelist: () => new List(), + disableRateLimitHeaders: false, + quotaExceededMessage: "Exceeding!", + rateLimitCounterPrefix: string.Empty, + new RateLimitRule("1s", 100.0D, limit), + (int)HttpStatusCode.TooManyRequests)) + .WithUpstreamHttpMethod(new() { "Get" }) + .WithUpstreamPathTemplate(upstreamTemplate) + .Build(); + var route = new RouteBuilder() + .WithDownstreamRoute(downstreamRoute) + .WithUpstreamHttpMethod(new() { "Get" }) + .Build(); + var downstreamRouteHolder = new _DownstreamRouteHolder_(new(), route); + + // Act, Assert + await WhenICallTheMiddlewareMultipleTimes(limit, downstreamRouteHolder); + _downstreamResponses.ForEach(dsr => dsr.ShouldBeNull()); + + // Act, Assert: the next request should fail + await WhenICallTheMiddlewareMultipleTimes(3, downstreamRouteHolder); + _downstreamResponses.ShouldNotBeNull(); + for (int i = 0; i < _downstreamResponses.Count; i++) + { + var response = _downstreamResponses[i].ShouldNotBeNull(); + response.StatusCode.ShouldBe(HttpStatusCode.TooManyRequests, $"Downstream Response no is {i}"); + var body = await response.Content.ReadAsStringAsync(); + body.ShouldBe("Exceeding!"); + } + } + + [Fact] + [Trait("Feat", "37")] + public async Task Should_call_middleware_withWhitelistClient() + { + // Arrange + var route = new RouteBuilder() + .WithDownstreamRoute(new DownstreamRouteBuilder() + .WithEnableRateLimiting(true) + .WithRateLimitOptions(new( + enableRateLimiting: true, + clientIdHeader: "ClientId", + getClientWhitelist: () => new List { "ocelotclient2" }, + disableRateLimitHeaders: false, + quotaExceededMessage: "Exceeding!", + rateLimitCounterPrefix: string.Empty, + new RateLimitRule("1s", 100.0D, 3), + (int)HttpStatusCode.TooManyRequests)) + .WithUpstreamHttpMethod(new() { "Get" }) + .Build()) + .WithUpstreamHttpMethod(new() { "Get" }) + .Build(); + var downstreamRoute = new _DownstreamRouteHolder_(new(), route); + + // Act + await WhenICallTheMiddlewareWithWhiteClient(downstreamRoute); + + // Assert + _downstreamResponses.ForEach(dsr => dsr.ShouldBeNull()); + } + + [Fact] + [Trait("Bug", "1590")] + public async Task MiddlewareInvoke_PeriodTimespanValueIsGreaterThanPeriod_StatusNotEqualTo429() + { + // Arrange + const long limit = 100L; + var upstreamTemplate = new UpstreamPathTemplateBuilder() + .Build(); + var downstreamRoute = new DownstreamRouteBuilder() + .WithEnableRateLimiting(true) + .WithRateLimitOptions(new( + enableRateLimiting: true, + clientIdHeader: "ClientId", + getClientWhitelist: () => new List(), + disableRateLimitHeaders: false, + quotaExceededMessage: "Exceeding!", + rateLimitCounterPrefix: string.Empty, + new RateLimitRule("1s", 30.0D, limit), // bug scenario + (int)HttpStatusCode.TooManyRequests)) + .WithUpstreamHttpMethod(new() { "Get" }) + .WithUpstreamPathTemplate(upstreamTemplate) + .Build(); + var route = new RouteBuilder() + .WithDownstreamRoute(downstreamRoute) + .WithUpstreamHttpMethod(new() { "Get" }) + .Build(); + var downstreamRouteHolder = new _DownstreamRouteHolder_(new(), route); + + // Act, Assert: 100 requests must be successful + var contexts = await WhenICallTheMiddlewareMultipleTimes(limit, downstreamRouteHolder); // make 100 requests, but not exceed the limit + _downstreamResponses.ForEach(dsr => dsr.ShouldBeNull()); + contexts.ForEach(ctx => + { + ctx.ShouldNotBeNull(); + ctx.Items.Errors().ShouldNotBeNull().ShouldBeEmpty(); // no errors + ctx.Response.StatusCode.ShouldBe((int)HttpStatusCode.OK); // not 429 aka TooManyRequests + }); + + // Act, Assert: the next 101st request should fail + contexts = await WhenICallTheMiddlewareMultipleTimes(1, downstreamRouteHolder); + _downstreamResponses.ShouldNotBeNull(); + var ds = _downstreamResponses.SingleOrDefault().ShouldNotBeNull(); + ds.StatusCode.ShouldBe(HttpStatusCode.TooManyRequests, $"Downstream Response no {limit + 1}"); + var body = await ds.Content.ReadAsStringAsync(); + body.ShouldBe("Exceeding!"); + contexts[0].Items.Errors().ShouldNotBeNull().ShouldNotBeEmpty(); // having errors + contexts[0].Items.Errors().Single().HttpStatusCode.ShouldBe((int)HttpStatusCode.TooManyRequests); + } + + private async Task> WhenICallTheMiddlewareMultipleTimes(long times, _DownstreamRouteHolder_ downstreamRoute) + { + var contexts = new List(); + _downstreamResponses.Clear(); + for (var i = 0; i < times; i++) + { + var context = new DefaultHttpContext(); + var stream = GetFakeStream($"{i}"); + context.Response.Body = stream; + context.Response.RegisterForDispose(stream); + context.Items.UpsertDownstreamRoute(downstreamRoute.Route.DownstreamRoute[0]); + context.Items.UpsertTemplatePlaceholderNameAndValues(downstreamRoute.TemplatePlaceholderNameAndValues); + context.Items.UpsertDownstreamRoute(downstreamRoute); + var request = new HttpRequestMessage(new HttpMethod("GET"), _url); + context.Items.UpsertDownstreamRequest(new DownstreamRequest(request)); + context.Request.Headers.TryAdd("ClientId", "ocelotclient1"); + contexts.Add(context); + + await _middleware.Invoke(context); + + _downstreamResponses.Add(context.Items.DownstreamResponse()); + } + + return contexts; + } + + private static Stream GetFakeStream(string str) + { + byte[] data = Encoding.ASCII.GetBytes(str); + return new MemoryStream(data, 0, data.Length); + } + + private async Task WhenICallTheMiddlewareWithWhiteClient(_DownstreamRouteHolder_ downstreamRoute) + { + const string ClientId = "ocelotclient2"; + for (var i = 0; i < 10; i++) + { + var context = new DefaultHttpContext(); + var stream = GetFakeStream($"{i}"); + context.Response.Body = stream; + context.Response.RegisterForDispose(stream); + context.Items.UpsertDownstreamRoute(downstreamRoute.Route.DownstreamRoute[0]); + context.Items.UpsertTemplatePlaceholderNameAndValues(downstreamRoute.TemplatePlaceholderNameAndValues); + context.Items.UpsertDownstreamRoute(downstreamRoute); + var request = new HttpRequestMessage(new HttpMethod("GET"), _url); + request.Headers.Add("ClientId", ClientId); + context.Items.UpsertDownstreamRequest(new DownstreamRequest(request)); + context.Request.Headers.TryAdd("ClientId", ClientId); + + await _middleware.Invoke(context); + + _downstreamResponses.Add(context.Items.DownstreamResponse()); + } + } +} diff --git a/test/Ocelot.UnitTests/RateLimiting/RateLimitingTests.cs b/test/Ocelot.UnitTests/RateLimiting/RateLimitingTests.cs new file mode 100644 index 000000000..a4eb4738e --- /dev/null +++ b/test/Ocelot.UnitTests/RateLimiting/RateLimitingTests.cs @@ -0,0 +1,268 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration; +using Ocelot.Configuration.Builder; +using Ocelot.RateLimiting; +using System.Runtime.CompilerServices; +using _RateLimiting_ = Ocelot.RateLimiting.RateLimiting; + +namespace Ocelot.UnitTests.RateLimiting; + +public sealed class RateLimitingTests +{ + private readonly Mock _storage; + private readonly _RateLimiting_ _sut; + + public RateLimitingTests() + { + _storage = new(); + _sut = new(_storage.Object); + } + + [Theory] + [Trait("Feat", "37")] + [InlineData(null)] + [InlineData("")] + public void ToTimespan_EmptyValue_ShouldReturnZero(string empty) + { + // Arrange, Act + var actual = _sut.ToTimespan(empty); + + // Assert + Assert.Equal(TimeSpan.Zero, actual); + } + + [Theory] + [Trait("Feat", "37")] + [InlineData("1a")] + [InlineData("2unknown")] + public void ToTimespan_UnknownType_ShouldThrowFormatException(string timespan) + { + // Arrange, Act, Assert + Assert.Throws( + () => _sut.ToTimespan(timespan)); + } + + [Theory] + [Trait("Feat", "37")] + [InlineData("1s", 1 * TimeSpan.TicksPerSecond)] + [InlineData("2m", 2 * TimeSpan.TicksPerMinute)] + [InlineData("3h", 3 * TimeSpan.TicksPerHour)] + [InlineData("4d", 4 * TimeSpan.TicksPerDay)] + public void ToTimespan_KnownType_HappyPath(string timespan, long ticks) + { + // Arrange + var expected = TimeSpan.FromTicks(ticks); + + // Act + var actual = _sut.ToTimespan(timespan); + + // Assert + Assert.Equal(expected, actual); + } + + [Fact] + [Trait("PR", "1592")] + public void Count_NoEntry_StartCounting() + { + // Arrange + RateLimitCounter? arg1 = null; // No Entry + RateLimitRule arg2 = null; + + // Act + RateLimitCounter actual = _sut.Count(arg1, arg2); + + // Assert + Assert.Equal(1L, actual.TotalRequests); + Assert.True(DateTime.UtcNow - actual.StartedAt < TimeSpan.FromSeconds(1.0D)); + } + + [Fact] + [Trait("PR", "1592")] + public void Count_EntryHasNotExpired_IncrementedRequestCount() + { + // Arrange + long total = 2; + RateLimitCounter? arg1 = new RateLimitCounter(DateTime.UtcNow, null, total); // entry has not expired + RateLimitRule arg2 = new("1s", 1.0D, total + 1); // with not exceeding limit + + // Act + RateLimitCounter actual = _sut.Count(arg1, arg2); + + // Assert + Assert.Equal(total + 1, actual.TotalRequests); // incremented request count + Assert.Equal(arg1.Value.StartedAt, actual.StartedAt); // starting point has not changed + } + + [Fact] + [Trait("PR", "1592")] + public void Count_EntryHasNotExpiredAndExceedingLimit_IncrementedRequestCountWithRenewedStartMoment() + { + // Arrange + long total = 2; + RateLimitCounter? arg1 = new RateLimitCounter(DateTime.UtcNow, null, total); // entry has not expired + RateLimitRule arg2 = new("1s", 1.0D, 1L); + + // Act + RateLimitCounter actual = _sut.Count(arg1, arg2); + + // Assert + Assert.Equal(total + 1, actual.TotalRequests); // incremented request count + Assert.InRange(actual.StartedAt, arg1.Value.StartedAt, DateTime.UtcNow); // starting point has renewed and it is between StartedAt and Now + } + + [Fact] + [Trait("PR", "1592")] + public void Count_RateLimitExceeded_StartedCounting() + { + // Arrange + long total = 3, limit = total - 1; + TimeSpan periodTimespan = TimeSpan.FromSeconds(1.0D); + DateTime startedAt = DateTime.UtcNow.AddSeconds(-2.0), // 2 secs ago + exceededAt = startedAt + periodTimespan; // 1 second ago + RateLimitCounter? arg1 = new RateLimitCounter(startedAt, exceededAt, total); // Entry has expired + RateLimitRule arg2 = new("1s", periodTimespan.TotalSeconds, limit); // rate limit exceeded + + // Act + RateLimitCounter actual = _sut.Count(arg1, arg2); + + // Assert + Assert.Equal(1L, actual.TotalRequests); // started counting, the counter was changed + Assert.InRange(actual.StartedAt, arg1.Value.ExceededAt.Value, DateTime.UtcNow); // starting point has renewed and it is between exceededAt and Now + } + + [Fact] + [Trait("PR", "1592")] + public void Count_RateLimitNotExceededAndPeriodIsElapsed_StartedCountingByDefault() + { + // Arrange + long total = 3, limit = 3; + RateLimitCounter? arg1 = new RateLimitCounter(DateTime.UtcNow.AddSeconds(-2.0), null, total); // Entry has expired + RateLimitRule arg2 = new("1s", 1.0D, limit); // Rate limit not exceeded + + // Act + RateLimitCounter actual = _sut.Count(arg1, arg2); + + // Assert + Assert.Equal(1L, actual.TotalRequests); // started counting + Assert.True(DateTime.UtcNow - actual.StartedAt < TimeSpan.FromSeconds(1.0D)); // started now + } + + [Fact] + [Trait("PR", "1592")] + public void ProcessRequest_RateLimitExceededAndBanPeriodElapsed_StartedCounting() + { + // Arrange + const double periodTimespan = 2.0D; + const int millisecondsBeforeAfterEnding = 100; // current processing time of unit test should not take more 100 ms + DateTime now = DateTime.UtcNow, + startedAt = now.AddSeconds(-3).AddMilliseconds(millisecondsBeforeAfterEnding); + DateTime? exceededAt = null; + long totalRequests = 2L; + TimeSpan expiration = TimeSpan.Zero; + + var (identity, options) = SetupProcessRequest("3s", periodTimespan, totalRequests, + () => new RateLimitCounter(startedAt, exceededAt, totalRequests), + (value) => expiration = value); + + // Act 1 + var counter = _sut.ProcessRequest(identity, options); + + // Assert 1 + Assert.Equal(3L, counter.TotalRequests); // old counting -> 3 + Assert.Equal(startedAt, counter.StartedAt); // starting point was not changed + Assert.NotNull(counter.ExceededAt); // exceeded + Assert.Equal(DateTime.UtcNow.Second, counter.ExceededAt.Value.Second); // exceeded now, in the same second + + // Arrange 2 + TimeSpan shift = TimeSpan.FromSeconds(periodTimespan); // don't wait, just move to future + startedAt = counter.StartedAt - shift; // move to past + exceededAt = counter.ExceededAt - shift; // move to past + totalRequests = counter.TotalRequests; // 3 + + // Act 2 + var actual = _sut.ProcessRequest(identity, options); + + // Assert + Assert.Equal(1L, actual.TotalRequests); // started counting + Assert.InRange(actual.StartedAt, now, DateTime.UtcNow); // starting point has renewed and it is between test starting and Now + Assert.Null(actual.ExceededAt); + _storage.Verify(x => x.Remove(It.IsAny()), + Times.Never()); // Once()? Seems Remove is never called because of renewing + _storage.Verify(x => x.Get(It.IsAny()), + Times.Exactly(2)); + _storage.Verify(x => x.Set(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Exactly(2)); + Assert.Equal(TimeSpan.FromSeconds(3), expiration); + } + + private (ClientRequestIdentity Identity, RateLimitOptions Options) SetupProcessRequest(string period, double periodTimespan, long limit, + Func counterFactory, Action expirationAction, [CallerMemberName] string testName = "") + { + ClientRequestIdentity identity = new(nameof(RateLimitingTests), "/" + testName, HttpMethods.Get); + RateLimitOptions options = new RateLimitOptionsBuilder() + .WithEnableRateLimiting(true) + .WithRateLimitCounterPrefix(nameof(_RateLimiting_.ProcessRequest)) + .WithRateLimitRule(new RateLimitRule(period, periodTimespan, limit)) + .Build(); + _storage.Setup(x => x.Get(It.IsAny())) + .Returns(counterFactory); // counter value factory + _storage.Setup(x => x.Remove(It.IsAny())) + .Verifiable(); + expirationAction?.Invoke(TimeSpan.Zero); + _storage.Setup(x => x.Set(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((id, counter, expirationTime) => expirationAction?.Invoke(expirationTime)) + .Verifiable(); + return (identity, options); + } + + [Fact] + [Trait("Bug", "1590")] + public void ProcessRequest_PeriodTimespanValueIsGreaterThanPeriod_ExpectedBehaviorAndExpirationInPeriod() + { + // Arrange: user scenario + const string period = "1s"; + const double periodTimespan = 30.0D; // seconds + const long limit = 100L, requestsPerSecond = 20L; + + // Arrange: setup + DateTime? startedAt = null; + TimeSpan expiration = TimeSpan.Zero; + long total = 1L, count = requestsPerSecond; + RateLimitCounter? current = null; + var (identity, options) = SetupProcessRequest(period, periodTimespan, limit, + () => current, + (value) => expiration = value); + + // Arrange 20 requests per period (1 sec) + var periodSeconds = TimeSpan.FromSeconds(double.Parse(period[0].ToString())); + var periodMilliseconds = periodSeconds.TotalMilliseconds; + int delay = (int)((periodMilliseconds - 200) / requestsPerSecond); // 20 requests per 1 second + + while (count > 0L) + { + // Act + var actual = _sut.ProcessRequest(identity, options); + + // life hack for the 1st request + if (count == requestsPerSecond) + { + startedAt = actual.StartedAt; // for the 1st request get expected value + } + + // Assert + Assert.True(actual.TotalRequests < limit); + actual.TotalRequests.ShouldBe(total++, $"Count is {count}"); + Assert.Equal(startedAt, actual.StartedAt); // starting point is not changed + Assert.Null(actual.ExceededAt); // no exceeding at all + Assert.Equal(periodSeconds, expiration); // expiration in the period + + // Arrange: next micro test + current = actual; + Thread.Sleep(delay); + count--; + } + + Assert.NotEqual(TimeSpan.FromSeconds(periodTimespan), expiration); // Not ban period expiration + Assert.Equal(periodSeconds, expiration); // last 20th request was in counting period + } +}