From d946c491d8a89517c0f671d45dadd27d4c0f06a2 Mon Sep 17 00:00:00 2001 From: Christian Bager Bach Houmann Date: Mon, 1 Jul 2024 15:33:46 +0200 Subject: [PATCH] fix: ai.prompt API now also allows the model parameter to be a string with simply the model's name as originally intended --- src/quickAddApi.ts | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/quickAddApi.ts b/src/quickAddApi.ts index affc611..fcd435b 100644 --- a/src/quickAddApi.ts +++ b/src/quickAddApi.ts @@ -107,7 +107,7 @@ export class QuickAddApi { ai: { prompt: async ( prompt: string, - model: Model, + model: Model | string, settings?: Partial<{ variableName: string; shouldAssignVariables: boolean; @@ -131,17 +131,29 @@ export class QuickAddApi { choiceExecutor ).format; - const modelProvider = getModelProvider(model.name); + let _model: Model; + if (typeof model === "string") { + const foundModel = getModelByName(model); + if (!foundModel) { + throw new Error(`Model '${model}' not found.`); + } + + _model = foundModel; + } else { + _model = model; + } + + const modelProvider = getModelProvider(_model.name); if (!modelProvider) { throw new Error( - `Model '${model.name}' not found in any provider` + `Model '${_model.name}' not found in any provider` ); } const assistantRes = await Prompt( { - model, + model: _model, prompt, apiKey: modelProvider.apiKey, modelOptions: settings?.modelOptions ?? {}, @@ -173,7 +185,7 @@ export class QuickAddApi { chunkedPrompt: async ( text: string, promptTemplate: string, - model: string, + model: Model | string, settings?: Partial<{ variableName: string; shouldAssignVariables: boolean; @@ -201,13 +213,19 @@ export class QuickAddApi { choiceExecutor ).format; - const _model = getModelByName(model); + let _model: Model; + if (typeof model === "string") { + const foundModel = getModelByName(model); + if (!foundModel) { + throw new Error(`Model ${model} not found.`); + } - if (!_model) { - throw new Error(`Model ${model} not found.`); + _model = foundModel; + } else { + _model = model; } - const modelProvider = getModelProvider(model); + const modelProvider = getModelProvider(_model.name); if (!modelProvider) { throw new Error(