Skip to content
This repository has been archived by the owner on Feb 25, 2025. It is now read-only.

Commit

Permalink
Merge pull request #597 from codestoryai/features/add-gemini-pro-support
Browse files Browse the repository at this point in the history
[ide] add gemini pro support
  • Loading branch information
theskcd authored Apr 26, 2024
2 parents af8167f + 007ce0a commit f90b1b2
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ export class CSInteractiveEditorSessionProvider implements vscode.InteractiveEdi
variables: [],
file_content_map: [],
terminal_selection: undefined,
folder_paths: [],
}
};
const messages = await this.sidecarClient.getInLineEditorResponse(context);
Expand Down
2 changes: 2 additions & 0 deletions extensions/codestory/src/sidecar/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ export class SideCarClient {
const activeWindowData = getCurrentActiveWindow();
const folders = folderFromQuery(query);
const sideCarModelConfiguration = await getSideCarModelConfiguration(await vscode.modelSelection.getConfiguration());
console.log(sideCarModelConfiguration);
console.log(JSON.stringify(sideCarModelConfiguration));
const agentSystemInstruction = readCustomSystemInstruction();
const body = {
repo_ref: repoRef.getRepresentation(),
Expand Down
13 changes: 12 additions & 1 deletion extensions/codestory/src/sidecar/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ export interface InEditorRequest {
snippetInformation: SnippetInformation;
textDocumentWeb: TextDocument;
diagnosticsInformation: DiagnosticInformationFromEditor | null;
userContext: { variables: SidecarVariableTypes[]; file_content_map: { file_path: string; file_content: string; language: string }[]; terminal_selection: string | undefined };
userContext: { variables: SidecarVariableTypes[]; file_content_map: { file_path: string; file_content: string; language: string }[]; terminal_selection: string | undefined; folder_paths: string[] };
}

export interface DiagnosticInformationFromEditor {
Expand Down Expand Up @@ -604,6 +604,14 @@ function getProviderConfiguration(type: string, value: ModelProviderConfiguratio
}
};
}
if (type === 'geminipro') {
return {
'GeminiPro': {
'api_key': value.apiKey,
'api_base': value.apiBase,
}
};
}
return null;
}

Expand Down Expand Up @@ -640,5 +648,8 @@ function getModelProviderConfiguration(providerConfiguration: ProviderSpecificCo
if (providerConfiguration.type === 'fireworkai') {
return 'FireworksAI';
}
if (providerConfiguration.type === 'geminipro') {
return 'GeminiPro';
}
return null;
}
39 changes: 30 additions & 9 deletions src/vs/platform/aiModel/common/aiModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export const humanReadableProviderConfigKey: Record<string, string> = {
'apiBase': 'Base URL'
};

export type ProviderType = 'codestory' | 'openai-default' | 'azure-openai' | 'togetherai' | 'ollama' | 'openai-compatible' | 'anthropic' | 'fireworkai';
export const providerTypeValues: ProviderType[] = ['codestory', 'openai-default', 'azure-openai', 'togetherai', 'ollama', 'openai-compatible', 'anthropic', 'fireworkai'];
export type ProviderType = 'codestory' | 'openai-default' | 'azure-openai' | 'togetherai' | 'ollama' | 'openai-compatible' | 'anthropic' | 'fireworkai' | 'geminipro';
export const providerTypeValues: ProviderType[] = ['codestory', 'openai-default', 'azure-openai', 'togetherai', 'ollama', 'openai-compatible', 'anthropic', 'fireworkai', 'geminipro'];

export interface AzureOpenAIModelProviderConfig {
readonly type: 'azure-openai';
Expand Down Expand Up @@ -93,7 +93,13 @@ export interface FireworkAIProviderConfig {
readonly apiKey: string;
}

export type ProviderConfig = CodeStoryProviderConfig | OpenAIProviderConfig | AzureOpenAIProviderConfig | TogetherAIProviderConfig | OpenAICompatibleProviderConfig | OllamaProviderConfig | AnthropicProviderConfig | FireworkAIProviderConfig;
export interface GeminiProProviderConfig {
readonly name: 'GeminiPro';
readonly apiKey: string;
readonly apiBase: string;
}

export type ProviderConfig = CodeStoryProviderConfig | OpenAIProviderConfig | AzureOpenAIProviderConfig | TogetherAIProviderConfig | OpenAICompatibleProviderConfig | OllamaProviderConfig | AnthropicProviderConfig | FireworkAIProviderConfig | GeminiProProviderConfig;
export type ProviderConfigsWithAPIKey = Exclude<ProviderConfig, CodeStoryProviderConfig | OllamaProviderConfig>;

export type IModelProviders =
Expand All @@ -104,7 +110,8 @@ export type IModelProviders =
| { 'openai-compatible': OpenAICompatibleProviderConfig }
| { 'ollama': OllamaProviderConfig }
| { 'anthropic': AnthropicProviderConfig }
| { 'fireworkai': FireworkAIProviderConfig };
| { 'fireworkai': FireworkAIProviderConfig }
| { 'geminipro': GeminiProProviderConfig };

export function isModelProviderItem(obj: any): obj is IModelProviders {
return obj && typeof obj === 'object'
Expand Down Expand Up @@ -262,7 +269,15 @@ export const defaultModelSelectionSettings: IModelSelectionSettings = {
provider: {
type: 'codestory'
}
}
},
'GeminiPro1.5': {
name: 'Gemini Pro 1.5',
contextLength: 1000000,
temperature: 0.2,
provider: {
type: 'geminipro',
}
},
},
providers: {
'codestory': {
Expand Down Expand Up @@ -297,6 +312,11 @@ export const defaultModelSelectionSettings: IModelSelectionSettings = {
name: 'Firework AI',
apiKey: '',
},
'geminipro': {
name: 'GeminiPro',
apiBase: '',
apiKey: '',
}
}
};

Expand All @@ -309,6 +329,7 @@ export const supportedModels: Record<ProviderType, string[]> = {
'ollama': ['Mixtral', 'MistralInstruct', 'CodeLlama13BInstruct', 'DeepSeekCoder1.3BInstruct', 'DeepSeekCoder6BInstruct', 'DeepSeekCoder33BInstruct'],
'anthropic': ['ClaudeOpus', 'ClaudeSonnet', 'ClaudeHaiku'],
'fireworkai': ['CodeLlama13BInstruct'],
'geminipro': ['GeminiPro1.5'],
};

export const providersSupportingModel = (model: string): ProviderType[] => {
Expand All @@ -332,23 +353,23 @@ export const isDefaultProviderConfig = (key: ProviderType, config: ProviderConfi
const defaultConfig = defaultModelSelectionSettings.providers[key as keyof IModelProviders] as ProviderConfig;
return defaultConfig
&& defaultConfig.name === config.name
&& (defaultConfig.name === 'OpenAI' || defaultConfig.name === 'Together AI' || defaultConfig.name === 'Azure OpenAI' || defaultConfig.name === 'OpenAI Compatible' || defaultConfig.name === 'Anthropic' || defaultConfig.name === 'Firework AI'
&& (defaultConfig.name === 'OpenAI' || defaultConfig.name === 'Together AI' || defaultConfig.name === 'Azure OpenAI' || defaultConfig.name === 'OpenAI Compatible' || defaultConfig.name === 'Anthropic' || defaultConfig.name === 'Firework AI' || defaultConfig.name === 'GeminiPro'
? (defaultConfig).apiKey === (config as ProviderConfigsWithAPIKey).apiKey
: true
)
&& (defaultConfig.name === 'Azure OpenAI' || defaultConfig.name === 'OpenAI Compatible'
&& (defaultConfig.name === 'Azure OpenAI' || defaultConfig.name === 'OpenAI Compatible' || defaultConfig.name === 'GeminiPro'
? defaultConfig.apiBase === (config as BaseOpenAICompatibleProviderConfig).apiBase
: true
);
};

export const areProviderConfigsEqual = (a: ProviderConfig, b: ProviderConfig) => {
return a.name === b.name
&& (a.name === 'OpenAI' || a.name === 'Together AI' || a.name === 'Azure OpenAI' || a.name === 'OpenAI Compatible' || a.name === 'Anthropic' || a.name === 'Firework AI'
&& (a.name === 'OpenAI' || a.name === 'Together AI' || a.name === 'Azure OpenAI' || a.name === 'OpenAI Compatible' || a.name === 'Anthropic' || a.name === 'Firework AI' || a.name === 'GeminiPro'
? (a as ProviderConfigsWithAPIKey).apiKey === (b as ProviderConfigsWithAPIKey).apiKey
: true
)
&& (a.name === 'Azure OpenAI' || a.name === 'OpenAI Compatible'
&& (a.name === 'Azure OpenAI' || a.name === 'OpenAI Compatible' || a.name === 'GeminiPro'
? (a as BaseOpenAICompatibleProviderConfig).apiBase === (b as BaseOpenAICompatibleProviderConfig).apiBase
: true
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@
background-color: currentColor;
}

.model-selection-editor > .model-selection-body > .table-container .monaco-table-td .provider-logo.geminipro {
mask-image: url('logos/fireworks.svg');
-webkit-mask-image: url('logos/fireworks.svg');
mask-size: contain;
-webkit-mask-size: contain;
mask-repeat: no-repeat;
-webkit-mask-repeat: no-repeat;
mask-position: center;
-webkit-mask-position: center;
background-color: currentColor;
}

.model-selection-editor > .model-selection-body > .table-container .monaco-table-td .provider-logo {
display: block;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ class ProviderConfigColumnRenderer implements ITableRenderer<IProviderItemEntry,
disposeTemplate(templateData: IProviderConfigColumnTemplateData): void { }

private getEmptyConfigurationMessage(providerType: ProviderType): { message: string; complete: boolean } {
if (providerType === 'azure-openai' || providerType === 'openai-default' || providerType === 'togetherai' || providerType === 'openai-compatible' || providerType === 'anthropic' || providerType === 'fireworkai') {
if (providerType === 'azure-openai' || providerType === 'openai-default' || providerType === 'togetherai' || providerType === 'openai-compatible' || providerType === 'anthropic' || providerType === 'fireworkai' || providerType === 'geminipro') {
return { message: 'Configuration incomplete', complete: false };
} else if (providerType === 'codestory' || providerType === 'ollama') {
return { message: 'No configuration required', complete: true };
Expand Down
33 changes: 29 additions & 4 deletions src/vs/workbench/services/aiModel/browser/aiModelService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class AIModelsService extends Disposable implements IAIModelSelectionServ
const key = untypedKey as ProviderType;
const acc = untypedAcc as { [key: string]: ProviderConfig };
const provider = modelSelection.providers[key as keyof typeof modelSelection.providers] as ProviderConfig;
if ((provider.name === 'Azure OpenAI' || provider.name === 'OpenAI Compatible') && (provider.apiBase.length > 0 && provider.apiKey.length > 0)) {
if ((provider.name === 'Azure OpenAI' || provider.name === 'OpenAI Compatible' || provider.name === 'GeminiPro') && (provider.apiBase.length > 0 && provider.apiKey.length > 0)) {
acc[key] = provider;
} else if ((provider.name === 'OpenAI' || provider.name === 'Together AI' || provider.name === 'OpenAI Compatible' || provider.name === 'Anthropic' || provider.name === 'Firework AI') && (provider.apiKey?.length ?? 0) > 0) {
acc[key] = provider;
Expand All @@ -95,7 +95,8 @@ export class AIModelsService extends Disposable implements IAIModelSelectionServ
|| model.provider.type === 'openai-compatible'
|| model.provider.type === 'ollama'
|| model.provider.type === 'anthropic'
|| model.provider.type === 'fireworkai') {
|| model.provider.type === 'fireworkai'
|| model.provider.type === 'geminipro') {
acc[key] = model;
}
}
Expand Down Expand Up @@ -381,7 +382,7 @@ class ModelSelectionJsonSchema {
'fireworksaiProvider': {
'type': 'object',
'properties': {
'anthropic': {
'fireworksai': {
'type': 'object',
'properties': {
'name': {
Expand All @@ -397,6 +398,29 @@ class ModelSelectionJsonSchema {
}
}
},
'geminiProProvider': {
'type': 'object',
'properties': {
'geminipro': {
'type': 'object',
'properties': {
'name': {
'enum': ['Gemini Pro 1.5'],
'description': nls.localize('modelSelection.json.geminiProProvider.name', 'Name of the provider')
},
'apiKey': {
'type': 'string',
'description': nls.localize('modelSelection.json.geminiProProvider.apiKey', 'API key for the provider')
},
'apiBase': {
'type': 'string',
'description': nls.localize('modelSelection.json.geminiProProvider.apiBase', 'Base URL of the provider\'s API')
}
},
'required': ['name', 'apiKey']
}
}
},
'providers': {
'oneOf': [
{ '$ref': '#/definitions/codestoryProvider' },
Expand All @@ -407,6 +431,7 @@ class ModelSelectionJsonSchema {
{ '$ref': '#/definitions/ollamaProvider' },
{ '$ref': '#/definitions/anthropicProvider' },
{ '$ref': '#/definitions/fireworksaiProvider' },
{ '$ref': '#/definitions/geminiProProvider' },
]
},
'azureOpenAIModelProviderConfig': {
Expand All @@ -427,7 +452,7 @@ class ModelSelectionJsonSchema {
'type': 'object',
'properties': {
'type': {
'enum': ['codestory', 'openai-default', 'togetherai', 'openai-compatible', 'ollama', 'anthropic', 'fireworkai'],
'enum': ['codestory', 'openai-default', 'togetherai', 'openai-compatible', 'ollama', 'anthropic', 'fireworkai', 'geminipro'],
'description': nls.localize('modelSelection.json.genericModelProviderConfig.type', 'Type of the provider')
}
},
Expand Down
6 changes: 5 additions & 1 deletion src/vs/workbench/services/preferences/common/preferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { URI } from 'vs/base/common/uri';
import { IRange } from 'vs/editor/common/core/range';
import { IEditorContribution } from 'vs/editor/common/editorCommon';
import { ITextModel } from 'vs/editor/common/model';
import { ModelProviderConfig, ProviderType, ProviderConfig, AzureOpenAIProviderConfig, OpenAIProviderConfig, OpenAICompatibleProviderConfig, AnthropicProviderConfig, FireworkAIProviderConfig } from 'vs/platform/aiModel/common/aiModels';
import { ModelProviderConfig, ProviderType, ProviderConfig, AzureOpenAIProviderConfig, OpenAIProviderConfig, OpenAICompatibleProviderConfig, AnthropicProviderConfig, FireworkAIProviderConfig, GeminiProProviderConfig } from 'vs/platform/aiModel/common/aiModels';
import { ConfigurationTarget } from 'vs/platform/configuration/common/configuration';
import { ConfigurationScope, EditPresentationTypes, IExtensionInfo } from 'vs/platform/configuration/common/configurationRegistry';
import { IEditorOptions } from 'vs/platform/editor/common/editor';
Expand Down Expand Up @@ -394,5 +394,9 @@ export const isProviderItemConfigComplete = (providerItem: IProviderItem): boole
const { name, apiKey } = providerItem as FireworkAIProviderConfig;
return !!name && !!apiKey;
}
case 'geminipro': {
const { name, apiKey, apiBase } = providerItem as GeminiProProviderConfig;
return !!name && !!apiKey && !!apiBase;
}
}
};

0 comments on commit f90b1b2

Please sign in to comment.