From a3618c67da8b3fe8c4cbaefd9a354b4438c040ac Mon Sep 17 00:00:00 2001 From: Luke Bakken Date: Thu, 26 Dec 2024 15:57:18 -0800 Subject: [PATCH] Fix cancellation of RPC methods Fixes #1750 * Start by adding a test that demonstrates the error. Give a 5ms cancellation to `BasicConsumeAsync`, with a much longer delay via a hacked RabbitMQ. If running in debug mode, you will see the same `task canceled` exception, but it does not propagate to the test itself. * Set cancellation correctly for TaskCompletionSource in AsyncRpcContinuation --- .../ConsumerDispatcherChannelBase.cs | 30 +++++- .../Impl/AsyncRpcContinuations.cs | 66 ++++++++---- .../Test/Integration/GH/TestGitHubIssues.cs | 100 ++++++++++++++++++ 3 files changed, 171 insertions(+), 25 deletions(-) create mode 100644 projects/Test/Integration/GH/TestGitHubIssues.cs diff --git a/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs b/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs index 92c273957..13f0f5d21 100644 --- a/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs +++ b/projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs @@ -85,11 +85,23 @@ internal ConsumerDispatcherChannelBase(Impl.Channel channel, ushort concurrency) public ValueTask HandleBasicConsumeOkAsync(IAsyncBasicConsumer consumer, string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { - AddConsumer(consumer, consumerTag); - WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag); - return _writer.WriteAsync(work, cancellationToken); + try + { + AddConsumer(consumer, consumerTag); + WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag); + + cancellationToken.ThrowIfCancellationRequested(); + return _writer.WriteAsync(work, cancellationToken); + } + catch + { + _ = GetAndRemoveConsumer(consumerTag); + throw; + } } else { @@ -101,10 +113,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag, string exchange, string routingKey, IReadOnlyBasicProperties basicProperties, RentedMemory body, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetConsumerOrDefault(consumerTag); var work = WorkStruct.CreateDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, body); + + cancellationToken.ThrowIfCancellationRequested(); return _writer.WriteAsync(work, cancellationToken); } else @@ -115,10 +131,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag, public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag); WorkStruct work = WorkStruct.CreateCancelOk(consumer, consumerTag); + + cancellationToken.ThrowIfCancellationRequested(); return _writer.WriteAsync(work, cancellationToken); } else @@ -129,10 +149,14 @@ public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken public ValueTask HandleBasicCancelAsync(string consumerTag, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + if (false == _disposed && false == _quiesce) { IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag); WorkStruct work = WorkStruct.CreateCancel(consumer, consumerTag); + + cancellationToken.ThrowIfCancellationRequested(); return _writer.WriteAsync(work, cancellationToken); } else diff --git a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs index 1e0068a48..94085d2b0 100644 --- a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs +++ b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs @@ -51,7 +51,7 @@ internal abstract class AsyncRpcContinuation : IRpcContinuation private bool _disposedValue; - public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken cancellationToken) + public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken rpcCancellationToken) { /* * Note: we can't use an ObjectPool for these because the netstandard2.0 @@ -89,7 +89,7 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken canc _tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false); _linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource( - _continuationTimeoutCancellationTokenSource.Token, cancellationToken); + _continuationTimeoutCancellationTokenSource.Token, rpcCancellationToken); } public CancellationToken CancellationToken @@ -105,7 +105,27 @@ public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter GetAwaiter() return _tcsConfiguredTaskAwaitable.GetAwaiter(); } - public abstract Task HandleCommandAsync(IncomingCommand cmd); + public async Task HandleCommandAsync(IncomingCommand cmd) + { + try + { + await DoHandleCommandAsync(cmd) + .ConfigureAwait(false); + } + catch (OperationCanceledException) + { + if (CancellationToken.IsCancellationRequested) + { + _tcs.SetCanceled(); + } + else + { + throw; + } + } + } + + protected abstract Task DoHandleCommandAsync(IncomingCommand cmd); public virtual void HandleChannelShutdown(ShutdownEventArgs reason) { @@ -141,17 +161,17 @@ public ConnectionSecureOrTuneAsyncRpcContinuation(TimeSpan continuationTimeout, { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.ConnectionSecure) { var secure = new ConnectionSecure(cmd.MethodSpan); - _tcs.TrySetResult(new ConnectionSecureOrTune(secure._challenge, default)); + _tcs.SetResult(new ConnectionSecureOrTune(secure._challenge, default)); } else if (cmd.CommandId == ProtocolCommandId.ConnectionTune) { var tune = new ConnectionTune(cmd.MethodSpan); - _tcs.TrySetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails + _tcs.SetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails { m_channelMax = tune._channelMax, m_frameMax = tune._frameMax, @@ -178,11 +198,11 @@ public SimpleAsyncRpcContinuation(ProtocolCommandId expectedCommandId, TimeSpan _expectedCommandId = expectedCommandId; } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == _expectedCommandId) { - _tcs.TrySetResult(true); + _tcs.SetResult(true); } else { @@ -206,14 +226,14 @@ public BasicCancelAsyncRpcContinuation(string consumerTag, IConsumerDispatcher c _consumerDispatcher = consumerDispatcher; } - public override async Task HandleCommandAsync(IncomingCommand cmd) + protected override async Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicCancelOk) { - _tcs.TrySetResult(true); Debug.Assert(_consumerTag == new BasicCancelOk(cmd.MethodSpan)._consumerTag); await _consumerDispatcher.HandleBasicCancelOkAsync(_consumerTag, CancellationToken) .ConfigureAwait(false); + _tcs.SetResult(true); } else { @@ -235,14 +255,16 @@ public BasicConsumeAsyncRpcContinuation(IAsyncBasicConsumer consumer, IConsumerD _consumerDispatcher = consumerDispatcher; } - public override async Task HandleCommandAsync(IncomingCommand cmd) + protected override async Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicConsumeOk) { var method = new BasicConsumeOk(cmd.MethodSpan); - _tcs.TrySetResult(method._consumerTag); + await _consumerDispatcher.HandleBasicConsumeOkAsync(_consumer, method._consumerTag, CancellationToken) .ConfigureAwait(false); + + _tcs.SetResult(method._consumerTag); } else { @@ -264,7 +286,7 @@ public BasicGetAsyncRpcContinuation(Func adjustDeliveryTag, internal DateTime StartTime { get; } = DateTime.UtcNow; - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicGetOk) { @@ -280,11 +302,11 @@ public override Task HandleCommandAsync(IncomingCommand cmd) header, cmd.Body.ToArray()); - _tcs.TrySetResult(result); + _tcs.SetResult(result); } else if (cmd.CommandId == ProtocolCommandId.BasicGetEmpty) { - _tcs.TrySetResult(null); + _tcs.SetResult(null); } else { @@ -325,7 +347,7 @@ public override void HandleChannelShutdown(ShutdownEventArgs reason) public Task OnConnectionShutdownAsync(object? sender, ShutdownEventArgs reason) { - _tcs.TrySetResult(true); + _tcs.SetResult(true); return Task.CompletedTask; } } @@ -377,13 +399,13 @@ public QueueDeclareAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellati { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeclareOk) { var method = new Client.Framing.QueueDeclareOk(cmd.MethodSpan); var result = new QueueDeclareOk(method._queue, method._messageCount, method._consumerCount); - _tcs.TrySetResult(result); + _tcs.SetResult(result); } else { @@ -417,12 +439,12 @@ public QueueDeleteAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellatio { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeleteOk) { var method = new QueueDeleteOk(cmd.MethodSpan); - _tcs.TrySetResult(method._messageCount); + _tcs.SetResult(method._messageCount); } else { @@ -440,12 +462,12 @@ public QueuePurgeAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellation { } - public override Task HandleCommandAsync(IncomingCommand cmd) + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueuePurgeOk) { var method = new QueuePurgeOk(cmd.MethodSpan); - _tcs.TrySetResult(method._messageCount); + _tcs.SetResult(method._messageCount); } else { diff --git a/projects/Test/Integration/GH/TestGitHubIssues.cs b/projects/Test/Integration/GH/TestGitHubIssues.cs new file mode 100644 index 000000000..dc0fb54b5 --- /dev/null +++ b/projects/Test/Integration/GH/TestGitHubIssues.cs @@ -0,0 +1,100 @@ +// This source code is dual-licensed under the Apache License, version +// 2.0, and the Mozilla Public License, version 2.0. +// +// The APL v2.0: +// +//--------------------------------------------------------------------------- +// Copyright (c) 2007-2024 Broadcom. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//--------------------------------------------------------------------------- +// +// The MPL v2.0: +// +//--------------------------------------------------------------------------- +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright (c) 2007-2024 Broadcom. All Rights Reserved. +//--------------------------------------------------------------------------- + +using System.Threading; +using System.Threading.Tasks; +using RabbitMQ.Client; +using RabbitMQ.Client.Events; +using Xunit; +using Xunit.Abstractions; + +#nullable enable + +namespace Test.Integration.GH +{ + public class TestGitHubIssues : IntegrationFixture + { + public TestGitHubIssues(ITestOutputHelper output) : base(output) + { + } + + public override Task InitializeAsync() + { + // NB: nothing to do here since each test creates its own factory, + // connections and channels + Assert.Null(_connFactory); + Assert.Null(_conn); + Assert.Null(_channel); + return Task.CompletedTask; + } + + [Fact] + public async Task TestBasicConsumeCancellation_GH1750() + { + /* + * Note: + * Testing that the task is actually canceled requires a hacked RabbitMQ server. + * Modify deps/rabbit/src/rabbit_channel.erl, handle_cast for basic.consume_ok + * Before send/2, add timer:sleep(1000), then `make run-broker` + * + * The _output line at the end of the test will print TaskCanceledException + */ + Assert.Null(_connFactory); + Assert.Null(_conn); + Assert.Null(_channel); + + _connFactory = CreateConnectionFactory(); + _connFactory.AutomaticRecoveryEnabled = false; + _connFactory.TopologyRecoveryEnabled = false; + + _conn = await _connFactory.CreateConnectionAsync(); + _channel = await _conn.CreateChannelAsync(); + + QueueDeclareOk q = await _channel.QueueDeclareAsync(); + + var consumer = new AsyncEventingBasicConsumer(_channel); + consumer.ReceivedAsync += (o, a) => + { + return Task.CompletedTask; + }; + + try + { + using var cts = new CancellationTokenSource(5); + await _channel.BasicConsumeAsync(q.QueueName, true, consumer, cts.Token); + } + catch (TaskCanceledException ex) + { + _output.WriteLine("ex: {0}", ex); + } + } + } +}