Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: plugin-openai #2898

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions packages/plugin-openai/src/actions/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,27 @@ export function validateApiKey(): string {
* @returns The response data.
* @throws Will throw an error for request failures or rate limits.
*/

export interface OpenAIRequestData {
model: string;
prompt: string;
max_tokens: number;
temperature: number;
[key: string]: unknown;
}

export interface OpenAIEditRequestData {
model: string;
input: string;
instruction: string;
max_tokens: number;
temperature: number;
[key: string]: unknown;
}

export async function callOpenAiApi<T>(
url: string,
data: any,
data: OpenAIRequestData | OpenAIEditRequestData,
apiKey: string,
): Promise<T> {
try {
Expand All @@ -55,7 +73,7 @@ export async function callOpenAiApi<T>(
const response = await axios.post<T>(url, data, config);
return response.data;
} catch (error) {
console.error("Error communicating with OpenAI API:", error.message);
console.error("Error communicating with OpenAI API:", error instanceof Error ? error.message : String(error));
if (axios.isAxiosError(error)) {
if (error.response?.status === 429) {
throw new Error("Rate limit exceeded. Please try again later.");
Expand All @@ -73,12 +91,13 @@ export async function callOpenAiApi<T>(
* @param temperature - The sampling temperature.
* @returns The request payload for OpenAI completions.
*/

export function buildRequestData(
prompt: string,
model: string = DEFAULT_MODEL,
maxTokens: number = DEFAULT_MAX_TOKENS,
temperature: number = DEFAULT_TEMPERATURE,
): Record<string, any> {
): OpenAIRequestData {
return {
model,
prompt,
Expand Down
7 changes: 4 additions & 3 deletions packages/plugin-openai/src/actions/analyzeSentimentAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@ import {
export const analyzeSentimentAction: Action = {
name: "analyzeSentiment",
description: "Analyze sentiment using OpenAI",
async handler(runtime, message, state) {
similes: [], // Added missing required property
async handler(_runtime, message, _state) {
const prompt = `Analyze the sentiment of the following text: "${message.content.text?.trim() || ""}"`;
validatePrompt(prompt);

const apiKey = validateApiKey();
const requestData = buildRequestData(prompt);

const response = await callOpenAiApi(
const response = await callOpenAiApi<{ choices: Array<{ text: string }> }>(
"https://api.openai.com/v1/completions",
requestData,
apiKey,
);
return response.choices[0].text.trim();
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
Expand Down
18 changes: 12 additions & 6 deletions packages/plugin-openai/src/actions/editTextAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ import {
validatePrompt,
validateApiKey,
callOpenAiApi,
buildRequestData,
} from "./action";

interface EditResponse {
choices: Array<{ text: string }>;
}

export const editTextAction: Action = {
name: "editText",
description: "Edit text using OpenAI",
async handler(runtime, message, state) {
const input = message.content.input?.trim() || "";
const instruction = message.content.instruction?.trim() || "";
similes: [],
async handler(_runtime, message, _state) {
const input = (message.content.input as string)?.trim() || "";
const instruction = (message.content.instruction as string)?.trim() || "";
validatePrompt(input);
validatePrompt(instruction);

Expand All @@ -20,16 +24,18 @@ export const editTextAction: Action = {
model: "text-davinci-edit-001",
input,
instruction,
max_tokens: 1000,
temperature: 0.7,
};

const response = await callOpenAiApi(
const response = await callOpenAiApi<EditResponse>(
"https://api.openai.com/v1/edits",
requestData,
apiKey,
);
return response.choices[0].text.trim();
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
Expand Down
19 changes: 10 additions & 9 deletions packages/plugin-openai/src/actions/generateEmbeddingAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@ import {
export const generateEmbeddingAction: Action = {
name: "generateEmbedding",
description: "Generate embeddings using OpenAI",
async handler(runtime, message, state) {
const input = message.content.text?.trim() || "";
similes: [],
async handler(_runtime, message, _state) {
const input = (message.content.text as string)?.trim() || "";
validatePrompt(input);

const apiKey = validateApiKey();
const requestData = {
model: "text-embedding-ada-002",
input,
};
const requestData = buildRequestData(
"text-embedding-ada-002",
input
);

const response = await callOpenAiApi(
"https://api.openai.com/v1/embeddings",
requestData,
apiKey,
);
return response.data.map((item) => item.embedding);
) as { data: Array<{ embedding: number[] }> };
return response.data.map((item: { embedding: number[] }) => item.embedding);
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
Expand Down
15 changes: 8 additions & 7 deletions packages/plugin-openai/src/actions/generateTextAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,27 @@ import {
export const generateTextAction: Action = {
name: "generateText",
description: "Generate text using OpenAI",
async handler(runtime, message, state) {
const prompt = message.content.text?.trim() || "";
similes: [],
async handler(_runtime, message, _state) {
const prompt = (message.content.text as string)?.trim() || "";
validatePrompt(prompt);

const apiKey = validateApiKey();
const requestData = buildRequestData(
String(message.content.model),
prompt,
message.content.model,
message.content.maxTokens,
message.content.temperature,
typeof message.content.maxTokens === 'number' ? message.content.maxTokens : undefined,
typeof message.content.temperature === 'number' ? message.content.temperature : undefined,
);

const response = await callOpenAiApi(
"https://api.openai.com/v1/completions",
requestData,
apiKey,
);
) as { choices: Array<{ text: string }> };
return { text: response.choices[0].text.trim() };
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
Expand Down
17 changes: 11 additions & 6 deletions packages/plugin-openai/src/actions/moderateContentAction.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import type { Action } from "@elizaos/core";
import { validatePrompt, validateApiKey, callOpenAiApi } from "./action";
import { validatePrompt, validateApiKey, callOpenAiApi, buildRequestData } from "./action";

export const moderateContentAction: Action = {
name: "moderateContent",
description: "Moderate content using OpenAI",
async handler(runtime, message, state) {
const input = message.content.text?.trim() || "";
similes: [],
async handler(_runtime, message, _state) {
const input = (message.content.text as string)?.trim() || "";
validatePrompt(input);

const apiKey = validateApiKey();
const requestData = { input };
const requestData = buildRequestData(
"text-moderation-latest",
input
);

const response = await callOpenAiApi(
"https://api.openai.com/v1/moderations",
requestData,
apiKey,
);
) as { results: Array<{ flagged: boolean; categories: Record<string, boolean>; category_scores: Record<string, number> }> };
return response.results;
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
};

22 changes: 16 additions & 6 deletions packages/plugin-openai/src/actions/transcribeAudioAction.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,38 @@
import type { Action } from "@elizaos/core";
import { validateApiKey, callOpenAiApi } from "./action";
import {
validateApiKey,
callOpenAiApi,
buildRequestData,
type OpenAIRequestData
} from "./action";

export const transcribeAudioAction: Action = {
name: "transcribeAudio",
description: "Transcribe audio using OpenAI Whisper",
async handler(runtime, message, state) {
similes: [],
async handler(_runtime, message, _state) {
const file = message.content.file;
if (!file) {
throw new Error("No audio file provided");
}

const apiKey = validateApiKey();
const formData = new FormData();
formData.append("file", file);
formData.append("file", file as Blob);
formData.append("model", "whisper-1");

interface TranscriptionResponse {
text: string;
}

const response = await callOpenAiApi(
"https://api.openai.com/v1/audio/transcriptions",
formData,
formData as unknown as OpenAIRequestData,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type cast from FormData to OpenAIRequestData is unsafe and will cause runtime issues since these types have fundamentally different structures. Consider either:

  1. Creating a dedicated function like callOpenAiMultipartApi for handling form data uploads
  2. Modifying callOpenAiApi to use a union type that accepts both JSON and form data payloads

This would maintain type safety while properly handling the different content types needed for file uploads.

Spotted by Graphite Reviewer

Is this helpful? React 👍 or 👎 to let us know.

apiKey,
);
) as TranscriptionResponse;
return response.text;
},
validate: async (runtime, message) => {
validate: async (runtime, _message) => {
return !!runtime.getSetting("OPENAI_API_KEY");
},
examples: [],
Expand Down
14 changes: 12 additions & 2 deletions packages/plugin-openai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ export interface OpenAITextResponse {
choices: Array<{
text: string;
index: number;
logprobs: null | any;
logprobs: null | {
tokens: string[];
token_logprobs: number[];
top_logprobs: Record<string, number>[];
text_offset: number[];
};
finish_reason: string;
}>;
usage: {
Expand Down Expand Up @@ -59,7 +64,12 @@ export interface OpenAISentimentAnalysisResponse {
choices: Array<{
text: string;
index: number;
logprobs: null | any;
logprobs: null | {
tokens: string[];
token_logprobs: number[];
top_logprobs: Record<string, number>[];
text_offset: number[];
};
finish_reason: string;
}>;
}
Expand Down
Loading