diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts index 7707c94d0..e44a7942b 100644 --- a/src/middlewares/requestValidator/index.ts +++ b/src/middlewares/requestValidator/index.ts @@ -1,9 +1,44 @@ import { Context } from 'hono'; import { CONTENT_TYPES, POWERED_BY, VALID_PROVIDERS } from '../../globals'; import { configSchema } from './schema/config'; +import Providers from '../../providers'; -export const requestValidator = (c: Context, next: any) => { +const providers = Object.keys(Providers); + +export const requestValidator = async (c: Context, next: any) => { const requestHeaders = Object.fromEntries(c.req.raw.headers); + const clonedReq = c.req.raw.clone(); + if (clonedReq.method === 'GET') { + const provider = new URLSearchParams(new URL(clonedReq.url).search).get( + 'provider' + ); + if (provider && providers.includes(provider)) { + requestHeaders[`x-${POWERED_BY}-provider`] = provider; + const newHeaders = new Headers(requestHeaders); + const newRequest = new Request(c.req.raw.url, { + method: c.req.raw.method, + headers: newHeaders, + }); + + c.req.raw = newRequest; + } + } else { + const originalBody = await clonedReq.text(); + const modifiedBody = JSON.parse(originalBody); + const [provider, ...modelNames] = modifiedBody.model.split(':'); + if (providers.includes(provider)) { + requestHeaders[`x-${POWERED_BY}-provider`] = provider; + modifiedBody.model = modelNames.join(':') || undefined; + const newHeaders = new Headers(requestHeaders); + const newRequest = new Request(c.req.raw.url, { + method: c.req.raw.method, + headers: newHeaders, + body: JSON.stringify(modifiedBody), + }); + + c.req.raw = newRequest; + } + } const contentType = requestHeaders['content-type']; if ( @@ -151,5 +186,5 @@ export const requestValidator = (c: Context, next: any) => { ); } } - return next(); + return await next(); };