Skip to content

Commit

Permalink
Correct cancellation token usage.
Browse files Browse the repository at this point in the history
The dispatch service used its internal cancellation token source
incorrectly, leading to deadlocks and lost events on shutdown. This
change removes every use of the token source except the intended one
(signalling running responders that they should cancel) and moves the
responsibility for stopping the service's own dispatcher and finalizer
tasks to the data channels: completing the payload channel's writer now
tells the readers to terminate.

Fixes Remora#305.
  • Loading branch information
Nihlus authored and VelvetToroyashi committed Jun 17, 2023
1 parent dd9010f commit 538db5a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 78 deletions.
1 change: 1 addition & 0 deletions .idea/.idea.Remora.Discord/.idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

143 changes: 65 additions & 78 deletions Backend/Remora.Discord.Gateway/Services/ResponderDispatchService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,18 @@ public class ResponderDispatchService : IAsyncDisposable, IResponderDispatchServ
private readonly IResponderTypeRepository _responderTypeRepository;

private readonly Dictionary<Type, Type> _cachedInterfaceTypeArguments;
private readonly Dictionary<Type, Func<IPayload, CancellationToken, Task<IReadOnlyList<Result>>>> _cachedDispatchDelegates;
private readonly CancellationTokenSource _dispatchCancellationSource;
private readonly Dictionary<Type, Func<IPayload, Task<IReadOnlyList<Result>>>> _cachedDispatchDelegates;
private readonly Task _dispatcher;
private readonly Task _finalizer;
private readonly Channel<IPayload> _payloadsToDispatch;
private readonly Channel<Task<IReadOnlyList<Result>>> _respondersToFinalize;

/// <summary>
/// Holds the token source used to get tokens for running responders. Execution of the dispatch service's own tasks
/// is controlled via the channels.
/// </summary>
private readonly CancellationTokenSource _responderCancellationSource;

private bool _isDisposed;

/// <summary>
Expand All @@ -81,7 +86,7 @@ IOptions<ResponderDispatchOptions> options
_cachedInterfaceTypeArguments = new();
_cachedDispatchDelegates = new();

_dispatchCancellationSource = new();
_responderCancellationSource = new();
_payloadsToDispatch = Channel.CreateBounded<IPayload>
(
new BoundedChannelOptions((int)_options.MaxItems)
Expand All @@ -99,8 +104,8 @@ IOptions<ResponderDispatchOptions> options
}
);

_dispatcher = Task.Run(DispatcherTaskAsync, _dispatchCancellationSource.Token);
_finalizer = Task.Run(FinalizerTaskAsync, _dispatchCancellationSource.Token);
_dispatcher = Task.Run(DispatcherTaskAsync, CancellationToken.None);
_finalizer = Task.Run(FinalizerTaskAsync, CancellationToken.None);
}

/// <summary>
Expand Down Expand Up @@ -138,31 +143,24 @@ public async Task<Result> DispatchAsync(IPayload payload, CancellationToken ct =
/// </summary>
private async Task DispatcherTaskAsync()
{
while (!_dispatchCancellationSource.Token.IsCancellationRequested)
try
{
var payload = await _payloadsToDispatch.Reader.ReadAsync(_dispatchCancellationSource.Token);
var dispatch = UnwrapAndDispatchEvent(payload, _dispatchCancellationSource.Token);
if (!dispatch.IsSuccess)
while (await _payloadsToDispatch.Reader.WaitToReadAsync())
{
_log.LogWarning("Failed to dispatch payload: {Reason}", dispatch.Error.Message);
continue;
}
var payload = await _payloadsToDispatch.Reader.ReadAsync();
var dispatch = UnwrapAndDispatchEvent(payload);
if (!dispatch.IsSuccess)
{
_log.LogWarning("Failed to dispatch payload: {Reason}", dispatch.Error.Message);
continue;
}

await _respondersToFinalize.Writer.WriteAsync(dispatch.Entity, _dispatchCancellationSource.Token);
await _respondersToFinalize.Writer.WriteAsync(dispatch.Entity);
}
}

// Finish up remaining dispatches
await _payloadsToDispatch.Reader.Completion;
await foreach (var payload in _payloadsToDispatch.Reader.ReadAllAsync())
catch (Exception ex) when (ex is OperationCanceledException or ChannelClosedException)
{
var dispatch = UnwrapAndDispatchEvent(payload, _dispatchCancellationSource.Token);
if (!dispatch.IsSuccess)
{
_log.LogWarning("Failed to dispatch payload: {Reason}", dispatch.Error.Message);
continue;
}

await _respondersToFinalize.Writer.WriteAsync(dispatch.Entity, _dispatchCancellationSource.Token);
// this is fine, no further incoming payloads to accept
}

_respondersToFinalize.Writer.Complete();
Expand All @@ -173,7 +171,7 @@ private async Task DispatcherTaskAsync()
/// </summary>
private async Task FinalizerTaskAsync()
{
if (_dispatchCancellationSource is null)
if (_responderCancellationSource is null)
{
throw new InvalidOperationException();
}
Expand All @@ -183,43 +181,47 @@ private async Task FinalizerTaskAsync()
throw new InvalidOperationException();
}

while (!_dispatchCancellationSource.Token.IsCancellationRequested)
try
{
var responderResults = await _respondersToFinalize.Reader.ReadAsync
(
_dispatchCancellationSource.Token
);

if (!responderResults.IsCompleted)
while (await _respondersToFinalize.Reader.WaitToReadAsync())
{
var timeout = Task.Delay(TimeSpan.FromMilliseconds(10));

var finishedTask = await Task.WhenAny(responderResults, timeout);
if (finishedTask == timeout)
var responderResults = await _respondersToFinalize.Reader.ReadAsync();
if (!responderResults.IsCompleted)
{
// This responder is taking too long... put it back on the channel and look at some other stuff
// in the meantime.
await _respondersToFinalize.Writer.WriteAsync(responderResults, _dispatchCancellationSource.Token);
continue;
var timeout = Task.Delay(TimeSpan.FromMilliseconds(10));

var finishedTask = await Task.WhenAny(responderResults, timeout);
if (finishedTask == timeout)
{
// This responder is taking too long... put it back on the channel and look at some other stuff
// in the meantime.
try
{
await _respondersToFinalize.Writer.WriteAsync(responderResults);
continue;
}
catch (ChannelClosedException)
{
// Okay, we can't put it back on, so we'll just drop out and await it. It should be the last
// item in the pipe anyway
}
}
}
}

FinalizeResponderDispatch(await responderResults);
FinalizeResponderDispatch(await responderResults);
}
}

await _respondersToFinalize.Reader.Completion;
await foreach (var responderResults in _respondersToFinalize.Reader.ReadAllAsync())
catch (Exception ex) when (ex is OperationCanceledException or ChannelClosedException)
{
FinalizeResponderDispatch(await responderResults);
// this is fine, nothing further to do
}
}

/// <summary>
/// Unwraps the given payload into its typed representation, dispatching all events for it.
/// </summary>
/// <param name="payload">The payload.</param>
/// <param name="ct">The cancellation token for the dispatched event.</param>
private Result<Task<IReadOnlyList<Result>>> UnwrapAndDispatchEvent(IPayload payload, CancellationToken ct = default)
private Result<Task<IReadOnlyList<Result>>> UnwrapAndDispatchEvent(IPayload payload)
{
var payloadType = payload.GetType();

Expand Down Expand Up @@ -269,15 +271,14 @@ private Result<Task<IReadOnlyList<Result>>> UnwrapAndDispatchEvent(IPayload payl
throw new MissingMethodException(nameof(DiscordGatewayClient), nameof(DispatchEventAsync));
}

var delegateType = typeof(Func<,,>).MakeGenericType
var delegateType = typeof(Func<,>).MakeGenericType
(
typeof(IPayload<>).MakeGenericType(interfaceArgument),
typeof(CancellationToken),
typeof(Task<IReadOnlyList<Result>>)
);

// Naughty unsafe cast, because we know we're calling it with compatible types in this method
dispatchDelegate = Unsafe.As<Func<IPayload, CancellationToken, Task<IReadOnlyList<Result>>>>
dispatchDelegate = Unsafe.As<Func<IPayload, Task<IReadOnlyList<Result>>>>
(
dispatchMethod
.MakeGenericMethod(interfaceArgument)
Expand All @@ -287,7 +288,9 @@ private Result<Task<IReadOnlyList<Result>>> UnwrapAndDispatchEvent(IPayload payl
_cachedDispatchDelegates.Add(interfaceArgument, dispatchDelegate);
}

var responderTask = Task.Run(() => dispatchDelegate(payload, ct), ct);
// Don't use the cancellation token here; we want the task to always run and let the responder decide when to
// actually cancel
var responderTask = Task.Run(() => dispatchDelegate(payload), CancellationToken.None);

return responderTask;
}
Expand All @@ -296,13 +299,8 @@ private Result<Task<IReadOnlyList<Result>>> UnwrapAndDispatchEvent(IPayload payl
/// Dispatches the given event to all relevant gateway event responders.
/// </summary>
/// <param name="gatewayEvent">The event to dispatch.</param>
/// <param name="ct">The cancellation token to use.</param>
/// <typeparam name="TGatewayEvent">The gateway event.</typeparam>
private async Task<IReadOnlyList<Result>> DispatchEventAsync<TGatewayEvent>
(
IPayload<TGatewayEvent> gatewayEvent,
CancellationToken ct = default
)
private async Task<IReadOnlyList<Result>> DispatchEventAsync<TGatewayEvent>(IPayload<TGatewayEvent> gatewayEvent)
where TGatewayEvent : IGatewayEvent
{
// Batch up the responders according to their groups
Expand All @@ -329,7 +327,7 @@ private async Task<IReadOnlyList<Result>> DispatchEventAsync<TGatewayEvent>
var responder = (IResponder<TGatewayEvent>)serviceScope.ServiceProvider
.GetRequiredService(rt);
return await responder.RespondAsync(gatewayEvent.Data, ct);
return await responder.RespondAsync(gatewayEvent.Data, _responderCancellationSource.Token);
}
catch (Exception e)
{
Expand Down Expand Up @@ -423,26 +421,15 @@ public async ValueTask DisposeAsync()

GC.SuppressFinalize(this);

// Stop!
_dispatchCancellationSource.Cancel();
_payloadsToDispatch.Writer.Complete();
// Signal running responders that they should cancel
_responderCancellationSource.Cancel();

// Wait for everything to actually stop...
try
{
await _dispatcher;
}
catch (OperationCanceledException)
{
}
// Prevent further payloads from being written, signalling the readers that they should terminate
_payloadsToDispatch.Writer.Complete();

try
{
await _finalizer;
}
catch (OperationCanceledException)
{
}
// Wait for everything to actually stop
await _dispatcher;
await _finalizer;

_isDisposed = true;
}
Expand Down

0 comments on commit 538db5a

Please sign in to comment.