From 8043bf9834ff1e5cd30ca803f376e69e349c5ff4 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet <32258950+brichet@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:15:39 +0100 Subject: [PATCH] Refactoring AIProvider and handling errors (#15) * Better handling of error with providers * Add a method to get the error message when catching an error --- src/chat-handler.ts | 70 ++++++++++++++++++--------- src/completion-provider.ts | 34 +++++++++---- src/index.ts | 6 +-- src/llm-models/base-completer.ts | 13 +++++ src/llm-models/codestral-completer.ts | 9 ++-- src/llm-models/utils.ts | 25 ++++++++-- src/provider.ts | 52 +++++++++++++------- src/token.ts | 4 +- 8 files changed, 146 insertions(+), 67 deletions(-) diff --git a/src/chat-handler.ts b/src/chat-handler.ts index 18417f6..a9b0ef8 100644 --- a/src/chat-handler.ts +++ b/src/chat-handler.ts @@ -16,6 +16,8 @@ import { mergeMessageRuns } from '@langchain/core/messages'; import { UUID } from '@lumino/coreutils'; +import { getErrorMessage } from './llm-models'; +import { IAIProvider } from './token'; export type ConnectionMessage = { type: 'connection'; @@ -25,14 +27,14 @@ export type ConnectionMessage = { export class ChatHandler extends ChatModel { constructor(options: ChatHandler.IOptions) { super(options); - this._provider = options.provider; + this._aiProvider = options.aiProvider; + this._aiProvider.modelChange.connect(() => { + this._errorMessage = this._aiProvider.chatError; + }); } get provider(): BaseChatModel | null { - return this._provider; - } - set provider(provider: BaseChatModel | null) { - this._provider = provider; + return this._aiProvider.chatModel; } async sendMessage(message: INewMessage): Promise { @@ -46,15 +48,15 @@ export class ChatHandler extends ChatModel { }; this.messageAdded(msg); - if (this._provider === null) { - const botMsg: IChatMessage = { + if (this._aiProvider.chatModel === null) { + const errorMsg: IChatMessage = { id: UUID.uuid4(), - body: '**AI provider not configured for the chat**', + body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`, sender: { username: 'ERROR' }, time: Date.now(), type: 'msg' }; - this.messageAdded(botMsg); + this.messageAdded(errorMsg); return false; } @@ -69,19 +71,37 @@ export class ChatHandler extends ChatModel { }) ); - const response = await this._provider.invoke(messages); - // TODO: fix deprecated response.text - const content = response.text; - const botMsg: IChatMessage = { - id: UUID.uuid4(), - body: content, - sender: { username: 'Bot' }, - time: Date.now(), - type: 'msg' - }; - this.messageAdded(botMsg); - this._history.messages.push(botMsg); - return true; + this.updateWriters([{ username: 'AI' }]); + return this._aiProvider.chatModel + .invoke(messages) + .then(response => { + const content = response.content; + const botMsg: IChatMessage = { + id: UUID.uuid4(), + body: content.toString(), + sender: { username: 'AI' }, + time: Date.now(), + type: 'msg' + }; + this.messageAdded(botMsg); + this._history.messages.push(botMsg); + return true; + }) + .catch(reason => { + const error = getErrorMessage(this._aiProvider.name, reason); + const errorMsg: IChatMessage = { + id: UUID.uuid4(), + body: `**${error}**`, + sender: { username: 'ERROR' }, + time: Date.now(), + type: 'msg' + }; + this.messageAdded(errorMsg); + return false; + }) + .finally(() => { + this.updateWriters([]); + }); } async getHistory(): Promise { @@ -96,12 +116,14 @@ export class ChatHandler extends ChatModel { super.messageAdded(message); } - private _provider: BaseChatModel | null; + private _aiProvider: IAIProvider; + private _errorMessage: string = ''; private _history: IChatHistory = { messages: [] }; + private _defaultErrorMessage = 'AI provider not configured'; } export namespace ChatHandler { export interface IOptions extends ChatModel.IOptions { - provider: BaseChatModel | null; + aiProvider: IAIProvider; } } diff --git a/src/completion-provider.ts b/src/completion-provider.ts index b2ac0b1..e7000c5 100644 --- a/src/completion-provider.ts +++ b/src/completion-provider.ts @@ -5,7 +5,8 @@ import { } from '@jupyterlab/completer'; import { LLM } from '@langchain/core/language_models/llms'; -import { getCompleter, IBaseCompleter } from './llm-models'; +import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models'; +import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; /** * The generic completion provider to register to the completion provider manager. @@ -14,23 +15,36 @@ export class CompletionProvider implements IInlineCompletionProvider { readonly identifier = '@jupyterlite/ai'; constructor(options: CompletionProvider.IOptions) { - this.name = options.name; + const { name, settings } = options; + this.setCompleter(name, settings); } /** - * Getter and setter of the name. - * The setter will create the appropriate completer, accordingly to the name. + * Set the completer. + * + * @param name - the name of the completer. + * @param settings - The settings associated to the completer. + */ + setCompleter(name: string, settings: ReadonlyPartialJSONObject) { + try { + this._completer = getCompleter(name, settings); + this._name = this._completer === null ? 'None' : name; + } catch (e: any) { + this._completer = null; + this._name = 'None'; + throw e; + } + } + + /** + * Get the current completer name. */ get name(): string { return this._name; } - set name(name: string) { - this._name = name; - this._completer = getCompleter(name); - } /** - * get the current completer. + * Get the current completer. */ get completer(): IBaseCompleter | null { return this._completer; @@ -55,7 +69,7 @@ export class CompletionProvider implements IInlineCompletionProvider { } export namespace CompletionProvider { - export interface IOptions { + export interface IOptions extends BaseCompleter.IOptions { name: string; } } diff --git a/src/index.ts b/src/index.ts index 2cc8bdc..76d2ab2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -41,14 +41,10 @@ const chatPlugin: JupyterFrontEndPlugin = { } const chatHandler = new ChatHandler({ - provider: aiProvider.chatModel, + aiProvider: aiProvider, activeCellManager: activeCellManager }); - aiProvider.modelChange.connect(() => { - chatHandler.provider = aiProvider.chatModel; - }); - let sendWithShiftEnter = false; let enableCodeToolbar = true; diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts index 498abf6..fb84f4f 100644 --- a/src/llm-models/base-completer.ts +++ b/src/llm-models/base-completer.ts @@ -3,6 +3,7 @@ import { IInlineCompletionContext } from '@jupyterlab/completer'; import { LLM } from '@langchain/core/language_models/llms'; +import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; export interface IBaseCompleter { /** @@ -18,3 +19,15 @@ export interface IBaseCompleter { context: IInlineCompletionContext ): Promise; } + +/** + * The namespace for the base completer. + */ +export namespace BaseCompleter { + /** + * The options for the constructor of a completer. + */ + export interface IOptions { + settings: ReadonlyPartialJSONObject; + } +} diff --git a/src/llm-models/codestral-completer.ts b/src/llm-models/codestral-completer.ts index f1168c8..efa7934 100644 --- a/src/llm-models/codestral-completer.ts +++ b/src/llm-models/codestral-completer.ts @@ -7,7 +7,7 @@ import { MistralAI } from '@langchain/mistralai'; import { Throttler } from '@lumino/polling'; import { CompletionRequest } from '@mistralai/mistralai'; -import { IBaseCompleter } from './base-completer'; +import { BaseCompleter, IBaseCompleter } from './base-completer'; /* * The Mistral API has a rate limit of 1 request per second @@ -15,11 +15,8 @@ import { IBaseCompleter } from './base-completer'; const INTERVAL = 1000; export class CodestralCompleter implements IBaseCompleter { - constructor() { - this._mistralProvider = new MistralAI({ - apiKey: 'TMP', - model: 'codestral-latest' - }); + constructor(options: BaseCompleter.IOptions) { + this._mistralProvider = new MistralAI({ ...options.settings }); this._throttler = new Throttler(async (data: CompletionRequest) => { const response = await this._mistralProvider.completionWithRetry( data, diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts index 6d9b9f4..544d684 100644 --- a/src/llm-models/utils.ts +++ b/src/llm-models/utils.ts @@ -2,13 +2,17 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { ChatMistralAI } from '@langchain/mistralai'; import { IBaseCompleter } from './base-completer'; import { CodestralCompleter } from './codestral-completer'; +import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; /** * Get an LLM completer from the name. */ -export function getCompleter(name: string): IBaseCompleter | null { +export function getCompleter( + name: string, + settings: ReadonlyPartialJSONObject +): IBaseCompleter | null { if (name === 'MistralAI') { - return new CodestralCompleter(); + return new CodestralCompleter({ settings }); } return null; } @@ -16,9 +20,22 @@ export function getCompleter(name: string): IBaseCompleter | null { /** * Get an LLM chat model from the name. */ -export function getChatModel(name: string): BaseChatModel | null { +export function getChatModel( + name: string, + settings: ReadonlyPartialJSONObject +): BaseChatModel | null { if (name === 'MistralAI') { - return new ChatMistralAI({ apiKey: 'TMP' }); + return new ChatMistralAI({ ...settings }); } return null; } + +/** + * Get the error message from provider. + */ +export function getErrorMessage(name: string, error: any): string { + if (name === 'MistralAI') { + return error.message; + } + return 'Unknown provider'; +} diff --git a/src/provider.ts b/src/provider.ts index de88ba3..6019785 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -10,7 +10,10 @@ import { IAIProvider } from './token'; export class AIProvider implements IAIProvider { constructor(options: AIProvider.IOptions) { - this._completionProvider = new CompletionProvider({ name: 'None' }); + this._completionProvider = new CompletionProvider({ + name: 'None', + settings: {} + }); options.completionProviderManager.registerInlineProvider( this._completionProvider ); @@ -21,7 +24,7 @@ export class AIProvider implements IAIProvider { } /** - * get the current completer of the completion provider. + * Get the current completer of the completion provider. */ get completer(): IBaseCompleter | null { if (this._name === null) { @@ -31,7 +34,7 @@ export class AIProvider implements IAIProvider { } /** - * get the current llm chat model. + * Get the current llm chat model. */ get chatModel(): BaseChatModel | null { if (this._name === null) { @@ -40,6 +43,20 @@ export class AIProvider implements IAIProvider { return this._llmChatModel; } + /** + * Get the current chat error; + */ + get chatError(): string { + return this._chatError; + } + + /** + * get the current completer error. + */ + get completerError(): string { + return this._completerError; + } + /** * Set the models (chat model and completer). * Creates the models if the name has changed, otherwise only updates their config. @@ -48,22 +65,21 @@ export class AIProvider implements IAIProvider { * @param settings - the settings for the models. */ setModels(name: string, settings: ReadonlyPartialJSONObject) { - if (name !== this._name) { - this._name = name; - this._completionProvider.name = name; - this._llmChatModel = getChatModel(name); - this._modelChange.emit(); + try { + this._completionProvider.setCompleter(name, settings); + this._completerError = ''; + } catch (e: any) { + this._completerError = e.message; } - - // Update the inline completion provider settings. - if (this._completionProvider.llmCompleter) { - AIProvider.updateConfig(this._completionProvider.llmCompleter, settings); - } - - // Update the chat LLM settings. - if (this._llmChatModel) { - AIProvider.updateConfig(this._llmChatModel, settings); + try { + this._llmChatModel = getChatModel(name, settings); + this._chatError = ''; + } catch (e: any) { + this._chatError = e.message; + this._llmChatModel = null; } + this._name = name; + this._modelChange.emit(); } get modelChange(): ISignal { @@ -74,6 +90,8 @@ export class AIProvider implements IAIProvider { private _llmChatModel: BaseChatModel | null = null; private _name: string = 'None'; private _modelChange = new Signal(this); + private _chatError: string = ''; + private _completerError: string = ''; } export namespace AIProvider { diff --git a/src/token.ts b/src/token.ts index 626be4a..09f5a6e 100644 --- a/src/token.ts +++ b/src/token.ts @@ -5,10 +5,12 @@ import { ISignal } from '@lumino/signaling'; import { IBaseCompleter } from './llm-models'; export interface IAIProvider { - name: string | null; + name: string; completer: IBaseCompleter | null; chatModel: BaseChatModel | null; modelChange: ISignal; + chatError: string; + completerError: string; } export const IAIProvider = new Token(