Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket improvements #4129

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 65 additions & 24 deletions src/BizHawk.Client.Common/Api/ClientWebSocketWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,57 +1,95 @@
#nullable enable

using System.Collections.Generic;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace BizHawk.Client.Common
{
public struct ClientWebSocketWrapper
public class ClientWebSocketWrapper(Uri uri)
{
private ClientWebSocket? _w;

private readonly Queue<string> _receivedMessages = new();

private readonly Uri _uri = uri;

/// <summary>calls <see cref="ClientWebSocket.State"/> getter (unless closed/disposed, then <see cref="WebSocketState.Closed"/> is always returned)</summary>
public WebSocketState State => _w?.State ?? WebSocketState.Closed;

public ClientWebSocketWrapper(
Uri uri,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
{
_w = new ClientWebSocket();
_w.ConnectAsync(uri, cancellationToken).Wait(cancellationToken);
}

/// <summary>calls <see cref="ClientWebSocket.CloseAsync"/></summary>
/// <summary>calls <see cref="ClientWebSocket.CloseOutputAsync"/></summary>
/// <remarks>also calls <see cref="ClientWebSocket.Dispose"/> (wrapper property <see cref="State"/> will continue to work, method calls will throw <see cref="ObjectDisposedException"/>)</remarks>
public Task Close(
WebSocketCloseStatus closeStatus,
string statusDescription,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
{
if (_w == null) throw new ObjectDisposedException(nameof(_w));
var task = _w.CloseAsync(closeStatus, statusDescription, cancellationToken);
var task = _w.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
_w.Dispose();
_w = null;
return task;
}

/// <summary>calls <see cref="ClientWebSocket.ReceiveAsync"/></summary>
public Task<WebSocketReceiveResult> Receive(
ArraySegment<byte> buffer,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
=> _w?.ReceiveAsync(buffer, cancellationToken)
?? throw new ObjectDisposedException(nameof(_w));
public async Task Connect(int bufferSize, int maxMessages)
{
_w ??= new();
if ((_w != null) && (_w.State != WebSocketState.Open))
{
await _w.ConnectAsync(_uri, CancellationToken.None);
await Receive(bufferSize, maxMessages);
}
}

/// <summary>opens a connection to the configured server and passes messages to [consumer]</summary>
public async Task Connect(Action<string> consumer, int bufferSize = 1024)
{
_w ??= new();
if ((_w != null) && (_w.State != WebSocketState.Open))
{
await _w.ConnectAsync(_uri, CancellationToken.None);
}

var buffer = new ArraySegment<byte>(new byte[bufferSize]);
while ((_w != null) && (_w.State == WebSocketState.Open))
{
var result = await _w.ReceiveAsync(buffer, CancellationToken.None);
string message = Encoding.UTF8.GetString(buffer.Array, 0, result.Count);
consumer(message);
}
}

/// <summary>opens a connection to the configured server and passes messages to [consumer]</summary>
public async Task Connect(Action<byte[]> consumer, int bufferSize = 2048)
{
_w ??= new();
if ((_w != null) && (_w.State != WebSocketState.Open))
{
await _w.ConnectAsync(_uri, CancellationToken.None);
}

var buffer = new ArraySegment<byte>(new byte[bufferSize]);
while ((_w != null) && (_w.State == WebSocketState.Open))
{
_ = await _w.ReceiveAsync(buffer, CancellationToken.None);
consumer(buffer.Array);
}
}

/// <summary>calls <see cref="ClientWebSocket.ReceiveAsync"/></summary>
public string Receive(
int bufferCap,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
public async Task Receive(int bufferSize, int maxMessages)
{
if (_w == null) throw new ObjectDisposedException(nameof(_w));
var buffer = new byte[bufferCap];
var result = Receive(new ArraySegment<byte>(buffer), cancellationToken).Result;
return Encoding.UTF8.GetString(buffer, 0, result.Count);
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
while ((_w != null) && (_w.State == WebSocketState.Open))
{
var result = await _w.ReceiveAsync(buffer, CancellationToken.None);
if (maxMessages == 0 || _receivedMessages.Count < maxMessages)
{
_receivedMessages.Enqueue(Encoding.UTF8.GetString(buffer.Array, 0, result.Count));
}
}
}

/// <summary>calls <see cref="ClientWebSocket.SendAsync"/></summary>
Expand All @@ -77,5 +115,8 @@ public Task Send(
cancellationToken
);
}

/// <summary>pops the first cached message off the message queue, otherwise returns null</summary>
public string? PopMessage() => (_receivedMessages.Count > 0) ? _receivedMessages.Dequeue() : null;
}
}
2 changes: 0 additions & 2 deletions src/BizHawk.Client.Common/Api/Interfaces/ICommApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ public interface ICommApi : IExternalApi

SocketServer? Sockets { get; }

#if ENABLE_WEBSOCKETS
WebSocketServer WebSockets { get; }
#endif

string? HttpTest();

Expand Down
7 changes: 1 addition & 6 deletions src/BizHawk.Client.Common/Api/WebSocketServer.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
#nullable enable

using System.Threading;

namespace BizHawk.Client.Common
{
public sealed class WebSocketServer
{
public ClientWebSocketWrapper Open(
Uri uri,
CancellationToken cancellationToken = default/* == CancellationToken.None */)
=> new(uri, cancellationToken);
public ClientWebSocketWrapper Open(Uri uri) => new(uri);
}
}
45 changes: 27 additions & 18 deletions src/BizHawk.Client.Common/lua/CommonLibs/CommLuaLibrary.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Net.WebSockets;
using System.Text;

using NLua;

namespace BizHawk.Client.Common
Expand All @@ -13,7 +13,7 @@ public sealed class CommLuaLibrary : LuaLibraryBase
private readonly IDictionary<Guid, ClientWebSocketWrapper> _websockets = new Dictionary<Guid, ClientWebSocketWrapper>();

public CommLuaLibrary(ILuaLibraries luaLibsImpl, ApiContainer apiContainer, Action<string> logOutputCallback)
: base(luaLibsImpl, apiContainer, logOutputCallback) {}
: base(luaLibsImpl, apiContainer, logOutputCallback) { }

public override string Name => "comm";

Expand Down Expand Up @@ -253,20 +253,24 @@ private void CheckHttp()
}
}

#if ENABLE_WEBSOCKETS
[LuaMethod("ws_open", "Opens a websocket and returns the id so that it can be retrieved later.")]
[LuaMethod("ws_open", "Opens a websocket and returns the id so that it can be retrieved later. If an id is provided, reconnects to the server")]
[LuaMethodExample("local ws_id = comm.ws_open(\"wss://echo.websocket.org\");")]
public string WebSocketOpen(string uri)
public string WebSocketOpen(
string uri,
string guid = null,
int bufferSize = 1024,
int maxMessages = 20)
{
var wsServer = APIs.Comm.WebSockets;
if (wsServer == null)
{
Log("WebSocket server is somehow not available");
return null;
}
var guid = new Guid();
_websockets[guid] = wsServer.Open(new Uri(uri));
return guid.ToString();
var localGuid = guid == null ? Guid.NewGuid() : Guid.Parse(guid);

_websockets[localGuid] ??= wsServer.Open(new Uri(uri));
_websockets[localGuid].Connect(bufferSize, maxMessages).Wait(500);
return localGuid.ToString();
}

[LuaMethod("ws_send", "Send a message to a certain websocket id (boolean flag endOfMessage)")]
Expand All @@ -276,32 +280,37 @@ public void WebSocketSend(
string content,
bool endOfMessage)
{
if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper)) wrapper.Send(content, endOfMessage);
if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper))
{
_ = wrapper.Send(content, endOfMessage);
}
}

[LuaMethod("ws_receive", "Receive a message from a certain websocket id and a maximum number of bytes to read")]
[LuaMethodExample("local ws = comm.ws_receive(ws_id, str_len);")]
public string WebSocketReceive(string guid, int bufferCap)
[LuaMethod("ws_receive", "Receive a message from a certain websocket id")]
[LuaMethodExample("local ws = comm.ws_receive(ws_id);")]
public string WebSocketReceive(string guid)
=> _websockets.TryGetValue(Guid.Parse(guid), out var wrapper)
? wrapper.Receive(bufferCap)
? wrapper.PopMessage()
: null;

[LuaMethod("ws_get_status", "Get a websocket's status")]
[LuaMethodExample("local ws_status = comm.ws_get_status(ws_id);")]
public int? WebSocketGetStatus(string guid)
=> _websockets.TryGetValue(Guid.Parse(guid), out var wrapper)
? (int) wrapper.State
: (int?) null;
: null;

[LuaMethod("ws_close", "Close a websocket connection with a close status")]
[LuaMethodExample("local ws_status = comm.ws_close(ws_id, close_status);")]
public void WebSocketClose(
string guid,
WebSocketCloseStatus status,
int status,
string closeMessage)
{
if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper)) wrapper.Close(status, closeMessage);
if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper))
{
_ = wrapper.Close((WebSocketCloseStatus) status, closeMessage);
}
}
#endif
}
}
Loading