diff --git a/src/context/context.ts b/src/context/context.ts
index 8c93a6d..d529368 100644
--- a/src/context/context.ts
+++ b/src/context/context.ts
@@ -18,7 +18,7 @@ export class Context implements IContext {
   constructor(
     getEmbedding: GetEmbedding,
     vectorStorage?: Connection,
-    readonly topK = 10,
+    readonly topK = 10
   ) {
     this.#getEmbedding = getEmbedding;
     this.#vectorStorage = vectorStorage;
@@ -27,7 +27,7 @@ export class Context implements IContext {
   public static async create(
     sandbox: ISandbox,
     loader: Loader,
-    runEmbedding: GetEmbedding,
+    runEmbedding: GetEmbedding
   ): Promise<Context> {
     if (!sandbox.manifest.vectorStorage) {
       return new Context(runEmbedding);
@@ -49,14 +49,12 @@ export class Context implements IContext {
   async vectorSearch(tableName: string, vector: number[]): Promise<unknown[]> {
     if (!this.#vectorStorage) {
       throw new Error(
-        "Project did not provide vector storage. Unable to perform search",
+        "Project did not provide vector storage. Unable to perform search"
      );
     }
     const table = await this.#vectorStorage.openTable(tableName);

-    return await table.vectorSearch(vector)
-      .limit(this.topK)
-      .toArray();
+    return await table.vectorSearch(vector).limit(this.topK).toArray();
   }

   @LogPerformance(logger)
diff --git a/src/runners/openai.ts b/src/runners/openai.ts
index b1a4c03..3eb3fa4 100644
--- a/src/runners/openai.ts
+++ b/src/runners/openai.ts
@@ -11,16 +11,18 @@ import { getLogger } from "../logger.ts";

 const logger = await getLogger("runner:openai");

+const dimensionsDict: { [key in string]: number } = {
+  "text-embedding-3-large": 3072,
+  "text-embedding-3-small": 1536,
+  "text-embedding-ada-002": 1536,
+};
+
 export class OpenAIRunnerFactory implements IRunnerFactory {
   #openai: OpenAI;
   #sandbox: ISandbox;
   #loader: Loader;

-  private constructor(
-    openAI: OpenAI,
-    sandbox: ISandbox,
-    loader: Loader,
-  ) {
+  private constructor(openAI: OpenAI, sandbox: ISandbox, loader: Loader) {
     this.#openai = openAI;
     this.#sandbox = sandbox;
     this.#loader = loader;
@@ -30,7 +32,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
     baseUrl: string | undefined,
     apiKey: string | undefined,
     sandbox: ISandbox,
-    loader: Loader,
+    loader: Loader
   ): Promise<OpenAIRunnerFactory> {
     const openai = new OpenAI({
       apiKey,
@@ -52,11 +54,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
       await openai.models.retrieve(sandbox.manifest.embeddingsModel);
     }

-    const factory = new OpenAIRunnerFactory(
-      openai,
-      sandbox,
-      loader,
-    );
+    const factory = new OpenAIRunnerFactory(openai, sandbox, loader);

     // Makes sure vector storage is loaded
     await factory.getContext();
@@ -66,10 +64,12 @@ export class OpenAIRunnerFactory implements IRunnerFactory {

   async runEmbedding(input: string | string[]): Promise<number[]> {
     const res = await this.#openai.embeddings.create({
-      model: this.#sandbox.manifest.embeddingsModel ??
-        "text-embedding-3-small",
+      model: this.#sandbox.manifest.embeddingsModel ?? "text-embedding-3-small",
       input,
-      dimensions: 768,
+      dimensions:
+        dimensionsDict[
+          this.#sandbox.manifest.embeddingsModel ?? "text-embedding-3-small"
+        ] || 1536,
     });

     return res.data[0].embedding;
@@ -80,7 +80,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
     return Context.create(
       this.#sandbox,
       this.#loader,
-      this.runEmbedding.bind(this),
+      this.runEmbedding.bind(this)
     );
   }

@@ -89,7 +89,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
       this.#openai,
       await this.getContext(),
       this.#sandbox,
-      chatStorage,
+      chatStorage
     );
   }
 }
@@ -102,17 +102,19 @@ export class OpenAIRunner implements IRunner {
     openAI: OpenAI,
     context: IContext,
     private sandbox: ISandbox,
-    private chatStorage: IChatStorage,
+    private chatStorage: IChatStorage
   ) {
     this.#openai = openAI;
     this.#context = context;
   }

   async prompt(message: string): Promise<string> {
-    const outMessage = await this.promptMessages([{
-      role: "user",
-      content: message,
-    }]);
+    const outMessage = await this.promptMessages([
+      {
+        role: "user",
+        content: message,
+      },
+    ]);
     return outMessage.message.content;
   }

@@ -146,14 +148,14 @@ export class OpenAIRunner implements IRunner {
           function: async (args: unknown) => {
             try {
               logger.debug(
-                `Calling tool: "${t.function.name}" args: "${
-                  JSON.stringify(args)
-                }`,
+                `Calling tool: "${t.function.name}" args: "${JSON.stringify(
+                  args
+                )}`
               );
               return await this.sandbox.runTool(
                 t.function.name,
                 args,
-                this.#context,
+                this.#context
               );
             } catch (e: unknown) {
               logger.error(`Tool call failed: ${e}`);