From 36cfc75147ef4e6273ae08198d23acb12c296616 Mon Sep 17 00:00:00 2001
From: Lukas
Date: Sat, 13 Apr 2024 12:57:00 +0200
Subject: [PATCH 1/4] move vertex types into own files

---
 .../google-vertex-ai/chatComplete.ts          | 78 ++-----------------
 .../transformGenerationConfig.ts              | 25 ++++++
 src/providers/google-vertex-ai/types.ts       | 44 +++++++++++
 3 files changed, 77 insertions(+), 70 deletions(-)
 create mode 100644 src/providers/google-vertex-ai/transformGenerationConfig.ts
 create mode 100644 src/providers/google-vertex-ai/types.ts

diff --git a/src/providers/google-vertex-ai/chatComplete.ts b/src/providers/google-vertex-ai/chatComplete.ts
index 3962716da..43a25001a 100644
--- a/src/providers/google-vertex-ai/chatComplete.ts
+++ b/src/providers/google-vertex-ai/chatComplete.ts
@@ -1,3 +1,6 @@
+// Docs for REST API
+// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-drest
+
 import { GOOGLE_VERTEX_AI } from '../../globals';
 import { ContentType, Message, Params } from '../../types/requestBody';
 import {
@@ -9,31 +12,11 @@ import {
   generateErrorResponse,
   generateInvalidProviderResponseError,
 } from '../utils';
-
-const transformGenerationConfig = (params: Params) => {
-  const generationConfig: Record<string, any> = {};
-  if (params['temperature']) {
-    generationConfig['temperature'] = params['temperature'];
-  }
-  if (params['top_p']) {
-    generationConfig['topP'] = params['top_p'];
-  }
-  if (params['top_k']) {
-    generationConfig['topK'] = params['top_k'];
-  }
-  if (params['max_tokens']) {
-    generationConfig['maxOutputTokens'] = params['max_tokens'];
-  }
-  if (params['stop']) {
-    generationConfig['stopSequences'] = params['stop'];
-  }
-  return generationConfig;
-};
-
-// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model.
-
-// Docs for REST API
-// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-drest
+import { transformGenerationConfig } from './transformGenerationConfig';
+import type {
+  GoogleErrorResponse,
+  GoogleGenerateContentResponse,
+} from './types';
 
 export const GoogleChatCompleteConfig: ProviderConfig = {
   // https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
@@ -171,51 +154,6 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
   },
 };
 
-export interface GoogleErrorResponse {
-  error: {
-    code: number;
-    message: string;
-    status: string;
-    details: Array<Record<string, any>>;
-  };
-}
-
-interface GoogleGenerateFunctionCall {
-  name: string;
-  args: Record<string, any>;
-}
-
-interface GoogleGenerateContentResponse {
-  candidates: {
-    content: {
-      parts: {
-        text?: string;
-        functionCall?: GoogleGenerateFunctionCall;
-      }[];
-    };
-    finishReason: string;
-    index: 0;
-    safetyRatings: {
-      category: string;
-      probability: string;
-    }[];
-  }[];
-  promptFeedback: {
-    safetyRatings: {
-      category: string;
-      probability: string;
-      probabilityScore: number;
-      severity: string;
-      severityScore: number;
-    }[];
-  };
-  usageMetadata: {
-    promptTokenCount: number;
-    candidatesTokenCount: number;
-    totalTokenCount: number;
-  };
-}
-
 export const GoogleChatCompleteResponseTransform: (
   response:
     | GoogleGenerateContentResponse
diff --git a/src/providers/google-vertex-ai/transformGenerationConfig.ts b/src/providers/google-vertex-ai/transformGenerationConfig.ts
new file mode 100644
index 000000000..d2efef95e
--- /dev/null
+++ b/src/providers/google-vertex-ai/transformGenerationConfig.ts
@@ -0,0 +1,25 @@
+import { Params } from '../../types/requestBody';
+
+/**
+ * @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#request_body
+ */
+export function transformGenerationConfig(params: Params) {
+  const generationConfig: Record<string, any> = {};
+  if (params['temperature']) {
+    generationConfig['temperature'] = params['temperature'];
+  }
+  if (params['top_p']) {
+    generationConfig['topP'] = params['top_p'];
+  }
+  if (params['top_k']) {
+    generationConfig['topK'] = params['top_k'];
+  }
+  if (params['max_tokens']) {
+    generationConfig['maxOutputTokens'] = params['max_tokens'];
+  }
+  if (params['stop']) {
+    generationConfig['stopSequences'] = params['stop'];
+  }
+
+  return generationConfig;
+}
diff --git a/src/providers/google-vertex-ai/types.ts b/src/providers/google-vertex-ai/types.ts
new file mode 100644
index 000000000..8d054df15
--- /dev/null
+++ b/src/providers/google-vertex-ai/types.ts
@@ -0,0 +1,44 @@
+export interface GoogleErrorResponse {
+  error: {
+    code: number;
+    message: string;
+    status: string;
+    details: Array<Record<string, any>>;
+  };
+}
+
+export interface GoogleGenerateFunctionCall {
+  name: string;
+  args: Record<string, any>;
+}
+
+export interface GoogleGenerateContentResponse {
+  candidates: {
+    content: {
+      parts: {
+        text?: string;
+        functionCall?: GoogleGenerateFunctionCall;
+      }[];
+    };
+    finishReason: string;
+    index: 0;
+    safetyRatings: {
+      category: string;
+      probability: string;
+    }[];
+  }[];
+  promptFeedback: {
+    safetyRatings: {
+      category: string;
+      probability: string;
+      probabilityScore: number;
+      severity: string;
+      severityScore: number;
+    }[];
+  };
+  usageMetadata: {
+    promptTokenCount: number;
+    candidatesTokenCount: number;
+    totalTokenCount: number;
+  };
+}
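A minimal usage sketch of the extracted helper (illustrative only, not part of the patches; the input values below are made up):

    import { transformGenerationConfig } from './transformGenerationConfig';

    // OpenAI-style parameters, as modeled by the gateway's Params type
    const generationConfig = transformGenerationConfig({
      temperature: 0.2,
      top_p: 0.95,
      max_tokens: 1024,
      stop: ['END'],
    });

    // Per the mapping above, the result uses Vertex AI field names:
    // { temperature: 0.2, topP: 0.95, maxOutputTokens: 1024, stopSequences: ['END'] }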
From bdd84e14fec0f125e19c6d9b7e8fb49e73fae0a8 Mon Sep 17 00:00:00 2001
From: Lukas
Date: Sat, 13 Apr 2024 12:58:24 +0200
Subject: [PATCH 2/4] support json response for gemini 1.5

---
 src/providers/google-vertex-ai/chatComplete.ts              | 4 ++++
 src/providers/google-vertex-ai/transformGenerationConfig.ts | 3 +++
 src/types/requestBody.ts                                    | 1 +
 3 files changed, 8 insertions(+)

diff --git a/src/providers/google-vertex-ai/chatComplete.ts b/src/providers/google-vertex-ai/chatComplete.ts
index 43a25001a..870e975a0 100644
--- a/src/providers/google-vertex-ai/chatComplete.ts
+++ b/src/providers/google-vertex-ai/chatComplete.ts
@@ -124,6 +124,10 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
     param: 'generationConfig',
     transform: (params: Params) => transformGenerationConfig(params),
   },
+  response_format: {
+    param: 'generationConfig',
+    transform: (params: Params) => transformGenerationConfig(params),
+  },
   // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes
   // Example payload to be included in the request that sets the safety settings:
   //   "safety_settings": [
diff --git a/src/providers/google-vertex-ai/transformGenerationConfig.ts b/src/providers/google-vertex-ai/transformGenerationConfig.ts
index d2efef95e..dc277864d 100644
--- a/src/providers/google-vertex-ai/transformGenerationConfig.ts
+++ b/src/providers/google-vertex-ai/transformGenerationConfig.ts
@@ -20,6 +20,9 @@ export function transformGenerationConfig(params: Params) {
   if (params['stop']) {
     generationConfig['stopSequences'] = params['stop'];
   }
+  if (params?.response_format?.type === 'json_object') {
+    generationConfig['responseMimeType'] = 'application/json';
+  }
 
   return generationConfig;
 }
diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts
index c59c80a17..45508ac68 100644
--- a/src/types/requestBody.ts
+++ b/src/types/requestBody.ts
@@ -212,6 +212,7 @@ export interface Params {
   examples?: Examples[];
   top_k?: number;
   tools?: Tool[];
+  response_format?: { type: 'json_object' | 'text' };
 }
 
 interface Examples {
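A sketch of what this patch enables, assuming the gateway forwards the transformed object as the request's generationConfig (field names follow the Gemini REST docs referenced above; the surrounding request envelope is omitted):

    import { transformGenerationConfig } from './transformGenerationConfig';

    // OpenAI-style request parameters arriving at the gateway
    const params = {
      max_tokens: 512,
      response_format: { type: 'json_object' as const },
    };

    const generationConfig = transformGenerationConfig(params);
    // => { maxOutputTokens: 512, responseMimeType: 'application/json' }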
From 91f20ef66d7c1e4957a6016e109afe71a8e5af04 Mon Sep 17 00:00:00 2001
From: Lukas
Date: Sat, 13 Apr 2024 12:59:11 +0200
Subject: [PATCH 3/4] return error on vertex `PROHIBITED_CONTENT` case

---
 src/providers/google-vertex-ai/chatComplete.ts | 7 +++++++
 src/types/requestBody.ts                       | 2 ++
 2 files changed, 9 insertions(+)

diff --git a/src/providers/google-vertex-ai/chatComplete.ts b/src/providers/google-vertex-ai/chatComplete.ts
index 870e975a0..180d0dc30 100644
--- a/src/providers/google-vertex-ai/chatComplete.ts
+++ b/src/providers/google-vertex-ai/chatComplete.ts
@@ -198,6 +198,13 @@ export const GoogleChatCompleteResponseTransform: (
     );
   }
 
+  if (
+    'candidates' in response &&
+    response.candidates[0].finishReason === 'PROHIBITED_CONTENT'
+  ) {
+    return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI);
+  }
+
   if ('candidates' in response) {
     const {
       promptTokenCount = 0,
diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts
index 45508ac68..8f67ecab7 100644
--- a/src/types/requestBody.ts
+++ b/src/types/requestBody.ts
@@ -213,6 +213,8 @@ export interface Params {
   top_k?: number;
   tools?: Tool[];
   response_format?: { type: 'json_object' | 'text' };
+  // Google Vertex AI specific
+  safety_settings?: any;
 }
 
 interface Examples {
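A sketch of how a caller would observe the new guard. The blocked response below is a hypothetical fixture reduced to the one field the check reads, and the transform is assumed to keep its (response, status) call shape and to surface the ErrorResponse unchanged:

    import { GoogleChatCompleteResponseTransform } from './chatComplete';
    import type { GoogleGenerateContentResponse } from './types';

    // Hypothetical Vertex AI response whose candidate was blocked
    const blocked = {
      candidates: [{ finishReason: 'PROHIBITED_CONTENT' }],
    } as unknown as GoogleGenerateContentResponse;

    const result = GoogleChatCompleteResponseTransform(blocked, 200);
    if ('error' in result) {
      // reads "Invalid response received from ..." once patch 4/4 lands
      console.warn(result.error.message);
    }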
From f8138dbc5170aa61904027bb83c75e884d34fd9b Mon Sep 17 00:00:00 2001
From: Lukas
Date: Sat, 13 Apr 2024 13:00:46 +0200
Subject: [PATCH 4/4] fix typo in `recieved`

---
 src/handlers/handlerUtils.ts | 2 +-
 src/providers/utils.ts       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts
index 4d8a35594..84e0853a0 100644
--- a/src/handlers/handlerUtils.ts
+++ b/src/handlers/handlerUtils.ts
@@ -685,7 +685,7 @@ export async function tryProvidersInSequence(
 /**
  * Handles various types of responses based on the specified parameters
  * and returns a mapped response
- * @param {Response} response - The HTTP response recieved from LLM.
+ * @param {Response} response - The HTTP response received from LLM.
  * @param {boolean} streamingMode - Indicates whether streaming mode is enabled.
  * @param {string} proxyProvider - The provider string.
  * @param {string | undefined} responseTransformer - The response transformer to determine type of call.
diff --git a/src/providers/utils.ts b/src/providers/utils.ts
index 8cb466651..e30ab689e 100644
--- a/src/providers/utils.ts
+++ b/src/providers/utils.ts
@@ -6,7 +6,7 @@ export const generateInvalidProviderResponseError: (
 ) => ErrorResponse = (response, provider) => {
   return {
     error: {
-      message: `Invalid response recieved from ${provider}: ${JSON.stringify(
+      message: `Invalid response received from ${provider}: ${JSON.stringify(
         response
       )}`,
       type: null,