Skip to content

Commit

Permalink
Merge pull request #18 from modern-python/17-store-overrides-in-one-c…
Browse files Browse the repository at this point in the history
…ommon-dict

use shared dict for overrides for all scopes
  • Loading branch information
lesnik512 authored Nov 23, 2024
2 parents 124b6b3 + 9bfc4b0 commit 9733827
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
10 changes: 5 additions & 5 deletions packages/modern-di/modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

class Container(contextlib.AbstractAsyncContextManager["Container"]):
__slots__ = (
"scope",
"parent_container",
"context",
"_is_async",
"_provider_states",
"_overrides",
"_provider_states",
"_use_threading_lock",
"context",
"parent_container",
"scope",
)

def __init__(
Expand All @@ -37,7 +37,7 @@ def __init__(
self.context: dict[str, typing.Any] = context or {}
self._is_async: bool | None = None
self._provider_states: dict[str, ProviderState[typing.Any]] = {}
self._overrides: dict[str, typing.Any] = {}
self._overrides: dict[str, typing.Any] = parent_container._overrides if parent_container else {} # noqa: SLF001
self._use_threading_lock = use_threading_lock

def _exit(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion packages/modern-di/modern_di/provider_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class ProviderState(typing.Generic[T_co]):
__slots__ = "context_stack", "instance", "asyncio_lock", "threading_lock"
__slots__ = "asyncio_lock", "context_stack", "instance", "threading_lock"

def __init__(self, use_asyncio_lock: bool, use_threading_lock: bool) -> None:
self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None
Expand Down
4 changes: 2 additions & 2 deletions packages/modern-di/modern_di/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
"AbstractProvider",
"ContainerProvider",
"ContextAdapter",
"Factory",
"Dict",
"Factory",
"List",
"Resource",
"Selector",
"Singleton",
"Resource",
]
4 changes: 2 additions & 2 deletions packages/modern-di/modern_di/providers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def _check_providers_scope(self, providers: typing.Iterable[typing.Any]) -> None

class AbstractOverrideProvider(AbstractProvider[T_co], abc.ABC):
def override(self, override_object: object, container: Container) -> None:
container.find_container(self.scope).override(self.provider_id, override_object)
container.override(self.provider_id, override_object)

def reset_override(self, container: Container) -> None:
container.find_container(self.scope).reset_override(self.provider_id)
container.reset_override(self.provider_id)


class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
Expand Down
40 changes: 28 additions & 12 deletions packages/modern-di/tests_core/providers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,44 @@ async def test_app_factory_in_request_scope() -> None:
assert instance1 is not instance2


async def test_factory_overridden() -> None:
async def test_factory_overridden_app_scope() -> None:
async with Container(scope=Scope.APP) as app_container:
with app_container.build_child_container(scope=Scope.REQUEST) as request_container:
instance1 = app_factory.sync_resolve(app_container)
instance1 = app_factory.sync_resolve(app_container)

app_factory.override(SimpleCreator(dep1="override"), container=request_container)
app_factory.override(SimpleCreator(dep1="override"), container=app_container)

instance2 = app_factory.sync_resolve(app_container)
instance3 = await app_factory.async_resolve(app_container)
assert instance1 is not instance2
assert instance2 is instance3
assert instance2.dep1 != instance1.dep1
instance2 = app_factory.sync_resolve(app_container)
instance3 = await app_factory.async_resolve(app_container)
assert instance1 is not instance2
assert instance2 is instance3
assert instance2.dep1 != instance1.dep1

app_factory.reset_override(app_container)
app_factory.reset_override(app_container)

instance4 = app_factory.sync_resolve(app_container)
instance4 = app_factory.sync_resolve(app_container)

assert instance4.dep1 == instance1.dep1
assert instance4.dep1 == instance1.dep1

assert instance3 is not instance4


async def test_factory_overridden_request_scope() -> None:
async with Container(scope=Scope.APP) as app_container:
request_factory.override(DependentCreator(dep1=SimpleCreator(dep1="override")), app_container)

with app_container.build_child_container(scope=Scope.REQUEST) as request_container:
instance1 = request_factory.sync_resolve(request_container)
instance2 = request_factory.sync_resolve(request_container)
assert instance1 is instance2
assert instance2.dep1.dep1 == instance1.dep1.dep1 == "override"

request_factory.reset_override(request_container)

instance3 = request_factory.sync_resolve(request_container)

assert instance3 is not instance1


async def test_factory_wrong_dependency_scope() -> None:
def some_factory(_: SimpleCreator) -> None: ...

Expand Down

0 comments on commit 9733827

Please sign in to comment.