Skip to content

Commit

Permalink
Merge pull request #339 from microsoft/mistral-support-and-extended-m…
Browse files Browse the repository at this point in the history
…iddleware

Extended middleware and Mistral Chat Completions support
  • Loading branch information
gloveboxes authored Sep 27, 2024
2 parents b36f741 + 1da801e commit 66d6888
Show file tree
Hide file tree
Showing 27 changed files with 675 additions and 336 deletions.
233 changes: 113 additions & 120 deletions database/aoai-proxy.sql

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/AzureAIProxy/Middleware/ApiKeyAuthenticationHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ protected override async Task<AuthenticateResult> HandleAuthenticateAsync()

var apiKey = apiKeyValues.ToString(); // Convert StringValues to string
if (string.IsNullOrWhiteSpace(apiKey))
return AuthenticateResult.Fail("Missing API key is empty.");
return AuthenticateResult.Fail("API key is empty.");

var requestContext = await authorizeService.IsUserAuthorizedAsync(apiKey);
if (requestContext is null)
return AuthenticateResult.Fail("Authentication failed.");

Context.Items["RequestContext"] = requestContext;
Context.Items["RateLimited"] = requestContext.RateLimitExceed;
Context.Items["DailyRequestCap"] = requestContext.DailyRequestCap;

var identity = new ClaimsIdentity(null, nameof(ApiKeyAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name);
Expand Down
37 changes: 37 additions & 0 deletions src/AzureAIProxy/Middleware/BearerTokenAuthenticationHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using System.Security.Claims;
using System.Text.Encodings.Web;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Options;
using AzureAIProxy.Services;

namespace AzureAIProxy.Middleware;

/// <summary>
/// Authentication handler that reads an API key from the <c>Authorization</c> request header
/// and validates it through <see cref="IAuthorizeService"/>.
/// </summary>
/// <remarks>
/// On success the object returned by the authorize service is stored in
/// <c>Context.Items["RequestContext"]</c> for later pipeline stages.
/// </remarks>
public class BearerTokenAuthenticationHandler(
    IOptionsMonitor<ProxyAuthenticationOptions> options,
    IAuthorizeService authorizeService,
    ILoggerFactory logger,
    UrlEncoder encoder
) : AuthenticationHandler<ProxyAuthenticationOptions>(options, logger, encoder)
{
    /// <summary>
    /// Extracts the API key from the <c>Authorization</c> header and authenticates the request.
    /// </summary>
    /// <returns>
    /// A failed result when the header is absent, carries no key, or the authorize service
    /// returns <c>null</c>; otherwise a success ticket for the current scheme.
    /// </returns>
    protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
    {
        if (!Request.Headers.TryGetValue("Authorization", out var authorizationValues))
            return AuthenticateResult.Fail("Missing Authorization header.");

        // The key is the last whitespace-separated token, so both "Bearer <key>" and a bare
        // "<key>" are accepted. Empty entries are dropped so trailing or repeated spaces
        // cannot produce an empty key for an otherwise valid header.
        var tokens = authorizationValues
            .ToString() // Convert StringValues to string
            .Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
        if (tokens.Length == 0)
            return AuthenticateResult.Fail("API key is empty.");

        var apiKey = tokens[^1];

        var requestContext = await authorizeService.IsUserAuthorizedAsync(apiKey);
        if (requestContext is null)
            return AuthenticateResult.Fail("Authentication failed.");

        Context.Items["RequestContext"] = requestContext;

        var identity = new ClaimsIdentity(null, nameof(BearerTokenAuthenticationHandler));
        var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name);

        return AuthenticateResult.Success(ticket);
    }
}
12 changes: 12 additions & 0 deletions src/AzureAIProxy/Middleware/BearerTokenAuthorizeAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Microsoft.AspNetCore.Authorization;
using AzureAIProxy.Middleware;
using System.Diagnostics;

namespace AzureAIProxy;

/// <summary>
/// Authorization attribute whose <c>AuthenticationSchemes</c> is preset to
/// <c>ProxyAuthenticationOptions.BearerTokenScheme</c>.
/// </summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
[DebuggerDisplay("{ToString(),nq}")]
public class BearerTokenAuthorizeAttribute : AuthorizeAttribute
{
    /// <summary>
    /// Creates the attribute and selects the bearer-token authentication scheme.
    /// </summary>
    public BearerTokenAuthorizeAttribute()
    {
        AuthenticationSchemes = ProxyAuthenticationOptions.BearerTokenScheme;
    }
}
2 changes: 1 addition & 1 deletion src/AzureAIProxy/Middleware/JwtAuthenticationHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected override async Task<AuthenticateResult> HandleAuthenticateAsync()

Context.Items["RequestContext"] = requestContext;

var identity = new ClaimsIdentity(null, nameof(ApiKeyAuthenticationHandler));
var identity = new ClaimsIdentity(null, nameof(JwtAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name);

return AuthenticateResult.Success(ticket);
Expand Down
68 changes: 68 additions & 0 deletions src/AzureAIProxy/Middleware/LoadProperties.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using AzureAIProxy.Shared.Database;
using System.Text.Json;

namespace AzureAIProxy.Middleware;

/// <summary>
/// Middleware that reads a JSON request body and publishes values derived from it in
/// <c>HttpContext.Items</c> before invoking the next middleware.
/// </summary>
/// <remarks>
/// Items written: <c>"IsStreaming"</c> and <c>"ModelName"</c> (only when a non-blank JSON body
/// was parsed), <c>"requestPath"</c>, and <c>"jsonDoc"</c> (null when no JSON body was parsed).
/// The parsed document is disposed once the rest of the pipeline has completed.
/// </remarks>
public class LoadProperties(RequestDelegate next)
{
    /// <summary>
    /// Parses the request body when its content type contains <c>application/json</c>, stores the
    /// derived values in <c>context.Items</c>, and calls the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>A task that completes when the pipeline (or the error response) has finished.</returns>
    public async Task InvokeAsync(HttpContext context)
    {
        JsonDocument? jsonDoc = null;
        try
        {
            if (!context.Request.HasFormContentType &&
                context.Request.ContentType is not null &&
                context.Request.ContentType.Contains("application/json", StringComparison.OrdinalIgnoreCase))
            {
                using var reader = new StreamReader(context.Request.Body);
                string json = await reader.ReadToEndAsync();
                if (!string.IsNullOrWhiteSpace(json))
                {
                    try
                    {
                        // Parse exactly once: JsonDocument is IDisposable, so a second parse that
                        // overwrites this reference would leave the first document undisposed.
                        jsonDoc = JsonDocument.Parse(json);
                    }
                    catch (JsonException)
                    {
                        // Malformed JSON short-circuits the pipeline with a 400 response.
                        await OpenAIErrorResponse.BadRequest($"Invalid JSON in request body: {json}").WriteAsync(context);
                        return;
                    }

                    context.Items["IsStreaming"] = IsStreaming(jsonDoc);
                    context.Items["ModelName"] = GetModelName(jsonDoc);
                }
            }

            // Everything after the last "/api/v1/" segment of the request path.
            context.Items["requestPath"] = context.Request.Path.Value!.Split("/api/v1/").Last();
            context.Items["jsonDoc"] = jsonDoc;

            await next(context);
        }
        finally
        {
            jsonDoc?.Dispose();
        }
    }

    /// <summary>
    /// Returns true only when the root is a JSON object whose <c>stream</c> property is the
    /// literal <c>true</c>.
    /// </summary>
    private static bool IsStreaming(JsonDocument requestJsonDoc)
    {
        return requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object
            && requestJsonDoc.RootElement.TryGetProperty("stream", out JsonElement streamElement)
            && streamElement.ValueKind == JsonValueKind.True;
    }

    /// <summary>
    /// Returns the string value of the root object's <c>model</c> property, or null when the
    /// root is not an object or the property is absent or not a string.
    /// </summary>
    private static string? GetModelName(JsonDocument requestJsonDoc)
    {
        return requestJsonDoc.RootElement.ValueKind == JsonValueKind.Object &&
            requestJsonDoc.RootElement.TryGetProperty("model", out JsonElement modelElement) &&
            modelElement.ValueKind == JsonValueKind.String
            ? modelElement.GetString()
            : null;
    }
}
38 changes: 38 additions & 0 deletions src/AzureAIProxy/Middleware/MaxTokensHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using AzureAIProxy.Shared.Database;
using System.Text.Json;

namespace AzureAIProxy.Middleware;

/// <summary>
/// Middleware that rejects a request whose <c>max_tokens</c> value is greater than the
/// <c>MaxTokenCap</c> of the request context, when that cap is greater than zero.
/// </summary>
public class MaxTokensHandler(RequestDelegate next)
{
    /// <summary>
    /// Compares the body's <c>max_tokens</c> with the cap and either writes a 400 response or
    /// passes the request to the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context; reads the "RequestContext" and "jsonDoc" items.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        var requestContext = context.Items["RequestContext"] as RequestContext;
        var body = context.Items["jsonDoc"] as JsonDocument;

        // Without both a request context and a parsed body there is nothing to check.
        if (requestContext is null || body is null)
        {
            await next(context);
            return;
        }

        if (GetMaxTokens(body) is int requested
            && requested > requestContext.MaxTokenCap
            && requestContext.MaxTokenCap > 0)
        {
            await OpenAIErrorResponse.BadRequest(
                $"max_tokens exceeds the event max token cap of {requestContext.MaxTokenCap}"
            ).WriteAsync(context);
            return;
        }

        await next(context);
    }

    /// <summary>
    /// Reads <c>max_tokens</c> from the root JSON object.
    /// </summary>
    /// <returns>
    /// The value when it is a JSON number that fits in an Int32; otherwise null.
    /// </returns>
    private static int? GetMaxTokens(JsonDocument requestJsonDoc)
    {
        JsonElement root = requestJsonDoc.RootElement;
        if (root.ValueKind != JsonValueKind.Object)
        {
            return null;
        }

        if (!root.TryGetProperty("max_tokens", out JsonElement element))
        {
            return null;
        }

        if (element.ValueKind != JsonValueKind.Number)
        {
            return null;
        }

        return element.TryGetInt32(out int value) ? value : null;
    }
}
31 changes: 31 additions & 0 deletions src/AzureAIProxy/Middleware/OpenAIErrorResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System.Net;
using Microsoft.AspNetCore.Http.HttpResults;

namespace AzureAIProxy.Middleware;

/// <summary>
/// Builds a JSON error payload of the form <c>{ error: { code, message, status } }</c> for a
/// given HTTP status code and writes it to a response.
/// </summary>
public class OpenAIErrorResponse(string message, HttpStatusCode statusCode)
{
    // Payload records; "Code" carries the HttpStatusCode name and "Status" its numeric value.
    record ErrorDetails(string Code, string Message, int Status);

    record ErrorResponse(ErrorDetails Error);

    private readonly JsonHttpResult<ErrorResponse> innerResult = TypedResults.Json(
        new ErrorResponse(new ErrorDetails(statusCode.ToString(), message, (int)statusCode)),
        statusCode: (int)statusCode
    );

    /// <summary>Creates an error response with status 400.</summary>
    public static OpenAIErrorResponse BadRequest(string message)
    {
        return new OpenAIErrorResponse(message, HttpStatusCode.BadRequest);
    }

    /// <summary>Creates an error response with status 401.</summary>
    public static OpenAIErrorResponse Unauthorized(string message)
    {
        return new OpenAIErrorResponse(message, HttpStatusCode.Unauthorized);
    }

    /// <summary>Creates an error response with status 429.</summary>
    public static OpenAIErrorResponse TooManyRequests(string message)
    {
        return new OpenAIErrorResponse(message, HttpStatusCode.TooManyRequests);
    }

    /// <summary>
    /// Sets the response status code and writes the error payload as JSON.
    /// </summary>
    /// <param name="httpContext">The context whose response receives the error.</param>
    public async Task WriteAsync(HttpContext httpContext)
    {
        HttpResponse response = httpContext.Response;
        response.StatusCode = (int)innerResult.StatusCode!;
        await response.WriteAsJsonAsync(innerResult.Value);
    }
}
1 change: 1 addition & 0 deletions src/AzureAIProxy/Middleware/ProxyAuthenticationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ public class ProxyAuthenticationOptions : AuthenticationSchemeOptions
{
public const string ApiKeyScheme = "ApiKeyScheme";
public const string JwtScheme = "JwtScheme";
public const string BearerTokenScheme = "BearerTokenScheme";
}
23 changes: 9 additions & 14 deletions src/AzureAIProxy/Middleware/RateLimiterHandler.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
using AzureAIProxy.Shared.Database;

namespace AzureAIProxy.Middleware;

/// <summary>
/// Middleware that stops a request with a 429 response when the request context stored under
/// <c>Items["RequestContext"]</c> reports that the rate limit has been exceeded.
/// </summary>
public class RateLimiterHandler(RequestDelegate next)
{
    /// <summary>
    /// Writes a Too Many Requests error when <c>RateLimitExceed</c> is set on the request
    /// context; otherwise passes the request to the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        RequestContext? requestContext = context.Items["RequestContext"] as RequestContext;

        if (requestContext is not null && requestContext.RateLimitExceed)
        {
            await OpenAIErrorResponse.TooManyRequests(
                $"The event daily request rate of {requestContext.DailyRequestCap} calls has been exceeded. Requests are disabled until UTC midnight."
            ).WriteAsync(context);
        }
        else
        {
            // No request context, or the limit has not been exceeded: continue the pipeline.
            await next(context);
        }
    }
}
2 changes: 1 addition & 1 deletion src/AzureAIProxy/Models/AssistantResponse.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.Text.Json.Serialization;

namespace AzureAIProxy.Services;
namespace AzureAIProxy.Models;

public class AssistantResponse
{
Expand Down
2 changes: 1 addition & 1 deletion src/AzureAIProxy/Models/AttendeeKey.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System.Text.Json.Serialization;

namespace AzureAIProxy.Services;
namespace AzureAIProxy.Models;

/// <summary>
/// Positional record pairing an API key with an active flag.
/// </summary>
/// <param name="ApiKey">The API key; serialized under the JSON name <c>api_key</c>.</param>
/// <param name="Active">Active flag for the key (NOTE(review): exact semantics are defined by callers — confirm).</param>
public record AttendeeKey(
    [property: JsonPropertyName("api_key")] string ApiKey,
    bool Active
);
7 changes: 7 additions & 0 deletions src/AzureAIProxy/Models/AuthHeader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace AzureAIProxy.Models;

/// <summary>
/// Mutable key/value pair representing a single request header.
/// </summary>
public class RequestHeader
{
    /// <summary>
    /// Creates a header with the given key and value.
    /// </summary>
    /// <param name="key">Initial value of <see cref="Key"/>.</param>
    /// <param name="value">Initial value of <see cref="Value"/>.</param>
    public RequestHeader(string key, string value)
    {
        Key = key;
        Value = value;
    }

    /// <summary>The header key.</summary>
    public string Key { get; set; }

    /// <summary>The header value.</summary>
    public string Value { get; set; }
}
6 changes: 6 additions & 0 deletions src/AzureAIProxy/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
.AddScheme<ProxyAuthenticationOptions, JwtAuthenticationHandler>(
ProxyAuthenticationOptions.JwtScheme,
_ => { }
)
.AddScheme<ProxyAuthenticationOptions, BearerTokenAuthenticationHandler>(
ProxyAuthenticationOptions.BearerTokenScheme,
_ => { }
);

builder.Services.AddMemoryCache();
Expand All @@ -31,6 +35,8 @@
app.UseAuthentication();
app.UseAuthorization();
app.UseMiddleware<RateLimiterHandler>();
app.UseMiddleware<LoadProperties>();
app.UseMiddleware<MaxTokensHandler>();
app.MapProxyRoutes();

app.Run();
Loading

0 comments on commit 66d6888

Please sign in to comment.