Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly signal stream initiation and termination #29

Merged
merged 11 commits into from
Aug 7, 2024
40 changes: 15 additions & 25 deletions jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from contextlib import AbstractContextManager, contextmanager
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import UUID

from anyio import create_task_group, create_unix_listener
from anyio import create_unix_listener
from anyio.from_thread import BlockingPortal
from google.protobuf import duration_pb2
from grpc.aio import Channel
Expand Down Expand Up @@ -70,30 +70,20 @@ class Lease:
def __post_init__(self, *args):
jumpstarter_pb2_grpc.ControllerServiceStub.__init__(self, self.channel)

@contextmanager
def connect(self):
response = self.portal.call(self.Dial, jumpstarter_pb2.DialRequest(uuid=str(self.uuid)))
@asynccontextmanager
async def connect_async(self):
response = await self.Dial(jumpstarter_pb2.DialRequest(uuid=str(self.uuid)))

with TemporaryDirectory() as tempdir:
socketpath = Path(tempdir) / "socket"
async with await create_unix_listener(socketpath) as listener:
async with await insecure_channel(f"unix://{socketpath}") as inner:
inner.get_state(try_to_connect=True)
async with await listener.accept() as stream:
async with connect_router_stream(response.router_endpoint, response.router_token, stream):
yield await client_from_channel(inner, self.portal)

with self.portal.wrap_async_context_manager(self.portal.call(create_unix_listener, socketpath)) as listener:

async def create_tg():
return create_task_group()

with self.portal.wrap_async_context_manager(self.portal.call(create_tg)) as tg:

async def start_soon():
tg.start_soon(self.__accept, listener, response)

self.portal.call(start_soon)

with self.portal.wrap_async_context_manager(
self.portal.call(insecure_channel, f"unix://{socketpath}")
) as inner:
yield self.portal.call(client_from_channel, inner, self.portal)

async def __accept(self, listener, response):
async with await listener.accept() as stream:
await connect_router_stream(response.router_endpoint, response.router_token, stream)
@contextmanager
def connect(self):
with self.portal.wrap_async_context_manager(self.connect_async()) as client:
yield client
101 changes: 59 additions & 42 deletions jumpstarter/common/streams.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,73 @@
import logging
from contextlib import asynccontextmanager

import grpc
from anyio import BrokenResourceError, create_memory_object_stream, create_task_group
from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream, create_task_group
from anyio.abc import ByteStream, ObjectStream
from anyio.streams.stapled import StapledObjectStream

from jumpstarter.v1 import router_pb2, router_pb2_grpc

logger = logging.getLogger(__name__)

async def forward_server_stream(request_iterator, stream):
async with create_task_group() as tg:

async def client_to_server():
try:
async for frame in request_iterator:
await stream.send(frame.payload)
except BrokenResourceError:
pass
finally:
await stream.send_eof()

tg.start_soon(client_to_server)

# server_to_client
try:
async for payload in stream:
yield router_pb2.StreamResponse(payload=payload)
except BrokenResourceError:
pass
async def encapsulate_stream(rx, cls):
try:
yield cls(frame_type=router_pb2.FRAME_TYPE_PING)
async for payload in rx:
yield cls(payload=payload)
yield cls(frame_type=router_pb2.FRAME_TYPE_GOAWAY)
except (BrokenResourceError, ClosedResourceError):
logger.debug("stream encapsulation error ignored")


async def forward_client_stream(router, stream, metadata):
async def client_to_server():
try:
async for payload in stream:
yield router_pb2.StreamRequest(payload=payload)
except BrokenResourceError:
pass

# server_to_client
async def decapsulate_stream(tx, rx, tg):
try:
async for frame in router.Stream(
client_to_server(),
metadata=metadata,
):
if not frame.payload:
break
await stream.send(frame.payload)
except grpc.aio.AioRpcError:
# TODO: handle connection error
pass
async for frame in rx:
match frame.frame_type:
case router_pb2.FRAME_TYPE_DATA:
await tx.send(frame.payload)
case router_pb2.FRAME_TYPE_GOAWAY:
if isinstance(tx, ObjectStream) or isinstance(tx, ByteStream):
await tx.send_eof()
case _:
logger.debug(f"unrecognized frame ignored: {frame}")
# ignore peer disconnect
except BrokenResourceError:
pass
logger.debug("stream decapsulation peer disconnect ignored")
# ignore rpc cancellation and internal error
except grpc.aio.AioRpcError as e:
match e.code():
case grpc.StatusCode.CANCELLED | grpc.StatusCode.INTERNAL:
logger.debug("stream decapsulation grpc error ignored")
case _:
raise
finally:
await stream.aclose()
tg.cancel_scope.cancel()


async def forward_server_stream(request_iterator, stream):
async with create_task_group() as tg:
tg.start_soon(decapsulate_stream, stream, request_iterator, tg)

async for v in encapsulate_stream(stream, router_pb2.StreamResponse):
yield v


@asynccontextmanager
async def forward_client_stream(router, stream, metadata):
response_iterator = router.Stream(
encapsulate_stream(stream, router_pb2.StreamRequest),
metadata=metadata,
)

async with create_task_group() as tg:
tg.start_soon(decapsulate_stream, stream, response_iterator, tg)
yield
tg.cancel_scope.cancel()


@asynccontextmanager
async def connect_router_stream(endpoint, token, stream):
credentials = grpc.composite_channel_credentials(
grpc.local_channel_credentials(), # TODO: Use TLS
Expand All @@ -61,7 +76,9 @@ async def connect_router_stream(endpoint, token, stream):

async with grpc.aio.secure_channel(endpoint, credentials) as channel:
router = router_pb2_grpc.RouterServiceStub(channel)
await forward_client_stream(router, stream, ())

async with forward_client_stream(router, stream, ()):
yield


def create_memory_stream():
Expand Down
4 changes: 3 additions & 1 deletion jumpstarter/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,14 @@ async def Stream(self, request_iterator, context):
self.resources[resource_uuid] = resource

await resource.send(str(resource_uuid).encode("utf-8"))
await resource.send_eof()

async with remote:
async for v in forward_server_stream(request_iterator, remote):
yield v

del self.resources[resource_uuid]
# del self.resources[resource_uuid]
# small resources might be fully buffered in memory
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by this?

Copy link
Collaborator Author

@NickCao NickCao Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence of events:

  1. The client opens a stream for sharing a resource.
  2. The exporter accepts the stream and registers the other end of the stream as a resource.
  3. The exporter starts a task to forward data from the client side of the stream to the exporter side of the stream.
    3.1 (what actually happened) The stream is copied into an internal buffer before being consumed, the forwarding task exits, and the resource is unregistered.
  4. (originally intended) The driver gets the stream from the dict and begins consuming it.
    4.1 (what actually happened) The driver fails to find the resource.
  5. (originally intended) The stream is fully consumed by the driver, the forwarding task exits, and the resource is unregistered.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TL;DR: the resource shouldn't be unregistered when the forwarding task exits, but only once it has been fully consumed (or the client decides to cancel).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation, I see now :-)


async def GetReport(self, request, context):
"""
Expand Down
39 changes: 15 additions & 24 deletions jumpstarter/drivers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass

from anyio import create_task_group
from anyio import create_task_group, sleep_forever
from anyio.streams.stapled import StapledObjectStream
from google.protobuf import json_format, struct_pb2
from grpc.aio import Channel
Expand Down Expand Up @@ -65,26 +65,24 @@ async def streamingcall_async(self, method, *args):
@asynccontextmanager
async def stream_async(self, method):
client_stream, device_stream = create_memory_stream()

async with create_task_group() as tg:
tg.start_soon(
forward_client_stream,
self,
device_stream,
{"kind": "connect", "uuid": str(self.uuid), "method": method}.items(),
)
async with forward_client_stream(
self,
device_stream,
{"kind": "connect", "uuid": str(self.uuid), "method": method}.items(),
):
async with client_stream:
yield client_stream

@asynccontextmanager
async def portforward_async(self, method, listener):
async def handle(client):
async with client:
await forward_client_stream(
async with forward_client_stream(
self,
client,
{"kind": "connect", "uuid": str(self.uuid), "method": method}.items(),
)
):
await sleep_forever()

async with create_task_group() as tg:
tg.start_soon(listener.serve, handle)
Expand All @@ -102,17 +100,10 @@ async def resource_async(

combined = StapledObjectStream(tx, ProgressStream(stream=stream))

async def handle(stream):
async with stream:
await forward_client_stream(
self,
stream,
{"kind": "resource", "uuid": str(self.uuid)}.items(),
)

async with create_task_group() as tg:
tg.start_soon(handle, combined)
try:
async with combined:
async with forward_client_stream(
self,
combined,
{"kind": "resource", "uuid": str(self.uuid)}.items(),
):
yield (await rx.receive()).decode()
finally:
tg.cancel_scope.cancel()
5 changes: 3 additions & 2 deletions jumpstarter/exporter/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tempfile import TemporaryDirectory

import grpc
from anyio import connect_unix
from anyio import connect_unix, sleep_forever

from jumpstarter.common import Metadata
from jumpstarter.common.streams import connect_router_stream
Expand Down Expand Up @@ -65,6 +65,7 @@ async def serve(self):
await server.start()

async with await connect_unix(socketpath) as stream:
await connect_router_stream(request.router_endpoint, request.router_token, stream)
async with connect_router_stream(request.router_endpoint, request.router_token, stream):
await sleep_forever()
finally:
await server.stop(grace=None)
Loading