Skip to content

Commit

Permalink
Merge pull request #301 from flexchar/small-qol-update
Browse files Browse the repository at this point in the history
Small QoL update
  • Loading branch information
VisargD authored Apr 26, 2024
2 parents 3e18127 + 32ad33c commit 07d215d
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ export async function tryProvidersInSequence(
/**
* Handles various types of responses based on the specified parameters
* and returns a mapped response
* @param {Response} response - The HTTP response recieved from LLM.
* @param {Response} response - The HTTP response received from LLM.
* @param {boolean} streamingMode - Indicates whether streaming mode is enabled.
* @param {string} proxyProvider - The provider string.
* @param {string | undefined} responseTransformer - The response transformer to determine type of call.
Expand Down
89 changes: 19 additions & 70 deletions src/providers/google-vertex-ai/chatComplete.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Docs for REST API
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-drest

import { GOOGLE_VERTEX_AI } from '../../globals';
import { ContentType, Message, Params } from '../../types/requestBody';
import {
Expand All @@ -9,31 +12,11 @@ import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';

const transformGenerationConfig = (params: Params) => {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
generationConfig['temperature'] = params['temperature'];
}
if (params['top_p']) {
generationConfig['topP'] = params['top_p'];
}
if (params['top_k']) {
generationConfig['topK'] = params['top_k'];
}
if (params['max_tokens']) {
generationConfig['maxOutputTokens'] = params['max_tokens'];
}
if (params['stop']) {
generationConfig['stopSequences'] = params['stop'];
}
return generationConfig;
};

// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model.

// Docs for REST API
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-drest
import { transformGenerationConfig } from './transformGenerationConfig';
import type {
GoogleErrorResponse,
GoogleGenerateContentResponse,
} from './types';

export const GoogleChatCompleteConfig: ProviderConfig = {
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
Expand Down Expand Up @@ -141,6 +124,10 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
response_format: {
param: 'generationConfig',
transform: (params: Params) => transformGenerationConfig(params),
},
// https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes
// Example payload to be included in the request that sets the safety settings:
// "safety_settings": [
Expand Down Expand Up @@ -171,51 +158,6 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
},
};

export interface GoogleErrorResponse {
error: {
code: number;
message: string;
status: string;
details: Array<Record<string, any>>;
};
}

interface GoogleGenerateFunctionCall {
name: string;
args: Record<string, any>;
}

interface GoogleGenerateContentResponse {
candidates: {
content: {
parts: {
text?: string;
functionCall?: GoogleGenerateFunctionCall;
}[];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
}[];
}[];
promptFeedback: {
safetyRatings: {
category: string;
probability: string;
probabilityScore: number;
severity: string;
severityScore: number;
}[];
};
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
}

export const GoogleChatCompleteResponseTransform: (
response:
| GoogleGenerateContentResponse
Expand Down Expand Up @@ -256,6 +198,13 @@ export const GoogleChatCompleteResponseTransform: (
);
}

if (
'candidates' in response &&
response.candidates[0].finishReason === 'PROHIBITED_CONTENT'
) {
return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI);
}

if ('candidates' in response) {
const {
promptTokenCount = 0,
Expand Down
28 changes: 28 additions & 0 deletions src/providers/google-vertex-ai/transformGenerationConfig.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { Params } from '../../types/requestBody';

/**
* @see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#request_body
*/
export function transformGenerationConfig(params: Params) {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
generationConfig['temperature'] = params['temperature'];
}
if (params['top_p']) {
generationConfig['topP'] = params['top_p'];
}
if (params['top_k']) {
generationConfig['topK'] = params['top_k'];
}
if (params['max_tokens']) {
generationConfig['maxOutputTokens'] = params['max_tokens'];
}
if (params['stop']) {
generationConfig['stopSequences'] = params['stop'];
}
if (params?.response_format?.type === 'json_object') {
generationConfig['responseMimeType'] = 'application/json';
}

return generationConfig;
}
44 changes: 44 additions & 0 deletions src/providers/google-vertex-ai/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
export interface GoogleErrorResponse {
error: {
code: number;
message: string;
status: string;
details: Array<Record<string, any>>;
};
}

export interface GoogleGenerateFunctionCall {
name: string;
args: Record<string, any>;
}

export interface GoogleGenerateContentResponse {
candidates: {
content: {
parts: {
text?: string;
functionCall?: GoogleGenerateFunctionCall;
}[];
};
finishReason: string;
index: 0;
safetyRatings: {
category: string;
probability: string;
}[];
}[];
promptFeedback: {
safetyRatings: {
category: string;
probability: string;
probabilityScore: number;
severity: string;
severityScore: number;
}[];
};
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
}
2 changes: 1 addition & 1 deletion src/providers/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export const generateInvalidProviderResponseError: (
) => ErrorResponse = (response, provider) => {
return {
error: {
message: `Invalid response recieved from ${provider}: ${JSON.stringify(
message: `Invalid response received from ${provider}: ${JSON.stringify(
response
)}`,
type: null,
Expand Down
3 changes: 3 additions & 0 deletions src/types/requestBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ export interface Params {
examples?: Examples[];
top_k?: number;
tools?: Tool[];
response_format?: { type: 'json_object' | 'text' };
// Google Vertex AI specific
safety_settings?: any;
}

interface Examples {
Expand Down

0 comments on commit 07d215d

Please sign in to comment.