diff --git a/docker-compose-local.yaml b/demo/docker-compose-local.yaml similarity index 95% rename from docker-compose-local.yaml rename to demo/docker-compose-local.yaml index 32171ff..a141c31 100644 --- a/docker-compose-local.yaml +++ b/demo/docker-compose-local.yaml @@ -2,11 +2,10 @@ version: '3' services: websocket-gateway: - build: . + build: .. ports: - "8765:8765" volumes: - - ./server:/code - ./wait-for-tunnel.sh:/wait-for-tunnel.sh:ro,z entrypoint: /wait-for-tunnel.sh command: > diff --git a/docker-compose.yaml b/demo/docker-compose.yaml similarity index 91% rename from docker-compose.yaml rename to demo/docker-compose.yaml index a19ba33..557df8c 100644 --- a/docker-compose.yaml +++ b/demo/docker-compose.yaml @@ -2,11 +2,11 @@ version: '3' services: websocket-gateway: - build: . + build: .. ports: - "8765:8765" volumes: - - ./socketdock:/usr/src/app/socketdock:z + - ../socketdock:/usr/src/app/socketdock:z command: > --bindip 0.0.0.0 --backend http diff --git a/socket_client.py b/demo/socket_client.py similarity index 100% rename from socket_client.py rename to demo/socket_client.py diff --git a/wait-for-tunnel.sh b/demo/wait-for-tunnel.sh similarity index 89% rename from wait-for-tunnel.sh rename to demo/wait-for-tunnel.sh index 1eaea0c..748f064 100755 --- a/wait-for-tunnel.sh +++ b/demo/wait-for-tunnel.sh @@ -9,4 +9,4 @@ done WS_ENDPOINT=$(curl --silent "${TUNNEL_ENDPOINT}/start" | python -c "import sys, json; print(json.load(sys.stdin)['url'])" | sed -rn 's#https?://([^/]+).*#\1#p') echo "fetched hostname and port [$WS_ENDPOINT]" -exec "$@" --externalhostandport ${WS_ENDPOINT} \ No newline at end of file +exec "$@" --externalhostandport ${WS_ENDPOINT} diff --git a/socketdock/__main__.py b/socketdock/__main__.py index 2464b73..c8b0a0a 100644 --- a/socketdock/__main__.py +++ b/socketdock/__main__.py @@ -4,7 +4,7 @@ import argparse from sanic import Sanic -from .api import api, backend_var, endpoint_var +from .api import api, backend_var def config() -> argparse.Namespace: @@ -38,12 +38,13 @@ def main(): elif args.backend == "http": from .httpbackend import HTTPBackend - backend = HTTPBackend(args.connect_uri, args.message_uri, args.disconnect_uri) + backend = HTTPBackend( + args.endpoint, args.connect_uri, args.message_uri, args.disconnect_uri + ) else: raise ValueError("Invalid backend type") backend_var.set(backend) - endpoint_var.set(args.endpoint) logging.basicConfig(level=args.log_level) diff --git a/socketdock/api.py b/socketdock/api.py index 962d025..083f6f6 100644 --- a/socketdock/api.py +++ b/socketdock/api.py @@ -9,7 +9,6 @@ from .backend import Backend backend_var: ContextVar[Backend] = ContextVar("backend") -endpoint_var: ContextVar[str] = ContextVar("endpoint") api = Blueprint("api", url_prefix="/") @@ -78,9 +77,6 @@ async def socket_handler(request: Request, websocket: Websocket): global lifetime_connections backend = backend_var.get() socket_id = None - endpoint = endpoint_var.get() - send = f"{endpoint}/socket/{socket_id}/send" - disconnect = f"{endpoint_var.get()}/socket/{socket_id}/disconnect" try: # register user LOGGER.info("new client connected") @@ -92,23 +88,15 @@ async def socket_handler(request: Request, websocket: Websocket): LOGGER.info("Request headers: %s", dict(request.headers.items())) await backend.socket_connected( - { - "connection_id": socket_id, - "headers": dict(request.headers.items()), - "send": send, - "disconnect": disconnect, - }, + connection_id=socket_id, + headers=dict(request.headers.items()), ) async for message in websocket: if message: await backend.inbound_socket_message( - { - "connection_id": socket_id, - "send": send, - "disconnect": disconnect, - }, - message, + connection_id=socket_id, + message=message, ) else: LOGGER.warning("empty message received") @@ -118,4 +106,4 @@ async def socket_handler(request: Request, websocket: Websocket): if socket_id: del active_connections[socket_id] LOGGER.info("Removed connection: %s", socket_id) - await backend.socket_disconnected({"connection_id": socket_id}) + await backend.socket_disconnected(socket_id) diff --git a/socketdock/backend.py b/socketdock/backend.py index ebc6ad0..29c8227 100644 --- a/socketdock/backend.py +++ b/socketdock/backend.py @@ -1,25 +1,28 @@ """Backend interface for SocketDock.""" from abc import ABC, abstractmethod -from typing import Union +from typing import Dict, Union class Backend(ABC): """Backend interface for SocketDock.""" @abstractmethod - async def socket_connected(self, callback_uris: dict): + async def socket_connected( + self, + connection_id: str, + headers: Dict[str, str], + ): """Handle new socket connections, with calback provided.""" - raise NotImplementedError() @abstractmethod async def inbound_socket_message( - self, callback_uris: dict, message: Union[str, bytes] + self, + connection_id: str, + message: Union[str, bytes], ): """Handle inbound socket message, with calback provided.""" - raise NotImplementedError() @abstractmethod - async def socket_disconnected(self, bundle: dict): + async def socket_disconnected(self, connection_id: str): """Handle socket disconnected.""" - raise NotImplementedError() diff --git a/socketdock/httpbackend.py b/socketdock/httpbackend.py index f2ad632..c132ce2 100644 --- a/socketdock/httpbackend.py +++ b/socketdock/httpbackend.py @@ -1,7 +1,7 @@ """HTTP backend for SocketDock.""" import logging -from typing import Union +from typing import Dict, Union import aiohttp @@ -14,16 +14,46 @@ class HTTPBackend(Backend): """HTTP backend for SocketDock.""" - def __init__(self, connect_uri: str, message_uri: str, disconnect_uri: str): + def __init__( + self, + socket_base_uri: str, + connect_uri: str, + message_uri: str, + disconnect_uri: str, + ): """Initialize HTTP backend.""" self._connect_uri = connect_uri self._message_uri = message_uri self._disconnect_uri = disconnect_uri + self.socket_base_uri = socket_base_uri + + def send_callback(self, connection_id: str) -> str: + """Return the callback URI for sending a message to a connected socket.""" + return f"{self.socket_base_uri}/{connection_id}/send" + + def disconnect_callback(self, connection_id: str) -> str: + """Return the callback URI for disconnecting a connected socket.""" + return f"{self.socket_base_uri}/{connection_id}/disconnect" - async def socket_connected(self, callback_uris: dict): + def callback_uris(self, connection_id: str) -> Dict[str, str]: + """Return labelled callback URIs.""" + return { + "send": self.send_callback(connection_id), + "disconnect": self.disconnect_callback(connection_id), + } + + async def socket_connected( + self, + connection_id: str, + headers: Dict[str, str], + ): """Handle inbound socket message, with calback provided.""" http_body = { - "meta": callback_uris, + "meta": { + **self.callback_uris(connection_id), + "headers": headers, + "connection_id": connection_id, + }, } if self._connect_uri: @@ -37,11 +67,16 @@ async def socket_connected(self, callback_uris: dict): LOGGER.debug("Response: %s", response) async def inbound_socket_message( - self, callback_uris: dict, message: Union[str, bytes] + self, + connection_id: str, + message: Union[str, bytes], ): """Handle inbound socket message, with calback provided.""" http_body = { - "meta": callback_uris, + "meta": { + **self.callback_uris(connection_id), + "connection_id": connection_id, + }, "message": message.decode("utf-8") if isinstance(message, bytes) else message, } @@ -54,11 +89,15 @@ async def inbound_socket_message( else: LOGGER.debug("Response: %s", response) - async def socket_disconnected(self, bundle: dict): + async def socket_disconnected(self, connection_id: str): """Handle socket disconnected.""" async with aiohttp.ClientSession() as session: - LOGGER.info("Notifying of disconnect: %s %s", self._disconnect_uri, bundle) - async with session.post(self._disconnect_uri, json=bundle) as resp: + LOGGER.info( + "Notifying of disconnect: %s %s", self._disconnect_uri, connection_id + ) + async with session.post( + self._disconnect_uri, json={"connection_id": connection_id} + ) as resp: response = await resp.text() if resp.status != 200: LOGGER.error("Error posting to disconnect uri: %s", response) diff --git a/socketdock/testbackend.py b/socketdock/testbackend.py index dcacc53..ed43de3 100644 --- a/socketdock/testbackend.py +++ b/socketdock/testbackend.py @@ -1,6 +1,6 @@ """Test backend for SocketDock.""" -from typing import Union +from typing import Dict, Union import aiohttp from .backend import Backend @@ -9,27 +9,33 @@ class TestBackend(Backend): """Test backend for SocketDock.""" - async def socket_connected(self, callback_uris: dict): + def __init__(self, base_uri: str): + """Initialize backend.""" + self.base_uri = base_uri + + async def socket_connected( + self, + connection_id: str, + headers: Dict[str, str], + ): """Socket connected. This test backend doesn't care, but can be useful to clean up state. """ async def inbound_socket_message( - self, callback_uris: dict, message: Union[str, bytes] + self, + connection_id: str, + message: Union[str, bytes], ): """Receive socket message.""" - # send three backend messages in response - # TODO: send response message via callback URI for sending a message - send_uri = callback_uris["send"] + send_uri = f"{self.base_uri}/{connection_id}/send" async with aiohttp.ClientSession() as session: async with session.post(send_uri, data="Hello yourself") as resp: response = await resp.text() print(response) - # response = requests.post(send_uri, data="Hello yourself!") - - async def socket_disconnected(self, bundle: dict): + async def socket_disconnected(self, connection_id: str): """Socket disconnected. This test backend doesn't care, but can be useful to clean up state.