Skip to content

Commit

Permalink
Adjust custom endpoints to be client side fetches
Browse files Browse the repository at this point in the history
- Custom endpoints should be processed from the client
side since it will often be a localhost inference.
- We also wouldn't want arbitrary proxying to models from
the server side other than the predefined endpoints provided.

Signed-off-by: Brent Salisbury <[email protected]>
  • Loading branch information
nerdalert committed May 30, 2024
1 parent 96874b6 commit 77f5feb
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 84 deletions.
196 changes: 124 additions & 72 deletions ui/src/app/playground/chat/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ import { SelectOption, SelectList } from '@patternfly/react-core/dist/dynamic/co
import { MenuToggle, MenuToggleElement } from '@patternfly/react-core/dist/dynamic/components/MenuToggle';
import { Spinner } from '@patternfly/react-core/dist/dynamic/components/Spinner';
import UserIcon from '@patternfly/react-icons/dist/dynamic/icons/user-icon';
import CopyIcon from '@patternfly/react-icons/dist/dynamic/icons/copy-icon';
import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { faBroom } from '@fortawesome/free-solid-svg-icons';
import Image from 'next/image';
import styles from './chat.module.css';
import { Endpoint, Message, Model } from '@/types';
import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { faBroom } from '@fortawesome/free-solid-svg-icons';
import CopyToClipboardButton from '@/components/CopyToClipboardButton';

const ChatPage: React.FC = () => {
const [question, setQuestion] = useState('');
Expand All @@ -30,6 +30,7 @@ const ChatPage: React.FC = () => {
const [isSelectOpen, setIsSelectOpen] = useState(false);
const [selectedModel, setSelectedModel] = useState<Model | null>(null);
const [customModels, setCustomModels] = useState<Model[]>([]);
const [defaultModels, setDefaultModels] = useState<Model[]>([]);
const messagesContainerRef = useRef<HTMLDivElement>(null);

useEffect(() => {
Expand All @@ -52,9 +53,9 @@ const ChatPage: React.FC = () => {
}))
: [];

const allModels = [...defaultModels, ...customModels];
setCustomModels(allModels);
setSelectedModel(allModels[0] || null);
setDefaultModels(defaultModels);
setCustomModels(customModels);
setSelectedModel([...defaultModels, ...customModels][0] || null);
};

fetchDefaultModels();
Expand All @@ -65,7 +66,7 @@ const ChatPage: React.FC = () => {
};

const onSelect = (_event: React.MouseEvent<Element, MouseEvent> | undefined, value: string | number | undefined) => {
const selected = customModels.find((model) => model.name === value) || null;
const selected = [...defaultModels, ...customModels].find((model) => model.name === value) || null;
setSelectedModel(selected);
setIsSelectOpen(false);
};
Expand All @@ -76,7 +77,7 @@ const ChatPage: React.FC = () => {
</MenuToggle>
);

const dropdownItems = customModels
const dropdownItems = [...defaultModels, ...customModels]
.filter((model) => model.name && model.apiURL && model.modelName)
.map((model, index) => (
<SelectOption key={index} value={model.name}>
Expand All @@ -100,42 +101,123 @@ const ChatPage: React.FC = () => {
setQuestion('');

setIsLoading(true);
const response = await fetch(
`/api/playground/chat?apiURL=${encodeURIComponent(selectedModel.apiURL)}&modelName=${encodeURIComponent(selectedModel.modelName)}`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question, systemRole }),
}
);

if (response.body) {
const reader = response.body.getReader();
const textDecoder = new TextDecoder('utf-8');
let botMessage = '';

setMessages((messages) => [...messages, { text: '', isUser: false }]);

(async () => {
for (;;) {
const { value, done } = await reader.read();
if (done) break;
const chunk = textDecoder.decode(value, { stream: true });
botMessage += chunk;

setMessages((messages) => {
const updatedMessages = [...messages];
updatedMessages[updatedMessages.length - 1].text = botMessage;
return updatedMessages;
});

const messagesPayload = [
{ role: 'system', content: systemRole },
{ role: 'user', content: question },
];

const requestData = {
model: selectedModel.modelName,
messages: messagesPayload,
stream: true,
};

if (customModels.some((model) => model.name === selectedModel.name)) {
// Client-side fetch if the selected model is a custom endpoint
try {
const chatResponse = await fetch(`${selectedModel.apiURL}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
accept: 'application/json',
},
body: JSON.stringify(requestData),
});

if (!chatResponse.body) {
setMessages((messages) => [...messages, { text: 'Failed to fetch chat response', isUser: false }]);
setIsLoading(false);
return;
}

const reader = chatResponse.body.getReader();
const textDecoder = new TextDecoder('utf-8');
let botMessage = '';

setMessages((messages) => [...messages, { text: '', isUser: false }]);

let done = false;
while (!done) {
const { value, done: isDone } = await reader.read();
done = isDone;
if (value) {
const chunk = textDecoder.decode(value, { stream: true });
const lines = chunk.split('\n').filter((line) => line.trim() !== '');

for (const line of lines) {
if (line.startsWith('data: ')) {
const json = line.replace('data: ', '');
if (json === '[DONE]') {
setIsLoading(false);
return;
}

try {
const parsed = JSON.parse(json);
const deltaContent = parsed.choices[0].delta?.content;

if (deltaContent) {
botMessage += deltaContent;

setMessages((messages) => {
const updatedMessages = [...messages];
updatedMessages[updatedMessages.length - 1].text = botMessage;
return updatedMessages;
});
}
} catch (err) {
console.error('Error parsing chunk:', err);
}
}
}
}
}

setIsLoading(false);
} catch (error) {
setMessages((messages) => [...messages, { text: 'Error fetching chat response', isUser: false }]);
setIsLoading(false);
})();
}
} else {
setMessages((messages) => [...messages, { text: 'Failed to fetch response from the server.', isUser: false }]);
setIsLoading(false);
// Server-side fetch for default endpoints
const response = await fetch(
`/api/playground/chat?apiURL=${encodeURIComponent(selectedModel.apiURL)}&modelName=${encodeURIComponent(selectedModel.modelName)}`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question, systemRole }),
}
);

if (response.body) {
const reader = response.body.getReader();
const textDecoder = new TextDecoder('utf-8');
let botMessage = '';

setMessages((messages) => [...messages, { text: '', isUser: false }]);

(async () => {
for (;;) {
const { value, done } = await reader.read();
if (done) break;
const chunk = textDecoder.decode(value, { stream: true });
botMessage += chunk;

setMessages((messages) => {
const updatedMessages = [...messages];
updatedMessages[updatedMessages.length - 1].text = botMessage;
return updatedMessages;
});
}
setIsLoading(false);
})();
} else {
setMessages((messages) => [...messages, { text: 'Failed to fetch response from the server.', isUser: false }]);
setIsLoading(false);
}
}
};

Expand All @@ -145,32 +227,6 @@ const ChatPage: React.FC = () => {
}
}, [messages]);

const handleCopyToClipboard = (text: string) => {
if (navigator.clipboard && navigator.clipboard.writeText) {
navigator.clipboard
.writeText(text)
.then(() => {
console.log('Text copied to clipboard');
})
.catch((err) => {
console.error('Could not copy text: ', err);
});
} else {
const textArea = document.createElement('textarea');
textArea.value = text;
document.body.appendChild(textArea);
textArea.focus();
textArea.select();
try {
document.execCommand('copy');
console.log('Text copied to clipboard');
} catch (err) {
console.error('Could not copy text: ', err);
}
document.body.removeChild(textArea);
}
};

const handleCleanup = () => {
setMessages([]);
};
Expand Down Expand Up @@ -218,11 +274,7 @@ const ChatPage: React.FC = () => {
<pre>
<code>{msg.text}</code>
</pre>
{!msg.isUser && (
<Button variant="plain" onClick={() => handleCopyToClipboard(msg.text)} aria-label="Copy to clipboard">
<CopyIcon />
</Button>
)}
{!msg.isUser && <CopyToClipboardButton text={msg.text} />}
</div>
))}
{isLoading && <Spinner aria-label="Loading" size="lg" />}
Expand Down
Loading

0 comments on commit 77f5feb

Please sign in to comment.