Skip to content

Commit

Permalink
allow overriding Object provider (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
stkrizh authored Nov 10, 2024
1 parent 3657c23 commit 1703b65
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,4 @@ class DIContainer(BaseContainer):
async_resource=async_resource.cast,
)
singleton = providers.Singleton(SingletonFactory, dep1=True)
object = providers.Object(object())
16 changes: 16 additions & 0 deletions tests/providers/test_providers_overriding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,31 @@ async def test_batch_providers_overriding() -> None:
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
object_mock = object()

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
"async_factory": async_factory_mock,
"object": object_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
await container.DIContainer.simple_factory()
dependent_factory = await container.DIContainer.dependent_factory()
singleton = await container.DIContainer.singleton()
async_factory = await container.DIContainer.async_factory()
obj = await container.DIContainer.object()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert async_factory is async_factory_mock
assert obj is object_mock

assert (await container.DIContainer.async_resource()) != async_resource_mock

Expand All @@ -41,25 +45,29 @@ async def test_batch_providers_overriding_sync_resolve() -> None:
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
object_mock = object()

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
"object": object_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
container.DIContainer.simple_factory.sync_resolve()
await container.DIContainer.async_resource.async_resolve()
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
singleton = container.DIContainer.singleton.sync_resolve()
obj = container.DIContainer.object.sync_resolve()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert obj is object_mock

assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock

Expand Down Expand Up @@ -88,23 +96,27 @@ async def test_providers_overriding() -> None:
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
object_mock = object()
container.DIContainer.async_resource.override(async_resource_mock)
container.DIContainer.sync_resource.override(sync_resource_mock)
container.DIContainer.simple_factory.override(simple_factory_mock)
container.DIContainer.singleton.override(singleton_mock)
container.DIContainer.async_factory.override(async_factory_mock)
container.DIContainer.object.override(object_mock)

await container.DIContainer.simple_factory()
dependent_factory = await container.DIContainer.dependent_factory()
singleton = await container.DIContainer.singleton()
async_factory = await container.DIContainer.async_factory()
obj = await container.DIContainer.object()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert async_factory is async_factory_mock
assert obj is object_mock

container.DIContainer.reset_override()
assert (await container.DIContainer.async_resource()) != async_resource_mock
Expand All @@ -115,21 +127,25 @@ async def test_providers_overriding_sync_resolve() -> None:
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
object_mock = object()
container.DIContainer.async_resource.override(async_resource_mock)
container.DIContainer.sync_resource.override(sync_resource_mock)
container.DIContainer.simple_factory.override(simple_factory_mock)
container.DIContainer.singleton.override(singleton_mock)
container.DIContainer.object.override(object_mock)

container.DIContainer.simple_factory.sync_resolve()
await container.DIContainer.async_resource.async_resolve()
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
singleton = container.DIContainer.singleton.sync_resolve()
obj = container.DIContainer.object.sync_resolve()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert obj is object_mock

container.DIContainer.reset_override()
assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock
4 changes: 3 additions & 1 deletion that_depends/providers/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def __init__(self, obj: T_co) -> None:
self._obj: typing.Final = obj

async def async_resolve(self) -> T_co:
return self._obj
return self.sync_resolve()

def sync_resolve(self) -> T_co:
if self._override is not None:
return typing.cast(T_co, self._override)
return self._obj

0 comments on commit 1703b65

Please sign in to comment.