diff --git a/packages/modern-di/modern_di/container.py b/packages/modern-di/modern_di/container.py index cfbfe2d..f598415 100644 --- a/packages/modern-di/modern_di/container.py +++ b/packages/modern-di/modern_di/container.py @@ -15,13 +15,13 @@ class Container(contextlib.AbstractAsyncContextManager["Container"]): __slots__ = ( - "scope", - "parent_container", - "context", "_is_async", - "_provider_states", "_overrides", + "_provider_states", "_use_threading_lock", + "context", + "parent_container", + "scope", ) def __init__( @@ -37,7 +37,7 @@ def __init__( self.context: dict[str, typing.Any] = context or {} self._is_async: bool | None = None self._provider_states: dict[str, ProviderState[typing.Any]] = {} - self._overrides: dict[str, typing.Any] = {} + self._overrides: dict[str, typing.Any] = parent_container._overrides if parent_container else {} # noqa: SLF001 self._use_threading_lock = use_threading_lock def _exit(self) -> None: diff --git a/packages/modern-di/modern_di/provider_state.py b/packages/modern-di/modern_di/provider_state.py index b6edc8f..1ca6423 100644 --- a/packages/modern-di/modern_di/provider_state.py +++ b/packages/modern-di/modern_di/provider_state.py @@ -8,7 +8,7 @@ class ProviderState(typing.Generic[T_co]): - __slots__ = "context_stack", "instance", "asyncio_lock", "threading_lock" + __slots__ = "asyncio_lock", "context_stack", "instance", "threading_lock" def __init__(self, use_asyncio_lock: bool, use_threading_lock: bool) -> None: self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None diff --git a/packages/modern-di/modern_di/providers/__init__.py b/packages/modern-di/modern_di/providers/__init__.py index 377359b..13a47ea 100644 --- a/packages/modern-di/modern_di/providers/__init__.py +++ b/packages/modern-di/modern_di/providers/__init__.py @@ -13,10 +13,10 @@ "AbstractProvider", "ContainerProvider", "ContextAdapter", - "Factory", "Dict", + "Factory", "List", + "Resource", "Selector", "Singleton", - "Resource", ] diff --git a/packages/modern-di/modern_di/providers/abstract.py b/packages/modern-di/modern_di/providers/abstract.py index 05fe510..cbd4571 100644 --- a/packages/modern-di/modern_di/providers/abstract.py +++ b/packages/modern-di/modern_di/providers/abstract.py @@ -39,10 +39,10 @@ def _check_providers_scope(self, providers: typing.Iterable[typing.Any]) -> None class AbstractOverrideProvider(AbstractProvider[T_co], abc.ABC): def override(self, override_object: object, container: Container) -> None: - container.find_container(self.scope).override(self.provider_id, override_object) + container.override(self.provider_id, override_object) def reset_override(self, container: Container) -> None: - container.find_container(self.scope).reset_override(self.provider_id) + container.reset_override(self.provider_id) class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC): diff --git a/packages/modern-di/tests_core/providers/test_factory.py b/packages/modern-di/tests_core/providers/test_factory.py index 8e66f3a..a1b6e68 100644 --- a/packages/modern-di/tests_core/providers/test_factory.py +++ b/packages/modern-di/tests_core/providers/test_factory.py @@ -57,28 +57,44 @@ async def test_app_factory_in_request_scope() -> None: assert instance1 is not instance2 -async def test_factory_overridden() -> None: +async def test_factory_overridden_app_scope() -> None: async with Container(scope=Scope.APP) as app_container: - with app_container.build_child_container(scope=Scope.REQUEST) as request_container: - instance1 = app_factory.sync_resolve(app_container) + instance1 = app_factory.sync_resolve(app_container) - app_factory.override(SimpleCreator(dep1="override"), container=request_container) + app_factory.override(SimpleCreator(dep1="override"), container=app_container) - instance2 = app_factory.sync_resolve(app_container) - instance3 = await app_factory.async_resolve(app_container) - assert instance1 is not instance2 - assert instance2 is instance3 - assert instance2.dep1 != instance1.dep1 + instance2 = app_factory.sync_resolve(app_container) + instance3 = await app_factory.async_resolve(app_container) + assert instance1 is not instance2 + assert instance2 is instance3 + assert instance2.dep1 != instance1.dep1 - app_factory.reset_override(app_container) + app_factory.reset_override(app_container) - instance4 = app_factory.sync_resolve(app_container) + instance4 = app_factory.sync_resolve(app_container) - assert instance4.dep1 == instance1.dep1 + assert instance4.dep1 == instance1.dep1 assert instance3 is not instance4 +async def test_factory_overridden_request_scope() -> None: + async with Container(scope=Scope.APP) as app_container: + request_factory.override(DependentCreator(dep1=SimpleCreator(dep1="override")), app_container) + + with app_container.build_child_container(scope=Scope.REQUEST) as request_container: + instance1 = request_factory.sync_resolve(request_container) + instance2 = request_factory.sync_resolve(request_container) + assert instance1 is instance2 + assert instance2.dep1.dep1 == instance1.dep1.dep1 == "override" + + request_factory.reset_override(request_container) + + instance3 = request_factory.sync_resolve(request_container) + + assert instance3 is not instance1 + + async def test_factory_wrong_dependency_scope() -> None: def some_factory(_: SimpleCreator) -> None: ...