diff --git a/django_ai_assistant/ai/chat_message_histories.py b/django_ai_assistant/ai/chat_message_histories.py index 8b0b4d2..33b9760 100644 --- a/django_ai_assistant/ai/chat_message_histories.py +++ b/django_ai_assistant/ai/chat_message_histories.py @@ -9,6 +9,8 @@ import logging from typing import Any, List, Sequence, cast +from django.db import transaction + from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict @@ -68,12 +70,19 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None: Args: messages: A list of BaseMessage objects to store. """ - Message.objects.bulk_create( - [ + with transaction.atomic(): + message_objects = [ Message(thread_id=self._thread_id, message=message_to_dict(message)) for message in messages ] - ) + + created_messages = Message.objects.bulk_create(message_objects) + + # Update langchain message IDs with Django message IDs + for created_message in created_messages: + created_message.message["data"]["id"] = str(created_message.id) + + Message.objects.bulk_update(created_messages, ["message"]) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Add messages to the chat thread. @@ -81,12 +90,36 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: Args: messages: A list of BaseMessage objects to store. """ - await Message.objects.abulk_create( - [ - Message(thread_id=self._thread_id, message=message_to_dict(message)) - for message in messages - ] - ) + # NOTE: This method does not use transactions because it do not yet work in async mode. + # Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm + message_objects = [ + Message(thread_id=self._thread_id, message=message_to_dict(message)) + for message in messages + ] + + created_messages = await Message.objects.abulk_create(message_objects) + + # Update langchain message IDs with Django message IDs + for created_message in created_messages: + created_message.message["data"]["id"] = str(created_message.id) + + await Message.objects.abulk_update(created_messages, ["message"]) + + def remove_messages(self, message_ids: List[str]) -> None: + """Remove messages from the chat thread. + + Args: + message_ids: A list of message IDs to remove. + """ + Message.objects.filter(id__in=message_ids).delete() + + async def aremove_messages(self, message_ids: List[str]) -> None: + """Remove messages from the chat thread. + + Args: + message_ids: A list of message IDs to remove. + """ + await Message.objects.filter(id__in=message_ids).adelete() def _get_messages_qs(self): return Message.objects.filter(thread_id=self._thread_id).order_by("created_at") diff --git a/django_ai_assistant/conf.py b/django_ai_assistant/conf.py index 3c9cea2..24eb586 100644 --- a/django_ai_assistant/conf.py +++ b/django_ai_assistant/conf.py @@ -13,6 +13,7 @@ "CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread", "CAN_DELETE_THREAD_FN": "django_ai_assistant.permissions.owns_thread", "CAN_CREATE_MESSAGE_FN": "django_ai_assistant.permissions.owns_thread", + "CAN_DELETE_MESSAGE_FN": "django_ai_assistant.permissions.owns_thread", "CAN_RUN_ASSISTANT": "django_ai_assistant.permissions.allow_all", } diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 9411551..fb5a4d3 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -44,10 +44,11 @@ AIAssistantNotDefinedError, AIUserNotAllowedError, ) -from django_ai_assistant.models import Thread +from django_ai_assistant.models import Message, Thread from django_ai_assistant.permissions import ( can_create_message, can_create_thread, + can_delete_message, can_delete_thread, can_run_assistant, ) @@ -416,3 +417,15 @@ def create_thread_message_as_user( raise AIUserNotAllowedError("User is not allowed to create messages in this thread") DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)]) + + +def delete_message( + message: Message, + user: Any, + request: HttpRequest | None = None, + view: View | None = None, +): + if not can_delete_message(message=message, user=user, request=request, view=view): + raise AIUserNotAllowedError("User is not allowed to delete this message") + + return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([message.id]) diff --git a/django_ai_assistant/permissions.py b/django_ai_assistant/permissions.py index ffa43e3..0720169 100644 --- a/django_ai_assistant/permissions.py +++ b/django_ai_assistant/permissions.py @@ -4,7 +4,7 @@ from django.views import View from django_ai_assistant.conf import app_settings -from django_ai_assistant.models import Thread +from django_ai_assistant.models import Message, Thread def _get_default_kwargs(user: Any, request: HttpRequest | None, view: View | None): @@ -57,6 +57,21 @@ def can_create_message( ) +def can_delete_message( + message: Message, + user: Any, + request: HttpRequest | None = None, + view: View | None = None, + **kwargs, +) -> bool: + return app_settings.call_fn( + "CAN_DELETE_MESSAGE_FN", + **_get_default_kwargs(user, request, view), + message=message, + thread=message.thread, + ) + + def can_run_assistant( assistant_cls, user: Any, @@ -71,11 +86,11 @@ def can_run_assistant( ) -def allow_all(**kwargs): +def allow_all(**kwargs) -> bool: return True -def owns_thread(user, thread, **kwargs): +def owns_thread(user: Any, thread: Thread, **kwargs) -> bool: if user.is_superuser: return True diff --git a/django_ai_assistant/schemas.py b/django_ai_assistant/schemas.py index 7eabb9f..950179b 100644 --- a/django_ai_assistant/schemas.py +++ b/django_ai_assistant/schemas.py @@ -42,5 +42,6 @@ class ThreadMessageTypeEnum(str, Enum): class ThreadMessagesSchemaOut(Schema): + id: str # noqa: A003 type: ThreadMessageTypeEnum # noqa: A003 content: str diff --git a/django_ai_assistant/views.py b/django_ai_assistant/views.py index 8c3893f..0813725 100644 --- a/django_ai_assistant/views.py +++ b/django_ai_assistant/views.py @@ -15,7 +15,7 @@ get_thread_messages, get_threads, ) -from .models import Thread +from .models import Message, Thread from .schemas import ( AssistantSchema, ThreadMessagesSchemaIn, @@ -96,3 +96,17 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema ) return 201, None + + +@api.delete( + "threads/{thread_id}/messages/{message_id}/", response={204: None}, url_name="messages_delete" +) +def delete_thread_message(request, thread_id: str, message_id: str): + message = get_object_or_404(Message, id=message_id, thread_id=thread_id) + assistants.delete_message( + message=message, + user=request.user, + request=request, + view=None, + ) + return 204, None diff --git a/example/assets/js/Chat/Chat.tsx b/example/assets/js/Chat/Chat.tsx index 21af812..6f66241 100644 --- a/example/assets/js/Chat/Chat.tsx +++ b/example/assets/js/Chat/Chat.tsx @@ -1,19 +1,24 @@ import { - Container, - Text, - Stack, - Title, - Textarea, + ActionIcon, + Avatar, Box, Button, + Container, + Group, LoadingOverlay, + Paper, ScrollArea, + Stack, + Text, + Textarea, + Title, + Tooltip, } from "@mantine/core"; import { ThreadsNav } from "./ThreadsNav"; import classes from "./Chat.module.css"; import { useCallback, useEffect, useRef, useState } from "react"; -import { IconSend2 } from "@tabler/icons-react"; +import { IconSend2, IconTrash } from "@tabler/icons-react"; import { getHotkeyHandler } from "@mantine/hooks"; import Markdown from "react-markdown"; @@ -25,31 +30,101 @@ import { useThread, } from "django-ai-assistant-client"; -function ChatMessage({ message }: { message: ThreadMessagesSchemaOut }) { +function ChatMessage({ + threadId, + message, + deleteMessage, +}: { + threadId: string; + message: ThreadMessagesSchemaOut; + deleteMessage: ({ + threadId, + messageId, + }: { + threadId: string; + messageId: string; + }) => Promise; +}) { + const isUserMessage = message.type === "human"; + + const DeleteButton = () => ( + + { + await deleteMessage({ threadId, messageId: message.id }); + }} + aria-label="Delete message" + > + + + + ); + return ( - - {message.type === "ai" ? "AI" : "User"} - {message.content} - + + {!isUserMessage ? ( + + AI + + ) : null} + + {isUserMessage ? : null} + + + + {message.content} + + + + {!isUserMessage ? : null} + ); } function ChatMessageList({ + threadId, messages, + deleteMessage, }: { + threadId: string; messages: ThreadMessagesSchemaOut[]; + deleteMessage: ({ + threadId, + messageId, + }: { + threadId: string; + messageId: string; + }) => Promise; }) { if (messages.length === 0) { return No messages.; } - // TODO: check why horizontal scroll appears return ( -
+ {messages.map((message, index) => ( - + ))} -
+ ); } @@ -66,9 +141,12 @@ export function Chat() { loadingFetchMessages, createMessage, loadingCreateMessage, + deleteMessage, + loadingDeleteMessage, } = useMessage(); - const loadingMessages = loadingFetchMessages || loadingCreateMessage; + const loadingMessages = + loadingFetchMessages || loadingCreateMessage || loadingDeleteMessage; const isThreadSelected = assistantId && activeThread; const isChatActive = assistantId && activeThread && !loadingMessages; @@ -154,7 +232,11 @@ export function Chat() { overlayProps={{ blur: 2 }} /> {isThreadSelected ? ( - + ) : ( Select or create a thread to start chatting. diff --git a/example/example/settings.py b/example/example/settings.py index 39615ea..91101b1 100644 --- a/example/example/settings.py +++ b/example/example/settings.py @@ -159,6 +159,7 @@ AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" +AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all" diff --git a/example_movies/example_movies/settings.py b/example_movies/example_movies/settings.py index 3990a31..158518e 100644 --- a/example_movies/example_movies/settings.py +++ b/example_movies/example_movies/settings.py @@ -159,4 +159,5 @@ AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" +AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all" diff --git a/example_rag/example_rag/settings.py b/example_rag/example_rag/settings.py index 05a44e0..d08c669 100644 --- a/example_rag/example_rag/settings.py +++ b/example_rag/example_rag/settings.py @@ -159,6 +159,7 @@ AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" +AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all" diff --git a/frontend/openapi_schema.json b/frontend/openapi_schema.json index 877ca63..466bc24 100644 --- a/frontend/openapi_schema.json +++ b/frontend/openapi_schema.json @@ -163,6 +163,37 @@ "required": true } } + }, + "/threads/{thread_id}/messages/{message_id}/": { + "delete": { + "operationId": "django_ai_assistant_views_delete_thread_message", + "summary": "Delete Thread Message", + "parameters": [ + { + "in": "path", + "name": "thread_id", + "schema": { + "title": "Thread Id", + "type": "string" + }, + "required": true + }, + { + "in": "path", + "name": "message_id", + "schema": { + "title": "Message Id", + "type": "string" + }, + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + } + } + } } }, "components": { @@ -252,6 +283,10 @@ }, "ThreadMessagesSchemaOut": { "properties": { + "id": { + "title": "Id", + "type": "string" + }, "type": { "$ref": "#/components/schemas/ThreadMessageTypeEnum" }, @@ -261,6 +296,7 @@ } }, "required": [ + "id", "type", "content" ], diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index 0bbc27f..f451e3f 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -76,6 +76,10 @@ export const $ThreadMessageTypeEnum = { export const $ThreadMessagesSchemaOut = { properties: { + id: { + title: 'Id', + type: 'string' + }, type: { '$ref': '#/components/schemas/ThreadMessageTypeEnum' }, @@ -84,7 +88,7 @@ export const $ThreadMessagesSchemaOut = { type: 'string' } }, - required: ['type', 'content'], + required: ['id', 'type', 'content'], title: 'ThreadMessagesSchemaOut', type: 'object' } as const; diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 1bb5c8e..0f4e394 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { DjangoAiAssistantViewsListAssistantsResponse, DjangoAiAssistantViewsListThreadsResponse, DjangoAiAssistantViewsCreateThreadData, DjangoAiAssistantViewsCreateThreadResponse, DjangoAiAssistantViewsDeleteThreadData, DjangoAiAssistantViewsDeleteThreadResponse, DjangoAiAssistantViewsListThreadMessagesData, DjangoAiAssistantViewsListThreadMessagesResponse, DjangoAiAssistantViewsCreateThreadMessageData, DjangoAiAssistantViewsCreateThreadMessageResponse } from './types.gen'; +import type { DjangoAiAssistantViewsListAssistantsResponse, DjangoAiAssistantViewsListThreadsResponse, DjangoAiAssistantViewsCreateThreadData, DjangoAiAssistantViewsCreateThreadResponse, DjangoAiAssistantViewsDeleteThreadData, DjangoAiAssistantViewsDeleteThreadResponse, DjangoAiAssistantViewsListThreadMessagesData, DjangoAiAssistantViewsListThreadMessagesResponse, DjangoAiAssistantViewsCreateThreadMessageData, DjangoAiAssistantViewsCreateThreadMessageResponse, DjangoAiAssistantViewsDeleteThreadMessageData, DjangoAiAssistantViewsDeleteThreadMessageResponse } from './types.gen'; /** * List Assistants @@ -85,4 +85,21 @@ export const djangoAiAssistantViewsCreateThreadMessage = (data: DjangoAiAssistan }, body: data.requestBody, mediaType: 'application/json' +}); }; + +/** + * Delete Thread Message + * @param data The data for the request. + * @param data.threadId + * @param data.messageId + * @returns void No Content + * @throws ApiError + */ +export const djangoAiAssistantViewsDeleteThreadMessage = (data: DjangoAiAssistantViewsDeleteThreadMessageData): CancelablePromise => { return __request(OpenAPI, { + method: 'DELETE', + url: '/threads/{thread_id}/messages/{message_id}/', + path: { + thread_id: data.threadId, + message_id: data.messageId + } }); }; \ No newline at end of file diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 7505ee1..11292e9 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -19,6 +19,7 @@ export type ThreadSchemaIn = { export type ThreadMessageTypeEnum = 'human' | 'ai' | 'generic' | 'system' | 'function' | 'tool'; export type ThreadMessagesSchemaOut = { + id: string; type: ThreadMessageTypeEnum; content: string; }; @@ -57,6 +58,13 @@ export type DjangoAiAssistantViewsCreateThreadMessageData = { export type DjangoAiAssistantViewsCreateThreadMessageResponse = unknown; +export type DjangoAiAssistantViewsDeleteThreadMessageData = { + messageId: string; + threadId: string; +}; + +export type DjangoAiAssistantViewsDeleteThreadMessageResponse = void; + export type $OpenApiTs = { '/assistants/': { get: { @@ -118,4 +126,15 @@ export type $OpenApiTs = { }; }; }; + '/threads/{thread_id}/messages/{message_id}/': { + delete: { + req: DjangoAiAssistantViewsDeleteThreadMessageData; + res: { + /** + * No Content + */ + 204: void; + }; + }; + }; }; \ No newline at end of file diff --git a/frontend/src/hooks/useMessage.ts b/frontend/src/hooks/useMessage.ts index 9121291..eca4757 100644 --- a/frontend/src/hooks/useMessage.ts +++ b/frontend/src/hooks/useMessage.ts @@ -1,9 +1,10 @@ import { useCallback } from "react"; import { useState } from "react"; import { - ThreadMessagesSchemaOut, djangoAiAssistantViewsCreateThreadMessage, + djangoAiAssistantViewsDeleteThreadMessage, djangoAiAssistantViewsListThreadMessages, + ThreadMessagesSchemaOut, } from "../client"; /** @@ -17,6 +18,8 @@ export function useMessage() { useState(false); const [loadingCreateMessage, setLoadingCreateMessage] = useState(false); + const [loadingDeleteMessage, setLoadingDeleteMessage] = + useState(false); /** * Fetches a list of messages. @@ -50,7 +53,6 @@ export function useMessage() { * @param threadId The ID of the thread in which to create the message. * @param assistantId The ID of the assistant. * @param messageTextValue The content of the message. - * @returns A promise that resolves to undefined when the message is created. */ const createMessage = useCallback( async ({ @@ -61,7 +63,7 @@ export function useMessage() { threadId: string; assistantId: string; messageTextValue: string; - }): Promise => { + }): Promise => { try { setLoadingCreateMessage(true); // successful response is 201, None @@ -73,7 +75,6 @@ export function useMessage() { }, }); await fetchMessages({ threadId }); - return undefined; } finally { setLoadingCreateMessage(false); } @@ -81,6 +82,34 @@ export function useMessage() { [fetchMessages] ); + /** + * Deletes a message in a thread. + * + * @param threadId The ID of the thread in which to delete the message. + * @param messageId The ID of the message to delete. + */ + const deleteMessage = useCallback( + async ({ + threadId, + messageId, + }: { + threadId: string; + messageId: string; + }): Promise => { + try { + setLoadingDeleteMessage(true); + await djangoAiAssistantViewsDeleteThreadMessage({ + threadId, + messageId, + }); + await fetchMessages({ threadId }); + } finally { + setLoadingDeleteMessage(false); + } + }, + [fetchMessages] + ); + return { /** * Function to fetch messages for a thread from the server. @@ -90,6 +119,10 @@ export function useMessage() { * Function to create a new message in a thread. */ createMessage, + /** + * Function to delete a message in a thread. + */ + deleteMessage, /** * Array of fetched messages. */ @@ -102,5 +135,9 @@ export function useMessage() { * Loading state of the create operation. */ loadingCreateMessage, + /** + * Loading state of the delete operation. + */ + loadingDeleteMessage, }; } diff --git a/frontend/tests/useMessage.test.ts b/frontend/tests/useMessage.test.ts index cb75e3b..d18ef6c 100644 --- a/frontend/tests/useMessage.test.ts +++ b/frontend/tests/useMessage.test.ts @@ -2,7 +2,9 @@ import { act, renderHook } from "@testing-library/react"; import { useMessage } from "../src/hooks"; import { djangoAiAssistantViewsCreateThreadMessage, + djangoAiAssistantViewsDeleteThreadMessage, djangoAiAssistantViewsListThreadMessages, + ThreadMessagesSchemaOut, } from "../src/client"; jest.mock("../src/client", () => ({ @@ -12,6 +14,9 @@ jest.mock("../src/client", () => ({ djangoAiAssistantViewsListThreadMessages: jest .fn() .mockImplementation(() => Promise.resolve()), + djangoAiAssistantViewsDeleteThreadMessage: jest + .fn() + .mockImplementation(() => Promise.resolve()), })); describe("useMessage", () => { @@ -19,12 +24,14 @@ describe("useMessage", () => { jest.clearAllMocks(); }); - const mockMessages = [ + const mockMessages: ThreadMessagesSchemaOut[] = [ { + id: "1", type: "human", content: "Hello!", }, { + id: "2", type: "ai", content: "Hello! How can I assist you today?", }, @@ -143,4 +150,61 @@ describe("useMessage", () => { expect(result.current.loadingCreateMessage).toBe(false); }); }); + + describe("deleteMessage", () => { + it("should delete a message and update state correctly", async () => { + const deletedMessageId = mockMessages[0].id; + (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + mockMessages.filter((message) => message.id !== deletedMessageId) + ); + + const { result } = renderHook(() => useMessage()); + + result.current.messages = mockMessages; + + expect(result.current.messages).toEqual(mockMessages); + expect(result.current.loadingDeleteMessage).toBe(false); + + await act(async () => { + await result.current.deleteMessage({ + threadId: "1", + messageId: deletedMessageId, + }); + }); + + expect(result.current.messages).toEqual( + mockMessages.filter((message) => message.id !== deletedMessageId) + ); + expect(result.current.loadingDeleteMessage).toBe(false); + }); + + it("should set loading to false if delete fails", async () => { + const deletedMessageId = mockMessages[0].id; + (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + mockMessages.filter((message) => message.id !== deletedMessageId) + ); + ( + djangoAiAssistantViewsDeleteThreadMessage as jest.Mock + ).mockRejectedValue(new Error("Failed to delete")); + + const { result } = renderHook(() => useMessage()); + + result.current.messages = mockMessages; + + expect(result.current.messages).toEqual(mockMessages); + expect(result.current.loadingDeleteMessage).toBe(false); + + await expect(async () => { + await act(async () => { + await result.current.deleteMessage({ + threadId: "1", + messageId: deletedMessageId, + }); + }); + }).rejects.toThrow("Failed to delete"); + + expect(result.current.messages).toEqual(mockMessages); + expect(result.current.loadingDeleteMessage).toBe(false); + }); + }); }); diff --git a/tests/ai/test_chat_message_histories.py b/tests/ai/test_chat_message_histories.py index 4e363be..1127717 100644 --- a/tests/ai/test_chat_message_histories.py +++ b/tests/ai/test_chat_message_histories.py @@ -55,6 +55,71 @@ async def test_aadd_messages(thread_aaa, thread_bbb): assert await other_thread.messages.acount() == 0 +def test_remove_messages(thread_aaa, thread_bbb): + thread = Thread.objects.get(name="AAA") + other_thread = Thread.objects.get(name="BBB") + + Message.objects.bulk_create( + [ + Message(thread=thread, message={"data": {"content": "Hello, world!"}, "type": "human"}), + Message(thread=thread, message={"data": {"content": "Hi! How are you?"}, "type": "ai"}), + Message(thread=other_thread, message={"data": {"content": "Olá!"}, "type": "human"}), + Message( + thread=other_thread, message={"data": {"content": "Olá! Como vai?"}, "type": "ai"} + ), + Message( + thread=other_thread, + message={"data": {"content": "Bem, está quente em Recife?"}, "type": "human"}, + ), + ] + ) + + assert thread.messages.count() == 2 + + messages = thread.messages.order_by("created_at") + message_history = DjangoChatMessageHistory(thread_id=thread.id) + message_history.remove_messages([messages[0].id]) + + assert messages.count() == 1 + assert messages.first().message["data"]["content"] == "Hi! How are you?" + assert other_thread.messages.count() == 3 + + +@pytest.mark.asyncio +@pytest.mark.django_db(transaction=True) +async def test_aremove_messages(thread_aaa, thread_bbb): + thread = await Thread.objects.aget(name="AAA") + other_thread = await Thread.objects.aget(name="BBB") + + await Message.objects.abulk_create( + [ + Message(thread=thread, message={"data": {"content": "Hello, world!"}, "type": "human"}), + Message(thread=thread, message={"data": {"content": "Hi! How are you?"}, "type": "ai"}), + Message(thread=other_thread, message={"data": {"content": "Olá!"}, "type": "human"}), + Message( + thread=other_thread, message={"data": {"content": "Olá! Como vai?"}, "type": "ai"} + ), + Message( + thread=other_thread, + message={"data": {"content": "Bem, está quente em Recife?"}, "type": "human"}, + ), + ] + ) + + assert await thread.messages.acount() == 2 + + message_history = DjangoChatMessageHistory(thread_id=thread.id) + await message_history.aremove_messages( + [(await thread.messages.order_by("created_at").afirst()).id] + ) + + assert await thread.messages.acount() == 1 + assert (await thread.messages.order_by("created_at").afirst()).message["data"][ + "content" + ] == "Hi! How are you?" + assert await other_thread.messages.acount() == 3 + + def test_get_messages(thread_aaa, thread_bbb): thread = Thread.objects.get(name="AAA") other_thread = Thread.objects.get(name="BBB") diff --git a/tests/settings.py b/tests/settings.py index 64869d4..30f7741 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -112,4 +112,5 @@ AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" +AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread" AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all" diff --git a/tests/test_assistants.py b/tests/test_assistants.py index b5bf3ea..dd2bb7f 100644 --- a/tests/test_assistants.py +++ b/tests/test_assistants.py @@ -50,6 +50,9 @@ def test_AIAssistant_invoke(): thread_id=thread.id, ) + messages = thread.messages.order_by("created_at").values_list("message", flat=True) + messages_ids = thread.messages.order_by("created_at").values_list("id", flat=True) + assert response_0 == { "history": [], "input": "What is the temperature today in Recife?", @@ -57,23 +60,28 @@ def test_AIAssistant_invoke(): } assert response_1 == { "history": [ - HumanMessage(content="What is the temperature today in Recife?"), - AIMessage(content="The current temperature in Recife today is 32 degrees Celsius."), + HumanMessage(content="What is the temperature today in Recife?", id=messages_ids[0]), + AIMessage( + content="The current temperature in Recife today is 32 degrees Celsius.", + id=messages_ids[1], + ), ], "input": "What about tomorrow?", "output": "The forecasted temperature in Recife for tomorrow, June 10, 2024, is " "expected to be 35 degrees Celsius.", } - assert list( - thread.messages.order_by("created_at").values_list("message", flat=True) - ) == messages_to_dict( + assert list(messages) == messages_to_dict( [ - HumanMessage(content="What is the temperature today in Recife?"), - AIMessage(content="The current temperature in Recife today is 32 degrees Celsius."), - HumanMessage(content="What about tomorrow?"), + HumanMessage(content="What is the temperature today in Recife?", id=messages_ids[0]), + AIMessage( + content="The current temperature in Recife today is 32 degrees Celsius.", + id=messages_ids[1], + ), + HumanMessage(content="What about tomorrow?", id=messages_ids[2]), AIMessage( content="The forecasted temperature in Recife for tomorrow, June 10, 2024, is " - "expected to be 35 degrees Celsius." + "expected to be 35 degrees Celsius.", + id=messages_ids[3], ), ] ) @@ -140,6 +148,9 @@ def test_AIAssistant_with_rag_invoke(): thread_id=thread.id, ) + messages = thread.messages.order_by("created_at").values_list("message", flat=True) + messages_ids = thread.messages.order_by("created_at").values_list("id", flat=True) + assert response_0 == { "history": [], "input": "I'm at Central Park W & 79st, New York, NY 10024, United States.", @@ -152,8 +163,8 @@ def test_AIAssistant_with_rag_invoke(): } assert response_1 == { "history": [ - HumanMessage(content=response_0["input"]), - AIMessage(content=response_0["output"]), + HumanMessage(content=response_0["input"], id=messages_ids[0]), + AIMessage(content=response_0["output"], id=messages_ids[1]), ], "input": "11 W 53rd St, New York, NY 10019, United States.", "output": "You're at the location of the Museum of Modern Art (MoMA), home to an " @@ -163,13 +174,11 @@ def test_AIAssistant_with_rag_invoke(): "observation deck. These attractions offer a blend of artistic and urban " "experiences.", } - assert list( - thread.messages.order_by("created_at").values_list("message", flat=True) - ) == messages_to_dict( + assert list(messages) == messages_to_dict( [ - HumanMessage(content=response_0["input"]), - AIMessage(content=response_0["output"]), - HumanMessage(content=response_1["input"]), - AIMessage(content=response_1["output"]), + HumanMessage(content=response_0["input"], id=messages_ids[0]), + AIMessage(content=response_0["output"], id=messages_ids[1]), + HumanMessage(content=response_1["input"], id=messages_ids[2]), + AIMessage(content=response_1["output"], id=messages_ids[3]), ] )