Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lesnik512 committed Nov 21, 2024
1 parent 8a90e8e commit 45c0845
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 63 deletions.
16 changes: 5 additions & 11 deletions tests/providers/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ async def create_resource() -> typing.AsyncIterator[str]:

resource = providers.Resource(create_resource)

async def resolve_resource() -> str:
return await resource.async_resolve()

await asyncio.gather(resolve_resource(), resolve_resource())
await asyncio.gather(resource.async_resolve(), resource.async_resolve())

assert calls == 1

Expand All @@ -157,15 +154,12 @@ def create_resource() -> typing.Iterator[str]:

resource = providers.Resource(create_resource)

def resolve_resource() -> str:
return resource.sync_resolve()

with ThreadPoolExecutor(max_workers=4) as pool:
tasks = [
pool.submit(resolve_resource),
pool.submit(resolve_resource),
pool.submit(resolve_resource),
pool.submit(resolve_resource),
pool.submit(resource.sync_resolve),
pool.submit(resource.sync_resolve),
pool.submit(resource.sync_resolve),
pool.submit(resource.sync_resolve),
]
results = [x.result() for x in as_completed(tasks)]

Expand Down
100 changes: 48 additions & 52 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,46 +45,6 @@ async def test_singleton_provider() -> None:
await DIContainer.tear_down()


async def test_singleton_async_provider() -> None:
singleton1 = await DIContainer.singleton_async()
singleton2 = await DIContainer.singleton_async()
singleton3 = await DIContainer.singleton_async.async_resolve()
await DIContainer.singleton_async.tear_down()
singleton4 = await DIContainer.singleton_async.async_resolve()

assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1

await DIContainer.tear_down()


async def test_singleton_async_provider_override() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
singleton_async.override(SingletonFactory(dep1="bar"))

result = await singleton_async.async_resolve()
assert result == SingletonFactory(dep1="bar")


async def test_singleton_async_provider_concurrent() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")

results = await asyncio.gather(
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
)

assert all(val is results[0] for val in results)


async def test_singleton_async_provider_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()


async def test_singleton_attr_getter() -> None:
singleton1 = await DIContainer.singleton()

Expand Down Expand Up @@ -118,17 +78,16 @@ class SimpleCreator:
resource = providers.Resource(create_resource)
factory_with_resource = providers.Singleton(SimpleCreator, dep1=resource.cast)

async def resolve_factory() -> SimpleCreator:
return await factory_with_resource.async_resolve()

client1, client2 = await asyncio.gather(resolve_factory(), resolve_factory())
client1, client2 = await asyncio.gather(
factory_with_resource.async_resolve(), factory_with_resource.async_resolve()
)

assert client1 is client2
assert calls == 1


@pytest.mark.repeat(10)
def test_singleton_sync_resolve_concurrency() -> None:
def test_singleton_threading_concurrency() -> None:
calls: int = 0
lock = threading.Lock()

Expand All @@ -141,17 +100,54 @@ def create_singleton() -> str:

singleton = providers.Singleton(create_singleton)

def resolve_singleton() -> str:
return singleton.sync_resolve()

with ThreadPoolExecutor(max_workers=4) as pool:
tasks = [
pool.submit(resolve_singleton),
pool.submit(resolve_singleton),
pool.submit(resolve_singleton),
pool.submit(resolve_singleton),
pool.submit(singleton.sync_resolve),
pool.submit(singleton.sync_resolve),
pool.submit(singleton.sync_resolve),
pool.submit(singleton.sync_resolve),
]
results = [x.result() for x in as_completed(tasks)]

assert all(x == "" for x in results)
assert calls == 1


async def test_async_singleton() -> None:
singleton1 = await DIContainer.singleton_async()
singleton2 = await DIContainer.singleton_async()
singleton3 = await DIContainer.singleton_async.async_resolve()
await DIContainer.singleton_async.tear_down()
singleton4 = await DIContainer.singleton_async.async_resolve()

assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1

await DIContainer.tear_down()


async def test_async_singleton_override() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
singleton_async.override(SingletonFactory(dep1="bar"))

result = await singleton_async.async_resolve()
assert result == SingletonFactory(dep1="bar")


async def test_async_singleton_asyncio_concurrency() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")

results = await asyncio.gather(
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
)

assert all(val is results[0] for val in results)


async def test_async_singleton_sync_resolve_failure() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()

0 comments on commit 45c0845

Please sign in to comment.