Skip to content

Commit

Permalink
feat: add azure-openai image generate model support
Browse files Browse the repository at this point in the history
  • Loading branch information
VisargD committed Apr 1, 2024
1 parent 0454d41 commit 604ef1d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 18 deletions.
32 changes: 14 additions & 18 deletions src/providers/azure-openai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,17 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = {
getEndpoint: ({ providerOptions, fn }) => {
const { apiVersion, urlToFetch } = providerOptions;
let mappedFn = fn;
if (
fn === 'proxy' &&
urlToFetch &&
urlToFetch?.indexOf('/chat/completions') > -1
) {
mappedFn = 'chatComplete';
} else if (
fn === 'proxy' &&
urlToFetch &&
urlToFetch?.indexOf('/completions') > -1
) {
mappedFn = 'complete';
} else if (
fn === 'proxy' &&
urlToFetch &&
urlToFetch?.indexOf('/embeddings') > -1
) {
mappedFn = 'embed';

if (fn === 'proxy' && urlToFetch) {
if (urlToFetch?.indexOf('/chat/completions') > -1) {
mappedFn = 'chatComplete';
} else if (urlToFetch?.indexOf('/completions') > -1) {
mappedFn = 'complete';
} else if (urlToFetch?.indexOf('/embeddings') > -1) {
mappedFn = 'embed';
} else if (urlToFetch?.indexOf('/images/generations') > -1) {
mappedFn = 'imageGenerate';
}
}

switch (mappedFn) {
Expand All @@ -43,6 +36,9 @@ const AzureOpenAIAPIConfig: ProviderAPIConfig = {
case 'embed': {
return `/embeddings?api-version=${apiVersion}`;
}
case 'imageGenerate': {
return `/images/generations?api-version=${apiVersion}`;
}
default:
return '';
}
Expand Down
56 changes: 56 additions & 0 deletions src/providers/azure-openai/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { AZURE_OPEN_AI } from '../../globals';
import { OpenAIErrorResponseTransform } from '../openai/chatComplete';
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';

export const AzureOpenAIImageGenerateConfig: ProviderConfig = {
prompt: {
param: 'prompt',
required: true,
},
model: {
param: 'model',
required: true,
default: 'dall-e-3',
},
n: {
param: 'n',
min: 1,
max: 10,
},
quality: {
param: 'quality',
},
response_format: {
param: 'response_format',
},
size: {
param: 'size',
},
style: {
param: 'style',
},
user: {
param: 'user',
},
};

interface AzureOpenAIImageObject {
b64_json?: string; // The base64-encoded JSON of the generated image, if response_format is b64_json.
url?: string; // The URL of the generated image, if response_format is url (default).
revised_prompt?: string; // The prompt that was used to generate the image, if there was any revision to the prompt.
}

interface AzureOpenAIImageGenerateResponse extends ImageGenerateResponse {
data: AzureOpenAIImageObject[];
}

export const AzureOpenAIImageGenerateResponseTransform: (
response: AzureOpenAIImageGenerateResponse | ErrorResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200 && 'error' in response) {
return OpenAIErrorResponseTransform(response, AZURE_OPEN_AI);
}

return response;
};
6 changes: 6 additions & 0 deletions src/providers/azure-openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@ import {
AzureOpenAIChatCompleteConfig,
AzureOpenAIChatCompleteResponseTransform,
} from './chatComplete';
import {
AzureOpenAIImageGenerateConfig,
AzureOpenAIImageGenerateResponseTransform,
} from './imageGenerate';

const AzureOpenAIConfig: ProviderConfigs = {
complete: AzureOpenAICompleteConfig,
embed: AzureOpenAIEmbedConfig,
api: AzureOpenAIAPIConfig,
imageGenerate: AzureOpenAIImageGenerateConfig,
chatComplete: AzureOpenAIChatCompleteConfig,
responseTransforms: {
complete: AzureOpenAICompleteResponseTransform,
chatComplete: AzureOpenAIChatCompleteResponseTransform,
embed: AzureOpenAIEmbedResponseTransform,
imageGenerate: AzureOpenAIImageGenerateResponseTransform,
},
};

Expand Down

0 comments on commit 604ef1d

Please sign in to comment.