diff --git a/src/AzureAIProxy/Middleware/ApiKeyAuthenticationHandler.cs b/src/AzureAIProxy/Middleware/ApiKeyAuthenticationHandler.cs index dca8e366..1f03d630 100644 --- a/src/AzureAIProxy/Middleware/ApiKeyAuthenticationHandler.cs +++ b/src/AzureAIProxy/Middleware/ApiKeyAuthenticationHandler.cs @@ -20,15 +20,13 @@ protected override async Task HandleAuthenticateAsync() var apiKey = apiKeyValues.ToString(); // Convert StringValues to string if (string.IsNullOrWhiteSpace(apiKey)) - return AuthenticateResult.Fail("Missing API key is empty."); + return AuthenticateResult.Fail("API key is empty."); var requestContext = await authorizeService.IsUserAuthorizedAsync(apiKey); if (requestContext is null) return AuthenticateResult.Fail("Authentication failed."); Context.Items["RequestContext"] = requestContext; - Context.Items["RateLimited"] = requestContext.RateLimitExceed; - Context.Items["DailyRequestCap"] = requestContext.DailyRequestCap; var identity = new ClaimsIdentity(null, nameof(ApiKeyAuthenticationHandler)); var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name); diff --git a/src/AzureAIProxy/Middleware/BearerTokenAuthenticationHandler.cs b/src/AzureAIProxy/Middleware/BearerTokenAuthenticationHandler.cs new file mode 100644 index 00000000..1db0c6ae --- /dev/null +++ b/src/AzureAIProxy/Middleware/BearerTokenAuthenticationHandler.cs @@ -0,0 +1,37 @@ +using System.Security.Claims; +using System.Text.Encodings.Web; +using Microsoft.AspNetCore.Authentication; +using Microsoft.Extensions.Options; +using AzureAIProxy.Services; + +namespace AzureAIProxy.Middleware; + +public class BearerTokenAuthenticationHandler( + IOptionsMonitor options, + IAuthorizeService authorizeService, + ILoggerFactory logger, + UrlEncoder encoder +) : AuthenticationHandler(options, logger, encoder) +{ + protected override async Task HandleAuthenticateAsync() + { + if (!Request.Headers.TryGetValue("Authorization", out var apiKeyValues)) + return AuthenticateResult.Fail("Missing API key is empty."); + + // Extract the API key from the Authorization header + var apiKey = apiKeyValues.ToString().Split(" ").Last(); // Convert StringValues to string + if (string.IsNullOrWhiteSpace(apiKey)) + return AuthenticateResult.Fail("API key is empty."); + + var requestContext = await authorizeService.IsUserAuthorizedAsync(apiKey); + if (requestContext is null) + return AuthenticateResult.Fail("Authentication failed."); + + Context.Items["RequestContext"] = requestContext; + + var identity = new ClaimsIdentity(null, nameof(BearerTokenAuthenticationHandler)); + var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name); + + return AuthenticateResult.Success(ticket); + } +} diff --git a/src/AzureAIProxy/Middleware/BearerTokenAuthorizeAttribute.cs b/src/AzureAIProxy/Middleware/BearerTokenAuthorizeAttribute.cs new file mode 100644 index 00000000..d964d83a --- /dev/null +++ b/src/AzureAIProxy/Middleware/BearerTokenAuthorizeAttribute.cs @@ -0,0 +1,12 @@ +using Microsoft.AspNetCore.Authorization; +using AzureAIProxy.Middleware; +using System.Diagnostics; + +namespace AzureAIProxy; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] +[DebuggerDisplay("{ToString(),nq}")] +public class BearerTokenAuthorizeAttribute : AuthorizeAttribute +{ + public BearerTokenAuthorizeAttribute() => AuthenticationSchemes = ProxyAuthenticationOptions.BearerTokenScheme; +} diff --git a/src/AzureAIProxy/Middleware/JwtAuthenticationHandler.cs b/src/AzureAIProxy/Middleware/JwtAuthenticationHandler.cs index 53d10833..49d24119 100644 --- a/src/AzureAIProxy/Middleware/JwtAuthenticationHandler.cs +++ b/src/AzureAIProxy/Middleware/JwtAuthenticationHandler.cs @@ -28,7 +28,7 @@ protected override async Task HandleAuthenticateAsync() Context.Items["RequestContext"] = requestContext; - var identity = new ClaimsIdentity(null, nameof(ApiKeyAuthenticationHandler)); + var identity = new ClaimsIdentity(null, nameof(JwtAuthenticationHandler)); var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name); return AuthenticateResult.Success(ticket); diff --git a/src/AzureAIProxy/Middleware/LoadProperties.cs b/src/AzureAIProxy/Middleware/LoadProperties.cs new file mode 100644 index 00000000..a1c58cf5 --- /dev/null +++ b/src/AzureAIProxy/Middleware/LoadProperties.cs @@ -0,0 +1,68 @@ +using AzureAIProxy.Shared.Database; +using System.Text.Json; + +namespace AzureAIProxy.Middleware; + +public class LoadProperties(RequestDelegate next) +{ + private readonly RequestDelegate _next = next; + + public async Task InvokeAsync(HttpContext context) + { + JsonDocument? jsonDoc = null; + try + { + if (!context.Request.HasFormContentType && + context.Request.ContentType != null && + context.Request.ContentType.Contains("application/json", StringComparison.InvariantCultureIgnoreCase)) + { + using var reader = new StreamReader(context.Request.Body); + string json = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(json)) + { + try + { + jsonDoc = JsonDocument.Parse(json); + } + catch (JsonException) + { + await OpenAIErrorResponse.BadRequest($"Invalid JSON in request body: {json}").WriteAsync(context); + return; + } + + jsonDoc = JsonDocument.Parse(json); + + context.Items["IsStreaming"] = IsStreaming(jsonDoc); + context.Items["ModelName"] = GetModelName(jsonDoc); + } + } + + context.Items["requestPath"]= context.Request.Path.Value!.Split("/api/v1/").Last(); + context.Items["jsonDoc"] = jsonDoc; + + await _next(context); + } + finally + { + jsonDoc?.Dispose(); + } + } + + private static bool IsStreaming(JsonDocument requestJsonDoc) + { + return requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object + && requestJsonDoc.RootElement.TryGetProperty("stream", out JsonElement streamElement) + && ( + streamElement.ValueKind == JsonValueKind.True + || streamElement.ValueKind == JsonValueKind.False + ) + && streamElement.GetBoolean(); + } + + private static string? GetModelName(JsonDocument requestJsonDoc) + { + return requestJsonDoc.RootElement.TryGetProperty("model", out JsonElement modelElement) && modelElement.ValueKind == JsonValueKind.String + ? modelElement.GetString() + : null; + } +} diff --git a/src/AzureAIProxy/Middleware/MaxTokensHandler.cs b/src/AzureAIProxy/Middleware/MaxTokensHandler.cs new file mode 100644 index 00000000..a1a9977a --- /dev/null +++ b/src/AzureAIProxy/Middleware/MaxTokensHandler.cs @@ -0,0 +1,40 @@ +using AzureAIProxy.Shared.Database; +using System.Text.Json; + +namespace AzureAIProxy.Middleware; + +public class MaxTokensHandler(RequestDelegate next) +{ + private readonly RequestDelegate _next = next; + + public async Task InvokeAsync(HttpContext context) + { + RequestContext? requestContext = context.Items["RequestContext"]! as RequestContext; + JsonDocument? jsonDoc = context.Items["jsonDoc"]! as JsonDocument; + + if (requestContext is not null && jsonDoc is not null) + { + int? maxTokens = GetMaxTokens(jsonDoc); + if (maxTokens.HasValue && maxTokens > requestContext.MaxTokenCap && requestContext.MaxTokenCap > 0) + { + await OpenAIErrorResponse.BadRequest( + $"max_tokens exceeds the event max token cap of {requestContext.MaxTokenCap}" + ).WriteAsync(context); + return; + } + } + + await _next(context); + } + + private static int? GetMaxTokens(JsonDocument requestJsonDoc) + { + return + requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object + && requestJsonDoc.RootElement.TryGetProperty("max_tokens", out var maxTokensElement) + && maxTokensElement.ValueKind == JsonValueKind.Number + && maxTokensElement.TryGetInt32(out int maxTokens) + ? maxTokens + : null; + } +} diff --git a/src/AzureAIProxy/Middleware/OpenAIErrorResponse.cs b/src/AzureAIProxy/Middleware/OpenAIErrorResponse.cs new file mode 100644 index 00000000..c218fce5 --- /dev/null +++ b/src/AzureAIProxy/Middleware/OpenAIErrorResponse.cs @@ -0,0 +1,32 @@ +using System.Net; +using Microsoft.AspNetCore.Http.HttpResults; + +namespace AzureAIProxy.Middleware; + +public class OpenAIErrorResponse(string value, HttpStatusCode statusCode) +{ + private readonly JsonHttpResult innerResult = TypedResults.Json( + new OpenAIErrorPayload((int)statusCode, value), + statusCode: (int)statusCode + ); + + public static OpenAIErrorResponse BadRequest(string message) => + new(message, HttpStatusCode.BadRequest); + + public static OpenAIErrorResponse TooManyRequests(string message) => new(message, HttpStatusCode.TooManyRequests); + + public static OpenAIErrorResponse Unauthorized(string message) => + new(message, HttpStatusCode.Unauthorized); + + public async Task WriteAsync(HttpContext httpContext) + { + httpContext.Response.StatusCode = (int)innerResult.StatusCode!; + await httpContext.Response.WriteAsJsonAsync(new + { + code = innerResult.StatusCode, + message = innerResult.Value + }); + } + + record OpenAIErrorPayload(int Code, string Message); +} diff --git a/src/AzureAIProxy/Middleware/ProxyAuthenticationOptions.cs b/src/AzureAIProxy/Middleware/ProxyAuthenticationOptions.cs index 1df1d2d1..48738697 100644 --- a/src/AzureAIProxy/Middleware/ProxyAuthenticationOptions.cs +++ b/src/AzureAIProxy/Middleware/ProxyAuthenticationOptions.cs @@ -6,4 +6,5 @@ public class ProxyAuthenticationOptions : AuthenticationSchemeOptions { public const string ApiKeyScheme = "ApiKeyScheme"; public const string JwtScheme = "JwtScheme"; + public const string BearerTokenScheme = "BearerTokenScheme"; } diff --git a/src/AzureAIProxy/Middleware/RateLimiterHandler.cs b/src/AzureAIProxy/Middleware/RateLimiterHandler.cs index 94e15337..524c29d7 100644 --- a/src/AzureAIProxy/Middleware/RateLimiterHandler.cs +++ b/src/AzureAIProxy/Middleware/RateLimiterHandler.cs @@ -1,23 +1,20 @@ +using AzureAIProxy.Shared.Database; + namespace AzureAIProxy.Middleware; public class RateLimiterHandler(RequestDelegate next) { - private const int RateLimitStatusCode = 429; private readonly RequestDelegate _next = next; public async Task InvokeAsync(HttpContext context) { - if (context.Items.TryGetValue("RateLimited", out var rateLimit) && rateLimit is true) + RequestContext? requestContext = context.Items["RequestContext"] as RequestContext; + + if (requestContext is not null && requestContext.RateLimitExceed) { - var dailyRequestCap = context.Items["DailyRequestCap"] ?? 0; - context.Response.StatusCode = RateLimitStatusCode; // Too Many Requests - await context.Response.WriteAsJsonAsync( - new - { - code = RateLimitStatusCode, - message = $"The event daily request rate of {dailyRequestCap} calls has been exceeded. Requests are disabled until UTC midnight." - } - ); + await OpenAIErrorResponse.TooManyRequests( + $"The event daily request rate of {requestContext.DailyRequestCap} calls has been exceeded. Requests are disabled until UTC midnight." + ).WriteAsync(context); } else { diff --git a/src/AzureAIProxy/Program.cs b/src/AzureAIProxy/Program.cs index 98b24706..c63947d7 100644 --- a/src/AzureAIProxy/Program.cs +++ b/src/AzureAIProxy/Program.cs @@ -18,6 +18,10 @@ .AddScheme( ProxyAuthenticationOptions.JwtScheme, _ => { } + ) + .AddScheme( + ProxyAuthenticationOptions.BearerTokenScheme, + _ => { } ); builder.Services.AddMemoryCache(); @@ -31,6 +35,8 @@ app.UseAuthentication(); app.UseAuthorization(); app.UseMiddleware(); +app.UseMiddleware(); +app.UseMiddleware(); app.MapProxyRoutes(); app.Run(); diff --git a/src/AzureAIProxy/Routes/AzureAIProxy.cs b/src/AzureAIProxy/Routes/AzureAIProxy.cs index 94d3004f..8bced7ba 100644 --- a/src/AzureAIProxy/Routes/AzureAIProxy.cs +++ b/src/AzureAIProxy/Routes/AzureAIProxy.cs @@ -30,64 +30,37 @@ public static RouteGroupBuilder MapAzureAIProxyRoutes(this RouteGroupBuilder bui private static async Task ProcessRequestAsync( [FromServices] ICatalogService catalogService, [FromServices] IProxyService proxyService, - [FromBody] JsonDocument requestJsonDoc, HttpContext context, string deploymentName ) { - using (requestJsonDoc) - { - var requestPath = context.Request.Path.Value!.Split("/api/v1/").Last(); - var requestContext = (RequestContext)context.Items["RequestContext"]!; - - var streaming = IsStreaming(requestJsonDoc); - var maxTokens = GetMaxTokens(requestJsonDoc); + string requestPath = (string)context.Items["requestPath"]!; + RequestContext requestContext = (RequestContext)context.Items["RequestContext"]!; + JsonDocument requestJsonDoc = (JsonDocument)context.Items["jsonDoc"]!; + bool streaming = (bool)context.Items["IsStreaming"]!; - if ( - maxTokens.HasValue - && maxTokens > requestContext.MaxTokenCap - && requestContext.MaxTokenCap > 0 - ) - { - return OpenAIResult.BadRequest( - $"max_tokens exceeds the event max token cap of {requestContext.MaxTokenCap}" - ); - } + var (deployment, eventCatalog) = await catalogService.GetCatalogItemAsync( + requestContext.EventId, + deploymentName! + ); - var (deployment, eventCatalog) = await catalogService.GetCatalogItemAsync( - requestContext.EventId, - deploymentName + if (deployment is null) + { + return OpenAIResult.NotFound( + $"Deployment '{deploymentName}' not found for this event. Available deployments are: {string.Join(", ", eventCatalog.Select(d => d.DeploymentName))}" ); + } - if (deployment is null) - { - return OpenAIResult.NotFound( - $"Deployment '{deploymentName}' not found for this event. Available deployments are: {string.Join(", ", eventCatalog.Select(d => d.DeploymentName))}" - ); - } - - var url = new UriBuilder(deployment.EndpointUrl.TrimEnd('/')) - { - Path = requestPath - }; + var url = new UriBuilder(deployment.EndpointUrl.TrimEnd('/')) + { + Path = requestPath + }; - try + try + { + if (streaming) { - if (streaming) - { - await proxyService.HttpPostStreamAsync( - url, - deployment.EndpointKey, - context, - requestJsonDoc, - requestContext, - deployment - ); - return new ProxyResult(null!, (int)HttpStatusCode.OK); - } - - - var (responseContent, statusCode) = await proxyService.HttpPostAsync( + await proxyService.HttpPostStreamAsync( url, deployment.EndpointKey, context, @@ -95,42 +68,31 @@ await proxyService.HttpPostStreamAsync( requestContext, deployment ); - return new ProxyResult(responseContent, statusCode); - } - catch (TaskCanceledException ex) when (ex.InnerException is System.Net.Sockets.SocketException) - { - return OpenAIResult.ServiceUnavailable("The request was canceled due to timeout. Inner exception: " + ex.InnerException.Message); + return new ProxyResult(null!, (int)HttpStatusCode.OK); } - catch (TaskCanceledException ex) - { - return OpenAIResult.ServiceUnavailable("The request was canceled: " + ex.Message); - } - catch (HttpRequestException ex) - { - return OpenAIResult.ServiceUnavailable("The request failed: " + ex.Message); - } - } - } - private static bool IsStreaming(JsonDocument requestJsonDoc) - { - return requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object - && requestJsonDoc.RootElement.TryGetProperty("stream", out JsonElement streamElement) - && ( - streamElement.ValueKind == JsonValueKind.True - || streamElement.ValueKind == JsonValueKind.False - ) - && streamElement.GetBoolean(); - } - private static int? GetMaxTokens(JsonDocument requestJsonDoc) - { - return - requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object - && requestJsonDoc.RootElement.TryGetProperty("max_tokens", out var maxTokensElement) - && maxTokensElement.ValueKind == JsonValueKind.Number - && maxTokensElement.TryGetInt32(out int maxTokens) - ? maxTokens - : null; + var (responseContent, statusCode) = await proxyService.HttpPostAsync( + url, + deployment.EndpointKey, + context, + requestJsonDoc, + requestContext, + deployment + ); + return new ProxyResult(responseContent, statusCode); + } + catch (TaskCanceledException ex) when (ex.InnerException is System.Net.Sockets.SocketException) + { + return OpenAIResult.ServiceUnavailable("The request was canceled due to timeout. Inner exception: " + ex.InnerException.Message); + } + catch (TaskCanceledException ex) + { + return OpenAIResult.ServiceUnavailable("The request was canceled: " + ex.Message); + } + catch (HttpRequestException ex) + { + return OpenAIResult.ServiceUnavailable("The request failed: " + ex.Message); + } } } diff --git a/src/AzureAIProxy/Routes/AzureOpenAIAssistants.cs b/src/AzureAIProxy/Routes/AzureOpenAIAssistants.cs index 44e50124..431e81d6 100644 --- a/src/AzureAIProxy/Routes/AzureOpenAIAssistants.cs +++ b/src/AzureAIProxy/Routes/AzureOpenAIAssistants.cs @@ -40,7 +40,6 @@ public static RouteGroupBuilder MapAzureOpenAIAssistantsRoutes(this RouteGroupBu /// The proxy service for forwarding requests. /// The assistant service for managing assistant and thread IDs. /// The HTTP context of the request. - /// The optional JSON document in the request body. /// The optional assistant identifier from the route. /// The optional thread identifier from the route. /// An representing the result of the operation. @@ -50,13 +49,13 @@ private static async Task CreateThreadAsync( [FromServices] IProxyService proxyService, [FromServices] IAssistantService assistantService, HttpContext context, - [FromBody] JsonDocument? requestJsonDoc = null, string? assistantId = null, string? threadId = null ) { - var requestPath = context.Request.Path.Value!.Split("/api/v1/").Last(); - var requestContext = (RequestContext)context.Items["RequestContext"]!; + string requestPath = (string)context.Items["requestPath"]!; + RequestContext requestContext = (RequestContext)context.Items["RequestContext"]!; + JsonDocument requestJsonDoc = (JsonDocument)context.Items["jsonDoc"]!; var deployment = await catalogService.GetEventAssistantAsync(requestContext.EventId); if (deployment is null) diff --git a/src/AzureAIProxy/Routes/AzureOpenAIFiles.cs b/src/AzureAIProxy/Routes/AzureOpenAIFiles.cs index ed0223cc..143ef564 100644 --- a/src/AzureAIProxy/Routes/AzureOpenAIFiles.cs +++ b/src/AzureAIProxy/Routes/AzureOpenAIFiles.cs @@ -46,8 +46,8 @@ private static async Task CreateThreadAsync( string? fileId = null ) { - var requestPath = context.Request.Path.Value!.Split("/api/v1/").Last(); - var requestContext = (RequestContext)context.Items["RequestContext"]!; + string requestPath = (string)context.Items["requestPath"]!; + RequestContext requestContext = (RequestContext)context.Items["RequestContext"]!; var deployment = await catalogService.GetEventAssistantAsync(requestContext.EventId); if (deployment is null) diff --git a/src/AzureAIProxy/Routes/OpenAIProxy.cs b/src/AzureAIProxy/Routes/OpenAIProxy.cs new file mode 100644 index 00000000..9623f3da --- /dev/null +++ b/src/AzureAIProxy/Routes/OpenAIProxy.cs @@ -0,0 +1,89 @@ +using System.Net; +using System.Text.Json; +using Microsoft.AspNetCore.Mvc; +using AzureAIProxy.Shared.Database; +using AzureAIProxy.Routes.CustomResults; +using AzureAIProxy.Services; + +namespace AzureAIProxy.Routes; + +public static class OpenAIAIProxy +{ + public static RouteGroupBuilder MapOpenAIProxyRoutes(this RouteGroupBuilder builder) + { + // OpenAI Routes for Mistral chat completions compatibity + builder.MapPost("/chat/completions", ProcessRequestAsync); + + return builder; + } + + [BearerTokenAuthorize] + private static async Task ProcessRequestAsync( + [FromServices] ICatalogService catalogService, + [FromServices] IProxyService proxyService, + HttpContext context + ) + { + string requestPath = (string)context.Items["requestPath"]!; + RequestContext requestContext = (RequestContext)context.Items["RequestContext"]!; + JsonDocument requestJsonDoc = (JsonDocument)context.Items["jsonDoc"]!; + bool streaming = (bool)context.Items["IsStreaming"]!; + string deploymentName = (string)context.Items["ModelName"]!; + + var (deployment, eventCatalog) = await catalogService.GetCatalogItemAsync( + requestContext.EventId, + deploymentName! + ); + + if (deployment is null) + { + return OpenAIResult.NotFound( + $"Deployment '{deploymentName}' not found for this event. Available deployments are: {string.Join(", ", eventCatalog.Select(d => d.DeploymentName))}" + ); + } + + var url = new UriBuilder(deployment.EndpointUrl.TrimEnd('/')) + { + Path = requestPath + }; + + try + { + if (streaming) + { + await proxyService.HttpPostStreamAsync( + url, + deployment.EndpointKey, + context, + requestJsonDoc, + requestContext, + deployment + ); + return new ProxyResult(null!, (int)HttpStatusCode.OK); + } + + + var (responseContent, statusCode) = await proxyService.HttpPostAsync( + url, + deployment.EndpointKey, + context, + requestJsonDoc, + requestContext, + deployment + ); + return new ProxyResult(responseContent, statusCode); + } + catch (TaskCanceledException ex) when (ex.InnerException is System.Net.Sockets.SocketException) + { + return OpenAIResult.ServiceUnavailable("The request was canceled due to timeout. Inner exception: " + ex.InnerException.Message); + } + catch (TaskCanceledException ex) + { + return OpenAIResult.ServiceUnavailable("The request was canceled: " + ex.Message); + } + catch (HttpRequestException ex) + { + return OpenAIResult.ServiceUnavailable("The request failed: " + ex.Message); + } + } +} diff --git a/src/AzureAIProxy/Routes/ProxyRoutes.cs b/src/AzureAIProxy/Routes/ProxyRoutes.cs index 2865598f..baba4e80 100644 --- a/src/AzureAIProxy/Routes/ProxyRoutes.cs +++ b/src/AzureAIProxy/Routes/ProxyRoutes.cs @@ -8,6 +8,7 @@ public static IEndpointRouteBuilder MapProxyRoutes(this IEndpointRouteBuilder bu .MapAttendeeRoutes() .MapEventRoutes() .MapAzureAIProxyRoutes() + .MapOpenAIProxyRoutes() .MapAzureOpenAIAssistantsRoutes() .MapAzureOpenAIFilesRoutes(); } diff --git a/src/AzureAIProxy/Services/ProxyService.cs b/src/AzureAIProxy/Services/ProxyService.cs index ca2ed4c4..7c44465d 100644 --- a/src/AzureAIProxy/Services/ProxyService.cs +++ b/src/AzureAIProxy/Services/ProxyService.cs @@ -192,6 +192,7 @@ Deployment deployment "application/json" ); requestMessage.Headers.Add("api-key", endpointKey); + requestMessage.Headers.Add("Authorization", $"Bearer {endpointKey}"); var response = await httpClient.SendAsync(requestMessage); var responseContent = await response.Content.ReadAsStringAsync();