Skip to content

Commit

Permalink
Tests and docs for litestar testing with overriding (#123)
Browse files Browse the repository at this point in the history
* cover test case with mock overriding

* docs with examples about overriding with litestar
  • Loading branch information
nightblure authored Nov 19, 2024
1 parent d9dfff9 commit f185392
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
72 changes: 71 additions & 1 deletion docs/testing/provider-overriding.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,74 @@ async def test_case_2():
redis_url = func()
assert redis_url == MOCK_REDIS_URL # ASSERTION ERROR

```
```

---
## Using with Litestar
In order to be able to inject dependencies of any type instead of existing objects,
we need to **change the typing** for the injected parameter as follows:

```python3
import typing
from functools import partial
from typing import Annotated
from unittest.mock import Mock

from litestar import Litestar, Router, get
from litestar.di import Provide
from litestar.params import Dependency
from litestar.testing import TestClient

from that_depends import BaseContainer, providers


class ExampleService:
def do_smth(self) -> str:
return "something"


class DIContainer(BaseContainer):
example_service = providers.Factory(ExampleService)


@get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)})
async def endpoint_handler(
example_service: Annotated[ExampleService, Dependency(skip_validation=True)],
) -> dict[str, typing.Any]:
return {"object": example_service.do_smth()}


# or if you want a little less code
NoValidationDependency = partial(Dependency, skip_validation=True)


@get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)})
async def endpoint_handler(
example_service: Annotated[ExampleService, NoValidationDependency()],
) -> dict[str, typing.Any]:
return {"object": example_service.do_smth()}


router = Router(
path="/router",
route_handlers=[endpoint_handler],
)

app = Litestar(route_handlers=[router])
```

Now we are ready to write tests with **overriding** and this will work with **any types**:
```python3
def test_litestar_endpoint_with_overriding() -> None:
some_service_mock = Mock(do_smth=lambda: "mock func")

with DIContainer.example_service.override_context(some_service_mock), TestClient(app=app) as client:
response = client.get("/router/another-endpoint")

assert response.status_code == 200
assert response.json()["object"] == "mock func"
```

More about `Dependency`
in the [Litestar documentation](https://docs.litestar.dev/2/usage/dependency-injection.html#the-dependency-function).

36 changes: 35 additions & 1 deletion tests/integrations/test_litestar_di.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import typing
from functools import partial
from typing import Annotated
from unittest.mock import Mock

from litestar import Controller, Litestar, Router, get
from litestar.di import Provide
from litestar.params import Dependency
from litestar.status_codes import HTTP_200_OK
from litestar.testing import TestClient

Expand All @@ -24,18 +28,32 @@ def int_fn() -> int:
return 1


class SomeService:
def do_smth(self) -> str:
return "something"


class DIContainer(BaseContainer):
bool_fn = providers.Factory(bool_fn, value=False)
str_fn = providers.Factory(str_fn)
list_fn = providers.Factory(list_fn)
int_fn = providers.Factory(int_fn)
some_service = providers.Factory(SomeService)


_NoValidationDependency = partial(Dependency, skip_validation=True)


class MyController(Controller):
path = "/controller"
dependencies = {"controller_dependency": Provide(DIContainer.list_fn)} # noqa: RUF012

@get(path="/handler", dependencies={"local_dependency": Provide(DIContainer.int_fn)})
@get(
path="/handler",
dependencies={
"local_dependency": Provide(DIContainer.int_fn),
},
)
async def my_route_handler(
self,
app_dependency: bool,
Expand All @@ -50,6 +68,12 @@ async def my_route_handler(
"local_dependency": local_dependency,
}

@get(path="/mock_overriding", dependencies={"some_service": Provide(DIContainer.some_service)})
async def mock_overriding_endpoint_handler(
self, some_service: Annotated[SomeService, _NoValidationDependency()]
) -> dict[str, typing.Any]:
return {"object": some_service.do_smth()}


my_router = Router(
path="/router",
Expand All @@ -61,6 +85,16 @@ async def my_route_handler(
app = Litestar(route_handlers=[my_router], dependencies={"app_dependency": Provide(DIContainer.bool_fn)}, debug=True)


def test_litestar_endpoint_with_mock_overriding() -> None:
some_service_mock = Mock(do_smth=lambda: "mock func")

with DIContainer.some_service.override_context(some_service_mock), TestClient(app=app) as client:
response = client.get("/router/controller/mock_overriding")

assert response.status_code == HTTP_200_OK
assert response.json()["object"] == "mock func"


def test_litestar_di() -> None:
with TestClient(app=app) as client:
response = client.get("/router/controller/handler")
Expand Down

0 comments on commit f185392

Please sign in to comment.