diff --git a/jumpstarter/client/lease.py b/jumpstarter/client/lease.py index abfaa5ac..6214497b 100644 --- a/jumpstarter/client/lease.py +++ b/jumpstarter/client/lease.py @@ -1,4 +1,5 @@ -from contextlib import AbstractContextManager, contextmanager +from asyncio.exceptions import InvalidStateError +from contextlib import AbstractContextManager, asynccontextmanager, contextmanager from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory @@ -70,32 +71,27 @@ class Lease: def __post_init__(self, *args): jumpstarter_pb2_grpc.ControllerServiceStub.__init__(self, self.channel) - @contextmanager - def connect(self): - response = self.portal.call(self.Dial, jumpstarter_pb2.DialRequest(uuid=str(self.uuid))) + @asynccontextmanager + async def connect_async(self): + response = await self.Dial(jumpstarter_pb2.DialRequest(uuid=str(self.uuid))) with TemporaryDirectory() as tempdir: socketpath = Path(tempdir) / "socket" + async with await create_unix_listener(socketpath) as listener: + async with create_task_group() as tg: + tg.start_soon(self.__accept, listener, response) + async with await insecure_channel(f"unix://{socketpath}") as inner: + yield await client_from_channel(inner, self.portal) + tg.cancel_scope.cancel() - with self.portal.wrap_async_context_manager(self.portal.call(create_unix_listener, socketpath)) as listener: - - async def create_tg(): - return create_task_group() - - with self.portal.wrap_async_context_manager(self.portal.call(create_tg)) as tg: - - async def start_soon(): - tg.start_soon(self.__accept, listener, response) - - self.portal.call(start_soon) - - with self.portal.wrap_async_context_manager( - self.portal.call(insecure_channel, f"unix://{socketpath}") - ) as inner: - yield self.portal.call(client_from_channel, inner, self.portal) - - self.portal.call(tg.cancel_scope.cancel) + @contextmanager + def connect(self): + with self.portal.wrap_async_context_manager(self.connect_async()) as client: + yield client async def __accept(self, listener, response): - async with await listener.accept() as stream: - await connect_router_stream(response.router_endpoint, response.router_token, stream) + try: + async with await listener.accept() as stream: + await connect_router_stream(response.router_endpoint, response.router_token, stream) + except InvalidStateError: + pass diff --git a/jumpstarter/common/streams.py b/jumpstarter/common/streams.py index b0aa3906..849af8fe 100644 --- a/jumpstarter/common/streams.py +++ b/jumpstarter/common/streams.py @@ -51,8 +51,6 @@ async def forward_client_stream(router, stream, metadata): except grpc.aio.AioRpcError: # TODO: handle connection error pass - finally: - await stream.aclose() async def connect_router_stream(endpoint, token, stream):