Skip to content

Commit

Permalink
separate sync_container_context to fix problems with the same context…
Browse files Browse the repository at this point in the history
…var token for every context (#93)
  • Loading branch information
lesnik512 authored Sep 28, 2024
1 parent 97ca478 commit ad3a43f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 92 deletions.
20 changes: 8 additions & 12 deletions docs/providers/context-resources.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,18 @@ value = MyContainer.sync_resource.sync_resolve()
> RuntimeError: Context is not set. Use container_context
```

### Resolving async and sync dependencies:
``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``.
This means that you can enter an async context with:

```python
async with container_context():
...
```
An async context will allow resolution of both sync and async dependencies.

### Resolving sync dependencies:
`container_context` implements only `AsyncContextManager`.
For sync context and `ContextManager` use `sync_container_context`.
A sync context can be entered using:
```python
with container_context():
with sync_container_context():
...
```
A sync context will only allow resolution of sync dependencies:
```python
async def my_func():
with container_context(): # enter sync context
with sync_container_context(): # enter sync context
# try to resolve async dependency.
await MyContainer.async_resource.async_resolve()

Expand All @@ -70,6 +63,7 @@ async def my_func():
Each time you enter `container_context` a new context is created in the background.
Resources are cached in the context after first resolution.
Resources created in a context are torn down again when `container_context` exits.
Same for `sync_container_context`.
```python
async with container_context():
value_outer = await MyContainer.resource.async_resolve()
Expand All @@ -90,3 +84,5 @@ async def insert_into_database(session = Provide[MyContainer.session]):
...
```
Each time ``await insert_into_database()`` is called new instance of ``session`` will be injected.

Same for `sync_container_context`.
14 changes: 9 additions & 5 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from that_depends import BaseContainer, fetch_context_item, providers
from that_depends.providers import container_context
from that_depends.providers.base import ResourceContext
from that_depends.providers.context_resources import sync_container_context


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,15 +77,18 @@ async def test_context_resource(context_resource: providers.ContextResource[str]
assert await context_resource() is context_resource_result


@container_context()
@sync_container_context()
def test_sync_context_resource(sync_context_resource: providers.ContextResource[str]) -> None:
context_resource_result = sync_context_resource.sync_resolve()

assert sync_context_resource.sync_resolve() is context_resource_result


async def test_async_context_resource_in_sync_context(async_context_resource: providers.ContextResource[str]) -> None:
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."), container_context():
with (
pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."),
sync_container_context(),
):
await async_context_resource()


Expand Down Expand Up @@ -149,10 +153,10 @@ async def test_context_resource_with_dynamic_resource() -> None:


async def test_early_exit_of_container_context() -> None:
with pytest.raises(RuntimeError, match="Context is not set, call ``__aenter__`` first"):
with pytest.raises(RuntimeError, match="generator didn't stop"):
await container_context().__aexit__(None, None, None)
with pytest.raises(RuntimeError, match="Context is not set, call ``__enter__`` first"):
container_context().__exit__(None, None, None)
with pytest.raises(RuntimeError, match="generator didn't stop"):
sync_container_context().__exit__(None, None, None)


async def test_resource_context_early_teardown() -> None:
Expand Down
4 changes: 2 additions & 2 deletions that_depends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from that_depends import providers
from that_depends.container import BaseContainer
from that_depends.injection import Provide, inject
from that_depends.providers import container_context
from that_depends.providers.context_resources import fetch_context_item
from that_depends.providers import container_context, fetch_context_item, sync_container_context


__all__ = [
"container_context",
"sync_container_context",
"fetch_context_item",
"providers",
"BaseContainer",
Expand Down
4 changes: 4 additions & 0 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
ContextResource,
DIContextMiddleware,
container_context,
fetch_context_item,
sync_container_context,
)
from that_depends.providers.factories import AsyncFactory, Factory
from that_depends.providers.object import Object
Expand All @@ -30,4 +32,6 @@
"Selector",
"Singleton",
"container_context",
"sync_container_context",
"fetch_context_item",
]
95 changes: 22 additions & 73 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import inspect
import contextlib
import logging
import typing
import uuid
import warnings
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from contextvars import ContextVar, Token
from functools import wraps
from types import TracebackType
from contextvars import ContextVar

from that_depends.providers.base import AbstractResource, ResourceContext

Expand All @@ -26,84 +23,36 @@
ContextType = dict[str, typing.Any]


class container_context( # noqa: N801
AbstractAsyncContextManager[ContextType], AbstractContextManager[ContextType]
):
"""Manage the context of ContextResources.
Can be entered using ``async with container_context()`` or with ``with container_context()``
as async-context-manager or context-manager respectively.
When used as async-context-manager, it will allow setup & teardown of both sync and async resources.
When used as sync-context-manager, it will only allow setup & teardown of sync resources.
"""

def __init__(self, initial_context: ContextType | None = None) -> None:
self._initial_context: ContextType = initial_context or {}
self._context_token: Token[ContextType] | None = None

def __enter__(self) -> ContextType:
self._initial_context[_ASYNC_CONTEXT_KEY] = False
return self._enter()

async def __aenter__(self) -> ContextType:
self._initial_context[_ASYNC_CONTEXT_KEY] = True
return self._enter()

def _enter(self) -> ContextType:
self._context_token = _CONTAINER_CONTEXT.set(self._initial_context or {})
return _CONTAINER_CONTEXT.get()

def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
if self._context_token is None:
msg = "Context is not set, call ``__enter__`` first"
raise RuntimeError(msg)

@contextlib.asynccontextmanager
async def container_context(initial_context: dict[str, typing.Any] | None = None) -> typing.AsyncIterator[None]:
initial_context_: ContextType = initial_context or {}
initial_context_[_ASYNC_CONTEXT_KEY] = True
token: typing.Final = _CONTAINER_CONTEXT.set(initial_context_)
try:
yield
finally:
try:
for context_item in reversed(_CONTAINER_CONTEXT.get().values()):
if isinstance(context_item, ResourceContext):
# we don't need to handle the case where the ResourceContext is async
context_item.sync_tear_down()

await context_item.tear_down()
finally:
_CONTAINER_CONTEXT.reset(self._context_token)
_CONTAINER_CONTEXT.reset(token)

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None
) -> None:
if self._context_token is None:
msg = "Context is not set, call ``__aenter__`` first"
raise RuntimeError(msg)

@contextlib.contextmanager
def sync_container_context(initial_context: dict[str, typing.Any] | None = None) -> typing.Iterator[None]:
initial_context_: ContextType = initial_context or {}
initial_context_[_ASYNC_CONTEXT_KEY] = False
token: typing.Final = _CONTAINER_CONTEXT.set(initial_context_)
try:
yield
finally:
try:
for context_item in reversed(_CONTAINER_CONTEXT.get().values()):
if not isinstance(context_item, ResourceContext):
continue

if context_item.is_context_stack_async(context_item.context_stack):
await context_item.tear_down()
else:
if isinstance(context_item, ResourceContext):
context_item.sync_tear_down()
finally:
_CONTAINER_CONTEXT.reset(self._context_token)

def __call__(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]:
if inspect.iscoroutinefunction(func):

@wraps(func)
async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T:
async with self:
return await func(*args, **kwargs) # type: ignore[no-any-return]

return typing.cast(typing.Callable[P, T], _async_inner)

@wraps(func)
def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T:
with self:
return func(*args, **kwargs)

return _sync_inner
_CONTAINER_CONTEXT.reset(token)


class DIContextMiddleware:
Expand Down

0 comments on commit ad3a43f

Please sign in to comment.