Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented ThreadLocalSingleton. #146

Merged
merged 7 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default: install lint test
install:
uv lock --upgrade
uv sync --only-dev --frozen
uv run pre-commit install --overwrite
just hook

lint:
uv run ruff format
Expand All @@ -24,7 +24,7 @@ publish:
uv publish --token $PYPI_TOKEN

hook:
uv run pre-commit install
uv run pre-commit install --install-hooks --overwrite

unhook:
uv run pre-commit uninstall
36 changes: 36 additions & 0 deletions docs/providers/singleton.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ await MyContainer.singleton.async_resolve()
# calling sync_resolve concurrently in different threads will create only one instance
MyContainer.singleton.sync_resolve()
```
## ThreadLocalSingleton

For cases when you need to have a separate instance for each thread, you can use `ThreadLocalSingleton` provider. It will create a new instance for each thread and cache it for future injections in the same thread.

```python
from that_depends.providers import ThreadLocalSingleton
import threading
import random

# Define a factory function
def factory() -> int:
return random.randint(1, 100)

# Create a ThreadLocalSingleton instance
singleton = ThreadLocalSingleton(factory)

# Same thread, same instance
instance1 = singleton.sync_resolve() # 56
instance2 = singleton.sync_resolve() # 56

# Example usage in multiple threads
def thread_task():
instance = singleton.sync_resolve()
return instance

# Create and start threads
threads = [threading.Thread(target=thread_task) for i in range(2)]
for thread in threads:
thread.start()
for thread in threads:
results = thread.join()

# Results will be different for each thread
print(results) # [56, 78]
```


## Example with `pydantic-settings`
Let's say we are storing our application configuration using [pydantic-settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/):
Expand Down
178 changes: 178 additions & 0 deletions tests/providers/test_local_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import asyncio
import random
import threading
import time
import typing
from concurrent.futures.thread import ThreadPoolExecutor

import pytest

from that_depends.providers import AsyncFactory, ThreadLocalSingleton


async def _async_factory() -> int:
await asyncio.sleep(0.01)
return threading.get_ident()


def _factory() -> int:
time.sleep(0.01)
return random.randint(1, 100) # noqa: S311


def test_thread_local_singleton_same_thread() -> None:
"""Test that the same instance is returned within a single thread."""
provider = ThreadLocalSingleton(_factory)

instance1 = provider.sync_resolve()
instance2 = provider.sync_resolve()

assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical."

provider.tear_down()

assert provider._instance is None, "Tear down failed: Instance should be reset to None."


def test_thread_local_singleton_different_threads() -> None:
"""Test that different threads receive different instances."""
provider = ThreadLocalSingleton(_factory)
results = []

def resolve_in_thread() -> None:
results.append(provider.sync_resolve())

number_of_threads = 10

threads = [threading.Thread(target=resolve_in_thread) for _ in range(number_of_threads)]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert len(results) == number_of_threads, "Test failed: Expected results from two threads."
assert results[0] != results[1], "Thread-local failed: Instances across threads should differ."


def test_thread_local_singleton_override() -> None:
"""Test overriding the ThreadLocalSingleton and resetting the override."""
provider = ThreadLocalSingleton(_factory)

override_value = 101
provider.override(override_value)
instance = provider.sync_resolve()
assert instance == override_value, "Override failed: Expected overridden value."

# Reset override and ensure a new instance is created
provider.reset_override()
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Reset override failed: Should no longer return overridden value."


def test_thread_local_singleton_override_in_threads() -> None:
"""Test that resetting the override in one thread does not affect another thread."""
provider = ThreadLocalSingleton(_factory)
results = {}

def _thread_task(thread_id: int, override_value: int | None = None) -> None:
if override_value is not None:
provider.override(override_value)
results[thread_id] = provider.sync_resolve()
if override_value is not None:
provider.reset_override()

override_value: typing.Final[int] = 101
thread1 = threading.Thread(target=_thread_task, args=(1, override_value))
thread2 = threading.Thread(target=_thread_task, args=(2,))

thread1.start()
thread2.start()

thread1.join()
thread2.join()

# Validate results
assert results[1] == override_value, "Thread 1: Override failed."
assert results[2] != override_value, "Thread 2: Should not be affected by Thread 1's override."
assert results[1] != results[2], "Instances should be unique across threads."


def test_thread_local_singleton_override_temporarily() -> None:
"""Test temporarily overriding the ThreadLocalSingleton."""
provider = ThreadLocalSingleton(_factory)

override_value: typing.Final = 101
# Set a temporary override
with provider.override_context(override_value):
instance = provider.sync_resolve()
assert instance == override_value, "Override context failed: Expected overridden value."

# After the context, reset to the factory
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Override context failed: Value should reset after context."


async def test_async_thread_local_singleton_asyncio_concurrency() -> None:
singleton_async = ThreadLocalSingleton(_factory)

expected = await singleton_async.async_resolve()

results = await asyncio.gather(
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
)

assert all(val is results[0] for val in results)
for val in results:
assert val == expected, "Instances should be identical across threads."


async def test_thread_local_singleton_async_resolve_with_async_dependencies() -> None:
lesnik512 marked this conversation as resolved.
Show resolved Hide resolved
async_provider = AsyncFactory(_async_factory)

def _dependent_creator(v: int) -> int:
return v

provider = ThreadLocalSingleton(_dependent_creator, v=async_provider.cast)

expected = await provider.async_resolve()

assert expected == await provider.async_resolve()

results = await asyncio.gather(*[provider.async_resolve() for _ in range(10)])

for val in results:
assert val == expected, "Instances should be identical across threads."
with pytest.raises(RuntimeError):
# This should raise an error because the provider is async and resolution is attempted.
await asyncio.gather(asyncio.to_thread(provider.sync_resolve))

def _run_async_in_thread(coroutine: typing.Awaitable[typing.Any]) -> typing.Any: # noqa: ANN401
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coroutine)
loop.close()
return result

with ThreadPoolExecutor() as executor:
future = executor.submit(_run_async_in_thread, provider.async_resolve())

result = future.result()

assert result != expected, (
"Since singleton should have been newly resolved, it should not have the same thread id."
)


async def test_thread_local_singleton_async_resolve_override() -> None:
provider = ThreadLocalSingleton(_factory)

override_value = 101

provider.override(override_value)

assert await provider.async_resolve() == override_value, "Override failed: Expected overridden value."
29 changes: 29 additions & 0 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,21 @@ async def create_async_obj(value: str) -> SingletonFactory:
return SingletonFactory(dep1=f"async {value}")


async def _async_creator() -> int:
await asyncio.sleep(0.001)
return threading.get_ident()


def _sync_creator_with_dependency(dep: int) -> str:
return f"Singleton {dep}"


class DIContainer(BaseContainer):
factory: providers.AsyncFactory[int] = providers.AsyncFactory(_async_creator)
settings: Settings = providers.Singleton(Settings).cast
singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)
singleton_async = providers.AsyncSingleton(create_async_obj, value=settings.some_setting)
singleton_with_dependency = providers.Singleton(_sync_creator_with_dependency, dep=factory.cast)


async def test_singleton_provider() -> None:
Expand Down Expand Up @@ -151,3 +162,21 @@ async def test_async_singleton_asyncio_concurrency() -> None:
async def test_async_singleton_sync_resolve_failure() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()


async def test_singleton_async_resolve_with_async_dependencies() -> None:
lesnik512 marked this conversation as resolved.
Show resolved Hide resolved
expected = await DIContainer.singleton_with_dependency.async_resolve()

assert expected == await DIContainer.singleton_with_dependency.async_resolve()

results = await asyncio.gather(*[DIContainer.singleton_with_dependency.async_resolve() for _ in range(10)])

for val in results:
assert val == expected

results = await asyncio.gather(
*[asyncio.to_thread(DIContainer.singleton_with_dependency.sync_resolve) for _ in range(10)],
)

for val in results:
assert val == expected
2 changes: 2 additions & 0 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
container_context,
)
from that_depends.providers.factories import AsyncFactory, Factory
from that_depends.providers.local_singleton import ThreadLocalSingleton
from that_depends.providers.object import Object
from that_depends.providers.resources import Resource
from that_depends.providers.selector import Selector
Expand All @@ -28,5 +29,6 @@
"Resource",
"Selector",
"Singleton",
"ThreadLocalSingleton",
"container_context",
]
Loading
Loading