Skip to content

Commit

Permalink
typing for sync/async decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
shughes-uk committed Nov 25, 2021
1 parent 6689c0a commit b941db3
Showing 1 changed file with 49 additions and 13 deletions.
62 changes: 49 additions & 13 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import warnings
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, Optional, overload
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, overload

from typing_extensions import ParamSpec

from .compatibility import current_task, get_running_loop
from .current_thread_executor import CurrentThreadExecutor
Expand Down Expand Up @@ -98,7 +100,11 @@ async def __aexit__(self, exc, value, tb):
pass


class AsyncToSync:
a_cls_params = ParamSpec("a_cls_params")
a_cls_return = TypeVar("a_cls_return")


class AsyncToSync(Generic[a_cls_params, a_cls_return]):
"""
Utility class which turns an awaitable that only works on the thread with
the event loop into a synchronous callable that works in a subthread.
Expand All @@ -118,7 +124,11 @@ class AsyncToSync:
# Local, not a threadlocal, so that tasks can work out what their parent used.
executors = Local()

def __init__(self, awaitable, force_new_loop=False):
def __init__(
self,
awaitable: Callable[a_cls_params, Awaitable[a_cls_return]],
force_new_loop=False,
):
if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
# Python does not have very reliable detection of async functions
# (lots of false negatives) so this is just a warning.
Expand Down Expand Up @@ -151,7 +161,9 @@ def __init__(self, awaitable, force_new_loop=False):
else:
self.main_event_loop = None

def __call__(self, *args, **kwargs):
def __call__(
self, *args: a_cls_params.args, **kwargs: a_cls_params.kwargs
) -> a_cls_return:
# You can't call AsyncToSync from a thread with a running event loop
try:
event_loop = get_running_loop()
Expand All @@ -172,7 +184,7 @@ def __call__(self, *args, **kwargs):
context = None

# Make a future for the return information
call_result = Future()
call_result: Future[a_cls_return] = Future()
# Get the source thread
source_thread = threading.current_thread()
# Make a CurrentThreadExecutor we'll use to idle in this thread - we
Expand Down Expand Up @@ -271,7 +283,13 @@ def __get__(self, parent, objtype):
return functools.update_wrapper(func, self.awaitable)

async def main_wrap(
self, args, kwargs, call_result, source_thread, exc_info, context
self,
args,
kwargs,
call_result: a_cls_return,
source_thread,
exc_info,
context,
):
"""
Wraps the awaitable with something that puts the result into the
Expand Down Expand Up @@ -303,7 +321,11 @@ async def main_wrap(
context[0] = contextvars.copy_context()


class SyncToAsync:
s_cls_params = ParamSpec("s_cls_params")
s_cls_return = TypeVar("s_cls_return")


class SyncToAsync(Generic[s_cls_params, s_cls_return]):
"""
Utility class which turns a synchronous callable into an awaitable that
runs in a threadpool. It also sets a threadlocal inside the thread so
Expand Down Expand Up @@ -369,7 +391,7 @@ class SyncToAsync:

def __init__(
self,
func: Callable[..., Any],
func: Callable[s_cls_params, s_cls_return],
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> None:
Expand All @@ -387,7 +409,9 @@ def __init__(
except AttributeError:
pass

async def __call__(self, *args, **kwargs):
async def __call__(
self, *args: s_cls_params.args, **kwargs: s_cls_params.kwargs
) -> s_cls_return:
loop = get_running_loop()

# Work out what thread to run the code in
Expand Down Expand Up @@ -461,7 +485,15 @@ def __get__(self, parent, objtype):
"""
return functools.partial(self.__call__, parent)

def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
def thread_handler(
self,
loop,
source_task,
exc_info,
func: Callable[s_cls_params, s_cls_return],
*args: s_cls_params.args,
**kwargs: s_cls_params.kwargs
):
"""
Wraps the sync application with exception handling.
"""
Expand Down Expand Up @@ -511,21 +543,25 @@ def get_current_task():
async_to_sync = AsyncToSync


s_params = ParamSpec("s_params")
s_return = TypeVar("s_return")


@overload
def sync_to_async(
func: None = None,
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> Callable[[Callable[..., Any]], SyncToAsync]:
) -> Callable[[Callable[s_params, s_return]], SyncToAsync[s_params, s_return]]:
...


@overload
def sync_to_async(
func: Callable[..., Any],
func: Callable[s_params, s_return],
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> SyncToAsync:
) -> SyncToAsync[s_params, s_return]:
...


Expand Down

0 comments on commit b941db3

Please sign in to comment.