Skip to content

Commit

Permalink
Merge pull request #856 from narengogi/chore/vertex/gemini-thinking-m…
Browse files Browse the repository at this point in the history
…odel-support

Support for Gemini thinking models
  • Loading branch information
VisargD authored Jan 18, 2025
2 parents 7283e59 + 19ffb37 commit b8f8753
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 12 deletions.
14 changes: 13 additions & 1 deletion src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ import { Options } from '../../types/requestBody';
import { endpointStrings, ProviderAPIConfig } from '../types';
import { getModelAndProvider, getAccessToken } from './utils';

// Decides whether a request must be routed to the `v1beta1` Vertex AI API
// version instead of `v1`. Two cases require it (per the visible routing
// logic in getProjectRoute): Meta-provided models, and the experimental
// Gemini thinking models (any model id containing
// 'gemini-2.0-flash-thinking-exp').
const shouldUseBeta1Version = (
  provider: string,
  inputModel: string
): boolean =>
  provider === 'meta' ||
  inputModel.includes('gemini-2.0-flash-thinking-exp');

const getProjectRoute = (
providerOptions: Options,
inputModel: string
Expand All @@ -17,7 +26,10 @@ const getProjectRoute = (
}

const { provider } = getModelAndProvider(inputModel as string);
const routeVersion = provider === 'meta' ? 'v1beta1' : 'v1';
let routeVersion = provider === 'meta' ? 'v1beta1' : 'v1';
if (shouldUseBeta1Version(provider, inputModel)) {
routeVersion = 'v1beta1';
}
return `/${routeVersion}/projects/${projectId}/locations/${vertexRegion}`;
};

Expand Down
44 changes: 41 additions & 3 deletions src/providers/google-vertex-ai/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,26 @@ export const GoogleChatCompleteResponseTransform: (
provider: GOOGLE_VERTEX_AI,
choices:
response.candidates?.map((generation, index) => {
const containsChainOfThoughtMessage =
generation.content?.parts.length > 1;
let message: Message = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
let content: string = generation.content.parts[0]?.text;
if (
containsChainOfThoughtMessage &&
generation.content.parts[1]?.text
) {
if (strictOpenAiCompliance)
content = generation.content.parts[1]?.text;
else
content =
generation.content.parts[0]?.text +
'\r\n\r\n' +
generation.content.parts[1]?.text;
}
message = {
role: 'assistant',
content: generation.content.parts[0]?.text,
content,
};
} else if (generation.content?.parts[0]?.functionCall) {
message = {
Expand Down Expand Up @@ -751,9 +766,11 @@ export const GoogleChatCompleteStreamChunkTransform: (
) => string = (
responseChunk,
fallbackId,
_streamState,
streamState,
strictOpenAiCompliance
) => {
streamState.containsChainOfThoughtMessage =
streamState?.containsChainOfThoughtMessage ?? false;
const chunk = responseChunk
.trim()
.replace(/^data: /, '')
Expand Down Expand Up @@ -784,9 +801,30 @@ export const GoogleChatCompleteStreamChunkTransform: (
parsedChunk.candidates?.map((generation, index) => {
let message: Message = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
if (generation.content.parts[0].thought)
streamState.containsChainOfThoughtMessage = true;

let content: string =
strictOpenAiCompliance && streamState.containsChainOfThoughtMessage
? ''
: generation.content.parts[0]?.text;
if (generation.content.parts[1]?.text) {
if (strictOpenAiCompliance)
content = generation.content.parts[1].text;
else content += '\r\n\r\n' + generation.content.parts[1]?.text;
streamState.containsChainOfThoughtMessage = false;
} else if (
streamState.containsChainOfThoughtMessage &&
!generation.content.parts[0]?.thought
) {
if (strictOpenAiCompliance)
content = generation.content.parts[0].text;
else content = '\r\n\r\n' + content;
streamState.containsChainOfThoughtMessage = false;
}
message = {
role: 'assistant',
content: generation.content.parts[0]?.text,
content,
};
} else if (generation.content?.parts[0]?.functionCall) {
message = {
Expand Down
1 change: 1 addition & 0 deletions src/providers/google-vertex-ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export interface GoogleGenerateContentResponse {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
};
Expand Down
12 changes: 8 additions & 4 deletions src/providers/google/api.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import { ProviderAPIConfig } from '../types';

export const GoogleApiConfig: ProviderAPIConfig = {
getBaseURL: () => 'https://generativelanguage.googleapis.com/v1beta',
getBaseURL: () => 'https://generativelanguage.googleapis.com',
headers: () => {
return { 'Content-Type': 'application/json' };
},
getEndpoint: ({ fn, providerOptions, gatewayRequestBodyJSON }) => {
let routeVersion = 'v1beta';
let mappedFn = fn;
const { model, stream } = gatewayRequestBodyJSON;
if (model?.includes('gemini-2.0-flash-thinking-exp')) {
routeVersion = 'v1alpha';
}
const { apiKey } = providerOptions;
if (stream && fn === 'chatComplete') {
return `/models/${model}:streamGenerateContent?key=${apiKey}`;
return `/${routeVersion}/models/${model}:streamGenerateContent?key=${apiKey}`;
}
switch (mappedFn) {
case 'chatComplete': {
return `/models/${model}:generateContent?key=${apiKey}`;
return `/${routeVersion}/models/${model}:generateContent?key=${apiKey}`;
}
case 'embed': {
return `/models/${model}:embedContent?key=${apiKey}`;
return `/${routeVersion}/models/${model}:embedContent?key=${apiKey}`;
}
default:
return '';
Expand Down
48 changes: 44 additions & 4 deletions src/providers/google/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ interface GoogleGenerateContentResponse {
content: {
parts: {
text?: string;
thought?: string; // for models like gemini-2.0-flash-thinking-exp refer: https://ai.google.dev/gemini-api/docs/thinking-mode#streaming_model_thinking
functionCall?: GoogleGenerateFunctionCall;
}[];
};
Expand Down Expand Up @@ -489,11 +490,26 @@ export const GoogleChatCompleteResponseTransform: (
provider: 'google',
choices:
response.candidates?.map((generation, idx) => {
const containsChainOfThoughtMessage =
generation.content?.parts.length > 1;
let message: Message = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
let content: string = generation.content.parts[0]?.text;
if (
containsChainOfThoughtMessage &&
generation.content.parts[1]?.text
) {
if (strictOpenAiCompliance)
content = generation.content.parts[1]?.text;
else
content =
generation.content.parts[0]?.text +
'\r\n\r\n' +
generation.content.parts[1]?.text;
}
message = {
role: 'assistant',
content: generation.content.parts[0]?.text,
content,
};
} else if (generation.content?.parts[0]?.functionCall) {
message = {
Expand Down Expand Up @@ -540,9 +556,11 @@ export const GoogleChatCompleteStreamChunkTransform: (
) => string = (
responseChunk,
fallbackId,
_streamState,
streamState,
strictOpenAiCompliance
) => {
streamState.containsChainOfThoughtMessage =
streamState?.containsChainOfThoughtMessage ?? false;
let chunk = responseChunk.trim();
if (chunk.startsWith('[')) {
chunk = chunk.slice(1);
Expand Down Expand Up @@ -572,10 +590,32 @@ export const GoogleChatCompleteStreamChunkTransform: (
choices:
parsedChunk.candidates?.map((generation, index) => {
let message: Message = { role: 'assistant', content: '' };
if (generation.content.parts[0]?.text) {
if (generation.content?.parts[0]?.text) {
if (generation.content.parts[0].thought)
streamState.containsChainOfThoughtMessage = true;
let content: string =
strictOpenAiCompliance &&
streamState.containsChainOfThoughtMessage
? ''
: generation.content.parts[0]?.text;
if (generation.content.parts[1]?.text) {
if (strictOpenAiCompliance)
content = generation.content.parts[1].text;
else content += '\r\n\r\n' + generation.content.parts[1]?.text;
streamState.containsChainOfThoughtMessage = false;
} else if (
streamState.containsChainOfThoughtMessage &&
!generation.content.parts[0]?.thought
) {
if (strictOpenAiCompliance)
content = generation.content.parts[0].text;
else content = '\r\n\r\n' + content;
streamState.containsChainOfThoughtMessage = false;
}
message = {
role: 'assistant',
content: generation.content.parts[0]?.text,
content,
};
} else if (generation.content.parts[0]?.functionCall) {
message = {
Expand Down

0 comments on commit b8f8753

Please sign in to comment.