From 4e6824886715a869ddb560614bae68bbbbef9c31 Mon Sep 17 00:00:00 2001 From: amandasavluchinske Date: Mon, 17 Jun 2024 18:58:25 +0100 Subject: [PATCH] Updates threads tests to use pytest standards --- tests/test_views.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/tests/test_views.py b/tests/test_views.py index 4402049..6746e8b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,10 +1,14 @@ from http import HTTPStatus +from django.contrib.auth.models import User + import pytest +from model_bakery import baker from django_ai_assistant.exceptions import AIAssistantNotDefinedError from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool +from django_ai_assistant.models import Thread # Set up @@ -71,16 +75,26 @@ def test_does_not_return_assistant_if_unauthorized(): # Threads Views -# @pytest.mark.django_db(transaction=True) -# def test_list_threads_without_results(client): -# response = client.get("/threads/") +@pytest.fixture +def authenticated_client(client): + User.objects.create_user(username="testuser", password="password") + client.login(username="testuser", password="password") + return client + + +@pytest.mark.django_db(transaction=True) +def test_list_threads_without_results(authenticated_client): + response = authenticated_client.get("/threads/") + + assert response.status_code == HTTPStatus.OK + assert response.json() == [] -# assert response.status_code == HTTPStatus.OK -# assert response.json() == [] -# @pytest.mark.django_db(transaction=True) -# def test_list_threads_with_results(client): -# response = client.get("/threads/") +@pytest.mark.django_db(transaction=True) +def test_list_threads_with_results(authenticated_client): + user = User.objects.first() + thread = baker.make(Thread, created_by=user) + response = authenticated_client.get("/threads/") -# assert response.status_code == HTTPStatus.OK -# assert response.json()[0].get("id") == thread.id + assert response.status_code == HTTPStatus.OK + assert response.json()[0].get("id") == thread.id