diff --git a/poetry.lock b/poetry.lock index 8c30480..aef4010 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1259,6 +1259,24 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "model-bakery" +version = "1.18.1" +description = "Smart object creation facility for Django." +optional = false +python-versions = ">=3.8" +files = [ + {file = "model_bakery-1.18.1-py3-none-any.whl", hash = "sha256:49d672d41e8c377854d42e50df61ff5ee56b91a3b2ff24372344c05c29166eb2"}, + {file = "model_bakery-1.18.1.tar.gz", hash = "sha256:8cc2b7b0879a2fc400808225a4c1f830b322b181007577b588f4d4aac5d4cb4b"}, +] + +[package.dependencies] +django = ">=4.2" + +[package.extras] +docs = ["myst-parser", "sphinx", "sphinx-rtd-theme"] +test = ["black", "coverage", "mypy", "pillow", "pytest", "pytest-django", "ruff"] + [[package]] name = "multidict" version = "6.0.5" @@ -1869,6 +1887,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2705,4 +2724,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.12,<3.13" -content-hash = "ccd01a998eb526dd7a4b6b4a401f34462d16a6bf5eb703c10b5b0a7175e74ef7" +content-hash = "2bc5e1b1fe25c2765af0cd3a3b22a560ad9db2f8694771dd10ffab07669673f0" diff --git a/pyproject.toml b/pyproject.toml index 32485db..9425735 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ ipython = "^8.24.0" pytest-asyncio = "^0.23.7" pytest-recording = "^0.13.1" coveralls = "^4.0.1" +model-bakery = "^1.18.1" [tool.poetry.group.example.dependencies] django-webpack-loader = "^3.1.0" diff --git a/tests/test_views.py b/tests/test_views.py index 5ce68bc..701739e 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,10 +1,15 @@ from http import HTTPStatus +from django.contrib.auth.models import User +from django.urls import reverse + import pytest +from model_bakery import baker from django_ai_assistant.exceptions import AIAssistantNotDefinedError from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool +from django_ai_assistant.models import Thread # Set up @@ -14,7 +19,6 @@ class TemperatureAssistant(AIAssistant): id = "temperature_assistant" # noqa: A003 name = "Temperature Assistant" - description = "A temperature assistant that provides temperature information." instructions = "You are a temperature bot." model = "gpt-4o" @@ -36,11 +40,18 @@ def fetch_forecast_temperature(self, location: str, dt_str: str) -> str: return "35 degrees Celsius" +@pytest.fixture +def authenticated_client(client): + User.objects.create_user(username="testuser", password="password") + client.login(username="testuser", password="password") + return client + + # Assistant Views def test_list_assistants_with_results(client): - response = client.get("/assistants/") + response = client.get(reverse("django_ai_assistant:assistants_list")) assert response.status_code == HTTPStatus.OK assert response.json() == [{"id": "temperature_assistant", "name": "Temperature Assistant"}] @@ -52,7 +63,11 @@ def test_does_not_list_assistants_if_unauthorized(): def test_get_assistant_that_exists(client): - response = client.get("/assistants/temperature_assistant/") + response = client.get( + reverse( + "django_ai_assistant:assistant_detail", kwargs={"assistant_id": "temperature_assistant"} + ) + ) assert response.status_code == HTTPStatus.OK assert response.json() == {"id": "temperature_assistant", "name": "Temperature Assistant"} @@ -60,7 +75,11 @@ def test_get_assistant_that_exists(client): def test_get_assistant_that_does_not_exist(client): with pytest.raises(AIAssistantNotDefinedError): - client.get("/assistants/fake_assistant/") + client.get( + reverse( + "django_ai_assistant:assistant_detail", kwargs={"assistant_id": "fake_assistant"} + ) + ) def test_does_not_return_assistant_if_unauthorized(): @@ -70,4 +89,136 @@ def test_does_not_return_assistant_if_unauthorized(): # Threads Views -# Up next +# GET + + +@pytest.mark.django_db(transaction=True) +def test_list_threads_without_results(authenticated_client): + response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create")) + + assert response.status_code == HTTPStatus.OK + assert response.json() == [] + + +@pytest.mark.django_db(transaction=True) +def test_list_threads_with_results(authenticated_client): + user = User.objects.first() + baker.make(Thread, created_by=user, _quantity=2) + response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create")) + + assert response.status_code == HTTPStatus.OK + assert len(response.json()) == 2 + + +@pytest.mark.django_db(transaction=True) +def test_does_not_list_other_users_threads(authenticated_client): + baker.make(Thread) + response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create")) + + assert response.status_code == HTTPStatus.OK + assert response.json() == [] + + +@pytest.mark.django_db(transaction=True) +def test_gets_specific_thread(authenticated_client): + thread = baker.make(Thread, created_by=User.objects.first()) + response = authenticated_client.get( + reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}) + ) + + assert response.status_code == HTTPStatus.OK + assert response.json().get("id") == thread.id + + +def test_does_not_list_threads_if_unauthorized(): + # TODO: Implement this test once permissions are in place + pass + + +# POST + + +@pytest.mark.django_db(transaction=True) +def test_create_thread(authenticated_client): + response = authenticated_client.post( + reverse("django_ai_assistant:threads_list_create"), data={}, content_type="application/json" + ) + + thread = Thread.objects.first() + + assert response.status_code == HTTPStatus.OK + assert response.json().get("id") == thread.id + + +def test_cannot_create_thread_if_unauthorized(): + # TODO: Implement this test once permissions are in place + pass + + +# PATCH + + +@pytest.mark.django_db(transaction=True) +def test_update_thread(authenticated_client): + thread = baker.make(Thread, created_by=User.objects.first()) + response = authenticated_client.patch( + reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}), + data={"name": "New name"}, + content_type="application/json", + ) + + assert response.status_code == HTTPStatus.OK + assert Thread.objects.filter(id=thread.id).first().name == "New name" + + +@pytest.mark.django_db(transaction=True) +def test_cannot_update_other_users_threads(authenticated_client): + thread = baker.make(Thread) + response = authenticated_client.patch( + reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}), + data={"name": "New name"}, + content_type="application/json", + ) + + assert response.status_code == HTTPStatus.FORBIDDEN + assert Thread.objects.filter(id=thread.id).first().name != "New name" + + +def test_cannot_update_thread_if_unauthorized(): + # TODO: Implement this test once permissions are in place + pass + + +# DELETE + + +@pytest.mark.django_db(transaction=True) +def test_delete_thread(authenticated_client): + thread = baker.make(Thread, created_by=User.objects.first()) + response = authenticated_client.delete( + reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}) + ) + + assert response.status_code == HTTPStatus.NO_CONTENT + assert not Thread.objects.filter(id=thread.id).exists() + + +@pytest.mark.django_db(transaction=True) +def test_cannot_delete_other_users_threads(authenticated_client): + thread = baker.make(Thread) + response = authenticated_client.delete( + reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}) + ) + + assert response.status_code == HTTPStatus.FORBIDDEN + assert Thread.objects.filter(id=thread.id).exists() + + +def test_cannot_delete_thread_if_unauthorized(): + # TODO: Implement this test once permissions are in place + pass + + +# Threads Messages Views (will need VCR) + +# TBD