Skip to content

Commit

Permalink
Implemented ThreadLocalSingleton. (#146)
Browse files Browse the repository at this point in the history
* Implemented ThreadLocalSingleton.

* Fixed a docstring example.

* Updated Justfile to reference hook.

* Moved ThreadLocalSingleton to separate file.

* Added async resolution to ThreadSafeSingleton.

* Added more thread safety tests for singletons.

* Added sleep to test methods to ensure concurrency.
  • Loading branch information
alexanderlazarev0 authored Jan 19, 2025
1 parent 9781e3d commit 6e9f8cd
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default: install lint test
install:
uv lock --upgrade
uv sync --only-dev --frozen
uv run pre-commit install --overwrite
just hook

lint:
uv run ruff format
Expand All @@ -24,7 +24,7 @@ publish:
uv publish --token $PYPI_TOKEN

hook:
uv run pre-commit install
uv run pre-commit install --install-hooks --overwrite

unhook:
uv run pre-commit uninstall
36 changes: 36 additions & 0 deletions docs/providers/singleton.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ await MyContainer.singleton.async_resolve()
# calling sync_resolve concurrently in different threads will create only one instance
MyContainer.singleton.sync_resolve()
```
## ThreadLocalSingleton

For cases when you need to have a separate instance for each thread, you can use `ThreadLocalSingleton` provider. It will create a new instance for each thread and cache it for future injections in the same thread.

```python
from that_depends.providers import ThreadLocalSingleton
import threading
import random

# Define a factory function
def factory() -> int:
return random.randint(1, 100)

# Create a ThreadLocalSingleton instance
singleton = ThreadLocalSingleton(factory)

# Same thread, same instance
instance1 = singleton.sync_resolve() # 56
instance2 = singleton.sync_resolve() # 56

# Example usage in multiple threads
def thread_task():
instance = singleton.sync_resolve()
return instance

# Create and start threads
threads = [threading.Thread(target=thread_task) for i in range(2)]
for thread in threads:
thread.start()
for thread in threads:
results = thread.join()

# Results will be different for each thread
print(results) # [56, 78]
```


## Example with `pydantic-settings`
Let's say we are storing our application configuration using [pydantic-settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/):
Expand Down
178 changes: 178 additions & 0 deletions tests/providers/test_local_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import asyncio
import random
import threading
import time
import typing
from concurrent.futures.thread import ThreadPoolExecutor

import pytest

from that_depends.providers import AsyncFactory, ThreadLocalSingleton


async def _async_factory() -> int:
await asyncio.sleep(0.01)
return threading.get_ident()


def _factory() -> int:
time.sleep(0.01)
return random.randint(1, 100) # noqa: S311


def test_thread_local_singleton_same_thread() -> None:
"""Test that the same instance is returned within a single thread."""
provider = ThreadLocalSingleton(_factory)

instance1 = provider.sync_resolve()
instance2 = provider.sync_resolve()

assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical."

provider.tear_down()

assert provider._instance is None, "Tear down failed: Instance should be reset to None."


def test_thread_local_singleton_different_threads() -> None:
"""Test that different threads receive different instances."""
provider = ThreadLocalSingleton(_factory)
results = []

def resolve_in_thread() -> None:
results.append(provider.sync_resolve())

number_of_threads = 10

threads = [threading.Thread(target=resolve_in_thread) for _ in range(number_of_threads)]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert len(results) == number_of_threads, "Test failed: Expected results from two threads."
assert results[0] != results[1], "Thread-local failed: Instances across threads should differ."


def test_thread_local_singleton_override() -> None:
"""Test overriding the ThreadLocalSingleton and resetting the override."""
provider = ThreadLocalSingleton(_factory)

override_value = 101
provider.override(override_value)
instance = provider.sync_resolve()
assert instance == override_value, "Override failed: Expected overridden value."

# Reset override and ensure a new instance is created
provider.reset_override()
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Reset override failed: Should no longer return overridden value."


def test_thread_local_singleton_override_in_threads() -> None:
"""Test that resetting the override in one thread does not affect another thread."""
provider = ThreadLocalSingleton(_factory)
results = {}

def _thread_task(thread_id: int, override_value: int | None = None) -> None:
if override_value is not None:
provider.override(override_value)
results[thread_id] = provider.sync_resolve()
if override_value is not None:
provider.reset_override()

override_value: typing.Final[int] = 101
thread1 = threading.Thread(target=_thread_task, args=(1, override_value))
thread2 = threading.Thread(target=_thread_task, args=(2,))

thread1.start()
thread2.start()

thread1.join()
thread2.join()

# Validate results
assert results[1] == override_value, "Thread 1: Override failed."
assert results[2] != override_value, "Thread 2: Should not be affected by Thread 1's override."
assert results[1] != results[2], "Instances should be unique across threads."


def test_thread_local_singleton_override_temporarily() -> None:
"""Test temporarily overriding the ThreadLocalSingleton."""
provider = ThreadLocalSingleton(_factory)

override_value: typing.Final = 101
# Set a temporary override
with provider.override_context(override_value):
instance = provider.sync_resolve()
assert instance == override_value, "Override context failed: Expected overridden value."

# After the context, reset to the factory
new_instance = provider.sync_resolve()
assert new_instance != override_value, "Override context failed: Value should reset after context."


async def test_async_thread_local_singleton_asyncio_concurrency() -> None:
singleton_async = ThreadLocalSingleton(_factory)

expected = await singleton_async.async_resolve()

results = await asyncio.gather(
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
)

assert all(val is results[0] for val in results)
for val in results:
assert val == expected, "Instances should be identical across threads."


async def test_thread_local_singleton_async_resolve_with_async_dependencies() -> None:
async_provider = AsyncFactory(_async_factory)

def _dependent_creator(v: int) -> int:
return v

provider = ThreadLocalSingleton(_dependent_creator, v=async_provider.cast)

expected = await provider.async_resolve()

assert expected == await provider.async_resolve()

results = await asyncio.gather(*[provider.async_resolve() for _ in range(10)])

for val in results:
assert val == expected, "Instances should be identical across threads."
with pytest.raises(RuntimeError):
# This should raise an error because the provider is async and resolution is attempted.
await asyncio.gather(asyncio.to_thread(provider.sync_resolve))

def _run_async_in_thread(coroutine: typing.Awaitable[typing.Any]) -> typing.Any: # noqa: ANN401
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coroutine)
loop.close()
return result

with ThreadPoolExecutor() as executor:
future = executor.submit(_run_async_in_thread, provider.async_resolve())

result = future.result()

assert result != expected, (
"Since singleton should have been newly resolved, it should not have the same thread id."
)


async def test_thread_local_singleton_async_resolve_override() -> None:
provider = ThreadLocalSingleton(_factory)

override_value = 101

provider.override(override_value)

assert await provider.async_resolve() == override_value, "Override failed: Expected overridden value."
29 changes: 29 additions & 0 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,21 @@ async def create_async_obj(value: str) -> SingletonFactory:
return SingletonFactory(dep1=f"async {value}")


async def _async_creator() -> int:
await asyncio.sleep(0.001)
return threading.get_ident()


def _sync_creator_with_dependency(dep: int) -> str:
return f"Singleton {dep}"


class DIContainer(BaseContainer):
factory: providers.AsyncFactory[int] = providers.AsyncFactory(_async_creator)
settings: Settings = providers.Singleton(Settings).cast
singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)
singleton_async = providers.AsyncSingleton(create_async_obj, value=settings.some_setting)
singleton_with_dependency = providers.Singleton(_sync_creator_with_dependency, dep=factory.cast)


async def test_singleton_provider() -> None:
Expand Down Expand Up @@ -151,3 +162,21 @@ async def test_async_singleton_asyncio_concurrency() -> None:
async def test_async_singleton_sync_resolve_failure() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()


async def test_singleton_async_resolve_with_async_dependencies() -> None:
expected = await DIContainer.singleton_with_dependency.async_resolve()

assert expected == await DIContainer.singleton_with_dependency.async_resolve()

results = await asyncio.gather(*[DIContainer.singleton_with_dependency.async_resolve() for _ in range(10)])

for val in results:
assert val == expected

results = await asyncio.gather(
*[asyncio.to_thread(DIContainer.singleton_with_dependency.sync_resolve) for _ in range(10)],
)

for val in results:
assert val == expected
2 changes: 2 additions & 0 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
container_context,
)
from that_depends.providers.factories import AsyncFactory, Factory
from that_depends.providers.local_singleton import ThreadLocalSingleton
from that_depends.providers.object import Object
from that_depends.providers.resources import Resource
from that_depends.providers.selector import Selector
Expand All @@ -28,5 +29,6 @@
"Resource",
"Selector",
"Singleton",
"ThreadLocalSingleton",
"container_context",
]
Loading

0 comments on commit 6e9f8cd

Please sign in to comment.