Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket server #4130

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/BizHawk.Client.Common/Api/Classes/CommApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ namespace BizHawk.Client.Common
{
public sealed class CommApi : ICommApi
{
private static readonly WebSocketServer _wsServer = new WebSocketServer();
private static readonly WebSocketClient _wsClient = new();

private readonly (HttpCommunication? HTTP, MemoryMappedFiles MMF, SocketServer? Sockets) _networkingHelpers;
private readonly (HttpCommunication? HTTP, MemoryMappedFiles MMF, SocketServer? Sockets, WebSocketServer? WebSocketServer) _networkingHelpers;

public HttpCommunication? HTTP => _networkingHelpers.HTTP;

public MemoryMappedFiles MMF => _networkingHelpers.MMF;

public SocketServer? Sockets => _networkingHelpers.Sockets;

public WebSocketServer WebSockets => _wsServer;
public WebSocketClient WebSockets => _wsClient;

public WebSocketServer? WebSocketServer => _networkingHelpers.WebSocketServer;

public CommApi(IMainFormForApi mainForm) => _networkingHelpers = mainForm.NetworkingHelpers;

Expand Down
9 changes: 9 additions & 0 deletions src/BizHawk.Client.Common/Api/WebSocketClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#nullable enable

namespace BizHawk.Client.Common
{
public sealed class WebSocketClient
{
public ClientWebSocketWrapper Open(Uri uri) => new(uri);
}
}
209 changes: 205 additions & 4 deletions src/BizHawk.Client.Common/Api/WebSocketServer.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,215 @@
#nullable enable

using System.Net;
using System.Threading.Tasks;
using System.Net.WebSockets;
using System.Threading;
using System.Text;
using System.Collections.Generic;
using BizHawk.Client.Common.Websocket.Messages;
using Newtonsoft.Json;
using BizHawk.Common.CollectionExtensions;
using Newtonsoft.Json.Serialization;

namespace BizHawk.Client.Common
{
public sealed class WebSocketServer
{
public ClientWebSocketWrapper Open(
Uri uri,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
=> new(uri, cancellationToken);
private static readonly HashSet<Topic> forcedRegistrationTopics = [Topic.Error, Topic.Registration];
private readonly HttpListener clientRegistrationListener;
private CancellationToken _cancellationToken = default;
private bool _running = false;
private readonly Dictionary<string, WebSocket> clients = [];
private readonly Dictionary<Topic, HashSet<string>> topicRegistrations = [];

/// <param name="host">
/// host address to register for listening to connections, defaults to <see cref="IPAddress.Loopback"/>>
/// </param>
/// <param name="port">port to register for listening to connections</param>
public WebSocketServer(IPAddress? host = null, int port = 3333)
{
clientRegistrationListener = new();
clientRegistrationListener.Prefixes.Add($"http://{host}:{port}/");
}

/// <summary>
/// Stops the server. Alternatively, use the cancellation token passed into <see cref="Start"/>.
/// The server can be restarted by calling <see cref="Start"/> again.
/// </summary>
public void Stop()
{
var cancellationTokenSource = new CancellationTokenSource();
_cancellationToken = cancellationTokenSource.Token;
cancellationTokenSource.Cancel();
_running = false;
}

/// <summary>
/// Starts the websocket server at the configured address and registers clients.
/// </summary>
/// <param name="cancellationToken">optional cancellation token to stop the server</param>
/// <returns>async task for the server loop</returns>
public async Task Start(CancellationToken cancellationToken = default)
{
if (_running)
{
throw new InvalidOperationException("Server has already been started");
}
_cancellationToken = cancellationToken;
_running = true;

clientRegistrationListener.Start();
await ListenForAndRegisterClients();
}

private async Task ListenForAndRegisterClients()
{
while (_running && !_cancellationToken.IsCancellationRequested)
{
var context = await clientRegistrationListener.GetContextAsync();
if (context is null) return;

if (!context.Request.IsWebSocketRequest)
{
context.Response.Abort();
return;
}

var webSocketContext = await context.AcceptWebSocketAsync(subProtocol: null);
if (webSocketContext is null) return;
RegisterClient(webSocketContext.WebSocket);
}
}

private void RegisterClient(WebSocket newClient)
{
string clientId = Guid.NewGuid().ToString();
clients.Add(clientId, newClient);
_ = Task.Run(() => ClientMessageReceiveLoop(clientId), _cancellationToken);
}

private async Task ClientMessageReceiveLoop(string clientId)
{
byte[] buffer = new byte[1024];
var messageStringBuilder = new StringBuilder(2048);
var client = clients[clientId];
while (client.State == WebSocketState.Open && !_cancellationToken.IsCancellationRequested)
{
ArraySegment<byte> messageBuffer = new(buffer);
var receiveResult = await client.ReceiveAsync(messageBuffer, _cancellationToken);
if (receiveResult.Count == 0)
return;

messageStringBuilder.Append(Encoding.ASCII.GetString(buffer, 0, receiveResult.Count));
if (receiveResult.EndOfMessage)
{
string messageString = messageStringBuilder.ToString();
messageStringBuilder = new StringBuilder(2048);

try
{
var request = JsonSerde.Deserialize<RequestMessageWrapper>(messageString);
await HandleRequest(clientId, request);
}
catch (Exception e)
{
// TODO proper logging
Console.WriteLine("Error deserializing message {0}", e);
await SendClientGenericError(clientId);
}
}
}
}

private async Task HandleRequest(string clientId, RequestMessageWrapper request)
{
try
{
switch (request.Topic)
{
case Topic.Error:
// clients arent allowed to publish to this topic
await SendClientGenericError(clientId);
break;

case Topic.Registration:
await HandleRegistrationRequest(clientId, request.Registration!.Value);
break;

case Topic.Echo:
await HandleEchoRequest(clientId, request.Echo!.Value);
break;
}

}
catch (Exception e)
{
// this could happen if, for instance, the client sent a registration request to the echo topic, such
// that we tried to access the wrong field of the wrapper
// TODO proper logging
Console.WriteLine("Error handling message {0}", e);
Copy link
Author

@austinmilt austinmilt Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would appreciate advice on how to properly do logging (both debug and otherwise)

await SendClientGenericError(clientId);
}
}

private async Task HandleRegistrationRequest(string clientId, RegistrationRequestMessage request)
{
foreach (Topic topic in Enum.GetValues(typeof(Topic)))
{
if (forcedRegistrationTopics.Contains(topic))
{
// we dont need to keep track of topics that clients must be registered for.
continue;
}
else if (request.Topics.Contains(topic))
{
_ = topicRegistrations.GetValueOrPut(topic, (_) => []).Add(clientId);
}
else
{
_ = topicRegistrations.GetValueOrDefault(topic, [])?.Remove(clientId);
}
}

var registeredTopics = request.Topics;
registeredTopics.AddRange(forcedRegistrationTopics);
var responseMessage = new ResponseMessageWrapper(new RegistrationResponseMessage(registeredTopics));
await SendClientMessage(clientId, responseMessage);
}

private async Task HandleEchoRequest(string clientId, EchoRequestMessage request)
{
if (topicRegistrations.GetValueOrDefault(Topic.Echo, [])?.Contains(clientId) ?? false)
{
await SendClientMessage(clientId, new ResponseMessageWrapper(new EchoResponseMessage(request.Message)));
}
}

// clients always get error topics
private async Task SendClientGenericError(string clientId) => await SendClientMessage(clientId, new ResponseMessageWrapper(new ErrorMessage(ErrorType.UnknownRequest)));

private async Task SendClientMessage(string clientId, object message)
{
await clients[clientId].SendAsync(
JsonSerde.Serialize(message),
WebSocketMessageType.Text,
endOfMessage: true,
_cancellationToken
);
}

private static class JsonSerde
{

private static readonly JsonSerializerSettings serializerSettings = new()
{
NullValueHandling = NullValueHandling.Ignore,
ContractResolver = new CamelCasePropertyNamesContractResolver(),
};

public static ArraySegment<byte> Serialize(object message) => new(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message, serializerSettings)));

public static T? Deserialize<T>(string message) => JsonConvert.DeserializeObject<T>(message, serializerSettings);
}
}
}
51 changes: 34 additions & 17 deletions src/BizHawk.Client.Common/ArgParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ public static class ArgParser
private sealed class BespokeOption<T> : Option<T>
{
public BespokeOption(string name)
: base(name) {}
: base(name) { }

public BespokeOption(string name, string description)
: base(name: name, description: description) {}
: base(name: name, description: description) { }

public BespokeOption(string[] aliases)
: base(aliases) {}
: base(aliases) { }

public BespokeOption(string[] aliases, string description)
: base(aliases, description) {}
: base(aliases, description) { }
}

private static readonly Argument<string?> ArgumentRomFilePath = new(name: "rom", () => null, description: "path; if specified, the file will be loaded the same way as it would be from `File` > `Open...`");
Expand All @@ -47,9 +47,9 @@ public BespokeOption(string[] aliases, string description)

private static readonly BespokeOption<string?> OptionConfigFilePath = new(name: "--config", description: "path of config file to use");

private static readonly BespokeOption<string?> OptionHTTPClientURIGET = new(aliases: [ "--url-get", "--url_get" ], description: "string; URI to use for HTTP 'GET' IPC (Lua `comm.http*Get*`)");
private static readonly BespokeOption<string?> OptionHTTPClientURIGET = new(aliases: ["--url-get", "--url_get"], description: "string; URI to use for HTTP 'GET' IPC (Lua `comm.http*Get*`)");

private static readonly BespokeOption<string?> OptionHTTPClientURIPOST = new(aliases: [ "--url-post", "--url_post" ], description: "string; URI to use for HTTP 'POST' IPC (Lua `comm.http*Post*`)");
private static readonly BespokeOption<string?> OptionHTTPClientURIPOST = new(aliases: ["--url-post", "--url_post"], description: "string; URI to use for HTTP 'POST' IPC (Lua `comm.http*Post*`)");

private static readonly BespokeOption<bool> OptionLaunchChromeless = new(name: "--chromeless", description: "pass and the 'chrome' (a.k.a. GUI) will never be shown, not even in windowed mode");

Expand All @@ -71,11 +71,15 @@ public BespokeOption(string[] aliases, string description)

private static readonly BespokeOption<bool> OptionQueryAppVersion = new(name: "--version", description: "pass to print version information and immediately quit");

private static readonly BespokeOption<string?> OptionSocketServerIP = new(aliases: [ "--socket-ip", "--socket_ip" ]); // desc added in static ctor
private static readonly BespokeOption<string?> OptionSocketServerIP = new(aliases: ["--socket-ip", "--socket_ip"]); // desc added in static ctor

private static readonly BespokeOption<ushort?> OptionSocketServerPort = new(aliases: [ "--socket-port", "--socket_port" ]); // desc added in static ctor
private static readonly BespokeOption<ushort?> OptionSocketServerPort = new(aliases: ["--socket-port", "--socket_port"]); // desc added in static ctor

private static readonly BespokeOption<bool> OptionSocketServerUseUDP = new(aliases: [ "--socket-udp", "--socket_udp" ]); // desc added in static ctor
private static readonly BespokeOption<bool> OptionSocketServerUseUDP = new(aliases: ["--socket-udp", "--socket_udp"]); // desc added in static ctor

private static readonly BespokeOption<string?> OptionWebSocketServerIP = new(aliases: ["--websocket-server-ip", "--ws_ip"]); // desc added in static ctor

private static readonly BespokeOption<ushort?> OptionWebSocketServerPort = new(aliases: ["--websocket-server-port", "--ws_port"]); // desc added in static ctor

private static readonly BespokeOption<string?> OptionUserdataUnparsedPairs = new(name: "--userdata", description: "pairs in the format `k1:v1;k2:v2` (mind your shell escape sequences); if the value is `true`/`false` it's interpreted as a boolean, if it's a valid 32-bit signed integer e.g. `-1234` it's interpreted as such, if it's a valid 32-bit float e.g. `12.34` it's interpreted as such, else it's interpreted as a string");

Expand All @@ -91,6 +95,8 @@ static ArgParser()
OptionSocketServerIP.Description = $"string; IP address for Unix socket IPC (Lua `comm.socket*`); must be paired with `--{OptionSocketServerPort.Name}`";
OptionSocketServerPort.Description = $"int; port for Unix socket IPC (Lua `comm.socket*`); must be paired with `--{OptionSocketServerIP.Name}`";
OptionSocketServerUseUDP.Description = $"pass to use UDP instead of TCP for Unix socket IPC (Lua `comm.socket*`); ignored unless `--{OptionSocketServerIP.Name} --{OptionSocketServerPort.Name}` also passed";
OptionWebSocketServerIP.Description = $"string; IP address for websocket server; must be paired with `--{OptionWebSocketServerPort.Name}`";
OptionWebSocketServerPort.Description = $"int; port for websocket server; must be paired with `--{OptionWebSocketServerIP.Name}`";

RootCommand root = new();
root.Add(ArgumentRomFilePath);
Expand All @@ -114,22 +120,24 @@ static ArgParser()
root.Add(/* --socket-ip */ OptionSocketServerIP);
root.Add(/* --socket-port */ OptionSocketServerPort);
root.Add(/* --socket-udp */ OptionSocketServerUseUDP);
root.Add(/* --websocket-server-ip */ OptionWebSocketServerIP);
root.Add(/* --websocket-server-port */ OptionWebSocketServerPort);
root.Add(/* --url-get */ OptionHTTPClientURIGET);
root.Add(/* --url-post */ OptionHTTPClientURIPOST);
root.Add(/* --userdata */ OptionUserdataUnparsedPairs);
root.Add(/* --version */ OptionQueryAppVersion);

Parser = new CommandLineBuilder(root)
// .UseVersionOption() // "cannot be combined with other arguments" which is fair enough but `--config` is crucial on NixOS
// .UseVersionOption() // "cannot be combined with other arguments" which is fair enough but `--config` is crucial on NixOS
.UseHelp()
// .UseEnvironmentVariableDirective() // useless
// .UseEnvironmentVariableDirective() // useless
.UseParseDirective()
.UseSuggestDirective()
// .RegisterWithDotnetSuggest() // intended for dotnet tools
// .UseTypoCorrections() // we're only using the parser, and I guess this only works with the full buy-in
// .UseParseErrorReporting() // we're only using the parser, and I guess this only works with the full buy-in
// .UseExceptionHandler() // we're only using the parser, so nothing should be throwing
// .CancelOnProcessTermination() // we're only using the parser, so there's not really anything to cancel
// .RegisterWithDotnetSuggest() // intended for dotnet tools
// .UseTypoCorrections() // we're only using the parser, and I guess this only works with the full buy-in
// .UseParseErrorReporting() // we're only using the parser, and I guess this only works with the full buy-in
// .UseExceptionHandler() // we're only using the parser, so nothing should be throwing
// .CancelOnProcessTermination() // we're only using the parser, so there's not really anything to cancel
.Build();
GeneratedOptions = root.Options.Where(static o =>
{
Expand Down Expand Up @@ -185,6 +193,14 @@ static ArgParser()
? (socketIP, socketPort.Value)
: throw new ArgParserException("Socket server needs both --socket_ip and --socket_port. Socket server was not started");

var websocketIP = result.GetValueForOption(OptionWebSocketServerIP);
var websocketPort = result.GetValueForOption(OptionWebSocketServerPort);
var webSocketServerAddress = websocketIP is null && websocketPort is null
? ((string, ushort)?) null // don't bother
: websocketIP is not null && websocketPort is not null
? (websocketIP, websocketPort.Value)
: throw new ArgParserException("Websocket server needs both --ws_ip and --ws_port. Websocket server was not started");

var httpClientURIGET = result.GetValueForOption(OptionHTTPClientURIGET);
var httpClientURIPOST = result.GetValueForOption(OptionHTTPClientURIPOST);
var httpAddresses = httpClientURIGET is null && httpClientURIPOST is null
Expand Down Expand Up @@ -221,6 +237,7 @@ static ArgParser()
luaScript: luaScript,
luaConsole: luaConsole,
socketAddress: socketAddress,
webSocketServerAddress: webSocketServerAddress,
mmfFilename: result.GetValueForOption(OptionMMFPath),
httpAddresses: httpAddresses,
audiosync: audiosync,
Expand All @@ -234,7 +251,7 @@ static ArgParser()

public sealed class ArgParserException : Exception
{
public ArgParserException(string message) : base(message) {}
public ArgParserException(string message) : base(message) { }
}
}
}
8 changes: 7 additions & 1 deletion src/BizHawk.Client.Common/IMainFormForApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ public interface IMainFormForApi
bool IsTurboing { get; }

/// <remarks>only referenced from <see cref="CommApi"/></remarks>
(HttpCommunication HTTP, MemoryMappedFiles MMF, SocketServer Sockets) NetworkingHelpers { get; }
(
HttpCommunication HTTP,
MemoryMappedFiles MMF,
SocketServer Sockets,
WebSocketServer WebSocketServer
) NetworkingHelpers
{ get; }

/// <remarks>only referenced from <c>EmuClientApi</c></remarks>
bool PauseAvi { get; set; }
Expand Down
Loading