Fix OpenAI endpoint, memoize OpenAI tool conversion
stwiname committed Dec 17, 2024
1 parent 4012d77 commit 19d444e
Showing 2 changed files with 46 additions and 33 deletions.
76 changes: 44 additions & 32 deletions src/runners/openai.ts
@@ -1,4 +1,4 @@
-import type { ChatResponse, Message } from "ollama";
+import type { ChatResponse, Message, Tool } from "ollama";
 import type { IRunner, IRunnerFactory } from "./runner.ts";
 import OpenAI from "openai";
 import type { IChatStorage } from "../chatStorage/chatStorage.ts";
@@ -122,45 +122,57 @@ export class OpenAIRunner implements IRunner {
     return this.runChat(tmpMessages);
   }
 
-  @LogPerformance(logger)
-  private async runChat(messages: Message[]): Promise<ChatResponse> {
+  // Converts the tools to the format expected by OpenAI
+  @Memoize()
+  private async getTools(): Promise<
+    {
+      type: "function";
+      function: Tool["function"] & {
+        function: (args: unknown) => Promise<string>;
+      };
+    }[]
+  > {
     const tools = await this.sandbox.getTools();
+    return tools.map((t) => {
+      if (t.type !== "function") {
+        throw new Error("expected function tool type");
+      }
+      return {
+        type: "function",
+        function: {
+          ...t.function,
+          function: async (args: unknown) => {
+            try {
+              logger.debug(
+                `Calling tool: "${t.function.name}" args: "${
+                  JSON.stringify(args)
+                }"`,
+              );
+              return await this.sandbox.runTool(
+                t.function.name,
+                args,
+                this.#context,
+              );
+            } catch (e: unknown) {
+              logger.error(`Tool call failed: ${e}`);
+              // Don't throw the error: that would exit the application. Pass the message back to the LLM instead.
+              return (e as Error).message;
+            }
+          },
+        },
+      };
+    });
+  }
 
+  @LogPerformance(logger)
+  private async runChat(messages: Message[]): Promise<ChatResponse> {
     const runner = this.#openai.beta.chat.completions.runTools({
       model: this.sandbox.manifest.model,
       messages: messages.map((m) => ({
         role: m.role as "user" | "system" | "assistant",
         content: m.content,
       })),
-      tools: tools.map((t) => {
-        if (t.type !== "function") {
-          throw new Error("expected function tool type");
-        }
-        return {
-          type: "function",
-          function: {
-            ...t.function,
-            function: async (args: unknown) => {
-              try {
-                logger.debug(
-                  `Calling tool: "${t.function.name}" args: "${
-                    JSON.stringify(args)
-                  }"`,
-                );
-                return await this.sandbox.runTool(
-                  t.function.name,
-                  args,
-                  this.#context,
-                );
-              } catch (e: unknown) {
-                logger.error(`Tool call failed: ${e}`);
-                // Don't throw the error: that would exit the application. Pass the message back to the LLM instead.
-                return (e as Error).message;
-              }
-            },
-          },
-        };
-      }),
+      tools: await this.getTools(),
     });
 
     const completion = await runner.finalChatCompletion();
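Why the memoization helps: the sandbox's tool list is static for the lifetime of a runner, but the old code called sandbox.getTools() and re-ran the conversion on every chat turn. Moving the conversion into getTools() behind @Memoize() means it runs once, and runTools receives the same wrapped array each time; the inner function property is what the openai SDK's runTools helper invokes automatically, feeding each tool's return value back to the model. The diff doesn't show where @Memoize() is imported from (a library such as typescript-memoize is a plausible source); as a rough sketch under that assumption, a per-instance memoize decorator for a zero-argument method could look like this:

// Rough sketch only — assumes the project's @Memoize() behaves like a
// typical per-instance memoizer (e.g. typescript-memoize); this is not
// the actual implementation used by the repo.
function Memoize() {
  return function (
    _target: unknown,
    _key: string,
    descriptor: PropertyDescriptor,
  ): PropertyDescriptor {
    const original = descriptor.value;
    const cache = new WeakMap<object, unknown>();
    descriptor.value = function (...args: unknown[]) {
      if (!cache.has(this as object)) {
        // Store the promise itself, not the resolved value, so that
        // concurrent callers share a single in-flight conversion.
        cache.set(this as object, original.apply(this, args));
      }
      return cache.get(this as object);
    };
    return descriptor;
  };
}

Caching the promise rather than the awaited result is the detail that matters here: two overlapping runChat calls both get the same pending getTools() promise instead of triggering two sandbox round-trips.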
3 changes: 2 additions & 1 deletion src/runners/runner.ts
Expand Up @@ -2,6 +2,7 @@ import { type ChatResponse, type Message, Ollama } from "ollama";
import type { IChatStorage } from "../chatStorage/index.ts";
import type { GenerateEmbedding } from "../embeddings/lance/writer.ts";
import OpenAI from "openai";
import { DEFAULT_LLM_HOST } from "../constants.ts";

export interface IRunner {
prompt(message: string): Promise<string>;
Expand Down Expand Up @@ -31,7 +32,7 @@ export async function getGenerateFunction(
   try {
     const openai = new OpenAI({
       apiKey,
-      baseURL: endpoint,
+      baseURL: endpoint === DEFAULT_LLM_HOST ? undefined : endpoint,
     });
 
     await openai.models.retrieve(model);
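The endpoint fix: getGenerateFunction probes the configured endpoint as an OpenAI-compatible server. Previously the endpoint was always passed through as baseURL, so when it was the project's default LLM host (presumably a local Ollama URL), the OpenAI probe was aimed at the local server rather than api.openai.com. Passing undefined in that case lets the SDK fall back to its own default base URL (https://api.openai.com/v1, or the OPENAI_BASE_URL environment variable if set). A sketch of the resulting behavior — the real value of DEFAULT_LLM_HOST lives in src/constants.ts and is not shown in this diff, so the value below is only a placeholder:

import OpenAI from "openai";

// Placeholder only — the actual constant is defined in src/constants.ts.
const DEFAULT_LLM_HOST = "http://localhost:11434";

function makeOpenAIClient(endpoint: string, apiKey?: string): OpenAI {
  return new OpenAI({
    apiKey,
    // undefined tells the SDK to use its built-in default base URL
    // instead of sending OpenAI requests to the local default host.
    baseURL: endpoint === DEFAULT_LLM_HOST ? undefined : endpoint,
  });
}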
