Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Files and Batches as a unified route #862

Merged
merged 16 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 180 additions & 112 deletions package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"ts-jest": "^29.2.4",
"tsx": "^4.7.0",
"typescript-eslint": "^8.1.0",
"wrangler": "^3.48.0"
"wrangler": "^3.97.0"
},
"bin": "build/start-server.js",
"type": "module"
Expand Down
4 changes: 4 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export const HEADER_KEYS: Record<string, string> = {
MODE: `x-${POWERED_BY}-mode`,
RETRIES: `x-${POWERED_BY}-retry-count`,
PROVIDER: `x-${POWERED_BY}-provider`,
CONFIG: `x-${POWERED_BY}-config`,
TRACE_ID: `x-${POWERED_BY}-trace-id`,
CACHE: `x-${POWERED_BY}-cache`,
METADATA: `x-${POWERED_BY}-metadata`,
Expand Down Expand Up @@ -129,6 +130,7 @@ export const VALID_PROVIDERS = [
SAGEMAKER,
NEBIUS,
RECRAFTAI,
POWERED_BY,
];

export const CONTENT_TYPES = {
Expand All @@ -137,6 +139,7 @@ export const CONTENT_TYPES = {
EVENT_STREAM: 'text/event-stream',
AUDIO_MPEG: 'audio/mpeg',
APPLICATION_OCTET_STREAM: 'application/octet-stream',
BINARY_OCTET_STREAM: 'binary/octet-stream',
GENERIC_AUDIO_PATTERN: 'audio',
PLAIN_TEXT: 'text/plain',
HTML: 'text/html',
Expand All @@ -146,6 +149,7 @@ export const CONTENT_TYPES = {
export const MULTIPART_FORM_DATA_ENDPOINTS: endpointStrings[] = [
'createTranscription',
'createTranslation',
'uploadFile',
];

export const fileExtensionMimeTypeMap = {
Expand Down
44 changes: 44 additions & 0 deletions src/handlers/batchesHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Context } from 'hono';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
} from './handlerUtils';
import { endpointStrings } from '../providers/types';

function batchesHandler(endpoint: endpointStrings, method: 'POST' | 'GET') {
async function handler(c: Context): Promise<Response> {
try {
let requestHeaders = Object.fromEntries(c.req.raw.headers);
let request = endpoint === 'createBatch' ? await c.req.json() : {};
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
const tryTargetsResponse = await tryTargetsRecursively(
c,
camelCaseConfig ?? {},
request,
requestHeaders,
endpoint,
method,
'config'
);

return tryTargetsResponse;
} catch (err: any) {
console.error({ message: `${endpoint} error ${err.message}` });
return new Response(
JSON.stringify({
status: 'failure',
message: 'Something went wrong',
}),
{
status: 500,
headers: {
'content-type': 'application/json',
},
}
);
}
}
return handler;
}

export default batchesHandler;
50 changes: 50 additions & 0 deletions src/handlers/filesHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { Context } from 'hono';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
} from './handlerUtils';
import { endpointStrings } from '../providers/types';

function filesHandler(
endpoint: endpointStrings,
method: 'POST' | 'GET' | 'DELETE'
) {
async function handler(c: Context): Promise<Response> {
try {
const requestHeaders = Object.fromEntries(c.req.raw.headers);
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
let body = {};
if (c.req.raw.body instanceof ReadableStream) {
body = c.req.raw.body;
}
const tryTargetsResponse = await tryTargetsRecursively(
c,
camelCaseConfig ?? {},
body,
requestHeaders,
endpoint,
method,
'config'
);

return tryTargetsResponse;
} catch (err: any) {
console.error({ message: `${endpoint} error ${err.message}` });
return new Response(
JSON.stringify({
status: 'failure',
message: 'Something went wrong',
}),
{
status: 500,
headers: {
'content-type': 'application/json',
},
}
);
}
}
return handler;
}

export default filesHandler;
89 changes: 70 additions & 19 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export function constructRequest(
let fetchOptions: RequestInit = {
method,
headers,
...(fn === 'uploadFile' && { duplex: 'half' }),
};
const contentType = headers['content-type']?.split(';')[0];
const isGetMethod = method === 'GET';
Expand All @@ -112,6 +113,8 @@ export function constructRequest(
let headers = fetchOptions.headers as Record<string, unknown>;
delete headers['content-type'];
}
if (fn === 'uploadFile')
headers['Content-Type'] = requestHeaders['content-type'];

return fetchOptions;
}
Expand Down Expand Up @@ -238,14 +241,17 @@ export function convertGuardrailsShorthand(guardrailsArr: any, type: string) {
export async function tryPost(
c: Context,
providerOption: Options,
inputParams: Params | FormData | ArrayBuffer,
requestBody: Params | FormData | ArrayBuffer | ReadableStream,
requestHeaders: Record<string, string>,
fn: endpointStrings,
currentIndex: number | string,
method: string = 'POST'
): Promise<Response> {
const overrideParams = providerOption?.overrideParams || {};
const params: Params = { ...inputParams, ...overrideParams };
const params: Params =
requestBody instanceof ReadableStream || requestBody instanceof FormData
? {}
: { ...requestBody, ...overrideParams };
const isStreamingMode = params.stream ? true : false;
let strictOpenAiCompliance = true;

Expand Down Expand Up @@ -277,14 +283,22 @@ export async function tryPost(
);

// Mapping providers to corresponding URLs
const apiConfig: ProviderAPIConfig = Providers[provider].api;
const providerConfig = Providers[provider];
const apiConfig: ProviderAPIConfig = providerConfig.api;
// Attach the body of the request
const transformedRequestBody = transformToProviderRequest(
provider,
params,
inputParams,
fn
);
let transformedRequestBody: ReadableStream | FormData | Params = {};
if (!providerConfig?.requestHandlers?.[fn]) {
transformedRequestBody =
method === 'POST'
? transformToProviderRequest(
provider,
params,
requestBody,
fn,
requestHeaders
)
: requestBody;
}

const forwardHeaders =
requestHeaders[HEADER_KEYS.FORWARD_HEADERS]
Expand All @@ -297,12 +311,19 @@ export async function tryPost(
requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || '';

const baseUrl =
customHost || apiConfig.getBaseURL({ providerOptions: providerOption });
customHost ||
(await apiConfig.getBaseURL({
providerOptions: providerOption,
fn,
c,
}));

const endpoint = apiConfig.getEndpoint({
c,
providerOptions: providerOption,
fn,
gatewayRequestBody: params,
gatewayRequestBodyJSON: params,
gatewayRequestBody: requestBody,
gatewayRequestURL: c.req.url,
});

Expand Down Expand Up @@ -343,6 +364,8 @@ export async function tryPost(
(fn == 'proxy' && requestContentType === CONTENT_TYPES.MULTIPART_FORM_DATA)
) {
fetchOptions.body = transformedRequestBody as FormData;
} else if (transformedRequestBody instanceof ReadableStream) {
fetchOptions.body = transformedRequestBody;
} else if (
fn == 'proxy' &&
requestContentType.startsWith(CONTENT_TYPES.GENERIC_AUDIO_PATTERN)
Expand Down Expand Up @@ -392,7 +415,8 @@ export async function tryPost(
url,
isCacheHit,
params,
strictOpenAiCompliance
strictOpenAiCompliance,
c.req.url
));
}

Expand All @@ -402,7 +426,8 @@ export async function tryPost(
params,
cacheStatus,
retryCount ?? 0,
requestHeaders[HEADER_KEYS.TRACE_ID] ?? ''
requestHeaders[HEADER_KEYS.TRACE_ID] ?? '',
provider
);

c.set('requestOptions', [
Expand Down Expand Up @@ -493,7 +518,8 @@ export async function tryPost(
fn,
requestHeaders,
hookSpan.id,
strictOpenAiCompliance
strictOpenAiCompliance,
requestBody
));

return createResponse(mappedResponse, undefined, false, true);
Expand All @@ -502,7 +528,7 @@ export async function tryPost(
export async function tryTargetsRecursively(
c: Context,
targetGroup: Targets,
request: Params | FormData,
request: Params | FormData | ReadableStream,
requestHeaders: Record<string, string>,
fn: endpointStrings,
method: string,
Expand Down Expand Up @@ -764,7 +790,8 @@ export function updateResponseHeaders(
params: Record<string, any>,
cacheStatus: string | undefined,
retryAttempt: number,
traceId: string
traceId: string,
provider: string
) {
response.headers.append(
RESPONSE_HEADER_KEYS.LAST_USED_OPTION_INDEX,
Expand Down Expand Up @@ -793,6 +820,9 @@ export function updateResponseHeaders(
// workerd environment handles this authomatically
response.headers.delete('content-length');
response.headers.delete('transfer-encoding');
if (provider && provider !== POWERED_BY) {
response.headers.append(HEADER_KEYS.PROVIDER, provider);
}
}

export function constructConfigFromRequestHeaders(
Expand Down Expand Up @@ -841,6 +871,9 @@ export function constructConfigFromRequestHeaders(
awsRoleArn: requestHeaders[`x-${POWERED_BY}-aws-role-arn`],
awsAuthType: requestHeaders[`x-${POWERED_BY}-aws-auth-type`],
awsExternalId: requestHeaders[`x-${POWERED_BY}-aws-external-id`],
awsS3Bucket: requestHeaders[`x-${POWERED_BY}-aws-s3-bucket`],
awsS3ObjectKey: requestHeaders[`x-${POWERED_BY}-aws-s3-object-key`],
awsBedrockModel: requestHeaders[`x-${POWERED_BY}-aws-bedrock-model`],
};

const sagemakerConfig = {
Expand Down Expand Up @@ -1035,7 +1068,8 @@ export async function recursiveAfterRequestHookHandler(
fn: any,
requestHeaders: Record<string, string>,
hookSpanId: string,
strictOpenAiCompliance: boolean
strictOpenAiCompliance: boolean,
requestBody?: ReadableStream | FormData | Params | ArrayBuffer
): Promise<{
mappedResponse: Response;
retryCount: number;
Expand All @@ -1050,6 +1084,21 @@ export async function recursiveAfterRequestHookHandler(

const { retry } = providerOption;

const provider = providerOption.provider ?? '';
const providerConfig = Providers[provider];
const requestHandlers = providerConfig.requestHandlers;
let requestHandler;
if (requestHandlers && requestHandlers[fn]) {
requestHandler = () =>
requestHandlers[fn]({
c,
providerOptions: providerOption,
requestURL: c.req.url,
requestHeaders,
requestBody,
});
}

({
response,
attempt: retryCount,
Expand All @@ -1059,7 +1108,8 @@ export async function recursiveAfterRequestHookHandler(
options,
retry?.attempts || 0,
retry?.onStatusCodes || [],
requestTimeout || null
requestTimeout || null,
requestHandler
));

const {
Expand All @@ -1074,7 +1124,8 @@ export async function recursiveAfterRequestHookHandler(
url,
false,
gatewayParams,
strictOpenAiCompliance
strictOpenAiCompliance,
c.req.url
);

const arhResponse = await afterRequestHookHandler(
Expand Down
3 changes: 2 additions & 1 deletion src/handlers/realtimeHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export async function realTimeHandler(c: Context): Promise<Response> {
const url = getURLForOutgoingConnection(
apiConfig,
providerOptions,
c.req.url
c.req.url,
c
);
const options = await getOptionsForOutgoingConnection(
apiConfig,
Expand Down
5 changes: 3 additions & 2 deletions src/handlers/realtimeHandlerNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ export async function realTimeHandlerNode(
const provider = camelCaseConfig?.provider ?? '';
const apiConfig: ProviderAPIConfig = Providers[provider].api;
const providerOptions = camelCaseConfig as Options;
const baseUrl = apiConfig.getBaseURL({ providerOptions });
const baseUrl = apiConfig.getBaseURL({ providerOptions, c });
const endpoint = apiConfig.getEndpoint({
c,
providerOptions,
fn: 'realtime',
gatewayRequestBody: {},
gatewayRequestBodyJSON: {},
gatewayRequestURL: c.req.url,
});
let url = `${baseUrl}${endpoint}`;
Expand Down
Loading
Loading