Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented ThreadLocalSingleton. #146

Merged
merged 7 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/providers/singleton.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ await MyContainer.singleton.async_resolve()
# calling sync_resolve concurrently in different threads will create only one instance
MyContainer.singleton.sync_resolve()
```
## ThreadLocalSingleton

For cases when you need to have a separate instance for each thread, you can use `ThreadLocalSingleton` provider. It will create a new instance for each thread and cache it for future injections in the same thread.

```python
from that_depends.providers import ThreadLocalSingleton
import threading
import random

# Define a factory function
def factory() -> int:
return random.randint(1, 100)

# Create a ThreadLocalSingleton instance
singleton = ThreadLocalSingleton(factory)

# Same thread, same instance
instance1 = singleton.sync_resolve() # 56
instance2 = singleton.sync_resolve() # 56

# Example usage in multiple threads
def thread_task():
instance = singleton.sync_resolve()
return instance

# Create and start threads
threads = [threading.Thread(target=thread_task) for i in range(2)]
for thread in threads:
thread.start()
for thread in threads:
results = thread.join()

# Results will be different for each thread
print(results) # [56, 78]
```


## Example with `pydantic-settings`
Let's say we are storing our application configuration using [pydantic-settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/):
Expand Down
120 changes: 120 additions & 0 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import dataclasses
import random
import threading
import time
import typing
Expand All @@ -9,6 +10,7 @@
import pytest

from that_depends import BaseContainer, providers
from that_depends.providers.singleton import ThreadLocalSingleton


@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
Expand Down Expand Up @@ -148,6 +150,124 @@ async def test_async_singleton_asyncio_concurrency() -> None:
assert all(val is results[0] for val in results)


async def test_thread_local_singleton_throws_on_async_resolve() -> None:
with pytest.raises(RuntimeError, match="ThreadLocalSingleton cannot be resolved in an async context."):
await ThreadLocalSingleton(lambda: None).async_resolve()


async def test_async_singleton_sync_resolve_failure() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()


def test_thread_local_singleton_same_thread() -> None:
"""Test that the same instance is returned within a single thread."""

def factory() -> int:
return random.randint(1, 100) # noqa: S311

provider = ThreadLocalSingleton(factory)

instance1 = provider.sync_resolve()
instance2 = provider.sync_resolve()

assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical."

provider.tear_down()

assert provider._instance is None, "Tear down failed: Instance should be reset to None."


def test_thread_local_singleton_different_threads() -> None:
"""Test that different threads receive different instances."""

def factory() -> int:
return random.randint(1, 100) # noqa: S311

provider = ThreadLocalSingleton(factory)
results = []

def resolve_in_thread() -> None:
results.append(provider.sync_resolve())

number_of_threads = 10

threads = [threading.Thread(target=resolve_in_thread) for _ in range(number_of_threads)]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert len(results) == number_of_threads, "Test failed: Expected results from two threads."
assert results[0] != results[1], "Thread-local failed: Instances across threads should differ."


def test_thread_local_singleton_override() -> None:
"""Test overriding the ThreadLocalSingleton and resetting the override."""

def _factory() -> int:
return random.randint(1, 100) # noqa: S311

provider = ThreadLocalSingleton(_factory)

override_value = 101
provider.override(override_value)
instance = provider.sync_resolve()
assert instance == override_value, "Override failed: Expected overridden value."

# Reset override and ensure a new instance is created
provider.reset_override()
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Reset override failed: Should no longer return overridden value."


def test_thread_local_singleton_override_in_threads() -> None:
"""Test that resetting the override in one thread does not affect another thread."""

def _factory() -> int:
return random.randint(1, 100) # noqa: S311

provider = ThreadLocalSingleton(_factory)
results = {}

def _thread_task(thread_id: int, override_value: int | None = None) -> None:
if override_value is not None:
provider.override(override_value)
results[thread_id] = provider.sync_resolve()
if override_value is not None:
provider.reset_override()

override_value: typing.Final[int] = 101
thread1 = threading.Thread(target=_thread_task, args=(1, override_value))
thread2 = threading.Thread(target=_thread_task, args=(2,))

thread1.start()
thread2.start()

thread1.join()
thread2.join()

# Validate results
assert results[1] == override_value, "Thread 1: Override failed."
assert results[2] != override_value, "Thread 2: Should not be affected by Thread 1's override."
assert results[1] != results[2], "Instances should be unique across threads."


def test_thread_local_singleton_override_temporarily() -> None:
"""Test temporarily overriding the ThreadLocalSingleton."""

def factory() -> int:
return random.randint(1, 100) # noqa: S311

provider = ThreadLocalSingleton(factory)

override_value: typing.Final = 101
# Set a temporary override
with provider.override_context(override_value):
instance = provider.sync_resolve()
assert instance == override_value, "Override context failed: Expected overridden value."

# After the context, reset to the factory
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Override context failed: Value should reset after context."
alexanderlazarev0 marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from that_depends.providers.object import Object
from that_depends.providers.resources import Resource
from that_depends.providers.selector import Selector
from that_depends.providers.singleton import AsyncSingleton, Singleton
from that_depends.providers.singleton import AsyncSingleton, Singleton, ThreadLocalSingleton


__all__ = [
Expand All @@ -28,5 +28,6 @@
"Resource",
"Selector",
"Singleton",
"ThreadLocalSingleton",
"container_context",
]
80 changes: 80 additions & 0 deletions that_depends/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,86 @@ async def tear_down(self) -> None:
self._instance = None


class ThreadLocalSingleton(AbstractProvider[T_co]):
"""Creates a new instance for each thread using a thread-local store.

This provider ensures that each thread gets its own instance, which is
created via the specified factory function. Once created, the instance is
cached for future injections within the same thread.

Example:
```python
def factory():
return random.randint(1, 100)

singleton = ThreadLocalSingleton(factory)

# Same thread, same instance
instance1 = singleton.sync_resolve()
instance2 = singleton.sync_resolve()

def thread_task():
return singleton.sync_resolve()

threads = [threading.Thread(target=thread_task) for i in range(10)]
for thread in threads:
thread.start() # Each thread will get a different instance
```

"""

def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
"""Initialize the ThreadLocalSingleton provider.

Args:
factory: A callable that returns a new instance of the dependency.
*args: Positional arguments to pass to the factory.
**kwargs: Keyword arguments to pass to the factory.

"""
super().__init__()
self._factory: typing.Final = factory
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
self._thread_local = threading.local()

@property
def _instance(self) -> T_co | None:
return getattr(self._thread_local, "instance", None)

@_instance.setter
def _instance(self, value: T_co | None) -> None:
self._thread_local.instance = value

@override
async def async_resolve(self) -> T_co:
msg = "ThreadLocalSingleton cannot be resolved in an async context."
raise NotImplementedError(msg)
lesnik512 marked this conversation as resolved.
Show resolved Hide resolved

@override
def sync_resolve(self) -> T_co:
if self._override is not None:
return typing.cast(T_co, self._override)

if self._instance is not None:
return self._instance

self._instance = self._factory(
*[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
**{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type]
)
return self._instance

def tear_down(self) -> None:
"""Reset the thread-local instance.

After calling this method, subsequent calls to `sync_resolve` on the
same thread will produce a new instance.
"""
if self._instance is not None:
self._instance = None


class AsyncSingleton(AbstractProvider[T_co]):
"""A provider that creates an instance asynchronously and caches it for subsequent injections.

Expand Down
Loading