Skip to content

Commit

Permalink
Merge pull request #304 from meronogbai/feat/add-reka-ai
Browse files Browse the repository at this point in the history
feat: add reka ai support
  • Loading branch information
VisargD authored May 7, 2024
2 parents c1c714b + 69b61b5 commit a7987cc
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export const TOGETHER_AI: string = 'together-ai';
export const GOOGLE: string = 'google';
export const GOOGLE_VERTEX_AI: string = 'vertex-ai';
export const PERPLEXITY_AI: string = 'perplexity-ai';
export const REKA_AI: string = 'reka-ai';
export const MISTRAL_AI: string = 'mistral-ai';
export const DEEPINFRA: string = 'deepinfra';
export const STABILITY_AI: string = 'stability-ai';
Expand Down Expand Up @@ -62,6 +63,7 @@ export const VALID_PROVIDERS = [
OPEN_AI,
PALM,
PERPLEXITY_AI,
REKA_AI,
TOGETHER_AI,
DEEPINFRA,
STABILITY_AI,
Expand Down
2 changes: 2 additions & 0 deletions src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import SegmindConfig from './segmind';
import JinaConfig from './jina';
import FireworksAIConfig from './fireworks-ai';
import WorkersAiConfig from './workers-ai';
import RekaAIConfig from './reka-ai';
import MoonshotConfig from './moonshot';
import OpenrouterConfig from './openrouter';
import LingYiConfig from './lingyi';
Expand Down Expand Up @@ -51,6 +52,7 @@ const Providers: { [key: string]: ProviderConfigs } = {
jina: JinaConfig,
'fireworks-ai': FireworksAIConfig,
'workers-ai': WorkersAiConfig,
'reka-ai': RekaAIConfig,
moonshot: MoonshotConfig,
openrouter: OpenrouterConfig,
lingyi: LingYiConfig,
Expand Down
18 changes: 18 additions & 0 deletions src/providers/reka-ai/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { ProviderAPIConfig } from '../types';

const RekaAIApiConfig: ProviderAPIConfig = {
getBaseURL: () => 'https://api.reka.ai',
headers: ({ providerOptions }) => {
return { 'x-api-key': `${providerOptions.apiKey}` };
},
getEndpoint: ({ fn }) => {
switch (fn) {
case 'chatComplete':
return '/chat';
default:
return '';
}
},
};

export default RekaAIApiConfig;
189 changes: 189 additions & 0 deletions src/providers/reka-ai/chatComplete.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import { REKA_AI } from '../../globals';
import { Message, Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
ProviderConfig,
} from '../types';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';

interface RekaMessageItem {
text: string;
media_url?: string;
type: 'human' | 'model';
}

export const RekaAIChatCompleteConfig: ProviderConfig = {
model: {
param: 'model_name',
required: true,
default: 'reka-flash',
},
messages: {
param: 'conversation_history',
transform: (params: Params) => {
const messages: RekaMessageItem[] = [];
let lastType: 'human' | 'model' | undefined;

const addMessage = ({
type,
text,
media_url,
}: {
type: 'human' | 'model';
text: string;
media_url?: string;
}) => {
// NOTE: can't have more than one image in conversation history
if (media_url && messages[0].media_url) {
return;
}

const newMessage: RekaMessageItem = { type, text, media_url };

if (lastType === type) {
const placeholder: RekaMessageItem = {
type: type === 'human' ? 'model' : 'human',
text: 'Placeholder for alternation',
};
media_url
? messages.unshift(placeholder)
: messages.push(placeholder);
}

// NOTE: image need to be first
media_url ? messages.unshift(newMessage) : messages.push(newMessage);
lastType = type;
};

params.messages?.forEach((message) => {
const currentType: 'human' | 'model' =
message.role === 'user' ? 'human' : 'model';

if (!Array.isArray(message.content)) {
addMessage({ type: currentType, text: message.content || '' });
} else {
message.content.forEach((item) => {
addMessage({
type: currentType,
text: item.text || '',
media_url: item.image_url?.url,
});
});
}
});

if (messages[0].type !== 'human') {
messages.unshift({
type: 'human',
text: 'Placeholder for alternation',
});
}
return messages;
},
},
max_tokens: {
param: 'request_output_len',
},
temperature: {
param: 'temperature',
},
top_p: {
param: 'runtime_top_p',
},
stop: {
param: 'stop_words',
transform: (params: Params) => {
if (params.stop && !Array.isArray(params.stop)) {
return [params.stop];
}

return params.stop;
},
},
seed: {
param: 'random_seed',
},
frequency_penalty: {
param: 'frequency_penalty',
},
presence_penalty: {
param: 'presence_penalty',
},
// the following are reka specific
top_k: {
param: 'runtime_top_k',
},
length_penalty: {
param: 'length_penalty',
},
retrieval_dataset: {
param: 'retrieval_dataset',
},
use_search_engine: {
param: 'use_search_engine',
},
};

export interface RekaAIChatCompleteResponse {
type: string;
text: string;
finish_reason: string;
metadata: {
input_tokens: number;
generated_tokens: number;
};
}

export interface RekaAIErrorResponse {
detail: any; // could be string or array
}

export const RekaAIChatCompleteResponseTransform: (
response: RekaAIChatCompleteResponse | RekaAIErrorResponse,
responseStatus: number
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
if ('detail' in response) {
return generateErrorResponse(
{
message: JSON.stringify(response.detail),
type: null,
param: null,
code: null,
},
REKA_AI
);
}

if ('text' in response) {
return {
id: crypto.randomUUID(),
object: 'chat_completion',
created: Math.floor(Date.now() / 1000),
model: 'Unknown',
provider: REKA_AI,
choices: [
{
message: {
role: 'assistant',
content: response.text,
},
index: 0,
logprobs: null,
finish_reason: response.finish_reason,
},
],
usage: {
prompt_tokens: response.metadata.input_tokens,
completion_tokens: response.metadata.generated_tokens,
total_tokens:
response.metadata.input_tokens + response.metadata.generated_tokens,
},
};
}

return generateInvalidProviderResponseError(response, REKA_AI);
};
16 changes: 16 additions & 0 deletions src/providers/reka-ai/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { ProviderConfigs } from '../types';
import RekaAIApiConfig from './api';
import {
RekaAIChatCompleteConfig,
RekaAIChatCompleteResponseTransform,
} from './chatComplete';

const RekaAIConfig: ProviderConfigs = {
chatComplete: RekaAIChatCompleteConfig,
api: RekaAIApiConfig,
responseTransforms: {
chatComplete: RekaAIChatCompleteResponseTransform,
},
};

export default RekaAIConfig;

0 comments on commit a7987cc

Please sign in to comment.