Skip to content

Commit

Permalink
Merge pull request #26 from SimCubeLtd/develop
Browse files Browse the repository at this point in the history
Add Rate Limit Policy Attribute
  • Loading branch information
prom3theu5 authored Oct 15, 2022
2 parents 33e4b48 + 0155b9b commit 2597bed
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 48 deletions.
20 changes: 20 additions & 0 deletions src/SimCube.Spartan/Attributes/RateLimitPolicyAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
namespace SimCube.Spartan.Attributes;

/// <summary>
/// Attribute for Rate Limiting Policy setup on an Endpoint.
/// </summary>
[ExcludeFromCodeCoverage]
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
public sealed class RateLimitPolicyAttribute : Attribute
{
/// <summary>
/// Initializes a new instance of the <see cref="RateLimitPolicyAttribute"/> class.
/// </summary>
/// <param name="policyName">The Policy Name.</param>
public RateLimitPolicyAttribute(string policyName) => PolicyName = policyName;

/// <summary>
/// Gets the type of the http request.
/// </summary>
public string PolicyName { get; }
}
24 changes: 16 additions & 8 deletions src/SimCube.Spartan/Extensions/AttributeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,19 @@ internal static void SetupMediatedRequestEndpointAttributes(this Attribute[]? at
var configureMethod = request.GetMethod(nameof(BaseMediatedRequest.ConfigureEndpoint));
var endpointFilters = request.GetProperty(nameof(BaseMediatedRequest.EndpointFilters));
var cachePolicyName = GetCachedPolicyNameIfHasAttribute(attributes);
var rateLimitPolicyName = GetRateLimitPolicyNameIfHasAttribute(attributes);

CreateMediatedHandler(mediatedRequestAttribute, request, app, configureMethod, endpointFilters, cachePolicyName);
CreateMediatedHandler(mediatedRequestAttribute, request, app, configureMethod, endpointFilters, cachePolicyName, rateLimitPolicyName);
}

private static string? GetCachedPolicyNameIfHasAttribute(Attribute[] attributes) =>
attributes?.FirstOrDefault(x => x is CachePolicyAttribute) is not CachePolicyAttribute cachePolicyAttribute
? null
: cachePolicyAttribute.PolicyName;

private static void CreateMediatedHandler(
MediatedEndpointAttribute mediatedEndpointAttribute,
Type request,
WebApplication app,
MethodBase? configureMethod,
PropertyInfo? endpointFilters,
string? cachedPolicyName)
string? cachedPolicyName,
string? rateLimitPolicyName)
{
var resultType = GetResultType(request);

Expand All @@ -59,7 +56,8 @@ private static void CreateMediatedHandler(
mediatedEndpointAttribute.Route,
configureMethod?.Invoke(FormatterServices.GetUninitializedObject(request), Array.Empty<object>()) as Action<RouteHandlerBuilder>,
endpointFilters?.GetValue(FormatterServices.GetUninitializedObject(request)) as List<IEndpointFilter>,
cachedPolicyName
cachedPolicyName,
rateLimitPolicyName
});
}

Expand All @@ -73,4 +71,14 @@ private static void CreateMediatedHandler(
return Array.Find(request.GetInterfaces(), x => x.GetGenericTypeDefinition() == typeof(IMediatedStream<>))
?.GetGenericArguments().FirstOrDefault();
}

private static string? GetCachedPolicyNameIfHasAttribute(Attribute[] attributes) =>
attributes?.FirstOrDefault(x => x is CachePolicyAttribute) is not CachePolicyAttribute cachePolicyAttribute
? null
: cachePolicyAttribute.PolicyName;

private static string? GetRateLimitPolicyNameIfHasAttribute(Attribute[] attributes) =>
attributes?.FirstOrDefault(x => x is RateLimitPolicyAttribute) is not RateLimitPolicyAttribute rateLimitPolicyAttribute
? null
: rateLimitPolicyAttribute.PolicyName;
}
55 changes: 45 additions & 10 deletions src/SimCube.Spartan/Extensions/MediatedRequestExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ public static class MediatedRequestExtensions
/// <param name="route">the route to map the request on.</param>
/// <param name="configureEndpoint">The optional route handler configuration action for endpoint extension.</param>
/// <param name="endpointFilters">The optional Endpoint filters to chain to the request pipeline.</param>
/// <param name="namedCachePolicy">The named Output Cache Policy to use.</param>
/// <param name="namedCachePolicy">The optional named Output Cache Policy to use.</param>
/// <param name="namedRateLimitPolicy">The optional named Rate Limit Policy to use.</param>
/// <typeparam name="TRequest">The type of the request to map with its parameters.</typeparam>
/// <returns>A type of <see cref="IResult"/>.</returns>
public static WebApplication MediatedGet<TRequest>(
this WebApplication app,
string route,
Action<RouteHandlerBuilder>? configureEndpoint = null,
IReadOnlyCollection<IEndpointFilter>? endpointFilters = default,
string? namedCachePolicy = default)
string? namedCachePolicy = default,
string? namedRateLimitPolicy = default)
where TRequest : IMediatedRequest
{
var builder = app.MapGet(route, async (IMediator mediator, [AsParameters] TRequest request, CancellationToken cancellationToken)
Expand All @@ -37,6 +39,11 @@ public static WebApplication MediatedGet<TRequest>(
builder.CacheOutput(namedCachePolicy);
}

if (namedRateLimitPolicy is not null)
{
builder.RequireRateLimiting(namedRateLimitPolicy);
}

configureEndpoint?.Invoke(builder);

return app;
Expand All @@ -49,15 +56,17 @@ public static WebApplication MediatedGet<TRequest>(
/// <param name="route">the route to map the request on.</param>
/// <param name="configureEndpoint">The optional route handler configuration action for endpoint extension.</param>
/// <param name="endpointFilters">The optional Endpoint filters to chain to the request pipeline.</param>
/// <param name="namedCachePolicy">The named Output Cache Policy to use.</param>
/// <param name="namedCachePolicy">The optional named Output Cache Policy to use.</param>
/// <param name="namedRateLimitPolicy">The optional named Rate Limit Policy to use.</param>
/// <typeparam name="TRequest">The type of the request to map with its parameters.</typeparam>
/// <returns>A type of <see cref="IResult"/>.</returns>
public static WebApplication MediatedPost<TRequest>(
this WebApplication app,
string route,
Action<RouteHandlerBuilder>? configureEndpoint = null,
IReadOnlyCollection<IEndpointFilter>? endpointFilters = default,
string? namedCachePolicy = default)
string? namedCachePolicy = default,
string? namedRateLimitPolicy = default)
where TRequest : IMediatedRequest
{
var builder = app.MapPost(route, async (IMediator mediator, [AsParameters] TRequest request, CancellationToken cancellationToken) =>
Expand All @@ -73,6 +82,11 @@ public static WebApplication MediatedPost<TRequest>(
builder.CacheOutput(namedCachePolicy);
}

if (namedRateLimitPolicy is not null)
{
builder.RequireRateLimiting(namedRateLimitPolicy);
}

configureEndpoint?.Invoke(builder);

return app;
Expand All @@ -85,15 +99,17 @@ public static WebApplication MediatedPost<TRequest>(
/// <param name="route">the route to map the request on.</param>
/// <param name="configureEndpoint">The optional route handler configuration action for endpoint extension.</param>
/// <param name="endpointFilters">The optional Endpoint filters to chain to the request pipeline.</param>
/// <param name="namedCachePolicy">The named Output Cache Policy to use.</param>
/// <param name="namedCachePolicy">The optional named Output Cache Policy to use.</param>
/// <param name="namedRateLimitPolicy">The optional named Rate Limit Policy to use.</param>
/// <typeparam name="TRequest">The type of the request to map with its parameters.</typeparam>
/// <returns>A type of <see cref="IResult"/>.</returns>
public static WebApplication MediatedPut<TRequest>(
this WebApplication app,
string route,
Action<RouteHandlerBuilder>? configureEndpoint = null,
IReadOnlyCollection<IEndpointFilter>? endpointFilters = default,
string? namedCachePolicy = default)
string? namedCachePolicy = default,
string? namedRateLimitPolicy = default)
where TRequest : IMediatedRequest
{
var builder = app.MapPut(route, async (IMediator mediator, [AsParameters] TRequest request, CancellationToken cancellationToken) =>
Expand All @@ -109,6 +125,11 @@ public static WebApplication MediatedPut<TRequest>(
builder.CacheOutput(namedCachePolicy);
}

if (namedRateLimitPolicy is not null)
{
builder.RequireRateLimiting(namedRateLimitPolicy);
}

configureEndpoint?.Invoke(builder);

return app;
Expand All @@ -121,15 +142,17 @@ public static WebApplication MediatedPut<TRequest>(
/// <param name="route">the route to map the request on.</param>
/// <param name="configureEndpoint">The optional route handler configuration action for endpoint extension.</param>
/// <param name="endpointFilters">The optional Endpoint filters to chain to the request pipeline.</param>
/// <param name="namedCachePolicy">The named Output Cache Policy to use.</param>
/// <param name="namedCachePolicy">The optional named Output Cache Policy to use.</param>
/// <param name="namedRateLimitPolicy">The optional named Rate Limit Policy to use.</param>
/// <typeparam name="TRequest">The type of the request to map with its parameters.</typeparam>
/// <returns>A type of <see cref="IResult"/>.</returns>
public static WebApplication MediatedPatch<TRequest>(
this WebApplication app,
string route,
Action<RouteHandlerBuilder>? configureEndpoint = null,
IReadOnlyCollection<IEndpointFilter>? endpointFilters = default,
string? namedCachePolicy = default)
string? namedCachePolicy = default,
string? namedRateLimitPolicy = default)
where TRequest : IMediatedRequest
{
var builder = app.MapPatch(route, async (IMediator mediator, [AsParameters] TRequest request, CancellationToken cancellationToken) =>
Expand All @@ -145,6 +168,11 @@ public static WebApplication MediatedPatch<TRequest>(
builder.CacheOutput(namedCachePolicy);
}

if (namedRateLimitPolicy is not null)
{
builder.RequireRateLimiting(namedRateLimitPolicy);
}

configureEndpoint?.Invoke(builder);

return app;
Expand All @@ -157,15 +185,17 @@ public static WebApplication MediatedPatch<TRequest>(
/// <param name="route">the route to map the request on.</param>
/// <param name="configureEndpoint">The optional route handler configuration action for endpoint extension.</param>
/// <param name="endpointFilters">The optional Endpoint filters to chain to the request pipeline.</param>
/// <param name="namedCachePolicy">The named Output Cache Policy to use.</param>
/// <param name="namedCachePolicy">The optional named Output Cache Policy to use.</param>
/// <param name="namedRateLimitPolicy">The optional named Rate Limit Policy to use.</param>
/// <typeparam name="TRequest">The type of the request to map with its parameters.</typeparam>
/// <returns>A type of <see cref="IResult"/>.</returns>
public static WebApplication MediatedDelete<TRequest>(
this WebApplication app,
string route,
Action<RouteHandlerBuilder>? configureEndpoint = null,
IReadOnlyCollection<IEndpointFilter>? endpointFilters = default,
string? namedCachePolicy = default)
string? namedCachePolicy = default,
string? namedRateLimitPolicy = default)
where TRequest : IMediatedRequest
{
var builder = app.MapDelete(route, async (IMediator mediator, [AsParameters] TRequest request, CancellationToken cancellationToken) =>
Expand All @@ -181,6 +211,11 @@ public static WebApplication MediatedDelete<TRequest>(
builder.CacheOutput(namedCachePolicy);
}

if (namedRateLimitPolicy is not null)
{
builder.RequireRateLimiting(namedRateLimitPolicy);
}

configureEndpoint?.Invoke(builder);

return app;
Expand Down
Loading

0 comments on commit 2597bed

Please sign in to comment.