From 1703b6501f9c9f600bfbebe389e7f63a7bf78bdd Mon Sep 17 00:00:00 2001 From: Stanislav Kiriukhin <44553725+stkrizh@users.noreply.github.com> Date: Sun, 10 Nov 2024 14:12:54 +0400 Subject: [PATCH] allow overriding Object provider (#118) --- tests/container.py | 1 + tests/providers/test_providers_overriding.py | 16 ++++++++++++++++ that_depends/providers/object.py | 4 +++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/container.py b/tests/container.py index 944c92d..69b1277 100644 --- a/tests/container.py +++ b/tests/container.py @@ -66,3 +66,4 @@ class DIContainer(BaseContainer): async_resource=async_resource.cast, ) singleton = providers.Singleton(SingletonFactory, dep1=True) + object = providers.Object(object()) diff --git a/tests/providers/test_providers_overriding.py b/tests/providers/test_providers_overriding.py index a6dd0b5..1f19b10 100644 --- a/tests/providers/test_providers_overriding.py +++ b/tests/providers/test_providers_overriding.py @@ -11,6 +11,7 @@ async def test_batch_providers_overriding() -> None: async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) + object_mock = object() providers_for_overriding = { "async_resource": async_resource_mock, @@ -18,6 +19,7 @@ async def test_batch_providers_overriding() -> None: "simple_factory": simple_factory_mock, "singleton": singleton_mock, "async_factory": async_factory_mock, + "object": object_mock, } with container.DIContainer.override_providers(providers_for_overriding): @@ -25,6 +27,7 @@ async def test_batch_providers_overriding() -> None: dependent_factory = await container.DIContainer.dependent_factory() singleton = await container.DIContainer.singleton() async_factory = await container.DIContainer.async_factory() + obj = await container.DIContainer.object() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 @@ -32,6 +35,7 @@ async def test_batch_providers_overriding() -> None: assert dependent_factory.async_resource == async_resource_mock assert singleton is singleton_mock assert async_factory is async_factory_mock + assert obj is object_mock assert (await container.DIContainer.async_resource()) != async_resource_mock @@ -41,12 +45,14 @@ async def test_batch_providers_overriding_sync_resolve() -> None: sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) + object_mock = object() providers_for_overriding = { "async_resource": async_resource_mock, "sync_resource": sync_resource_mock, "simple_factory": simple_factory_mock, "singleton": singleton_mock, + "object": object_mock, } with container.DIContainer.override_providers(providers_for_overriding): @@ -54,12 +60,14 @@ async def test_batch_providers_overriding_sync_resolve() -> None: await container.DIContainer.async_resource.async_resolve() dependent_factory = container.DIContainer.dependent_factory.sync_resolve() singleton = container.DIContainer.singleton.sync_resolve() + obj = container.DIContainer.object.sync_resolve() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 assert dependent_factory.sync_resource == sync_resource_mock assert dependent_factory.async_resource == async_resource_mock assert singleton is singleton_mock + assert obj is object_mock assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock @@ -88,16 +96,19 @@ async def test_providers_overriding() -> None: async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) + object_mock = object() container.DIContainer.async_resource.override(async_resource_mock) container.DIContainer.sync_resource.override(sync_resource_mock) container.DIContainer.simple_factory.override(simple_factory_mock) container.DIContainer.singleton.override(singleton_mock) container.DIContainer.async_factory.override(async_factory_mock) + container.DIContainer.object.override(object_mock) await container.DIContainer.simple_factory() dependent_factory = await container.DIContainer.dependent_factory() singleton = await container.DIContainer.singleton() async_factory = await container.DIContainer.async_factory() + obj = await container.DIContainer.object() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 @@ -105,6 +116,7 @@ async def test_providers_overriding() -> None: assert dependent_factory.async_resource == async_resource_mock assert singleton is singleton_mock assert async_factory is async_factory_mock + assert obj is object_mock container.DIContainer.reset_override() assert (await container.DIContainer.async_resource()) != async_resource_mock @@ -115,21 +127,25 @@ async def test_providers_overriding_sync_resolve() -> None: sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) singleton_mock = container.SingletonFactory(dep1=False) + object_mock = object() container.DIContainer.async_resource.override(async_resource_mock) container.DIContainer.sync_resource.override(sync_resource_mock) container.DIContainer.simple_factory.override(simple_factory_mock) container.DIContainer.singleton.override(singleton_mock) + container.DIContainer.object.override(object_mock) container.DIContainer.simple_factory.sync_resolve() await container.DIContainer.async_resource.async_resolve() dependent_factory = container.DIContainer.dependent_factory.sync_resolve() singleton = container.DIContainer.singleton.sync_resolve() + obj = container.DIContainer.object.sync_resolve() assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 assert dependent_factory.sync_resource == sync_resource_mock assert dependent_factory.async_resource == async_resource_mock assert singleton is singleton_mock + assert obj is object_mock container.DIContainer.reset_override() assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock diff --git a/that_depends/providers/object.py b/that_depends/providers/object.py index a0095c7..2b218a3 100644 --- a/that_depends/providers/object.py +++ b/that_depends/providers/object.py @@ -15,7 +15,9 @@ def __init__(self, obj: T_co) -> None: self._obj: typing.Final = obj async def async_resolve(self) -> T_co: - return self._obj + return self.sync_resolve() def sync_resolve(self) -> T_co: + if self._override is not None: + return typing.cast(T_co, self._override) return self._obj