Skip to content

Commit

Permalink
Add ability to set awaiters on coroutines and futures
Browse files Browse the repository at this point in the history
Summary:
Fixed a refleak in the original change. This is essentially a squashed version of two diffs: D67813375 and D68596487.

Landing it together for safety.

Reviewed By: aleivag

Differential Revision: D68597387

fbshipit-source-id: 2bfe721a5cf74cd71a9782718da557d1fee9444a
  • Loading branch information
Aniket Panse authored and facebook-github-bot committed Jan 24, 2025
1 parent ae66b18 commit a8cfab6
Show file tree
Hide file tree
Showing 10 changed files with 927 additions and 23 deletions.
11 changes: 11 additions & 0 deletions Include/cpython/genobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
extern "C" {
#endif

static inline void Ci_PyAwaitable_SetAwaiter(PyObject *receiver, PyObject *awaiter) {
PyTypeObject *ty = Py_TYPE(receiver);
if (!PyType_HasFeature(ty, Ci_TPFLAGS_HAVE_AM_EXTRA)) {
return;
}
Ci_AsyncMethodsWithExtra *ame = (Ci_AsyncMethodsWithExtra *)ty->tp_as_async;
if ((ame != NULL) && (ame->ame_setawaiter != NULL)) {
ame->ame_setawaiter(receiver, awaiter);
}
}

/* --- Generators --------------------------------------------------------- */

/* _PyGenObject_HEAD defines the initial segment of generator
Expand Down
63 changes: 63 additions & 0 deletions Lib/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
'current_task', 'all_tasks',
'create_eager_task_factory', 'eager_task_factory',
'_register_task', '_unregister_task', '_enter_task', '_leave_task',
'get_async_stack',
)

import concurrent.futures
Expand All @@ -16,6 +17,7 @@
import inspect
import itertools
import types
import sys
import warnings
import weakref
from types import GenericAlias
Expand Down Expand Up @@ -732,6 +734,11 @@ def cancel(self, msg=None):
self._cancel_requested = True
return ret

def __set_awaiter__(self, awaiter):
for child in self._children:
if hasattr(child, "__set_awaiter__"):
child.__set_awaiter__(awaiter)


def gather(*coros_or_futures, return_exceptions=False):
"""Return a future aggregating results from the given coroutines/futures.
Expand Down Expand Up @@ -956,6 +963,62 @@ def callback():
return future


def get_async_stack():
"""Return the async call stack for the currently executing task as a list of
frames, with the most recent frame last.
The async call stack consists of the call stack for the currently executing
task, if any, plus the call stack formed by the transitive set of coroutines/async
generators awaiting the current task.
Consider the following example, where T represents a task, C represents
a coroutine, and A '->' B indicates A is awaiting B.
T0 +---> T1
| | |
C0 | C2
| | |
v | v
C1 | C3
| |
+-----|
The await stack from C3 would be C3, C2, C1, C0. In contrast, the
synchronous call stack while C3 is executing is only C3, C2.
"""
if not hasattr(sys, "_getframe"):
return []

task = current_task()
coro = task.get_coro()
coro_frame = coro.cr_frame

# Get the active portion of the stack
stack = []
frame = sys._getframe().f_back
while frame is not None:
stack.append(frame)
if frame is coro_frame:
break
frame = frame.f_back
assert frame is coro_frame

# Get the suspended portion of the stack
awaiter = coro.cr_awaiter
while awaiter is not None:
if hasattr(awaiter, "cr_frame"):
stack.append(awaiter.cr_frame)
awaiter = awaiter.cr_awaiter
elif hasattr(awaiter, "ag_frame"):
stack.append(awaiter.ag_frame)
awaiter = awaiter.ag_awaiter
else:
raise ValueError(f"Unexpected awaiter {awaiter}")

stack.reverse()
return stack


# WeakSet containing all alive tasks.
_all_tasks = weakref.WeakSet()


def create_eager_task_factory(custom_task_constructor):
"""Create a function suitable for use as a task factory on an event-loop.
Expand Down
25 changes: 25 additions & 0 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,5 +1903,30 @@ async def run():
self.loop.run_until_complete(run())


class AsyncGeneratorAwaiterTest(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
asyncio.set_event_loop_policy(None)

def test_basic_await(self):
async def async_gen():
self.assertIs(agen_obj.ag_awaiter, awaiter_obj)
yield 10

async def awaiter(agen):
async for x in agen:
pass

agen_obj = async_gen()
awaiter_obj = awaiter(agen_obj)
self.assertIsNone(agen_obj.ag_awaiter)
self.loop.run_until_complete(awaiter_obj)


if __name__ == "__main__":
unittest.main()
86 changes: 86 additions & 0 deletions Lib/test/test_asyncio/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2489,6 +2489,24 @@ def test_get_context(self):
finally:
loop.close()

def test_get_awaiter(self):
ctask = getattr(tasks, '_CTask', None)
if ctask is None or not issubclass(self.Task, ctask):
self.skipTest("Only subclasses of _CTask set cr_awaiter on wrapped coroutines")

async def coro():
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
return "ok"

async def awaiter(coro):
task = self.loop.create_task(coro)
return await task

coro_obj = coro()
awaiter_obj = awaiter(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)
self.assertEqual(self.loop.run_until_complete(awaiter_obj), "ok")
self.assertIsNone(coro_obj.cr_awaiter)

def add_subclass_tests(cls):
BaseTask = cls.Task
Expand Down Expand Up @@ -3237,6 +3255,22 @@ async def coro(s):
# NameError should not happen:
self.one_loop.call_exception_handler.assert_not_called()

def test_propagate_awaiter(self):
async def coro(idx):
self.assertIs(coro_objs[idx].cr_awaiter, awaiter_obj)
return "ok"

async def awaiter(coros):
tasks = [self.one_loop.create_task(c) for c in coros]
return await asyncio.gather(*tasks)

coro_objs = [coro(0), coro(1)]
awaiter_obj = awaiter(coro_objs)
self.assertIsNone(coro_objs[0].cr_awaiter)
self.assertIsNone(coro_objs[1].cr_awaiter)
self.assertEqual(self.one_loop.run_until_complete(awaiter_obj), ["ok", "ok"])
self.assertIsNone(coro_objs[0].cr_awaiter)
self.assertIsNone(coro_objs[1].cr_awaiter)

class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe."""
Expand Down Expand Up @@ -3449,5 +3483,57 @@ def tearDown(self):
super().tearDown()



class GetAsyncStackTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
asyncio.set_event_loop_policy(None)

def check_stack(self, frames, expected_funcs):
given = [f.f_code for f in frames]
expected = [f.__code__ for f in expected_funcs]
self.assertEqual(given, expected)

def test_single_task(self):
async def coro():
await coro2()

async def coro2():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2])

self.loop.run_until_complete(coro())

def test_cross_tasks(self):
async def coro():
t = asyncio.ensure_future(coro2())
await t

async def coro2():
t = asyncio.ensure_future(coro3())
await t

async def coro3():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2, coro3])

self.loop.run_until_complete(coro())

def test_cross_gather(self):
async def coro():
await asyncio.gather(coro2(), coro2())

async def coro2():
stack = asyncio.get_async_stack()
self.check_stack(stack, [coro, coro2])

self.loop.run_until_complete(coro())


if __name__ == '__main__':
unittest.main()
58 changes: 58 additions & 0 deletions Lib/test/test_coroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2478,5 +2478,63 @@ async def foo():
self.assertEqual(foo().send(None), 1)



class CoroutineAwaiterTest(unittest.TestCase):
def test_basic_await(self):
async def coro():
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)
return "success"

async def awaiter():
return await coro_obj

coro_obj = coro()
awaiter_obj = awaiter()
self.assertIsNone(coro_obj.cr_awaiter)
self.assertEqual(run_async(awaiter_obj), ([], "success"))

class FakeFuture:
def __await__(self):
return iter(["future"])

def test_coro_outlives_awaiter(self):
async def coro():
await self.FakeFuture()

async def awaiter(cr):
await cr

coro_obj = coro()
self.assertIsNone(coro_obj.cr_awaiter)
awaiter_obj = awaiter(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)

v1 = awaiter_obj.send(None)
self.assertEqual(v1, "future")
self.assertIs(coro_obj.cr_awaiter, awaiter_obj)

awaiter_id = id(awaiter_obj)
del awaiter_obj
self.assertEqual(id(coro_obj.cr_awaiter), awaiter_id)

def test_async_gen_awaiter(self):
async def coro():
self.assertIs(coro_obj.cr_awaiter, agen)
await self.FakeFuture()

async def async_gen(cr):
await cr
yield "hi"

coro_obj = coro()
self.assertIsNone(coro_obj.cr_awaiter)
agen = async_gen(coro_obj)
self.assertIsNone(coro_obj.cr_awaiter)

v1 = agen.asend(None).send(None)
self.assertEqual(v1, "future")



if __name__=="__main__":
unittest.main()
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,7 @@ Noah Oxer
Joonas Paalasmaa
Yaroslav Pankovych
Martin Packman
Matt Page
Elisha Paine
Shriphani Palakodety
Julien Palard
Expand Down
Loading

0 comments on commit a8cfab6

Please sign in to comment.