Skip to content

Commit

Permalink
feat: add dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
HuberTRoy committed Dec 19, 2024
1 parent b83cc25 commit d553cf8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
10 changes: 4 additions & 6 deletions src/context/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export class Context implements IContext {
constructor(
getEmbedding: GetEmbedding,
vectorStorage?: Connection,
readonly topK = 10,
readonly topK = 10
) {
this.#getEmbedding = getEmbedding;
this.#vectorStorage = vectorStorage;
Expand All @@ -27,7 +27,7 @@ export class Context implements IContext {
public static async create(
sandbox: ISandbox,
loader: Loader,
runEmbedding: GetEmbedding,
runEmbedding: GetEmbedding
): Promise<Context> {
if (!sandbox.manifest.vectorStorage) {
return new Context(runEmbedding);
Expand All @@ -49,14 +49,12 @@ export class Context implements IContext {
async vectorSearch(tableName: string, vector: number[]): Promise<unknown[]> {
if (!this.#vectorStorage) {
throw new Error(
"Project did not provide vector storage. Unable to perform search",
"Project did not provide vector storage. Unable to perform search"
);
}
const table = await this.#vectorStorage.openTable(tableName);

return await table.vectorSearch(vector)
.limit(this.topK)
.toArray();
return await table.vectorSearch(vector).limit(this.topK).toArray();
}

@LogPerformance(logger)
Expand Down
52 changes: 27 additions & 25 deletions src/runners/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ import { getLogger } from "../logger.ts";

const logger = await getLogger("runner:openai");

const dimensionsDict: { [key in string]: number } = {
"text-embedding-3-large": 3072,
"text-embedding-3-small": 1536,
"text-embedding-ada-002": 1536,
};

export class OpenAIRunnerFactory implements IRunnerFactory {
#openai: OpenAI;
#sandbox: ISandbox;
#loader: Loader;

private constructor(
openAI: OpenAI,
sandbox: ISandbox,
loader: Loader,
) {
private constructor(openAI: OpenAI, sandbox: ISandbox, loader: Loader) {
this.#openai = openAI;
this.#sandbox = sandbox;
this.#loader = loader;
Expand All @@ -30,7 +32,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
baseUrl: string | undefined,
apiKey: string | undefined,
sandbox: ISandbox,
loader: Loader,
loader: Loader
): Promise<OpenAIRunnerFactory> {
const openai = new OpenAI({
apiKey,
Expand All @@ -52,11 +54,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
await openai.models.retrieve(sandbox.manifest.embeddingsModel);
}

const factory = new OpenAIRunnerFactory(
openai,
sandbox,
loader,
);
const factory = new OpenAIRunnerFactory(openai, sandbox, loader);

// Makes sure vector storage is loaded
await factory.getContext();
Expand All @@ -66,10 +64,12 @@ export class OpenAIRunnerFactory implements IRunnerFactory {

async runEmbedding(input: string | string[]): Promise<number[]> {
const res = await this.#openai.embeddings.create({
model: this.#sandbox.manifest.embeddingsModel ??
"text-embedding-3-small",
model: this.#sandbox.manifest.embeddingsModel ?? "text-embedding-3-small",
input,
dimensions: 768,
dimensions:
dimensionsDict[
this.#sandbox.manifest.embeddingsModel ?? "text-embedding-3-small"
] || 1536,
});

return res.data[0].embedding;
Expand All @@ -80,7 +80,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
return Context.create(
this.#sandbox,
this.#loader,
this.runEmbedding.bind(this),
this.runEmbedding.bind(this)
);
}

Expand All @@ -89,7 +89,7 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
this.#openai,
await this.getContext(),
this.#sandbox,
chatStorage,
chatStorage
);
}
}
Expand All @@ -102,17 +102,19 @@ export class OpenAIRunner implements IRunner {
openAI: OpenAI,
context: IContext,
private sandbox: ISandbox,
private chatStorage: IChatStorage,
private chatStorage: IChatStorage
) {
this.#openai = openAI;
this.#context = context;
}

async prompt(message: string): Promise<string> {
const outMessage = await this.promptMessages([{
role: "user",
content: message,
}]);
const outMessage = await this.promptMessages([
{
role: "user",
content: message,
},
]);
return outMessage.message.content;
}

Expand Down Expand Up @@ -146,14 +148,14 @@ export class OpenAIRunner implements IRunner {
function: async (args: unknown) => {
try {
logger.debug(
`Calling tool: "${t.function.name}" args: "${
JSON.stringify(args)
}`,
`Calling tool: "${t.function.name}" args: "${JSON.stringify(
args
)}`
);
return await this.sandbox.runTool(
t.function.name,
args,
this.#context,
this.#context
);
} catch (e: unknown) {
logger.error(`Tool call failed: ${e}`);
Expand Down

0 comments on commit d553cf8

Please sign in to comment.