Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disconnect existing client connections when another client instance for same user is detected #196

Merged
merged 14 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions SampleMultiplayerClient/MultiplayerClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client;
using osu.Game.Online;
using osu.Game.Online.API;
using osu.Game.Online.Multiplayer;
using osu.Game.Online.Rooms;
Expand Down Expand Up @@ -45,6 +46,7 @@ public MultiplayerClient(HubConnection connection, int userId)
connection.On<MultiplayerPlaylistItem>(nameof(IMultiplayerClient.PlaylistItemChanged), ((IMultiplayerClient)this).PlaylistItemChanged);
connection.On<long>(nameof(IMultiplayerClient.PlaylistItemRemoved), ((IMultiplayerClient)this).PlaylistItemRemoved);
connection.On<int, long, string>(nameof(IMultiplayerClient.Invited), ((IMultiplayerClient)this).Invited);
connection.On(nameof(IStatefulUserHubClient.DisconnectRequested), ((IStatefulUserHubClient)this).DisconnectRequested);
}

public MultiplayerUserState State { get; private set; }
Expand Down Expand Up @@ -250,5 +252,11 @@ public Task PlaylistItemChanged(MultiplayerPlaylistItem item)
Console.WriteLine($"Playlist item changed (id: {item.ID} beatmap: {item.BeatmapID}, ruleset: {item.RulesetID})");
return Task.CompletedTask;
}

public async Task DisconnectRequested()
{
Console.WriteLine("Disconnect requested");
await LeaveRoom();
}
}
}
8 changes: 8 additions & 0 deletions SampleSpectatorClient/SpectatorClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client;
using osu.Game.Online;
using osu.Game.Online.Spectator;

namespace SampleSpectatorClient
Expand All @@ -26,6 +27,7 @@ public SpectatorClient(HubConnection connection)
connection.On<int, FrameDataBundle>(nameof(ISpectatorClient.UserSentFrames), ((ISpectatorClient)this).UserSentFrames);
connection.On<int, SpectatorState>(nameof(ISpectatorClient.UserFinishedPlaying), ((ISpectatorClient)this).UserFinishedPlaying);
connection.On<int, long>(nameof(ISpectatorClient.UserScoreProcessed), ((ISpectatorClient)this).UserScoreProcessed);
connection.On(nameof(IStatefulUserHubClient.DisconnectRequested), ((IStatefulUserHubClient)this).DisconnectRequested);
}

Task ISpectatorClient.UserBeganPlaying(int userId, SpectatorState state)
Expand Down Expand Up @@ -69,5 +71,11 @@ Task ISpectatorClient.UserScoreProcessed(int userId, long scoreId)
public Task EndPlaying(SpectatorState state) => connection.SendAsync(nameof(ISpectatorServer.EndPlaySession), state);

public Task WatchUser(int userId) => connection.SendAsync(nameof(ISpectatorServer.StartWatchingUser), userId);

public Task DisconnectRequested()
{
Console.WriteLine($"{connection.ConnectionId} Disconnect requested");
return Task.CompletedTask;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,11 @@ public virtual async Task PlaylistItemChanged(MultiplayerPlaylistItem item)
{
return (Task)GetType().GetMethod(method, BindingFlags.Instance | BindingFlags.Public)!.Invoke(this, args)!;
}

public async Task DisconnectRequested()
{
foreach (var c in Clients)
await c.DisconnectRequested();
}
}
}
6 changes: 4 additions & 2 deletions osu.Server.Spectator.Tests/StatefulUserHubTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
using Moq;
using osu.Game.Online;
using osu.Server.Spectator.Entities;
using osu.Server.Spectator.Extensions;
using osu.Server.Spectator.Hubs;
using Xunit;

Expand Down Expand Up @@ -104,7 +106,7 @@ public async Task SameUserOldConnectionDoesntDestroyNewState()
private void setNewConnectionId(string? connectionId = null) =>
mockContext.Setup(context => context.ConnectionId).Returns(connectionId ?? Guid.NewGuid().ToString());

private class TestStatefulHub : StatefulUserHub<object, ClientState>
private class TestStatefulHub : StatefulUserHub<IStatefulUserHubClient, ClientState>
{
public TestStatefulHub(IDistributedCache cache, EntityStore<ClientState> userStates)
: base(cache, userStates)
Expand All @@ -114,7 +116,7 @@ public TestStatefulHub(IDistributedCache cache, EntityStore<ClientState> userSta
public async Task CreateUserState()
{
using (var state = await GetOrCreateLocalUserState())
state.Item = new ClientState(Context.ConnectionId, CurrentContextUserId);
state.Item = new ClientState(Context.ConnectionId, Context.GetUserId());
}
}
}
Expand Down
121 changes: 121 additions & 0 deletions osu.Server.Spectator/ConcurrentConnectionLimiter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
using osu.Framework.Extensions.TypeExtensions;
using osu.Framework.Logging;
using osu.Game.Online;
using osu.Game.Online.Multiplayer;
using osu.Server.Spectator.Entities;
using osu.Server.Spectator.Extensions;
using osu.Server.Spectator.Hubs;

namespace osu.Server.Spectator
{
public class ConcurrentConnectionLimiter : IHubFilter
{
private readonly EntityStore<ConnectionState> connectionStates;

private readonly IServiceProvider serviceProvider;

private static readonly IEnumerable<Type> stateful_user_hubs
= typeof(IStatefulUserHub).Assembly.GetTypes().Where(type => typeof(IStatefulUserHub).IsAssignableFrom(type) && typeof(Hub).IsAssignableFrom(type) && !type.IsInterface && !type.IsAbstract).ToArray();

public ConcurrentConnectionLimiter(
EntityStore<ConnectionState> connectionStates,
IServiceProvider serviceProvider)
{
this.connectionStates = connectionStates;
this.serviceProvider = serviceProvider;
}

public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
try
{
var userId = context.Context.GetUserId();

using (var userState = await connectionStates.GetForUse(userId, true))
{
if (context.Context.GetTokenId() == userState.Item?.TokenId)
{
log(context, "subsequent connection from same client instance, registering");
peppy marked this conversation as resolved.
Show resolved Hide resolved
userState.Item.RegisterConnectionId(context);
return;
}

if (userState.Item != null)
{
log(context, "connection from new client instance, dropping existing state");

foreach (var hub in stateful_user_hubs)
{
var hubContextType = typeof(IHubContext<>).MakeGenericType(hub);
var hubContext = serviceProvider.GetRequiredService(hubContextType) as IHubContext;
hubContext?.Clients.Client(userState.Item.ConnectionIds[hub])
.SendCoreAsync(nameof(IStatefulUserHubClient.DisconnectRequested), Array.Empty<object>());
}

log(context, "existing state dropped");
}
else
log(context, "connection from first client instance");

userState.Item = new ConnectionState(context);
}
}
finally
{
await next(context);
}
}

private static void log(HubLifetimeContext context, string message)
=> Logger.Log($"[user:{context.Context.GetUserId()}] [connection:{context.Context.ConnectionId}] [hub:{context.Hub.GetType().ReadableName()}] {message}");

public async ValueTask<object?> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object?>> next)
{
// TODO: allow things to execute for hubs that aren't exclusive (like metadata or whatever)

var userId = invocationContext.Context.GetUserId();

using (var userState = await connectionStates.GetForUse(userId))
{
if (invocationContext.Context.GetTokenId() != userState.Item?.TokenId)
throw new InvalidStateException("State is not valid for this connection");
}

return await next(invocationContext);
}

public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception? exception, Func<HubLifetimeContext, Exception?, Task> next)
{
try
{
if (exception != null)
// network disconnection. wait for user to return.
return;

var userId = context.Context.GetUserId();

using (var userState = await connectionStates.GetForUse(userId, true))
{
if (userState.Item?.TokenId == context.Context.GetTokenId())
{
log(context, "disconnected");
userState.Destroy();
}
}
}
finally
{
await next(context, exception);
}
}
}
}
38 changes: 38 additions & 0 deletions osu.Server.Spectator/Entities/ConnectionState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;
using osu.Server.Spectator.Extensions;

namespace osu.Server.Spectator.Entities
{
public class ConnectionState
{
/// <summary>
/// The unique ID of the JWT the user is using to authenticate.
/// This is used to control user uniqueness.
/// </summary>
public readonly string TokenId;

/// <summary>
/// The connection IDs of the user.
/// </summary>
/// <remarks>
/// In SignalR, connection IDs are unique per user, <em>and</em> per hub instance.
/// Therefore, to keep track of all of them, a dictionary is necessary.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I find this just a tad confusing. Might be better to say something like "In SignalR, connection IDs are unique per connection. Because we use multiple hubs and a user is expected to be connected to each hub, we use a dictionary to track connections across all hubs for a specific user."

Copy link
Member

@peppy peppy Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a diff with this suggestion applied, alongside some other xmldoc and code structure improvements. Please apply as you see fit:

diff --git a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs
index 9c25535..835143f 100644
--- a/osu.Server.Spectator/ConcurrentConnectionLimiter.cs
+++ b/osu.Server.Spectator/ConcurrentConnectionLimiter.cs
@@ -42,34 +42,41 @@ namespace osu.Server.Spectator
 
                 using (var userState = await connectionStates.GetForUse(userId, true))
                 {
-                    if (context.Context.GetTokenId() == userState.Item?.TokenId)
+                    if (userState.Item == null)
                     {
+                        log(context, "connection from first client instance");
+                        userState.Item = new ConnectionState(context);
+                        return;
+                    }
+
+                    if (context.Context.GetTokenId() == userState.Item.TokenId)
+                    {
+                        // The assumption is that the client has already dropped the old connection,
+                        // so we don't bother to ask for a disconnection.
+
                         log(context, "subsequent connection from same client instance, registering");
+
+                        // Importantly, this will replace the old connection, ensuring it cannot be
+                        // used to communicate on anymore.
                         userState.Item.RegisterConnectionId(context);
                         return;
                     }
 
-                    if (userState.Item != null)
+                    log(context, "connection from new client instance, dropping existing state");
+
+                    foreach (var hubType in stateful_user_hubs)
                     {
-                        log(context, "connection from new client instance, dropping existing state");
+                        var hubContextType = typeof(IHubContext<>).MakeGenericType(hubType);
+                        var hubContext = serviceProvider.GetRequiredService(hubContextType) as IHubContext;
 
-                        foreach (var hubType in stateful_user_hubs)
+                        if (userState.Item.ConnectionIds.TryGetValue(hubType, out var connectionId))
                         {
-                            var hubContextType = typeof(IHubContext<>).MakeGenericType(hubType);
-                            var hubContext = serviceProvider.GetRequiredService(hubContextType) as IHubContext;
-
-                            if (userState.Item.ConnectionIds.TryGetValue(hubType, out var connectionId))
-                            {
-                                hubContext?.Clients.Client(connectionId)
-                                          .SendCoreAsync(nameof(IStatefulUserHubClient.DisconnectRequested), Array.Empty<object>());
-                            }
+                            hubContext?.Clients.Client(connectionId)
+                                      .SendCoreAsync(nameof(IStatefulUserHubClient.DisconnectRequested), Array.Empty<object>());
                         }
-
-                        log(context, "existing state dropped");
                     }
-                    else
-                        log(context, "connection from first client instance");
 
+                    log(context, "existing state dropped");
                     userState.Item = new ConnectionState(context);
                 }
             }
diff --git a/osu.Server.Spectator/Entities/ConnectionState.cs b/osu.Server.Spectator/Entities/ConnectionState.cs
index dff8e00..80615f7 100644
--- a/osu.Server.Spectator/Entities/ConnectionState.cs
+++ b/osu.Server.Spectator/Entities/ConnectionState.cs
@@ -8,6 +8,9 @@ using osu.Server.Spectator.Extensions;
 
 namespace osu.Server.Spectator.Entities
 {
+    /// <summary>
+    /// Maintains the connection state of a single client (notably, client, not user) across multiple hubs.
+    /// </summary>
     public class ConnectionState
     {
         /// <summary>
@@ -17,11 +20,13 @@ namespace osu.Server.Spectator.Entities
         public readonly string TokenId;
 
         /// <summary>
-        /// The connection IDs of the user.
+        /// The connection IDs of the user for each hub type.
         /// </summary>
         /// <remarks>
-        /// In SignalR, connection IDs are unique per user, <em>and</em> per hub instance.
-        /// Therefore, to keep track of all of them, a dictionary is necessary.
+        /// In SignalR, connection IDs are unique per connection.
+        ///
+        /// Because we use multiple hubs and a user is expected to be connected to each hub,
+        /// we use a dictionary to track connections across all hubs for a specific user.
         /// </remarks>
         public readonly Dictionary<Type, string> ConnectionIds = new Dictionary<Type, string>();
 
@@ -32,6 +37,10 @@ namespace osu.Server.Spectator.Entities
             RegisterConnectionId(context);
         }
 
+        /// <summary>
+        /// Registers the provided hub/connection context, replacing any existing connection for the hub type.
+        /// </summary>
+        /// <param name="context">The hub context to retrieve information from.</param>
         public void RegisterConnectionId(HubLifetimeContext context)
             => ConnectionIds[context.Hub.GetType()] = context.Context.ConnectionId;
     }

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied with minor alterations.

/// </remarks>
public readonly Dictionary<Type, string> ConnectionIds = new Dictionary<Type, string>();

public ConnectionState(HubLifetimeContext context)
{
TokenId = context.Context.GetTokenId();

RegisterConnectionId(context);
}

public void RegisterConnectionId(HubLifetimeContext context)
=> ConnectionIds.Add(context.Hub.GetType(), context.Context.ConnectionId);
}
}
34 changes: 34 additions & 0 deletions osu.Server.Spectator/Extensions/HubCallerContextExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using Microsoft.AspNetCore.SignalR;

namespace osu.Server.Spectator.Extensions
{
public static class HubCallerContextExtensions
{
/// <summary>
/// Returns the osu! user id for the supplied <paramref name="context"/>.
/// </summary>
public static int GetUserId(this HubCallerContext context)
{
if (context.UserIdentifier == null)
throw new InvalidOperationException($"Attempted to get user id with null {nameof(context.UserIdentifier)}");

return int.Parse(context.UserIdentifier);
}

/// <summary>
/// Returns the ID of the authorisation token (more accurately, the <c>jti</c> claim)
/// for the supplied <paramref name="context"/>.
/// This is used for the purpose of identifying individual client instances
/// and preventing multiple concurrent sessions from being active.
/// </summary>
public static string GetTokenId(this HubCallerContext context)
{
return context.User?.FindFirst(claim => claim.Type == "jti")?.Value
?? throw new InvalidOperationException("Could not retrieve JWT ID claim from token");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public static IServiceCollection AddHubEntities(this IServiceCollection serviceC
return serviceCollection.AddSingleton<EntityStore<SpectatorClientState>>()
.AddSingleton<EntityStore<MultiplayerClientState>>()
.AddSingleton<EntityStore<ServerMultiplayerRoom>>()
.AddSingleton<EntityStore<ConnectionState>>()
.AddSingleton<GracefulShutdownManager>()
.AddSingleton<MetadataBroadcaster>()
.AddSingleton<IScoreStorage, S3ScoreStorage>()
Expand Down
13 changes: 13 additions & 0 deletions osu.Server.Spectator/Hubs/IStatefulUserHub.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

namespace osu.Server.Spectator.Hubs
{
/// <summary>
/// Marker interface for <see cref="StatefulUserHub{TClient,TUserState}"/>.
/// Allows bypassing generic constraints.
/// </summary>
public interface IStatefulUserHub
{
}
}
14 changes: 0 additions & 14 deletions osu.Server.Spectator/Hubs/LoggingHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,6 @@ public LoggingHub()
logger = Logger.GetLogger(Name);
}

/// <summary>
/// The osu! user id for the currently processing context.
/// </summary>
protected int CurrentContextUserId
{
get
{
if (Context.UserIdentifier == null)
throw new InvalidOperationException($"Attempted to get user id with null {nameof(Context.UserIdentifier)}");

return int.Parse(Context.UserIdentifier);
}
}

public override async Task OnConnectedAsync()
{
Log("Connected");
Expand Down
Loading
Loading