Skip to content

Commit

Permalink
Fix stream handling under lease
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCao committed Aug 1, 2024
1 parent 97aeb76 commit 62cdfa6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 26 deletions.
44 changes: 20 additions & 24 deletions jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import AbstractContextManager, contextmanager
from asyncio.exceptions import InvalidStateError
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -70,32 +71,27 @@ class Lease:
def __post_init__(self, *args):
jumpstarter_pb2_grpc.ControllerServiceStub.__init__(self, self.channel)

@contextmanager
def connect(self):
response = self.portal.call(self.Dial, jumpstarter_pb2.DialRequest(uuid=str(self.uuid)))
@asynccontextmanager
async def connect_async(self):
response = await self.Dial(jumpstarter_pb2.DialRequest(uuid=str(self.uuid)))

with TemporaryDirectory() as tempdir:
socketpath = Path(tempdir) / "socket"
async with await create_unix_listener(socketpath) as listener:
async with create_task_group() as tg:
tg.start_soon(self.__accept, listener, response)
async with await insecure_channel(f"unix://{socketpath}") as inner:
yield await client_from_channel(inner, self.portal)
tg.cancel_scope.cancel()

with self.portal.wrap_async_context_manager(self.portal.call(create_unix_listener, socketpath)) as listener:

async def create_tg():
return create_task_group()

with self.portal.wrap_async_context_manager(self.portal.call(create_tg)) as tg:

async def start_soon():
tg.start_soon(self.__accept, listener, response)

self.portal.call(start_soon)

with self.portal.wrap_async_context_manager(
self.portal.call(insecure_channel, f"unix://{socketpath}")
) as inner:
yield self.portal.call(client_from_channel, inner, self.portal)

self.portal.call(tg.cancel_scope.cancel)
@contextmanager
def connect(self):
with self.portal.wrap_async_context_manager(self.connect_async()) as client:
yield client

async def __accept(self, listener, response):
async with await listener.accept() as stream:
await connect_router_stream(response.router_endpoint, response.router_token, stream)
try:
async with await listener.accept() as stream:
await connect_router_stream(response.router_endpoint, response.router_token, stream)
except InvalidStateError:
pass
2 changes: 0 additions & 2 deletions jumpstarter/common/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ async def forward_client_stream(router, stream, metadata):
except grpc.aio.AioRpcError:
# TODO: handle connection error
pass
finally:
await stream.aclose()


async def connect_router_stream(endpoint, token, stream):
Expand Down

0 comments on commit 62cdfa6

Please sign in to comment.