Skip to content

Commit

Permalink
Merge pull request #75 from vintasoftware/feat/message-delete
Browse files Browse the repository at this point in the history
Implement Message deletion
  • Loading branch information
fjsj authored Jun 14, 2024
2 parents 0b1889f + 277910b commit 8bace80
Show file tree
Hide file tree
Showing 19 changed files with 470 additions and 56 deletions.
51 changes: 42 additions & 9 deletions django_ai_assistant/ai/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import logging
from typing import Any, List, Sequence, cast

from django.db import transaction

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict

Expand Down Expand Up @@ -68,25 +70,56 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
Args:
messages: A list of BaseMessage objects to store.
"""
Message.objects.bulk_create(
[
with transaction.atomic():
message_objects = [
Message(thread_id=self._thread_id, message=message_to_dict(message))
for message in messages
]
)

created_messages = Message.objects.bulk_create(message_objects)

# Update langchain message IDs with Django message IDs
for created_message in created_messages:
created_message.message["data"]["id"] = str(created_message.id)

Message.objects.bulk_update(created_messages, ["message"])

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the chat thread.
Args:
messages: A list of BaseMessage objects to store.
"""
await Message.objects.abulk_create(
[
Message(thread_id=self._thread_id, message=message_to_dict(message))
for message in messages
]
)
# NOTE: This method does not use transactions because it do not yet work in async mode.
# Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm
message_objects = [
Message(thread_id=self._thread_id, message=message_to_dict(message))
for message in messages
]

created_messages = await Message.objects.abulk_create(message_objects)

# Update langchain message IDs with Django message IDs
for created_message in created_messages:
created_message.message["data"]["id"] = str(created_message.id)

await Message.objects.abulk_update(created_messages, ["message"])

def remove_messages(self, message_ids: List[str]) -> None:
"""Remove messages from the chat thread.
Args:
message_ids: A list of message IDs to remove.
"""
Message.objects.filter(id__in=message_ids).delete()

async def aremove_messages(self, message_ids: List[str]) -> None:
"""Remove messages from the chat thread.
Args:
message_ids: A list of message IDs to remove.
"""
await Message.objects.filter(id__in=message_ids).adelete()

def _get_messages_qs(self):
return Message.objects.filter(thread_id=self._thread_id).order_by("created_at")
Expand Down
1 change: 1 addition & 0 deletions django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_DELETE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_CREATE_MESSAGE_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_DELETE_MESSAGE_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_RUN_ASSISTANT": "django_ai_assistant.permissions.allow_all",
}

Expand Down
15 changes: 14 additions & 1 deletion django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.models import Thread
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
can_create_message,
can_create_thread,
can_delete_message,
can_delete_thread,
can_run_assistant,
)
Expand Down Expand Up @@ -416,3 +417,15 @@ def create_thread_message_as_user(
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")

DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)])


def delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_message(message=message, user=user, request=request, view=view):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([message.id])
21 changes: 18 additions & 3 deletions django_ai_assistant/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.views import View

from django_ai_assistant.conf import app_settings
from django_ai_assistant.models import Thread
from django_ai_assistant.models import Message, Thread


def _get_default_kwargs(user: Any, request: HttpRequest | None, view: View | None):
Expand Down Expand Up @@ -57,6 +57,21 @@ def can_create_message(
)


def can_delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
**kwargs,
) -> bool:
return app_settings.call_fn(
"CAN_DELETE_MESSAGE_FN",
**_get_default_kwargs(user, request, view),
message=message,
thread=message.thread,
)


def can_run_assistant(
assistant_cls,
user: Any,
Expand All @@ -71,11 +86,11 @@ def can_run_assistant(
)


def allow_all(**kwargs):
def allow_all(**kwargs) -> bool:
return True


def owns_thread(user, thread, **kwargs):
def owns_thread(user: Any, thread: Thread, **kwargs) -> bool:
if user.is_superuser:
return True

Expand Down
1 change: 1 addition & 0 deletions django_ai_assistant/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,6 @@ class ThreadMessageTypeEnum(str, Enum):


class ThreadMessagesSchemaOut(Schema):
id: str # noqa: A003
type: ThreadMessageTypeEnum # noqa: A003
content: str
16 changes: 15 additions & 1 deletion django_ai_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_thread_messages,
get_threads,
)
from .models import Thread
from .models import Message, Thread
from .schemas import (
AssistantSchema,
ThreadMessagesSchemaIn,
Expand Down Expand Up @@ -96,3 +96,17 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
)

return 201, None


@api.delete(
"threads/{thread_id}/messages/{message_id}/", response={204: None}, url_name="messages_delete"
)
def delete_thread_message(request, thread_id: str, message_id: str):
message = get_object_or_404(Message, id=message_id, thread_id=thread_id)
assistants.delete_message(
message=message,
user=request.user,
request=request,
view=None,
)
return 204, None
116 changes: 99 additions & 17 deletions example/assets/js/Chat/Chat.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import {
Container,
Text,
Stack,
Title,
Textarea,
ActionIcon,
Avatar,
Box,
Button,
Container,
Group,
LoadingOverlay,
Paper,
ScrollArea,
Stack,
Text,
Textarea,
Title,
Tooltip,
} from "@mantine/core";
import { ThreadsNav } from "./ThreadsNav";

import classes from "./Chat.module.css";
import { useCallback, useEffect, useRef, useState } from "react";
import { IconSend2 } from "@tabler/icons-react";
import { IconSend2, IconTrash } from "@tabler/icons-react";
import { getHotkeyHandler } from "@mantine/hooks";
import Markdown from "react-markdown";

Expand All @@ -25,31 +30,101 @@ import {
useThread,
} from "django-ai-assistant-client";

function ChatMessage({ message }: { message: ThreadMessagesSchemaOut }) {
function ChatMessage({
threadId,
message,
deleteMessage,
}: {
threadId: string;
message: ThreadMessagesSchemaOut;
deleteMessage: ({
threadId,
messageId,
}: {
threadId: string;
messageId: string;
}) => Promise<void>;
}) {
const isUserMessage = message.type === "human";

const DeleteButton = () => (
<Tooltip label="Delete message" withArrow position="bottom">
<ActionIcon
variant="light"
color="red"
size="sm"
onClick={async () => {
await deleteMessage({ threadId, messageId: message.id });
}}
aria-label="Delete message"
>
<IconTrash style={{ width: "70%", height: "70%" }} stroke={1.5} />
</ActionIcon>
</Tooltip>
);

return (
<Box mb="md">
<Text fw={700}>{message.type === "ai" ? "AI" : "User"}</Text>
<Markdown className={classes.mdMessage}>{message.content}</Markdown>
</Box>
<Group
gap="lg"
align="flex-end"
justify={isUserMessage ? "flex-end" : "flex-start"}
>
{!isUserMessage ? (
<Avatar color="green" radius="xl">
AI
</Avatar>
) : null}

{isUserMessage ? <DeleteButton /> : null}

<Paper
flex={1}
maw="75%"
shadow="none"
radius="lg"
p="lg"
bg="var(--mantine-color-gray-0)"
>
<Group gap="md" justify="space-between" align="top">
<Markdown className={classes.mdMessage}>{message.content}</Markdown>
</Group>
</Paper>

{!isUserMessage ? <DeleteButton /> : null}
</Group>
);
}

function ChatMessageList({
threadId,
messages,
deleteMessage,
}: {
threadId: string;
messages: ThreadMessagesSchemaOut[];
deleteMessage: ({
threadId,
messageId,
}: {
threadId: string;
messageId: string;
}) => Promise<void>;
}) {
if (messages.length === 0) {
return <Text c="dimmed">No messages.</Text>;
}

// TODO: check why horizontal scroll appears
return (
<div>
<Stack gap="xl">
{messages.map((message, index) => (
<ChatMessage key={index} message={message} />
<ChatMessage
key={index}
threadId={threadId}
message={message}
deleteMessage={deleteMessage}
/>
))}
</div>
</Stack>
);
}

Expand All @@ -66,9 +141,12 @@ export function Chat() {
loadingFetchMessages,
createMessage,
loadingCreateMessage,
deleteMessage,
loadingDeleteMessage,
} = useMessage();

const loadingMessages = loadingFetchMessages || loadingCreateMessage;
const loadingMessages =
loadingFetchMessages || loadingCreateMessage || loadingDeleteMessage;
const isThreadSelected = assistantId && activeThread;
const isChatActive = assistantId && activeThread && !loadingMessages;

Expand Down Expand Up @@ -154,7 +232,11 @@ export function Chat() {
overlayProps={{ blur: 2 }}
/>
{isThreadSelected ? (
<ChatMessageList messages={messages || []} />
<ChatMessageList
threadId={activeThread.id}
messages={messages || []}
deleteMessage={deleteMessage}
/>
) : (
<Text c="dimmed">
Select or create a thread to start chatting.
Expand Down
1 change: 1 addition & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"


Expand Down
1 change: 1 addition & 0 deletions example_movies/example_movies/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"
1 change: 1 addition & 0 deletions example_rag/example_rag/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"


Expand Down
Loading

0 comments on commit 8bace80

Please sign in to comment.