Skip to content

Commit

Permalink
Changing IRequest to just be a normal Task instead of Task<Unit>
Browse files Browse the repository at this point in the history
  • Loading branch information
jbogard committed Feb 9, 2023
1 parent 528a6cc commit 761fb0b
Show file tree
Hide file tree
Showing 20 changed files with 331 additions and 176 deletions.
4 changes: 2 additions & 2 deletions samples/MediatR.Examples/JingHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace MediatR.Examples;

public class JingHandler : AsyncRequestHandler<Jing>
public class JingHandler : IRequestHandler<Jing>
{
private readonly TextWriter _writer;

Expand All @@ -13,7 +13,7 @@ public JingHandler(TextWriter writer)
_writer = writer;
}

protected override Task Handle(Jing request, CancellationToken cancellationToken)
public Task Handle(Jing request, CancellationToken cancellationToken)
{
return _writer.WriteLineAsync($"--- Handled Jing: {request.Message}, no Jong");
}
Expand Down
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/IRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ namespace MediatR;
/// <summary>
/// Marker interface to represent a request with a void response
/// </summary>
public interface IRequest : IRequest<Unit> { }
public interface IRequest : IBaseRequest { }

/// <summary>
/// Marker interface to represent a request with a response
Expand Down
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/MediatR.Contracts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<SymbolPackageFormat>snupkg</SymbolPackageFormat>
<EmbedUntrackedSources>true</EmbedUntrackedSources>
<Deterministic>true</Deterministic>
<Version>1.0.1</Version>
<Version>2.0.0</Version>
<RootNamespace>MediatR</RootNamespace>

</PropertyGroup>
Expand Down
65 changes: 7 additions & 58 deletions src/MediatR/IRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,68 +21,17 @@ public interface IRequestHandler<in TRequest, TResponse>
}

/// <summary>
/// Defines a handler for a request with a void (<see cref="Unit" />) response.
/// You do not need to register this interface explicitly with a container as it inherits from the base <see cref="IRequestHandler{TRequest, TResponse}" /> interface.
/// Defines a handler for a request with a void response.
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public interface IRequestHandler<in TRequest> : IRequestHandler<TRequest, Unit>
where TRequest : IRequest<Unit>
{
}

/// <summary>
/// Wrapper class for a handler that asynchronously handles a request and does not return a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public abstract class AsyncRequestHandler<TRequest> : IRequestHandler<TRequest>
public interface IRequestHandler<in TRequest>
where TRequest : IRequest
{
async Task<Unit> IRequestHandler<TRequest, Unit>.Handle(TRequest request, CancellationToken cancellationToken)
{
await Handle(request, cancellationToken).ConfigureAwait(false);
return Unit.Value;
}

/// <summary>
/// Override in a derived class for the handler logic
/// </summary>
/// <param name="request">Request</param>
/// <param name="cancellationToken"></param>
/// <returns>Response</returns>
protected abstract Task Handle(TRequest request, CancellationToken cancellationToken);
}

/// <summary>
/// Wrapper class for a handler that synchronously handles a request and returns a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
/// <typeparam name="TResponse">The type of response from the handler</typeparam>
public abstract class RequestHandler<TRequest, TResponse> : IRequestHandler<TRequest, TResponse>
where TRequest : IRequest<TResponse>
{
Task<TResponse> IRequestHandler<TRequest, TResponse>.Handle(TRequest request, CancellationToken cancellationToken)
=> Task.FromResult(Handle(request));

/// <summary>
/// Override in a derived class for the handler logic
/// Handles a request
/// </summary>
/// <param name="request">Request</param>
/// <returns>Response</returns>
protected abstract TResponse Handle(TRequest request);
}

/// <summary>
/// Wrapper class for a handler that synchronously handles a request and does not return a response
/// </summary>
/// <typeparam name="TRequest">The type of request being handled</typeparam>
public abstract class RequestHandler<TRequest> : IRequestHandler<TRequest>
where TRequest : IRequest
{
Task<Unit> IRequestHandler<TRequest, Unit>.Handle(TRequest request, CancellationToken cancellationToken)
{
Handle(request);
return Unit.Task;
}

protected abstract void Handle(TRequest request);
/// <param name="request">The request</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Response from the request</returns>
Task Handle(TRequest request, CancellationToken cancellationToken);
}
9 changes: 9 additions & 0 deletions src/MediatR/ISender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ public interface ISender
/// <returns>A task that represents the send operation. The task result contains the handler response</returns>
Task<TResponse> Send<TResponse>(IRequest<TResponse> request, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously send a request to a single handler with no response
/// </summary>
/// <param name="request">Request object</param>
/// <param name="cancellationToken">Optional cancellation token</param>
/// <returns>A task that represents the send operation.</returns>
Task Send<TRequest>(TRequest request, CancellationToken cancellationToken = default)
where TRequest : IRequest;

/// <summary>
/// Asynchronously send an object request to a single handler via dynamic dispatch
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion src/MediatR/MediatR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="MediatR.Contracts" Version="1.0.1" />
<PackageReference Include="MediatR.Contracts" Version="[2.0.0, 3.0.0)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.2.0" PrivateAssets="All" />
</ItemGroup>

</Project>
41 changes: 37 additions & 4 deletions src/MediatR/Mediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, Cancellation
return handler.Handle(request, _serviceProvider, cancellationToken);
}

public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken = default)
where TRequest : IRequest
{
if (request == null)
{
throw new ArgumentNullException(nameof(request));
}

var requestType = typeof(TRequest);

var handler = (RequestHandlerWrapper)_requestHandlers.GetOrAdd(requestType,
static t => (RequestHandlerBase)(Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<>).MakeGenericType(t))
?? throw new InvalidOperationException($"Could not create wrapper type for {t}")));

return handler.Handle(request, _serviceProvider, cancellationToken);
}

public Task<object?> Send(object request, CancellationToken cancellationToken = default)
{
if (request == null)
Expand All @@ -55,13 +72,29 @@ public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, Cancellation
.GetInterfaces()
.FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>));

Type wrapperType;

if (requestInterfaceType is null)
{
throw new ArgumentException($"{requestTypeKey.Name} does not implement {nameof(IRequest)}", nameof(request));
requestInterfaceType = requestTypeKey
.GetInterfaces()
.FirstOrDefault(static i => i == typeof(IRequest));

if (requestInterfaceType is null)
{
throw new ArgumentException($"{requestTypeKey.Name} does not implement {nameof(IRequest)}",
nameof(request));
}

wrapperType =
typeof(RequestHandlerWrapperImpl<>).MakeGenericType(requestTypeKey);
}
else
{
var responseType = requestInterfaceType.GetGenericArguments()[0];
wrapperType =
typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType);
}

var responseType = requestInterfaceType.GetGenericArguments()[0];
var wrapperType = typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType);

return (RequestHandlerBase)(Activator.CreateInstance(wrapperType)
?? throw new InvalidOperationException($"Could not create wrapper for type {wrapperType}"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using MediatR;
using MediatR.Pipeline;
using MediatR.Registration;
Expand Down
1 change: 1 addition & 0 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public static void AddMediatRClasses(IServiceCollection services, MediatRService
var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPreProcessor<>), services, assembliesToScan, true, configuration);
Expand Down
48 changes: 40 additions & 8 deletions src/MediatR/Wrappers/RequestHandlerWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
using System;
using Microsoft.Extensions.DependencyInjection;

namespace MediatR.Wrappers;

using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;

namespace MediatR.Wrappers;

public abstract class RequestHandlerBase
{
public abstract Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken);

}

public abstract class RequestHandlerWrapper<TResponse> : RequestHandlerBase
Expand All @@ -20,21 +18,55 @@ public abstract Task<TResponse> Handle(IRequest<TResponse> request, IServiceProv
CancellationToken cancellationToken);
}

public abstract class RequestHandlerWrapper : RequestHandlerBase
{
public abstract Task<Unit> Handle(IRequest request, IServiceProvider serviceProvider,
CancellationToken cancellationToken);
}

public class RequestHandlerWrapperImpl<TRequest, TResponse> : RequestHandlerWrapper<TResponse>
where TRequest : IRequest<TResponse>
{
public override async Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken) =>
await Handle((IRequest<TResponse>)request, serviceProvider, cancellationToken).ConfigureAwait(false);
await Handle((IRequest<TResponse>) request, serviceProvider, cancellationToken).ConfigureAwait(false);

public override Task<TResponse> Handle(IRequest<TResponse> request, IServiceProvider serviceProvider,
CancellationToken cancellationToken)
{
Task<TResponse> Handler() => serviceProvider.GetRequiredService<IRequestHandler<TRequest, TResponse>>().Handle((TRequest) request, cancellationToken);
Task<TResponse> Handler() => serviceProvider.GetRequiredService<IRequestHandler<TRequest, TResponse>>()
.Handle((TRequest) request, cancellationToken);

return serviceProvider
.GetServices<IPipelineBehavior<TRequest, TResponse>>()
.Reverse()
.Aggregate((RequestHandlerDelegate<TResponse>) Handler, (next, pipeline) => () => pipeline.Handle((TRequest)request, next, cancellationToken))();
.Aggregate((RequestHandlerDelegate<TResponse>) Handler,
(next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))();
}
}

public class RequestHandlerWrapperImpl<TRequest> : RequestHandlerWrapper
where TRequest : IRequest
{
public override async Task<object?> Handle(object request, IServiceProvider serviceProvider,
CancellationToken cancellationToken) =>
await Handle((IRequest) request, serviceProvider, cancellationToken).ConfigureAwait(false);

public override Task<Unit> Handle(IRequest request, IServiceProvider serviceProvider,
CancellationToken cancellationToken)
{
async Task<Unit> Handler()
{
await serviceProvider.GetRequiredService<IRequestHandler<TRequest>>()
.Handle((TRequest) request, cancellationToken);

return Unit.Value;
}

return serviceProvider
.GetServices<IPipelineBehavior<TRequest, Unit>>()
.Reverse()
.Aggregate((RequestHandlerDelegate<Unit>) Handler,
(next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))();
}
}
10 changes: 6 additions & 4 deletions test/MediatR.Tests/ExceptionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public Task<Pong> Handle(NullPing request, CancellationToken cancellationToken)
}
}

public class VoidNullPingHandler : IRequestHandler<VoidNullPing, Unit>
public class VoidNullPingHandler : IRequestHandler<VoidNullPing>
{
public Task<Unit> Handle(VoidNullPing request, CancellationToken cancellationToken)
public Task Handle(VoidNullPing request, CancellationToken cancellationToken)
{
return Unit.Task;
return Task.CompletedTask;
}
}

Expand Down Expand Up @@ -244,7 +244,7 @@ public class PingException : IRequest

public class PingExceptionHandler : IRequestHandler<PingException>
{
public Task<Unit> Handle(PingException request, CancellationToken cancellationToken)
public Task Handle(PingException request, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
Expand All @@ -261,6 +261,7 @@ public async Task Should_throw_exception_for_non_generic_send_when_exception_occ
scanner.IncludeNamespaceContainingType<Ping>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand Down Expand Up @@ -309,6 +310,7 @@ public async Task Should_throw_exception_for_generic_send_when_exception_occurs(
scanner.IncludeNamespaceContainingType<Ping>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand Down
15 changes: 8 additions & 7 deletions test/MediatR.Tests/GenericTypeConstraintsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ public class Jing : IRequest
public string? Message { get; set; }
}

public class JingHandler : IRequestHandler<Jing, Unit>
public class JingHandler : IRequestHandler<Jing>
{
public Task<Unit> Handle(Jing request, CancellationToken cancellationToken)
public Task Handle(Jing request, CancellationToken cancellationToken)
{
// empty handle
return Unit.Task;
return Task.CompletedTask;
}
}

Expand Down Expand Up @@ -98,6 +98,7 @@ public GenericTypeConstraintsTests()
scanner.IncludeNamespaceContainingType<Jing>();
scanner.WithDefaultConventions();
scanner.AddAllTypesOf(typeof(IRequestHandler<,>));
scanner.AddAllTypesOf(typeof(IRequestHandler<>));
});
cfg.For<IMediator>().Use<Mediator>();
});
Expand All @@ -119,15 +120,15 @@ public async Task Should_Resolve_Void_Return_Request()

// Assert it is of type IRequest and IRequest<T>
Assert.True(genericTypeConstraintsVoidReturn.IsIRequest);
Assert.True(genericTypeConstraintsVoidReturn.IsIRequestT);
Assert.False(genericTypeConstraintsVoidReturn.IsIRequestT);
Assert.True(genericTypeConstraintsVoidReturn.IsIBaseRequest);

// Verify it is of IRequest and IBaseRequest and IRequest<Unit>
// Verify it is of IRequest and IBaseRequest
var results = genericTypeConstraintsVoidReturn.Handle(jing);

Assert.Equal(3, results.Length);
Assert.Equal(2, results.Length);

results.ShouldContain(typeof(IRequest<Unit>));
results.ShouldNotContain(typeof(IRequest<Unit>));
results.ShouldContain(typeof(IBaseRequest));
results.ShouldContain(typeof(IRequest));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void ShouldResolveRequestHandler()
[Fact]
public void ShouldResolveInternalHandler()
{
_provider.GetService<IRequestHandler<InternalPing, Unit>>().ShouldNotBeNull();
_provider.GetService<IRequestHandler<InternalPing>>().ShouldNotBeNull();
}

[Fact]
Expand Down
Loading

0 comments on commit 761fb0b

Please sign in to comment.