Skip to content

Commit

Permalink
feat: add reasoning model (vercel#750)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Apperson <[email protected]>
  • Loading branch information
jeremyphilemon and mattapperson authored Feb 3, 2025
1 parent 7680426 commit c61d4f9
Show file tree
Hide file tree
Showing 19 changed files with 335 additions and 202 deletions.
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Get your OpenAI API Key here: https://platform.openai.com/account/api-keys
# Get your OpenAI API Key here for chat models: https://platform.openai.com/account/api-keys
OPENAI_API_KEY=****

# Get your Fireworks AI API Key here for reasoning models: https://fireworks.ai/account/api-keys
FIREWORKS_API_KEY=****

# Generate a random secret: https://generate-secret.vercel.app/32 or `openssl rand -base64 32`
AUTH_SECRET=****

Expand Down
10 changes: 5 additions & 5 deletions app/(chat)/actions.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
'use server';

import { type CoreUserMessage, generateText, Message } from 'ai';
import { generateText, Message } from 'ai';
import { cookies } from 'next/headers';

import { customModel } from '@/lib/ai';
import {
deleteMessagesByChatIdAfterTimestamp,
getMessageById,
updateChatVisiblityById,
} from '@/lib/db/queries';
import { VisibilityType } from '@/components/visibility-selector';
import { myProvider } from '@/lib/ai/models';

export async function saveModelId(model: string) {
export async function saveChatModelAsCookie(model: string) {
const cookieStore = await cookies();
cookieStore.set('model-id', model);
cookieStore.set('chat-model', model);
}

export async function generateTitleFromUserMessage({
Expand All @@ -22,7 +22,7 @@ export async function generateTitleFromUserMessage({
message: Message;
}) {
const { text: title } = await generateText({
model: customModel('gpt-4o-mini'),
model: myProvider.languageModel('title-model'),
system: `\n
- you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 80 characters long
Expand Down
56 changes: 27 additions & 29 deletions app/(chat)/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import {
createDataStreamResponse,
smoothStream,
streamText,
wrapLanguageModel,
} from 'ai';

import { auth } from '@/app/(auth)/auth';
import { customModel } from '@/lib/ai';
import { models } from '@/lib/ai/models';
import { myProvider } from '@/lib/ai/models';
import { systemPrompt } from '@/lib/ai/prompts';
import {
deleteChatById,
Expand Down Expand Up @@ -48,8 +48,8 @@ export async function POST(request: Request) {
const {
id,
messages,
modelId,
}: { id: string; messages: Array<Message>; modelId: string } =
selectedChatModel,
}: { id: string; messages: Array<Message>; selectedChatModel: string } =
await request.json();

const session = await auth();
Expand All @@ -58,12 +58,6 @@ export async function POST(request: Request) {
return new Response('Unauthorized', { status: 401 });
}

const model = models.find((model) => model.id === modelId);

if (!model) {
return new Response('Model not found', { status: 404 });
}

const userMessage = getMostRecentUserMessage(messages);

if (!userMessage) {
Expand All @@ -84,7 +78,7 @@ export async function POST(request: Request) {
return createDataStreamResponse({
execute: (dataStream) => {
const result = streamText({
model: customModel(model.apiIdentifier),
model: myProvider.languageModel(selectedChatModel),
system: systemPrompt,
messages,
maxSteps: 5,
Expand All @@ -93,32 +87,31 @@ export async function POST(request: Request) {
experimental_generateMessageId: generateUUID,
tools: {
getWeather,
createDocument: createDocument({ session, dataStream, model }),
updateDocument: updateDocument({ session, dataStream, model }),
createDocument: createDocument({ session, dataStream }),
updateDocument: updateDocument({ session, dataStream }),
requestSuggestions: requestSuggestions({
session,
dataStream,
model,
}),
},
onFinish: async ({ response }) => {
onFinish: async ({ response, reasoning }) => {
if (session.user?.id) {
try {
const responseMessagesWithoutIncompleteToolCalls =
sanitizeResponseMessages(response.messages);
const sanitizedResponseMessages = sanitizeResponseMessages({
messages: response.messages,
reasoning,
});

await saveMessages({
messages: responseMessagesWithoutIncompleteToolCalls.map(
(message) => {
return {
id: message.id,
chatId: id,
role: message.role,
content: message.content,
createdAt: new Date(),
};
},
),
messages: sanitizedResponseMessages.map((message) => {
return {
id: message.id,
chatId: id,
role: message.role,
content: message.content,
createdAt: new Date(),
};
}),
});
} catch (error) {
console.error('Failed to save chat');
Expand All @@ -131,7 +124,12 @@ export async function POST(request: Request) {
},
});

result.mergeIntoDataStream(dataStream);
result.mergeIntoDataStream(dataStream, {
sendReasoning: true,
});
},
onError: (error) => {
return 'Oops, an error occurred!';
},
});
}
Expand Down
24 changes: 18 additions & 6 deletions app/(chat)/chat/[id]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { notFound } from 'next/navigation';

import { auth } from '@/app/(auth)/auth';
import { Chat } from '@/components/chat';
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
import { getChatById, getMessagesByChatId } from '@/lib/db/queries';
import { convertToUIMessages } from '@/lib/utils';
import { DataStreamHandler } from '@/components/data-stream-handler';
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';

export default async function Page(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
Expand Down Expand Up @@ -34,17 +34,29 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
});

const cookieStore = await cookies();
const modelIdFromCookie = cookieStore.get('model-id')?.value;
const selectedModelId =
models.find((model) => model.id === modelIdFromCookie)?.id ||
DEFAULT_MODEL_NAME;
const chatModelFromCookie = cookieStore.get('chat-model');

if (!chatModelFromCookie) {
return (
<>
<Chat
id={chat.id}
initialMessages={convertToUIMessages(messagesFromDb)}
selectedChatModel={DEFAULT_CHAT_MODEL}
selectedVisibilityType={chat.visibility}
isReadonly={session?.user?.id !== chat.userId}
/>
<DataStreamHandler id={id} />
</>
);
}

return (
<>
<Chat
id={chat.id}
initialMessages={convertToUIMessages(messagesFromDb)}
selectedModelId={selectedModelId}
selectedChatModel={chatModelFromCookie.value}
selectedVisibilityType={chat.visibility}
isReadonly={session?.user?.id !== chat.userId}
/>
Expand Down
24 changes: 18 additions & 6 deletions app/(chat)/page.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@
import { cookies } from 'next/headers';

import { Chat } from '@/components/chat';
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';
import { generateUUID } from '@/lib/utils';
import { DataStreamHandler } from '@/components/data-stream-handler';

export default async function Page() {
const id = generateUUID();

const cookieStore = await cookies();
const modelIdFromCookie = cookieStore.get('model-id')?.value;
const modelIdFromCookie = cookieStore.get('chat-model');

const selectedModelId =
models.find((model) => model.id === modelIdFromCookie)?.id ||
DEFAULT_MODEL_NAME;
if (!modelIdFromCookie) {
return (
<>
<Chat
key={id}
id={id}
initialMessages={[]}
selectedChatModel={DEFAULT_CHAT_MODEL}
selectedVisibilityType="private"
isReadonly={false}
/>
<DataStreamHandler id={id} />
</>
);
}

return (
<>
<Chat
key={id}
id={id}
initialMessages={[]}
selectedModelId={selectedModelId}
selectedChatModel={modelIdFromCookie.value}
selectedVisibilityType="private"
isReadonly={false}
/>
Expand Down
13 changes: 9 additions & 4 deletions components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ import { MultimodalInput } from './multimodal-input';
import { Messages } from './messages';
import { VisibilityType } from './visibility-selector';
import { useBlockSelector } from '@/hooks/use-block';
import { toast } from 'sonner';

export function Chat({
id,
initialMessages,
selectedModelId,
selectedChatModel,
selectedVisibilityType,
isReadonly,
}: {
id: string;
initialMessages: Array<Message>;
selectedModelId: string;
selectedChatModel: string;
selectedVisibilityType: VisibilityType;
isReadonly: boolean;
}) {
Expand All @@ -42,14 +43,18 @@ export function Chat({
reload,
} = useChat({
id,
body: { id, modelId: selectedModelId },
body: { id, selectedChatModel: selectedChatModel },
initialMessages,
experimental_throttle: 100,
sendExtraMessageFields: true,
generateId: generateUUID,
onFinish: () => {
mutate('/api/history');
},
onError: (error) => {
console.log(error);
toast.error('An error occurred, please try again!');
},
});

const { data: votes } = useSWR<Array<Vote>>(
Expand All @@ -65,7 +70,7 @@ export function Chat({
<div className="flex flex-col min-w-0 h-dvh bg-background">
<ChatHeader
chatId={id}
selectedModelId={selectedModelId}
selectedModelId={selectedChatModel}
selectedVisibilityType={selectedVisibilityType}
isReadonly={isReadonly}
/>
Expand Down
2 changes: 1 addition & 1 deletion components/markdown.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Link from 'next/link';
import React, { memo, useMemo, useState } from 'react';
import React, { memo } from 'react';
import ReactMarkdown, { type Components } from 'react-markdown';
import remarkGfm from 'remark-gfm';
import { CodeBlock } from './code-block';
Expand Down
75 changes: 75 additions & 0 deletions components/message-reasoning.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
'use client';

import { useState } from 'react';
import { ChevronDownIcon, LoaderIcon } from './icons';
import { motion, AnimatePresence } from 'framer-motion';
import { Markdown } from './markdown';

interface MessageReasoningProps {
  /** True while the model is still streaming reasoning tokens. */
  isLoading: boolean;
  /** Accumulated reasoning text; rendered as Markdown when expanded. */
  reasoning: string;
}

/**
 * Collapsible panel that displays a reasoning model's thinking trace.
 *
 * While `isLoading` is true, a spinner and "Reasoning" label are shown.
 * Once loading completes, the header offers a toggle that expands or
 * collapses the reasoning text with an animated height/opacity transition.
 */
export function MessageReasoning({
  isLoading,
  reasoning,
}: MessageReasoningProps) {
  // Expanded by default so streamed reasoning is visible immediately.
  const [isExpanded, setIsExpanded] = useState(true);

  const variants = {
    collapsed: {
      height: 0,
      opacity: 0,
      marginTop: 0,
      marginBottom: 0,
    },
    expanded: {
      height: 'auto',
      opacity: 1,
      marginTop: '1rem',
      marginBottom: '0.5rem',
    },
  };

  return (
    <div className="flex flex-col">
      {isLoading ? (
        <div className="flex flex-row gap-2 items-center">
          <div className="font-medium">Reasoning</div>
          <div className="animate-spin">
            <LoaderIcon />
          </div>
        </div>
      ) : (
        <div className="flex flex-row gap-2 items-center">
          <div className="font-medium">Reasoned for a few seconds</div>
          {/* Use a real <button> (not a clickable div) so the toggle is
              keyboard-focusable and announced correctly by screen readers. */}
          <button
            type="button"
            className="cursor-pointer"
            aria-expanded={isExpanded}
            aria-label={isExpanded ? 'Collapse reasoning' : 'Expand reasoning'}
            onClick={() => {
              setIsExpanded(!isExpanded);
            }}
          >
            <ChevronDownIcon />
          </button>
        </div>
      )}

      <AnimatePresence initial={false}>
        {isExpanded && (
          <motion.div
            key="content"
            initial="collapsed"
            animate="expanded"
            exit="collapsed"
            variants={variants}
            transition={{ duration: 0.2, ease: 'easeInOut' }}
            style={{ overflow: 'hidden' }}
            className="pl-4 text-zinc-600 dark:text-zinc-400 border-l flex flex-col gap-4"
          >
            <Markdown>{reasoning}</Markdown>
          </motion.div>
        )}
      </AnimatePresence>
    </div>
  );
}
Loading

0 comments on commit c61d4f9

Please sign in to comment.