Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: backend interface #19

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docker-compose-local.yaml → demo/docker-compose-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ version: '3'

services:
websocket-gateway:
build: .
build: ..
ports:
- "8765:8765"
volumes:
- ./server:/code
- ./wait-for-tunnel.sh:/wait-for-tunnel.sh:ro,z
entrypoint: /wait-for-tunnel.sh
command: >
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yaml → demo/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ version: '3'

services:
websocket-gateway:
build: .
build: ..
ports:
- "8765:8765"
volumes:
- ./socketdock:/usr/src/app/socketdock:z
- ../socketdock:/usr/src/app/socketdock:z
command: >
--bindip 0.0.0.0
--backend http
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion wait-for-tunnel.sh → demo/wait-for-tunnel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ done
WS_ENDPOINT=$(curl --silent "${TUNNEL_ENDPOINT}/start" | python -c "import sys, json; print(json.load(sys.stdin)['url'])" | sed -rn 's#https?://([^/]+).*#\1#p')
echo "fetched hostname and port [$WS_ENDPOINT]"

exec "$@" --externalhostandport ${WS_ENDPOINT}
exec "$@" --externalhostandport ${WS_ENDPOINT}
7 changes: 4 additions & 3 deletions socketdock/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import argparse
from sanic import Sanic

from .api import api, backend_var, endpoint_var
from .api import api, backend_var


def config() -> argparse.Namespace:
Expand Down Expand Up @@ -38,12 +38,13 @@ def main():
elif args.backend == "http":
from .httpbackend import HTTPBackend

backend = HTTPBackend(args.connect_uri, args.message_uri, args.disconnect_uri)
backend = HTTPBackend(
args.endpoint, args.connect_uri, args.message_uri, args.disconnect_uri
)
else:
raise ValueError("Invalid backend type")

backend_var.set(backend)
endpoint_var.set(args.endpoint)

logging.basicConfig(level=args.log_level)

Expand Down
22 changes: 5 additions & 17 deletions socketdock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .backend import Backend

backend_var: ContextVar[Backend] = ContextVar("backend")
endpoint_var: ContextVar[str] = ContextVar("endpoint")

api = Blueprint("api", url_prefix="/")

Expand Down Expand Up @@ -78,9 +77,6 @@ async def socket_handler(request: Request, websocket: Websocket):
global lifetime_connections
backend = backend_var.get()
socket_id = None
endpoint = endpoint_var.get()
send = f"{endpoint}/socket/{socket_id}/send"
disconnect = f"{endpoint_var.get()}/socket/{socket_id}/disconnect"
try:
# register user
LOGGER.info("new client connected")
Expand All @@ -92,23 +88,15 @@ async def socket_handler(request: Request, websocket: Websocket):
LOGGER.info("Request headers: %s", dict(request.headers.items()))

await backend.socket_connected(
{
"connection_id": socket_id,
"headers": dict(request.headers.items()),
"send": send,
"disconnect": disconnect,
},
connection_id=socket_id,
headers=dict(request.headers.items()),
)

async for message in websocket:
if message:
await backend.inbound_socket_message(
{
"connection_id": socket_id,
"send": send,
"disconnect": disconnect,
},
message,
connection_id=socket_id,
message=message,
)
else:
LOGGER.warning("empty message received")
Expand All @@ -118,4 +106,4 @@ async def socket_handler(request: Request, websocket: Websocket):
if socket_id:
del active_connections[socket_id]
LOGGER.info("Removed connection: %s", socket_id)
await backend.socket_disconnected({"connection_id": socket_id})
await backend.socket_disconnected(socket_id)
17 changes: 10 additions & 7 deletions socketdock/backend.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
"""Backend interface for SocketDock."""

from abc import ABC, abstractmethod
from typing import Union
from typing import Dict, Union


class Backend(ABC):
"""Backend interface for SocketDock."""

@abstractmethod
async def socket_connected(self, callback_uris: dict):
async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle new socket connections, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
raise NotImplementedError()
57 changes: 48 additions & 9 deletions socketdock/httpbackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""HTTP backend for SocketDock."""

import logging
from typing import Union
from typing import Dict, Union

import aiohttp

Expand All @@ -14,16 +14,46 @@
class HTTPBackend(Backend):
"""HTTP backend for SocketDock."""

def __init__(self, connect_uri: str, message_uri: str, disconnect_uri: str):
def __init__(
self,
socket_base_uri: str,
connect_uri: str,
message_uri: str,
disconnect_uri: str,
):
"""Initialize HTTP backend."""
self._connect_uri = connect_uri
self._message_uri = message_uri
self._disconnect_uri = disconnect_uri
self.socket_base_uri = socket_base_uri

def send_callback(self, connection_id: str) -> str:
"""Return the callback URI for sending a message to a connected socket."""
return f"{self.socket_base_uri}/{connection_id}/send"

def disconnect_callback(self, connection_id: str) -> str:
"""Return the callback URI for disconnecting a connected socket."""
return f"{self.socket_base_uri}/{connection_id}/disconnect"

async def socket_connected(self, callback_uris: dict):
def callback_uris(self, connection_id: str) -> Dict[str, str]:
"""Return labelled callback URIs."""
return {
"send": self.send_callback(connection_id),
"disconnect": self.disconnect_callback(connection_id),
}

async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"headers": headers,
"connection_id": connection_id,
},
}

if self._connect_uri:
Expand All @@ -37,11 +67,16 @@ async def socket_connected(self, callback_uris: dict):
LOGGER.debug("Response: %s", response)

async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"connection_id": connection_id,
},
"message": message.decode("utf-8") if isinstance(message, bytes) else message,
}

Expand All @@ -54,11 +89,15 @@ async def inbound_socket_message(
else:
LOGGER.debug("Response: %s", response)

async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
async with aiohttp.ClientSession() as session:
LOGGER.info("Notifying of disconnect: %s %s", self._disconnect_uri, bundle)
async with session.post(self._disconnect_uri, json=bundle) as resp:
LOGGER.info(
"Notifying of disconnect: %s %s", self._disconnect_uri, connection_id
)
async with session.post(
self._disconnect_uri, json={"connection_id": connection_id}
) as resp:
response = await resp.text()
if resp.status != 200:
LOGGER.error("Error posting to disconnect uri: %s", response)
Expand Down
24 changes: 15 additions & 9 deletions socketdock/testbackend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test backend for SocketDock."""

from typing import Union
from typing import Dict, Union
import aiohttp

from .backend import Backend
Expand All @@ -9,27 +9,33 @@
class TestBackend(Backend):
"""Test backend for SocketDock."""

async def socket_connected(self, callback_uris: dict):
def __init__(self, base_uri: str):
"""Initialize backend."""
self.base_uri = base_uri

async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Socket connected.

This test backend doesn't care, but can be useful to clean up state.
"""

async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Receive socket message."""
# send three backend messages in response
# TODO: send response message via callback URI for sending a message
send_uri = callback_uris["send"]
send_uri = f"{self.base_uri}/{connection_id}/send"
async with aiohttp.ClientSession() as session:
async with session.post(send_uri, data="Hello yourself") as resp:
response = await resp.text()
print(response)

# response = requests.post(send_uri, data="Hello yourself!")

async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Socket disconnected.

This test backend doesn't care, but can be useful to clean up state.
Expand Down
Loading