Skip to content

Commit

Permalink
Merge pull request #862 from Portkey-AI/feature/files-and-batches
Browse files Browse the repository at this point in the history
  • Loading branch information
VisargD authored Jan 17, 2025
2 parents 4fe6af0 + cdea58f commit c0a8fd4
Show file tree
Hide file tree
Showing 67 changed files with 4,023 additions and 355 deletions.
292 changes: 180 additions & 112 deletions package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"ts-jest": "^29.2.4",
"tsx": "^4.7.0",
"typescript-eslint": "^8.1.0",
"wrangler": "^3.48.0"
"wrangler": "^3.97.0"
},
"bin": "build/start-server.js",
"type": "module"
Expand Down
4 changes: 4 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export const HEADER_KEYS: Record<string, string> = {
MODE: `x-${POWERED_BY}-mode`,
RETRIES: `x-${POWERED_BY}-retry-count`,
PROVIDER: `x-${POWERED_BY}-provider`,
CONFIG: `x-${POWERED_BY}-config`,
TRACE_ID: `x-${POWERED_BY}-trace-id`,
CACHE: `x-${POWERED_BY}-cache`,
METADATA: `x-${POWERED_BY}-metadata`,
Expand Down Expand Up @@ -129,6 +130,7 @@ export const VALID_PROVIDERS = [
SAGEMAKER,
NEBIUS,
RECRAFTAI,
POWERED_BY,
];

export const CONTENT_TYPES = {
Expand All @@ -137,6 +139,7 @@ export const CONTENT_TYPES = {
EVENT_STREAM: 'text/event-stream',
AUDIO_MPEG: 'audio/mpeg',
APPLICATION_OCTET_STREAM: 'application/octet-stream',
BINARY_OCTET_STREAM: 'binary/octet-stream',
GENERIC_AUDIO_PATTERN: 'audio',
PLAIN_TEXT: 'text/plain',
HTML: 'text/html',
Expand All @@ -146,6 +149,7 @@ export const CONTENT_TYPES = {
export const MULTIPART_FORM_DATA_ENDPOINTS: endpointStrings[] = [
'createTranscription',
'createTranslation',
'uploadFile',
];

export const fileExtensionMimeTypeMap = {
Expand Down
44 changes: 44 additions & 0 deletions src/handlers/batchesHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Context } from 'hono';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
} from './handlerUtils';
import { endpointStrings } from '../providers/types';

function batchesHandler(endpoint: endpointStrings, method: 'POST' | 'GET') {
async function handler(c: Context): Promise<Response> {
try {
let requestHeaders = Object.fromEntries(c.req.raw.headers);
let request = endpoint === 'createBatch' ? await c.req.json() : {};
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
const tryTargetsResponse = await tryTargetsRecursively(
c,
camelCaseConfig ?? {},
request,
requestHeaders,
endpoint,
method,
'config'
);

return tryTargetsResponse;
} catch (err: any) {
console.error({ message: `${endpoint} error ${err.message}` });
return new Response(
JSON.stringify({
status: 'failure',
message: 'Something went wrong',
}),
{
status: 500,
headers: {
'content-type': 'application/json',
},
}
);
}
}
return handler;
}

export default batchesHandler;
50 changes: 50 additions & 0 deletions src/handlers/filesHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { Context } from 'hono';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
} from './handlerUtils';
import { endpointStrings } from '../providers/types';

function filesHandler(
endpoint: endpointStrings,
method: 'POST' | 'GET' | 'DELETE'
) {
async function handler(c: Context): Promise<Response> {
try {
const requestHeaders = Object.fromEntries(c.req.raw.headers);
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
let body = {};
if (c.req.raw.body instanceof ReadableStream) {
body = c.req.raw.body;
}
const tryTargetsResponse = await tryTargetsRecursively(
c,
camelCaseConfig ?? {},
body,
requestHeaders,
endpoint,
method,
'config'
);

return tryTargetsResponse;
} catch (err: any) {
console.error({ message: `${endpoint} error ${err.message}` });
return new Response(
JSON.stringify({
status: 'failure',
message: 'Something went wrong',
}),
{
status: 500,
headers: {
'content-type': 'application/json',
},
}
);
}
}
return handler;
}

export default filesHandler;
89 changes: 70 additions & 19 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export function constructRequest(
let fetchOptions: RequestInit = {
method,
headers,
...(fn === 'uploadFile' && { duplex: 'half' }),
};
const contentType = headers['content-type']?.split(';')[0];
const isGetMethod = method === 'GET';
Expand All @@ -112,6 +113,8 @@ export function constructRequest(
let headers = fetchOptions.headers as Record<string, unknown>;
delete headers['content-type'];
}
if (fn === 'uploadFile')
headers['Content-Type'] = requestHeaders['content-type'];

return fetchOptions;
}
Expand Down Expand Up @@ -238,14 +241,17 @@ export function convertGuardrailsShorthand(guardrailsArr: any, type: string) {
export async function tryPost(
c: Context,
providerOption: Options,
inputParams: Params | FormData | ArrayBuffer,
requestBody: Params | FormData | ArrayBuffer | ReadableStream,
requestHeaders: Record<string, string>,
fn: endpointStrings,
currentIndex: number | string,
method: string = 'POST'
): Promise<Response> {
const overrideParams = providerOption?.overrideParams || {};
const params: Params = { ...inputParams, ...overrideParams };
const params: Params =
requestBody instanceof ReadableStream || requestBody instanceof FormData
? {}
: { ...requestBody, ...overrideParams };
const isStreamingMode = params.stream ? true : false;
let strictOpenAiCompliance = true;

Expand Down Expand Up @@ -277,14 +283,22 @@ export async function tryPost(
);

// Mapping providers to corresponding URLs
const apiConfig: ProviderAPIConfig = Providers[provider].api;
const providerConfig = Providers[provider];
const apiConfig: ProviderAPIConfig = providerConfig.api;
// Attach the body of the request
const transformedRequestBody = transformToProviderRequest(
provider,
params,
inputParams,
fn
);
let transformedRequestBody: ReadableStream | FormData | Params = {};
if (!providerConfig?.requestHandlers?.[fn]) {
transformedRequestBody =
method === 'POST'
? transformToProviderRequest(
provider,
params,
requestBody,
fn,
requestHeaders
)
: requestBody;
}

const forwardHeaders =
requestHeaders[HEADER_KEYS.FORWARD_HEADERS]
Expand All @@ -297,12 +311,19 @@ export async function tryPost(
requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || '';

const baseUrl =
customHost || apiConfig.getBaseURL({ providerOptions: providerOption });
customHost ||
(await apiConfig.getBaseURL({
providerOptions: providerOption,
fn,
c,
}));

const endpoint = apiConfig.getEndpoint({
c,
providerOptions: providerOption,
fn,
gatewayRequestBody: params,
gatewayRequestBodyJSON: params,
gatewayRequestBody: requestBody,
gatewayRequestURL: c.req.url,
});

Expand Down Expand Up @@ -343,6 +364,8 @@ export async function tryPost(
(fn == 'proxy' && requestContentType === CONTENT_TYPES.MULTIPART_FORM_DATA)
) {
fetchOptions.body = transformedRequestBody as FormData;
} else if (transformedRequestBody instanceof ReadableStream) {
fetchOptions.body = transformedRequestBody;
} else if (
fn == 'proxy' &&
requestContentType.startsWith(CONTENT_TYPES.GENERIC_AUDIO_PATTERN)
Expand Down Expand Up @@ -392,7 +415,8 @@ export async function tryPost(
url,
isCacheHit,
params,
strictOpenAiCompliance
strictOpenAiCompliance,
c.req.url
));
}

Expand All @@ -402,7 +426,8 @@ export async function tryPost(
params,
cacheStatus,
retryCount ?? 0,
requestHeaders[HEADER_KEYS.TRACE_ID] ?? ''
requestHeaders[HEADER_KEYS.TRACE_ID] ?? '',
provider
);

c.set('requestOptions', [
Expand Down Expand Up @@ -493,7 +518,8 @@ export async function tryPost(
fn,
requestHeaders,
hookSpan.id,
strictOpenAiCompliance
strictOpenAiCompliance,
requestBody
));

return createResponse(mappedResponse, undefined, false, true);
Expand All @@ -502,7 +528,7 @@ export async function tryPost(
export async function tryTargetsRecursively(
c: Context,
targetGroup: Targets,
request: Params | FormData,
request: Params | FormData | ReadableStream,
requestHeaders: Record<string, string>,
fn: endpointStrings,
method: string,
Expand Down Expand Up @@ -764,7 +790,8 @@ export function updateResponseHeaders(
params: Record<string, any>,
cacheStatus: string | undefined,
retryAttempt: number,
traceId: string
traceId: string,
provider: string
) {
response.headers.append(
RESPONSE_HEADER_KEYS.LAST_USED_OPTION_INDEX,
Expand Down Expand Up @@ -793,6 +820,9 @@ export function updateResponseHeaders(
// workerd environment handles this authomatically
response.headers.delete('content-length');
response.headers.delete('transfer-encoding');
if (provider && provider !== POWERED_BY) {
response.headers.append(HEADER_KEYS.PROVIDER, provider);
}
}

export function constructConfigFromRequestHeaders(
Expand Down Expand Up @@ -841,6 +871,9 @@ export function constructConfigFromRequestHeaders(
awsRoleArn: requestHeaders[`x-${POWERED_BY}-aws-role-arn`],
awsAuthType: requestHeaders[`x-${POWERED_BY}-aws-auth-type`],
awsExternalId: requestHeaders[`x-${POWERED_BY}-aws-external-id`],
awsS3Bucket: requestHeaders[`x-${POWERED_BY}-aws-s3-bucket`],
awsS3ObjectKey: requestHeaders[`x-${POWERED_BY}-aws-s3-object-key`],
awsBedrockModel: requestHeaders[`x-${POWERED_BY}-aws-bedrock-model`],
};

const sagemakerConfig = {
Expand Down Expand Up @@ -1035,7 +1068,8 @@ export async function recursiveAfterRequestHookHandler(
fn: any,
requestHeaders: Record<string, string>,
hookSpanId: string,
strictOpenAiCompliance: boolean
strictOpenAiCompliance: boolean,
requestBody?: ReadableStream | FormData | Params | ArrayBuffer
): Promise<{
mappedResponse: Response;
retryCount: number;
Expand All @@ -1050,6 +1084,21 @@ export async function recursiveAfterRequestHookHandler(

const { retry } = providerOption;

const provider = providerOption.provider ?? '';
const providerConfig = Providers[provider];
const requestHandlers = providerConfig.requestHandlers;
let requestHandler;
if (requestHandlers && requestHandlers[fn]) {
requestHandler = () =>
requestHandlers[fn]({
c,
providerOptions: providerOption,
requestURL: c.req.url,
requestHeaders,
requestBody,
});
}

({
response,
attempt: retryCount,
Expand All @@ -1059,7 +1108,8 @@ export async function recursiveAfterRequestHookHandler(
options,
retry?.attempts || 0,
retry?.onStatusCodes || [],
requestTimeout || null
requestTimeout || null,
requestHandler
));

const {
Expand All @@ -1074,7 +1124,8 @@ export async function recursiveAfterRequestHookHandler(
url,
false,
gatewayParams,
strictOpenAiCompliance
strictOpenAiCompliance,
c.req.url
);

const arhResponse = await afterRequestHookHandler(
Expand Down
3 changes: 2 additions & 1 deletion src/handlers/realtimeHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export async function realTimeHandler(c: Context): Promise<Response> {
const url = getURLForOutgoingConnection(
apiConfig,
providerOptions,
c.req.url
c.req.url,
c
);
const options = await getOptionsForOutgoingConnection(
apiConfig,
Expand Down
5 changes: 3 additions & 2 deletions src/handlers/realtimeHandlerNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ export async function realTimeHandlerNode(
const provider = camelCaseConfig?.provider ?? '';
const apiConfig: ProviderAPIConfig = Providers[provider].api;
const providerOptions = camelCaseConfig as Options;
const baseUrl = apiConfig.getBaseURL({ providerOptions });
const baseUrl = apiConfig.getBaseURL({ providerOptions, c });
const endpoint = apiConfig.getEndpoint({
c,
providerOptions,
fn: 'realtime',
gatewayRequestBody: {},
gatewayRequestBodyJSON: {},
gatewayRequestURL: c.req.url,
});
let url = `${baseUrl}${endpoint}`;
Expand Down
Loading

0 comments on commit c0a8fd4

Please sign in to comment.