From b941db32b91e85a65398c4022d51a44656fd3cf4 Mon Sep 17 00:00:00 2001 From: Samantha Hughes Date: Sat, 25 Sep 2021 12:29:34 -0700 Subject: [PATCH] typing for sync/async decorators --- asgiref/sync.py | 62 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index 3710a7f1..1a33c60f 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -7,7 +7,9 @@ import warnings import weakref from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, overload +from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, overload + +from typing_extensions import ParamSpec from .compatibility import current_task, get_running_loop from .current_thread_executor import CurrentThreadExecutor @@ -98,7 +100,11 @@ async def __aexit__(self, exc, value, tb): pass -class AsyncToSync: +a_cls_params = ParamSpec("a_cls_params") +a_cls_return = TypeVar("a_cls_return") + + +class AsyncToSync(Generic[a_cls_params, a_cls_return]): """ Utility class which turns an awaitable that only works on the thread with the event loop into a synchronous callable that works in a subthread. @@ -118,7 +124,11 @@ class AsyncToSync: # Local, not a threadlocal, so that tasks can work out what their parent used. executors = Local() - def __init__(self, awaitable, force_new_loop=False): + def __init__( + self, + awaitable: Callable[a_cls_params, Awaitable[a_cls_return]], + force_new_loop=False, + ): if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable): # Python does not have very reliable detection of async functions # (lots of false negatives) so this is just a warning. @@ -151,7 +161,9 @@ def __init__(self, awaitable, force_new_loop=False): else: self.main_event_loop = None - def __call__(self, *args, **kwargs): + def __call__( + self, *args: a_cls_params.args, **kwargs: a_cls_params.kwargs + ) -> a_cls_return: # You can't call AsyncToSync from a thread with a running event loop try: event_loop = get_running_loop() @@ -172,7 +184,7 @@ def __call__(self, *args, **kwargs): context = None # Make a future for the return information - call_result = Future() + call_result: Future[a_cls_return] = Future() # Get the source thread source_thread = threading.current_thread() # Make a CurrentThreadExecutor we'll use to idle in this thread - we @@ -271,7 +283,13 @@ def __get__(self, parent, objtype): return functools.update_wrapper(func, self.awaitable) async def main_wrap( - self, args, kwargs, call_result, source_thread, exc_info, context + self, + args, + kwargs, + call_result: a_cls_return, + source_thread, + exc_info, + context, ): """ Wraps the awaitable with something that puts the result into the @@ -303,7 +321,11 @@ async def main_wrap( context[0] = contextvars.copy_context() -class SyncToAsync: +s_cls_params = ParamSpec("s_cls_params") +s_cls_return = TypeVar("s_cls_return") + + +class SyncToAsync(Generic[s_cls_params, s_cls_return]): """ Utility class which turns a synchronous callable into an awaitable that runs in a threadpool. It also sets a threadlocal inside the thread so @@ -369,7 +391,7 @@ class SyncToAsync: def __init__( self, - func: Callable[..., Any], + func: Callable[s_cls_params, s_cls_return], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, ) -> None: @@ -387,7 +409,9 @@ def __init__( except AttributeError: pass - async def __call__(self, *args, **kwargs): + async def __call__( + self, *args: s_cls_params.args, **kwargs: s_cls_params.kwargs + ) -> s_cls_return: loop = get_running_loop() # Work out what thread to run the code in @@ -461,7 +485,15 @@ def __get__(self, parent, objtype): """ return functools.partial(self.__call__, parent) - def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs): + def thread_handler( + self, + loop, + source_task, + exc_info, + func: Callable[s_cls_params, s_cls_return], + *args: s_cls_params.args, + **kwargs: s_cls_params.kwargs + ): """ Wraps the sync application with exception handling. """ @@ -511,21 +543,25 @@ def get_current_task(): async_to_sync = AsyncToSync +s_params = ParamSpec("s_params") +s_return = TypeVar("s_return") + + @overload def sync_to_async( func: None = None, thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> Callable[[Callable[..., Any]], SyncToAsync]: +) -> Callable[[Callable[s_params, s_return]], SyncToAsync[s_params, s_return]]: ... @overload def sync_to_async( - func: Callable[..., Any], + func: Callable[s_params, s_return], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> SyncToAsync: +) -> SyncToAsync[s_params, s_return]: ...