diff --git a/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs b/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs new file mode 100644 index 00000000..ad12ac8a --- /dev/null +++ b/osu.Server.Spectator.Tests/ConcurrentConnectionLimiterTests.cs @@ -0,0 +1,217 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Linq; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Moq; +using osu.Game.Online.Multiplayer; +using osu.Server.Spectator.Entities; +using osu.Server.Spectator.Hubs.Spectator; +using Xunit; + +namespace osu.Server.Spectator.Tests +{ + public class ConcurrentConnectionLimiterTests + { + private readonly EntityStore connectionStates; + private readonly Mock serviceProviderMock; + private readonly Mock hubMock; + + public ConcurrentConnectionLimiterTests() + { + connectionStates = new EntityStore(); + serviceProviderMock = new Mock(); + + var hubContextMock = new Mock(); + serviceProviderMock.Setup(sp => sp.GetService(It.IsAny())) + .Returns(hubContextMock.Object); + + hubMock = new Mock(); + } + + [Fact] + public async Task TestNormalOperation() + { + var hubCallerContextMock = new Mock(); + hubCallerContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + hubCallerContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object); + var lifetimeContext = new HubLifetimeContext(hubCallerContextMock.Object, serviceProviderMock.Object, hubMock.Object); + + bool connected = false; + await filter.OnConnectedAsync(lifetimeContext, _ => + { + connected = true; + return Task.CompletedTask; + }); + Assert.True(connected); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + + bool methodInvoked = false; + var invocationContext = new HubInvocationContext(hubCallerContextMock.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + await filter.InvokeMethodAsync(invocationContext, _ => + { + methodInvoked = true; + return new ValueTask(new object()); + }); + Assert.True(methodInvoked); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + + bool disconnected = false; + await filter.OnDisconnectedAsync(lifetimeContext, null, (_, _) => + { + disconnected = true; + return Task.CompletedTask; + }); + Assert.True(disconnected); + Assert.Null(connectionStates.GetEntityUnsafe(1234)); + } + + [Fact] + public async Task TestConcurrencyBlocked() + { + var firstContextMock = new Mock(); + var secondContextMock = new Mock(); + + firstContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstContextMock.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + + secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", Guid.NewGuid().ToString()) + }) + })); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstContextMock.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondContextMock.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + + var secondInvocationContext = new HubInvocationContext(secondContextMock.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + // should succeed. + await filter.InvokeMethodAsync(secondInvocationContext, _ => new ValueTask(new object())); + + var firstInvocationContext = new HubInvocationContext(firstContextMock.Object, serviceProviderMock.Object, hubMock.Object, + typeof(SpectatorHub).GetMethod(nameof(SpectatorHub.StartWatchingUser))!, new object[] { 1234 }); + // should throw. + await Assert.ThrowsAsync(() => filter.InvokeMethodAsync(firstInvocationContext, _ => new ValueTask(new object())).AsTask()); + } + + [Fact] + public async Task TestStaleDisconnectIsANoOp() + { + var firstContextMock = new Mock(); + var secondContextMock = new Mock(); + string commonTokenId = Guid.NewGuid().ToString(); + + firstContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstContextMock.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonTokenId) + }) + })); + + secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonTokenId) + }) + })); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstContextMock.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondContextMock.Object, serviceProviderMock.Object, hubMock.Object); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + + await filter.OnDisconnectedAsync(firstLifetimeContext, null, (_, _) => Task.CompletedTask); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds.Single().Value); + } + + [Fact] + public async Task TestHubDisconnectsTrackedSeparately() + { + var firstContextMock = new Mock(); + var secondContextMock = new Mock(); + string commonTokenId = Guid.NewGuid().ToString(); + + firstContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + firstContextMock.Setup(ctx => ctx.ConnectionId).Returns("abcd"); + firstContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonTokenId) + }) + })); + + secondContextMock.Setup(ctx => ctx.UserIdentifier).Returns("1234"); + secondContextMock.Setup(ctx => ctx.ConnectionId).Returns("efgh"); + secondContextMock.Setup(ctx => ctx.User).Returns(new ClaimsPrincipal(new[] + { + new ClaimsIdentity(new[] + { + new Claim("jti", commonTokenId) + }) + })); + + var filter = new ConcurrentConnectionLimiter(connectionStates, serviceProviderMock.Object); + + var firstLifetimeContext = new HubLifetimeContext(firstContextMock.Object, serviceProviderMock.Object, new FirstHub()); + await filter.OnConnectedAsync(firstLifetimeContext, _ => Task.CompletedTask); + + var secondLifetimeContext = new HubLifetimeContext(secondContextMock.Object, serviceProviderMock.Object, new SecondHub()); + await filter.OnConnectedAsync(secondLifetimeContext, _ => Task.CompletedTask); + Assert.Equal(2, connectionStates.GetEntityUnsafe(1234)!.ConnectionIds.Count); + Assert.Equal("abcd", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(FirstHub)]); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(SecondHub)]); + + await filter.OnDisconnectedAsync(firstLifetimeContext, null, (_, _) => Task.CompletedTask); + Assert.Single(connectionStates.GetEntityUnsafe(1234)!.ConnectionIds); + Assert.Equal("efgh", connectionStates.GetEntityUnsafe(1234)!.ConnectionIds[typeof(SecondHub)]); + } + + private class FirstHub : Hub + { + } + + private class SecondHub : Hub + { + } + } +} diff --git a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs index cdc67eeb..2a6cda46 100644 --- a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs +++ b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs @@ -68,7 +68,7 @@ public async Task OnConnectedAsync(HubLifetimeContext context, Func).MakeGenericType(hubType); var hubContext = serviceProvider.GetRequiredService(hubContextType) as IHubContext; - if (userState.Item.ConnectionIds.TryGetValue(hubType, out var connectionId)) + if (userState.Item.ConnectionIds.TryGetValue(hubType, out string? connectionId)) { hubContext?.Clients.Client(connectionId) .SendCoreAsync(nameof(IStatefulUserHubClient.DisconnectRequested), Array.Empty()); @@ -94,11 +94,16 @@ private static void log(HubLifetimeContext context, string message) using (var userState = await connectionStates.GetForUse(userId)) { - if (invocationContext.Context.GetTokenId() != userState.Item?.TokenId - || invocationContext.Context.ConnectionId != userState.Item?.ConnectionIds[invocationContext.Hub.GetType()]) - { + string? registeredConnectionId = null; + + bool tokenIdMatches = invocationContext.Context.GetTokenId() == userState.Item?.TokenId; + bool hubRegistered = userState.Item?.ConnectionIds.TryGetValue(invocationContext.Hub.GetType(), out registeredConnectionId) == true; + bool connectionIdMatches = registeredConnectionId == invocationContext.Context.ConnectionId; + + bool connectionIsValid = tokenIdMatches && hubRegistered && connectionIdMatches; + + if (!connectionIsValid) throw new InvalidStateException("State is not valid for this connection"); - } } return await next(invocationContext); @@ -116,9 +121,23 @@ public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception? exc using (var userState = await connectionStates.GetForUse(userId, true)) { - if (userState.Item?.TokenId == context.Context.GetTokenId()) + string? registeredConnectionId = null; + + bool tokenIdMatches = context.Context.GetTokenId() == userState.Item?.TokenId; + bool hubRegistered = userState.Item?.ConnectionIds.TryGetValue(context.Hub.GetType(), out registeredConnectionId) == true; + bool connectionIdMatches = registeredConnectionId == context.Context.ConnectionId; + + bool connectionCanBeCleanedUp = tokenIdMatches && hubRegistered && connectionIdMatches; + + if (connectionCanBeCleanedUp) + { + log(context, "disconnected from hub"); + userState.Item!.ConnectionIds.Remove(context.Hub.GetType()); + } + + if (userState.Item?.ConnectionIds.Count == 0) { - log(context, "disconnected"); + log(context, "all connections closed, destroying state"); userState.Destroy(); } }