Skip to content

Commit

Permalink
revert to unified container_context and fix error with same context v…
Browse files Browse the repository at this point in the history
…ar token when container_context used as decorator (#94)

* revert to unified container_context and fix error with same context var token when container_context used as decorator

---------

Co-authored-by: artur.shiriev <[email protected]>
  • Loading branch information
lesnik512 and artur.shiriev authored Sep 28, 2024
1 parent ad3a43f commit 47ecbb8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 46 deletions.
20 changes: 12 additions & 8 deletions docs/providers/context-resources.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ value = MyContainer.sync_resource.sync_resolve()
> RuntimeError: Context is not set. Use container_context
```

### Resolving sync dependencies:
`container_context` implements only `AsyncContextManager`.
For sync context and `ContextManager` use `sync_container_context`.
### Resolving async and sync dependencies:
``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``.
This means that you can enter an async context with:

```python
async with container_context():
...
```
An async context will allow resolution of both sync and async dependencies.

A sync context can be entered using:
```python
with sync_container_context():
with container_context():
...
```
A sync context will only allow resolution of sync dependencies:
```python
async def my_func():
with sync_container_context(): # enter sync context
with container_context(): # enter sync context
# try to resolve async dependency.
await MyContainer.async_resource.async_resolve()

Expand All @@ -63,7 +70,6 @@ async def my_func():
Each time you enter `container_context` a new context is created in the background.
Resources are cached in the context after first resolution.
Resources created in a context are torn down again when `container_context` exits.
Same for `sync_container_context`.
```python
async with container_context():
value_outer = await MyContainer.resource.async_resolve()
Expand All @@ -84,5 +90,3 @@ async def insert_into_database(session = Provide[MyContainer.session]):
...
```
Each time ``await insert_into_database()`` is called new instance of ``session`` will be injected.

Same for `sync_container_context`.
38 changes: 28 additions & 10 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import pytest

from that_depends import BaseContainer, fetch_context_item, providers
from that_depends import BaseContainer, Provide, fetch_context_item, inject, providers
from that_depends.providers import container_context
from that_depends.providers.base import ResourceContext
from that_depends.providers.context_resources import sync_container_context


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,18 +76,15 @@ async def test_context_resource(context_resource: providers.ContextResource[str]
assert await context_resource() is context_resource_result


@sync_container_context()
@container_context()
def test_sync_context_resource(sync_context_resource: providers.ContextResource[str]) -> None:
context_resource_result = sync_context_resource.sync_resolve()

assert sync_context_resource.sync_resolve() is context_resource_result


async def test_async_context_resource_in_sync_context(async_context_resource: providers.ContextResource[str]) -> None:
with (
pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."),
sync_container_context(),
):
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."), container_context():
await async_context_resource()


Expand Down Expand Up @@ -153,10 +149,10 @@ async def test_context_resource_with_dynamic_resource() -> None:


async def test_early_exit_of_container_context() -> None:
with pytest.raises(RuntimeError, match="generator didn't stop"):
with pytest.raises(RuntimeError, match="Context is not set, call ``__aenter__`` first"):
await container_context().__aexit__(None, None, None)
with pytest.raises(RuntimeError, match="generator didn't stop"):
sync_container_context().__exit__(None, None, None)
with pytest.raises(RuntimeError, match="Context is not set, call ``__enter__`` first"):
container_context().__exit__(None, None, None)


async def test_resource_context_early_teardown() -> None:
Expand All @@ -171,3 +167,25 @@ async def test_teardown_sync_container_context_with_async_resource() -> None:
resource_context.context_stack = AsyncExitStack()
with pytest.raises(RuntimeError, match="Cannot tear down async context in sync mode"):
resource_context.sync_tear_down()


async def test_sync_container_context_with_different_stack() -> None:
@container_context()
@inject
def some_injected(depth: int, val: str = Provide[DIContainer.sync_context_resource]) -> str:
if depth > 1:
return val
return some_injected(depth + 1)

some_injected(1)


async def test_async_container_context_with_different_stack() -> None:
@container_context()
@inject
async def some_injected(depth: int, val: str = Provide[DIContainer.async_context_resource]) -> str:
if depth > 1:
return val
return await some_injected(depth + 1)

await some_injected(1)
4 changes: 2 additions & 2 deletions that_depends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from that_depends import providers
from that_depends.container import BaseContainer
from that_depends.injection import Provide, inject
from that_depends.providers import container_context, fetch_context_item, sync_container_context
from that_depends.providers import container_context
from that_depends.providers.context_resources import fetch_context_item


__all__ = [
"container_context",
"sync_container_context",
"fetch_context_item",
"providers",
"BaseContainer",
Expand Down
4 changes: 0 additions & 4 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
ContextResource,
DIContextMiddleware,
container_context,
fetch_context_item,
sync_container_context,
)
from that_depends.providers.factories import AsyncFactory, Factory
from that_depends.providers.object import Object
Expand All @@ -32,6 +30,4 @@
"Selector",
"Singleton",
"container_context",
"sync_container_context",
"fetch_context_item",
]
95 changes: 73 additions & 22 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import contextlib
import inspect
import logging
import typing
import uuid
import warnings
from contextvars import ContextVar
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from contextvars import ContextVar, Token
from functools import wraps
from types import TracebackType

from that_depends.providers.base import AbstractResource, ResourceContext

Expand All @@ -23,36 +26,84 @@
ContextType = dict[str, typing.Any]


@contextlib.asynccontextmanager
async def container_context(initial_context: dict[str, typing.Any] | None = None) -> typing.AsyncIterator[None]:
initial_context_: ContextType = initial_context or {}
initial_context_[_ASYNC_CONTEXT_KEY] = True
token: typing.Final = _CONTAINER_CONTEXT.set(initial_context_)
try:
yield
finally:
class container_context( # noqa: N801
AbstractAsyncContextManager[ContextType], AbstractContextManager[ContextType]
):
"""Manage the context of ContextResources.
Can be entered using ``async with container_context()`` or with ``with container_context()``
as async-context-manager or context-manager respectively.
When used as async-context-manager, it will allow setup & teardown of both sync and async resources.
When used as sync-context-manager, it will only allow setup & teardown of sync resources.
"""

def __init__(self, initial_context: ContextType | None = None) -> None:
self._initial_context: ContextType = initial_context or {}
self._context_token: Token[ContextType] | None = None

def __enter__(self) -> ContextType:
self._initial_context[_ASYNC_CONTEXT_KEY] = False
return self._enter()

async def __aenter__(self) -> ContextType:
self._initial_context[_ASYNC_CONTEXT_KEY] = True
return self._enter()

def _enter(self) -> ContextType:
self._context_token = _CONTAINER_CONTEXT.set(self._initial_context or {})
return _CONTAINER_CONTEXT.get()

def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
if self._context_token is None:
msg = "Context is not set, call ``__enter__`` first"
raise RuntimeError(msg)

try:
for context_item in reversed(_CONTAINER_CONTEXT.get().values()):
if isinstance(context_item, ResourceContext):
await context_item.tear_down()
# we don't need to handle the case where the ResourceContext is async
context_item.sync_tear_down()

finally:
_CONTAINER_CONTEXT.reset(token)
_CONTAINER_CONTEXT.reset(self._context_token)

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None
) -> None:
if self._context_token is None:
msg = "Context is not set, call ``__aenter__`` first"
raise RuntimeError(msg)

@contextlib.contextmanager
def sync_container_context(initial_context: dict[str, typing.Any] | None = None) -> typing.Iterator[None]:
initial_context_: ContextType = initial_context or {}
initial_context_[_ASYNC_CONTEXT_KEY] = False
token: typing.Final = _CONTAINER_CONTEXT.set(initial_context_)
try:
yield
finally:
try:
for context_item in reversed(_CONTAINER_CONTEXT.get().values()):
if isinstance(context_item, ResourceContext):
if not isinstance(context_item, ResourceContext):
continue

if context_item.is_context_stack_async(context_item.context_stack):
await context_item.tear_down()
else:
context_item.sync_tear_down()
finally:
_CONTAINER_CONTEXT.reset(token)
_CONTAINER_CONTEXT.reset(self._context_token)

def __call__(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]:
if inspect.iscoroutinefunction(func):

@wraps(func)
async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T:
async with container_context(self._initial_context):
return await func(*args, **kwargs) # type: ignore[no-any-return]

return typing.cast(typing.Callable[P, T], _async_inner)

@wraps(func)
def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T:
with container_context(self._initial_context):
return func(*args, **kwargs)

return _sync_inner


class DIContextMiddleware:
Expand Down

0 comments on commit 47ecbb8

Please sign in to comment.