diff --git a/src/murfey/server/websocket.py b/src/murfey/server/websocket.py index 614b5139..73a565d9 100644 --- a/src/murfey/server/websocket.py +++ b/src/murfey/server/websocket.py @@ -7,9 +7,9 @@ from typing import Any, Dict, Generic, TypeVar from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from sqlmodel import select +from sqlmodel import Session, create_engine, select -from murfey.server.murfey_db import get_murfey_db_session +from murfey.server.murfey_db import get_murfey_db_session, url from murfey.util.db import ClientEnvironment from murfey.util.state import State, global_state @@ -31,23 +31,25 @@ async def connect(self, websocket: WebSocket, client_id: int): self._register_new_client(client_id) await websocket.send_json({"message": "state-full", "state": self._state.data}) - @staticmethod - def _register_new_client(client_id: int): + def _register_new_client(self, client_id: int): + engine = create_engine(url()) new_client = ClientEnvironment(client_id=client_id, connected=True) murfey_db = next(get_murfey_db_session()) - murfey_db.add(new_client) - murfey_db.commit() - murfey_db.close() + with Session(engine) as murfey_db: + murfey_db.add(new_client) + murfey_db.commit() def disconnect(self, websocket: WebSocket, client_id: int): + engine = create_engine(url()) self.active_connections.pop(client_id) - murfey_db = next(get_murfey_db_session()) - client_env = murfey_db.exec( - select(ClientEnvironment).where(ClientEnvironment.client_id == client_id) - ).one() - murfey_db.delete(client_env) - murfey_db.commit() - murfey_db.close() + with Session(engine) as murfey_db: + client_env = murfey_db.exec( + select(ClientEnvironment).where( + ClientEnvironment.client_id == client_id + ) + ).one() + murfey_db.delete(client_env) + murfey_db.commit() async def broadcast(self, message: str): for connection in self.active_connections: @@ -119,14 +121,15 @@ async def forward_log(logrecord: dict[str, Any], websocket: WebSocket): @ws.delete("/test/{client_id}") async def close_ws_connection(client_id: int): + engine = create_engine(url()) murfey_db = next(get_murfey_db_session()) - client_env = murfey_db.exec( - select(ClientEnvironment).where(ClientEnvironment.client_id == client_id) - ).one() - client_env.connected = False - murfey_db.add(client_env) - murfey_db.commit() - murfey_db.close() + with Session(engine) as murfey_db: + client_env = murfey_db.exec( + select(ClientEnvironment).where(ClientEnvironment.client_id == client_id) + ).one() + client_env.connected = False + murfey_db.add(client_env) + murfey_db.commit() client_id_str = str(client_id).replace("\r\n", "").replace("\n", "") log.info(f"Disconnecting {client_id_str}") manager.disconnect(manager.active_connections[client_id], client_id)