From 604ef1d3f4a9ead2aa6008e3d53927b5e805d455 Mon Sep 17 00:00:00 2001 From: visargD Date: Mon, 1 Apr 2024 15:40:12 +0530 Subject: [PATCH] feat: add azure-openai image generate model support --- src/providers/azure-openai/api.ts | 32 ++++++------ src/providers/azure-openai/imageGenerate.ts | 56 +++++++++++++++++++++ src/providers/azure-openai/index.ts | 6 +++ 3 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 src/providers/azure-openai/imageGenerate.ts diff --git a/src/providers/azure-openai/api.ts b/src/providers/azure-openai/api.ts index fc9dda58b..e60fa37f2 100644 --- a/src/providers/azure-openai/api.ts +++ b/src/providers/azure-openai/api.ts @@ -13,24 +13,17 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = { getEndpoint: ({ providerOptions, fn }) => { const { apiVersion, urlToFetch } = providerOptions; let mappedFn = fn; - if ( - fn === 'proxy' && - urlToFetch && - urlToFetch?.indexOf('/chat/completions') > -1 - ) { - mappedFn = 'chatComplete'; - } else if ( - fn === 'proxy' && - urlToFetch && - urlToFetch?.indexOf('/completions') > -1 - ) { - mappedFn = 'complete'; - } else if ( - fn === 'proxy' && - urlToFetch && - urlToFetch?.indexOf('/embeddings') > -1 - ) { - mappedFn = 'embed'; + + if (fn === 'proxy' && urlToFetch) { + if (urlToFetch?.indexOf('/chat/completions') > -1) { + mappedFn = 'chatComplete'; + } else if (urlToFetch?.indexOf('/completions') > -1) { + mappedFn = 'complete'; + } else if (urlToFetch?.indexOf('/embeddings') > -1) { + mappedFn = 'embed'; + } else if (urlToFetch?.indexOf('/images/generations') > -1) { + mappedFn = 'imageGenerate'; + } } switch (mappedFn) { @@ -43,6 +36,9 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = { case 'embed': { return `/embeddings?api-version=${apiVersion}`; } + case 'imageGenerate': { + return `/images/generations?api-version=${apiVersion}`; + } default: return ''; } diff --git a/src/providers/azure-openai/imageGenerate.ts b/src/providers/azure-openai/imageGenerate.ts new file mode 100644 index 000000000..ae562cc43 --- /dev/null +++ b/src/providers/azure-openai/imageGenerate.ts @@ -0,0 +1,56 @@ +import { AZURE_OPEN_AI } from '../../globals'; +import { OpenAIErrorResponseTransform } from '../openai/chatComplete'; +import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types'; + +export const AzureOpenAIImageGenerateConfig: ProviderConfig = { + prompt: { + param: 'prompt', + required: true, + }, + model: { + param: 'model', + required: true, + default: 'dall-e-3', + }, + n: { + param: 'n', + min: 1, + max: 10, + }, + quality: { + param: 'quality', + }, + response_format: { + param: 'response_format', + }, + size: { + param: 'size', + }, + style: { + param: 'style', + }, + user: { + param: 'user', + }, +}; + +interface AzureOpenAIImageObject { + b64_json?: string; // The base64-encoded JSON of the generated image, if response_format is b64_json. + url?: string; // The URL of the generated image, if response_format is url (default). + revised_prompt?: string; // The prompt that was used to generate the image, if there was any revision to the prompt. +} + +interface AzureOpenAIImageGenerateResponse extends ImageGenerateResponse { + data: AzureOpenAIImageObject[]; +} + +export const AzureOpenAIImageGenerateResponseTransform: ( + response: AzureOpenAIImageGenerateResponse | ErrorResponse, + responseStatus: number +) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + return OpenAIErrorResponseTransform(response, AZURE_OPEN_AI); + } + + return response; +}; diff --git a/src/providers/azure-openai/index.ts b/src/providers/azure-openai/index.ts index 8eb50335a..590d55fc0 100644 --- a/src/providers/azure-openai/index.ts +++ b/src/providers/azure-openai/index.ts @@ -12,16 +12,22 @@ import { AzureOpenAIChatCompleteConfig, AzureOpenAIChatCompleteResponseTransform, } from './chatComplete'; +import { + AzureOpenAIImageGenerateConfig, + AzureOpenAIImageGenerateResponseTransform, +} from './imageGenerate'; const AzureOpenAIConfig: ProviderConfigs = { complete: AzureOpenAICompleteConfig, embed: AzureOpenAIEmbedConfig, api: AzureOpenAIAPIConfig, + imageGenerate: AzureOpenAIImageGenerateConfig, chatComplete: AzureOpenAIChatCompleteConfig, responseTransforms: { complete: AzureOpenAICompleteResponseTransform, chatComplete: AzureOpenAIChatCompleteResponseTransform, embed: AzureOpenAIEmbedResponseTransform, + imageGenerate: AzureOpenAIImageGenerateResponseTransform, }, };