diff --git a/loguru/_logger.py b/loguru/_logger.py index 911dd6043..c9a98cc8e 100644 --- a/loguru/_logger.py +++ b/loguru/_logger.py @@ -114,10 +114,16 @@ from ._simple_sinks import AsyncSink, CallableSink, StandardSink, StreamSink if sys.version_info >= (3, 6): + from collections.abc import AsyncGenerator + from inspect import isasyncgenfunction from os import PathLike + else: from pathlib import PurePath as PathLike + def isasyncgenfunction(func): + return False + Level = namedtuple("Level", ["name", "no", "color", "icon"]) # noqa: PYI024 @@ -1293,6 +1299,34 @@ def catch_wrapper(*args, **kwargs): return (yield from function(*args, **kwargs)) return default + elif isasyncgenfunction(function): + + class AsyncGenCatchWrapper(AsyncGenerator): + + def __init__(self, gen): + self._gen = gen + + async def asend(self, value): + stop = False + with catcher: + try: + return await self._gen.asend(value) + except StopAsyncIteration: + stop = True + except Exception as e: + stop = True + raise e + if stop: + raise StopAsyncIteration + return None + + async def athrow(self, *args, **kwargs): + return await self._gen.athrow(*args, **kwargs) + + def catch_wrapper(*args, **kwargs): + gen = function(*args, **kwargs) + return AsyncGenCatchWrapper(gen) + else: def catch_wrapper(*args, **kwargs): diff --git a/pyproject.toml b/pyproject.toml index 269b914eb..c689c0114 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,9 @@ convention = "numpy" [tool.typos.default] extend-ignore-re = ["(?Rm)^.*# spellchecker: disable-line$"] +[tool.typos.default.extend-identifiers] +asend = "asend" + [tool.typos.files] extend-exclude = [ "tests/exceptions/output/**", # False positive due to ansi sequences. diff --git a/tests/exceptions/output/modern/decorate_async_generator.txt b/tests/exceptions/output/modern/decorate_async_generator.txt new file mode 100644 index 000000000..a965a70ed --- /dev/null +++ b/tests/exceptions/output/modern/decorate_async_generator.txt @@ -0,0 +1 @@ +Done diff --git a/tests/exceptions/output/modern/exception_formatting_async_generator.txt b/tests/exceptions/output/modern/exception_formatting_async_generator.txt new file mode 100644 index 000000000..c5fdfab61 --- /dev/null +++ b/tests/exceptions/output/modern/exception_formatting_async_generator.txt @@ -0,0 +1,42 @@ + +Traceback (most recent call last): + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 20, in + f.send(None) + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 14, in foo + yield a / b +ZeroDivisionError: division by zero + +Traceback (most recent call last): + + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 20, in + f.send(None) + │ └ + └ .Catcher.__call__..AsyncGenCatchWrapper.asend at 0xDEADBEEF> + + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 14, in foo + yield a / b + │ └ 0 + └ 1 + +ZeroDivisionError: division by zero + +Traceback (most recent call last): +> File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 20, in + f.send(None) + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 14, in foo + yield a / b +ZeroDivisionError: division by zero + +Traceback (most recent call last): + +> File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 20, in + f.send(None) + │ └ + └ .Catcher.__call__..AsyncGenCatchWrapper.asend at 0xDEADBEEF> + + File "tests/exceptions/source/modern/exception_formatting_async_generator.py", line 14, in foo + yield a / b + │ └ 0 + └ 1 + +ZeroDivisionError: division by zero diff --git a/tests/exceptions/source/modern/decorate_async_generator.py b/tests/exceptions/source/modern/decorate_async_generator.py new file mode 100644 index 000000000..b1dffed7c --- /dev/null +++ b/tests/exceptions/source/modern/decorate_async_generator.py @@ -0,0 +1,111 @@ +from loguru import logger +import asyncio +import sys + +logger.remove() + +# We're truly only testing whether the tests succeed, we do not care about the formatting. +# These should be regular Pytest test cases, but that is not possible because the syntax is not valid in Python 3.5. +logger.add(lambda m: None, format="", diagnose=True, backtrace=True, colorize=True) + +def test_decorate_async_generator(): + @logger.catch(reraise=True) + async def generator(x, y): + yield x + yield y + + async def coro(): + out = [] + async for val in generator(1, 2): + out.append(val) + return out + + res = asyncio.run(coro()) + assert res == [1, 2] + + +def test_decorate_async_generator_with_error(): + @logger.catch(reraise=False) + async def generator(x, y): + yield x + yield y + raise ValueError + + async def coro(): + out = [] + async for val in generator(1, 2): + out.append(val) + return out + + res = asyncio.run(coro()) + assert res == [1, 2] + +def test_decorate_async_generator_with_error_reraised(): + @logger.catch(reraise=True) + async def generator(x, y): + yield x + yield y + raise ValueError + + async def coro(): + out = [] + try: + async for val in generator(1, 2): + out.append(val) + except ValueError: + pass + else: + raise AssertionError("ValueError not raised") + return out + + res = asyncio.run(coro()) + assert res == [1, 2] + + +def test_decorate_async_generator_then_async_send(): + @logger.catch + async def generator(x, y): + yield x + yield y + + async def coro(): + gen = generator(1, 2) + await gen.asend(None) + await gen.asend(None) + try: + await gen.asend(None) + except StopAsyncIteration: + pass + else: + raise AssertionError("StopAsyncIteration not raised") + + asyncio.run(coro()) + + +def test_decorate_async_generator_then_async_throw(): + @logger.catch + async def generator(x, y): + yield x + yield y + + async def coro(): + gen = generator(1, 2) + await gen.asend(None) + try: + await gen.athrow(ValueError) + except ValueError: + pass + else: + raise AssertionError("ValueError not raised") + + asyncio.run(coro()) + + +test_decorate_async_generator() +test_decorate_async_generator_with_error() +test_decorate_async_generator_with_error_reraised() +test_decorate_async_generator_then_async_send() +test_decorate_async_generator_then_async_throw() + +logger.add(sys.stderr, format="{message}") +logger.info("Done") diff --git a/tests/exceptions/source/modern/exception_formatting_async_generator.py b/tests/exceptions/source/modern/exception_formatting_async_generator.py new file mode 100644 index 000000000..120bc4cf0 --- /dev/null +++ b/tests/exceptions/source/modern/exception_formatting_async_generator.py @@ -0,0 +1,22 @@ +import sys + +from loguru import logger + +logger.remove() +logger.add(sys.stderr, format="", diagnose=False, backtrace=False, colorize=False) +logger.add(sys.stderr, format="", diagnose=True, backtrace=False, colorize=False) +logger.add(sys.stderr, format="", diagnose=False, backtrace=True, colorize=False) +logger.add(sys.stderr, format="", diagnose=True, backtrace=True, colorize=False) + + +@logger.catch +async def foo(a, b): + yield a / b + + +f = foo(1, 0).asend(None) + +try: + f.send(None) +except StopAsyncIteration: + pass diff --git a/tests/test_exceptions_catch.py b/tests/test_exceptions_catch.py index 1e7aea97e..71ebc2917 100644 --- a/tests/test_exceptions_catch.py +++ b/tests/test_exceptions_catch.py @@ -407,9 +407,9 @@ def foo(x, y, z): def test_decorate_generator_with_error(): @logger.catch def foo(): - for i in range(3): - 1 / (2 - i) - yield i + yield 0 + yield 1 + raise ValueError assert list(foo()) == [0, 1] diff --git a/tests/test_exceptions_formatting.py b/tests/test_exceptions_formatting.py index 47e05cff2..46ccfd340 100644 --- a/tests/test_exceptions_formatting.py +++ b/tests/test_exceptions_formatting.py @@ -23,16 +23,21 @@ def normalize(exception): def fix_filepath(match): filepath = match.group(1) + + # Pattern to check if the filepath contains ANSI escape codes. pattern = ( r'((?:\x1b\[[0-9]*m)+)([^"]+?)((?:\x1b\[[0-9]*m)+)([^"]+?)((?:\x1b\[[0-9]*m)+)' ) + match = re.match(pattern, filepath) start_directory = os.path.dirname(os.path.dirname(__file__)) if match: + # Simplify the path while preserving the color highlighting of the file basename. groups = list(match.groups()) groups[1] = os.path.relpath(os.path.abspath(groups[1]), start_directory) + "/" relpath = "".join(groups) else: + # We can straightforwardly convert from absolute to relative path. relpath = os.path.relpath(os.path.abspath(filepath), start_directory) return 'File "%s"' % relpath.replace("\\", "/") @@ -241,6 +246,8 @@ def test_exception_others(filename): ("filename", "minimum_python_version"), [ ("type_hints", (3, 6)), + ("decorate_async_generator", (3, 6)), + ("exception_formatting_async_generator", (3, 6)), ("positional_only_argument", (3, 8)), ("walrus_operator", (3, 8)), ("match_statement", (3, 10)),