Skip to content

Commit

Permalink
Merge pull request #265 from flexchar/#10-support-vertex-ai
Browse files Browse the repository at this point in the history
Support for Google's Vertex AI
  • Loading branch information
VisargD authored Apr 10, 2024
2 parents 25709cf + f82669e commit a50f600
Show file tree
Hide file tree
Showing 12 changed files with 719 additions and 676 deletions.
879 changes: 213 additions & 666 deletions package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"prettier": "3.2.5",
"rollup": "^4.9.1",
"tsx": "^4.7.0",
"wrangler": "^3.0.1"
"wrangler": "^3.48.0"
},
"bin": "build/start-server.js",
"type": "module"
Expand Down
2 changes: 2 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export const ANYSCALE: string = 'anyscale';
export const PALM: string = 'palm';
export const TOGETHER_AI: string = 'together-ai';
export const GOOGLE: string = 'google';
export const GOOGLE_VERTEX_AI: string = 'vertex-ai';
export const PERPLEXITY_AI: string = 'perplexity-ai';
export const MISTRAL_AI: string = 'mistral-ai';
export const DEEPINFRA: string = 'deepinfra';
Expand All @@ -50,6 +51,7 @@ export const VALID_PROVIDERS = [
AZURE_OPEN_AI,
COHERE,
GOOGLE,
GOOGLE_VERTEX_AI,
MISTRAL_AI,
OPEN_AI,
PALM,
Expand Down
16 changes: 14 additions & 2 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ export const fetchProviderOptionsFromConfig = (
providerOptions[0].deploymentId = camelCaseConfig.deploymentId;
if (camelCaseConfig.apiVersion)
providerOptions[0].apiVersion = camelCaseConfig.apiVersion;
if (camelCaseConfig.apiVersion)
providerOptions[0].vertexProjectId = camelCaseConfig.vertexProjectId;
if (camelCaseConfig.apiVersion)
providerOptions[0].vertexRegion = camelCaseConfig.vertexRegion;
if (camelCaseConfig.workersAiAccountId)
providerOptions[0].workersAiAccountId =
camelCaseConfig.workersAiAccountId;
Expand Down Expand Up @@ -395,7 +399,11 @@ export async function tryPostProxy(
c.set('requestOptions', [
...requestOptions,
{
providerOptions: { ...providerOption, requestURL: url, rubeusURL: fn },
providerOptions: {
...providerOption,
requestURL: url,
rubeusURL: fn,
},
requestParams: params,
response: mappedResponse.clone(),
cacheStatus: cacheStatus,
Expand Down Expand Up @@ -589,7 +597,11 @@ export async function tryPost(
c.set('requestOptions', [
...requestOptions,
{
providerOptions: { ...providerOption, requestURL: url, rubeusURL: fn },
providerOptions: {
...providerOption,
requestURL: url,
rubeusURL: fn,
},
requestParams: transformedRequestBody,
response: mappedResponse.clone(),
cacheStatus: cacheStatus,
Expand Down
15 changes: 11 additions & 4 deletions src/handlers/streamHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
} from '../globals';
import { OpenAIChatCompleteResponse } from '../providers/openai/chatComplete';
import { OpenAICompleteResponse } from '../providers/openai/complete';
import { getStreamModeSplitPattern } from '../utils';
import { getStreamModeSplitPattern, type SplitPatternType } from '../utils';

function readUInt32BE(buffer: Uint8Array, offset: number) {
return (
Expand Down Expand Up @@ -114,7 +114,7 @@ export async function* readAWSStream(

export async function* readStream(
reader: ReadableStreamDefaultReader,
splitPattern: string,
splitPattern: SplitPatternType,
transformFunction: Function | undefined,
isSleepTimeRequired: boolean,
fallbackChunkId: string
Expand Down Expand Up @@ -235,7 +235,9 @@ export async function handleStreamingMode(
requestURL: string
): Promise<Response> {
const splitPattern = getStreamModeSplitPattern(proxyProvider, requestURL);
const fallbackChunkId = Date.now().toString();
// If the provider doesn't supply completion id,
// we generate a fallback id using the provider name + timestamp.
const fallbackChunkId = `${proxyProvider}-${Date.now().toString()}`;

if (!response.body) {
throw new Error('Response format is invalid. Body not found');
Expand Down Expand Up @@ -274,7 +276,12 @@ export async function handleStreamingMode(

// Convert GEMINI/COHERE json stream to text/event-stream for non-proxy calls
if (
[GOOGLE, COHERE, BEDROCK].includes(proxyProvider) &&
[
//
GOOGLE,
COHERE,
BEDROCK,
].includes(proxyProvider) &&
responseTransformer
) {
return new Response(readable, {
Expand Down
21 changes: 19 additions & 2 deletions src/middlewares/requestValidator/schema/config.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { z } from 'zod';
import { OLLAMA, VALID_PROVIDERS } from '../../../globals';
import { OLLAMA, VALID_PROVIDERS, GOOGLE_VERTEX_AI } from '../../../globals';

export const configSchema: any = z
.object({
Expand All @@ -20,7 +20,9 @@ export const configSchema: any = z
provider: z
.string()
.refine((value) => VALID_PROVIDERS.includes(value), {
message: `Invalid 'provider' value. Must be one of: ${VALID_PROVIDERS.join(', ')}`,
message: `Invalid 'provider' value. Must be one of: ${VALID_PROVIDERS.join(
', '
)}`,
})
.optional(),
api_key: z.string().optional(),
Expand Down Expand Up @@ -57,6 +59,9 @@ export const configSchema: any = z
request_timeout: z.number().optional(),
custom_host: z.string().optional(),
forward_headers: z.array(z.string()).optional(),
// Google Vertex AI specific
vertex_project_id: z.string().optional(),
vertex_region: z.string().optional(),
})
.refine(
(value) => {
Expand Down Expand Up @@ -94,4 +99,16 @@ export const configSchema: any = z
{
message: 'Invalid custom host',
}
)
// Validate Google Vertex AI specific fields
.refine(
(value) => {
const isGoogleVertexAIProvider = value.provider === GOOGLE_VERTEX_AI;
const hasGoogleVertexAIFields =
value.vertex_project_id && value.vertex_region;
return !(isGoogleVertexAIProvider && !hasGoogleVertexAIFields);
},
{
message: `Invalid configuration. 'vertex_project_id' and 'vertex_region' are required for '${GOOGLE_VERTEX_AI}' provider. Example: { 'provider': 'vertex-ai', 'vertex_project_id': 'my-project-id', 'vertex_region': 'us-central1', api_key: 'ya29...' }`,
}
);
42 changes: 42 additions & 0 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { ProviderAPIConfig } from '../types';

// Good reference for using REST: https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstarts/quickstart-multimodal#gemini-beginner-samples-drest
// Difference versus Studio AI: https://cloud.google.com/vertex-ai/docs/start/ai-platform-users
export const GoogleApiConfig: ProviderAPIConfig = {
getBaseURL: ({ providerOptions }) => {
const { vertexProjectId, vertexRegion } = providerOptions;

return `https://${vertexRegion}-aiplatform.googleapis.com/v1/projects/${vertexProjectId}/locations/${vertexRegion}/publishers/google`;
},
headers: ({ providerOptions }) => {
const { apiKey } = providerOptions;

return {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
};
},
getEndpoint: ({ fn, gatewayRequestBody }) => {
let mappedFn = fn;
const { model, stream } = gatewayRequestBody;
if (stream) {
mappedFn = `stream-${fn}`;
}
switch (mappedFn) {
case 'chatComplete': {
return `/models/${model}:generateContent`;
}
case 'stream-chatComplete': {
return `/models/${model}:streamGenerateContent?alt=sse`;
}

// Embed API is not yet implemented in the gateway
// This may be as easy as copy-paste from Google provider, but needs to be tested

default:
return '';
}
},
};

export default GoogleApiConfig;
Loading

0 comments on commit a50f600

Please sign in to comment.