Skip to content

Commit

Permalink
Merge pull request #93 from vintasoftware/td/add-tests-to-views-threads
Browse files Browse the repository at this point in the history
Adds views tests for Threads
  • Loading branch information
fjsj authored Jun 19, 2024
2 parents c987e93 + eb43ab5 commit 59ff7ed
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 7 deletions.
23 changes: 21 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ipython = "^8.24.0"
pytest-asyncio = "^0.23.7"
pytest-recording = "^0.13.1"
coveralls = "^4.0.1"
model-bakery = "^1.18.1"

[tool.poetry.group.example.dependencies]
django-webpack-loader = "^3.1.0"
Expand Down
161 changes: 156 additions & 5 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from http import HTTPStatus

from django.contrib.auth.models import User
from django.urls import reverse

import pytest
from model_bakery import baker

from django_ai_assistant.exceptions import AIAssistantNotDefinedError
from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant
from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool
from django_ai_assistant.models import Thread


# Set up
Expand All @@ -14,7 +19,6 @@
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
description = "A temperature assistant that provides temperature information."
instructions = "You are a temperature bot."
model = "gpt-4o"

Expand All @@ -36,11 +40,18 @@ def fetch_forecast_temperature(self, location: str, dt_str: str) -> str:
return "35 degrees Celsius"


@pytest.fixture
def authenticated_client(client):
User.objects.create_user(username="testuser", password="password")
client.login(username="testuser", password="password")
return client


# Assistant Views


def test_list_assistants_with_results(client):
response = client.get("/assistants/")
response = client.get(reverse("django_ai_assistant:assistants_list"))

assert response.status_code == HTTPStatus.OK
assert response.json() == [{"id": "temperature_assistant", "name": "Temperature Assistant"}]
Expand All @@ -52,15 +63,23 @@ def test_does_not_list_assistants_if_unauthorized():


def test_get_assistant_that_exists(client):
response = client.get("/assistants/temperature_assistant/")
response = client.get(
reverse(
"django_ai_assistant:assistant_detail", kwargs={"assistant_id": "temperature_assistant"}
)
)

assert response.status_code == HTTPStatus.OK
assert response.json() == {"id": "temperature_assistant", "name": "Temperature Assistant"}


def test_get_assistant_that_does_not_exist(client):
with pytest.raises(AIAssistantNotDefinedError):
client.get("/assistants/fake_assistant/")
client.get(
reverse(
"django_ai_assistant:assistant_detail", kwargs={"assistant_id": "fake_assistant"}
)
)


def test_does_not_return_assistant_if_unauthorized():
Expand All @@ -70,4 +89,136 @@ def test_does_not_return_assistant_if_unauthorized():

# Threads Views

# Up next
# GET


@pytest.mark.django_db(transaction=True)
def test_list_threads_without_results(authenticated_client):
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert response.json() == []


@pytest.mark.django_db(transaction=True)
def test_list_threads_with_results(authenticated_client):
user = User.objects.first()
baker.make(Thread, created_by=user, _quantity=2)
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert len(response.json()) == 2


@pytest.mark.django_db(transaction=True)
def test_does_not_list_other_users_threads(authenticated_client):
baker.make(Thread)
response = authenticated_client.get(reverse("django_ai_assistant:threads_list_create"))

assert response.status_code == HTTPStatus.OK
assert response.json() == []


@pytest.mark.django_db(transaction=True)
def test_gets_specific_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.get(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.OK
assert response.json().get("id") == thread.id


def test_does_not_list_threads_if_unauthorized():
# TODO: Implement this test once permissions are in place
pass


# POST


@pytest.mark.django_db(transaction=True)
def test_create_thread(authenticated_client):
response = authenticated_client.post(
reverse("django_ai_assistant:threads_list_create"), data={}, content_type="application/json"
)

thread = Thread.objects.first()

assert response.status_code == HTTPStatus.OK
assert response.json().get("id") == thread.id


def test_cannot_create_thread_if_unauthorized():
# TODO: Implement this test once permissions are in place
pass


# PATCH


@pytest.mark.django_db(transaction=True)
def test_update_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.patch(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}),
data={"name": "New name"},
content_type="application/json",
)

assert response.status_code == HTTPStatus.OK
assert Thread.objects.filter(id=thread.id).first().name == "New name"


@pytest.mark.django_db(transaction=True)
def test_cannot_update_other_users_threads(authenticated_client):
thread = baker.make(Thread)
response = authenticated_client.patch(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id}),
data={"name": "New name"},
content_type="application/json",
)

assert response.status_code == HTTPStatus.FORBIDDEN
assert Thread.objects.filter(id=thread.id).first().name != "New name"


def test_cannot_update_thread_if_unauthorized():
# TODO: Implement this test once permissions are in place
pass


# DELETE


@pytest.mark.django_db(transaction=True)
def test_delete_thread(authenticated_client):
thread = baker.make(Thread, created_by=User.objects.first())
response = authenticated_client.delete(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.NO_CONTENT
assert not Thread.objects.filter(id=thread.id).exists()


@pytest.mark.django_db(transaction=True)
def test_cannot_delete_other_users_threads(authenticated_client):
thread = baker.make(Thread)
response = authenticated_client.delete(
reverse("django_ai_assistant:thread_detail_update_delete", kwargs={"thread_id": thread.id})
)

assert response.status_code == HTTPStatus.FORBIDDEN
assert Thread.objects.filter(id=thread.id).exists()


def test_cannot_delete_thread_if_unauthorized():
# TODO: Implement this test once permissions are in place
pass


# Threads Messages Views (will need VCR)

# TBD

0 comments on commit 59ff7ed

Please sign in to comment.