Skip to content

Commit

Permalink
add Singleton provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lesnik512 committed Mar 20, 2024
1 parent 91f0137 commit 982a84b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "that-depends"
version = "1.2.0"
version = "1.3.0"
description = "Simple Dependency Injection framework"
authors = ["Artur Shiriev <[email protected]>"]
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions tests/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class AsyncDependentFactory:
async_resource: str


@dataclasses.dataclass(kw_only=True, slots=True)
class SingletonFactory:
dep1: bool


class DIContainer(BaseContainer):
sync_resource = providers.Resource[str](create_sync_resource)
async_resource = providers.AsyncResource[str](create_async_resource)
Expand All @@ -54,3 +59,4 @@ class DIContainer(BaseContainer):
independent_factory=independent_factory,
async_resource=async_resource,
)
singleton = providers.Singleton(SingletonFactory, dep1=True)
4 changes: 4 additions & 0 deletions tests/test_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ async def test_di() -> None:
sync_dependent_factory = await DIContainer.sync_dependent_factory()
async_dependent_factory = await DIContainer.async_dependent_factory()
sequence = await DIContainer.sequence()
singleton1 = await DIContainer.singleton()
singleton2 = await DIContainer.singleton()

assert sync_dependent_factory.independent_factory is not independent_factory
assert sync_dependent_factory.sync_resource == "sync resource"
assert async_dependent_factory.async_resource == "async resource"
assert sequence == ["sync resource", "async resource"]
assert singleton1 is singleton2


def test_wrong_providers_init() -> None:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_overriding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,22 @@ async def test_overriding() -> None:
async_resource_mock = "async overriding"
sync_resource_mock = "sync overriding"
independent_factory_mock = container.IndependentFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)
container.DIContainer.async_resource.override(async_resource_mock)
container.DIContainer.sync_resource.override(sync_resource_mock)
container.DIContainer.independent_factory.override(independent_factory_mock)
container.DIContainer.singleton.override(singleton_mock)

await container.DIContainer.independent_factory()
sync_dependent_factory = await container.DIContainer.sync_dependent_factory()
async_dependent_factory = await container.DIContainer.async_dependent_factory()
singleton = await container.DIContainer.singleton()

assert sync_dependent_factory.independent_factory.dep1 == independent_factory_mock.dep1
assert sync_dependent_factory.independent_factory.dep2 == independent_factory_mock.dep2
assert sync_dependent_factory.sync_resource == sync_resource_mock
assert async_dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock

container.DIContainer.reset_override()
assert (await container.DIContainer.async_resource()) == "async resource"
20 changes: 20 additions & 0 deletions that_depends/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,23 @@ async def resolve(self) -> list[T]: # type: ignore[override]

async def __call__(self) -> list[T]: # type: ignore[override]
return await self.resolve()


class Singleton(AbstractProvider[T]):
def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory = factory
self._args = args
self._kwargs = kwargs
self._override = None
self._instance: T | None = None

async def resolve(self) -> T:
if self._override:
return typing.cast(T, self._override)

if not self._instance:
self._instance = self._factory(
*[await x() if isinstance(x, AbstractProvider) else x for x in self._args],
**{k: await v() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)
return self._instance

0 comments on commit 982a84b

Please sign in to comment.