Skip to content

Commit

Permalink
feat: Implement user limits #82
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay committed Feb 1, 2024
1 parent 79ed1b2 commit 87d4627
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/main/java/com/epam/aidial/core/config/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class Config {
private Assistants assistant = new Assistants();
private Map<String, Key> keys = new HashMap<>();
private Map<String, Role> roles = new HashMap<>();
private Map<String, Role> userRoles = new HashMap<>();


public Deployment selectDeployment(String deploymentId) {
Expand Down
41 changes: 35 additions & 6 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
import java.util.Map;

@Slf4j
@RequiredArgsConstructor
public class RateLimiter {
Expand All @@ -32,10 +35,7 @@ public Future<Void> increase(ProxyContext context) {
if (resourceService == null) {
return Future.succeededFuture();
}
Key key = context.getKey();
if (key == null) {
return Future.succeededFuture();
}

Deployment deployment = context.getDeployment();
TokenUsage usage = context.getTokenUsage();

Expand All @@ -59,8 +59,7 @@ public Future<RateLimitResult> limit(ProxyContext context) {
Key key = context.getKey();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture(RateLimitResult.SUCCESS);
limit = getLimitByUser(context);
} else {
limit = getLimitByApiKey(context);
}
Expand Down Expand Up @@ -130,6 +129,36 @@ private Limit getLimitByApiKey(ProxyContext context) {
return role.getLimits().get(deployment.getName());
}

/**
 * Resolves the token limit for a user request (no API key) on the current deployment.
 *
 * <p>Every role of the user is looked up in the configured user roles. Among the roles that
 * define a limit for the deployment, the largest minute value and the largest day value are
 * chosen independently. If no role defines a limit for the deployment, both values are
 * {@link Long#MAX_VALUE}.
 *
 * @param context request context providing the user roles, the deployment and the config
 * @return a new {@link Limit} holding the resolved minute and day values; never {@code null}
 */
private Limit getLimitByUser(ProxyContext context) {
    List<String> userRoles = context.getUserRoles();
    String deploymentName = context.getDeployment().getName();
    Map<String, Role> userRoleToDeploymentLimits = context.getConfig().getUserRoles();

    // "Found" is tracked explicitly instead of using 0 as a sentinel: with the sentinel, a role
    // that configures a limit of 0 (or a negative value) was silently promoted to Long.MAX_VALUE.
    // NOTE(review): assumes an unset Limit field does not default to 0 - confirm against Limit.
    boolean limitFound = false;
    long minuteLimit = Long.MAX_VALUE;
    long dayLimit = Long.MAX_VALUE;
    for (String userRole : userRoles) {
        Role role = userRoleToDeploymentLimits.get(userRole);
        if (role == null) {
            continue;
        }
        Limit roleLimit = role.getLimits().get(deploymentName);
        if (roleLimit == null) {
            continue;
        }
        if (limitFound) {
            minuteLimit = Math.max(minuteLimit, roleLimit.getMinute());
            dayLimit = Math.max(dayLimit, roleLimit.getDay());
        } else {
            // The first matching role seeds the result, so configured values are kept as-is.
            minuteLimit = roleLimit.getMinute();
            dayLimit = roleLimit.getDay();
            limitFound = true;
        }
    }

    Limit resolved = new Limit();
    resolved.setMinute(minuteLimit);
    resolved.setDay(dayLimit);
    return resolved;
}

/**
 * Builds the path string for a deployment by appending {@code "/tokens"} to its name.
 *
 * @param deploymentName name of the deployment
 * @return {@code deploymentName} followed by {@code "/tokens"}
 */
private static String getPath(String deploymentName) {
    final String tokensSuffix = "/tokens";
    return deploymentName + tokensSuffix;
}
Expand Down
110 changes: 98 additions & 12 deletions src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import redis.embedded.RedisServer;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

Expand Down Expand Up @@ -127,17 +127,6 @@ public void testLimit_EntityNotFound() {
assertEquals(HttpStatus.FORBIDDEN, result.result().status());
}

/**
 * Calls {@code limit} for a context built with user claims that carry no roles and
 * expects an OK result.
 */
@Test
public void testLimit_SuccessUser() {
    ExtractedClaims claimsWithoutRoles = new ExtractedClaims("sub", Collections.emptyList(), "hash");
    ProxyContext userContext =
            new ProxyContext(new Config(), request, new ApiKeyData(), claimsWithoutRoles, "trace-id", "span-id");

    Future<RateLimitResult> outcome = rateLimiter.limit(userContext);

    assertNotNull(outcome);
    assertNotNull(outcome.result());
    assertEquals(HttpStatus.OK, outcome.result().status());
}

@Test
public void testLimit_ApiKeyLimitNotFound() {
Key key = new Key();
Expand Down Expand Up @@ -280,4 +269,101 @@ public void testLimit_ApiKeySuccess_KeyExist() {

}

/**
 * A user holding two roles that both limit the deployment "model": the first usage report
 * must still be accepted, the second one must be rejected with TOO_MANY_REQUESTS.
 */
@Test
public void testLimit_User_LimitFound() {
    // role1 carries the smaller limits (100 per minute, 10000 per day).
    Limit role1Limit = new Limit();
    role1Limit.setDay(10000);
    role1Limit.setMinute(100);
    Role role1 = new Role();
    role1.setLimits(Map.of("model", role1Limit));

    // role2 carries the larger limits (200 per minute, 20000 per day).
    Limit role2Limit = new Limit();
    role2Limit.setDay(20000);
    role2Limit.setMinute(200);
    Role role2 = new Role();
    role2.setLimits(Map.of("model", role2Limit));

    Config config = new Config();
    config.getUserRoles().put("role1", role1);
    config.getUserRoles().put("role2", role2);

    ExtractedClaims claims = new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash");
    ProxyContext proxyContext =
            new ProxyContext(config, request, new ApiKeyData(), claims, "trace-id", "span-id");
    Model deployment = new Model();
    deployment.setName("model");
    proxyContext.setDeployment(deployment);

    // Run "blocking" work inline so the returned futures are already completed.
    when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
        Callable<?> task = invocation.getArgument(0);
        return Future.succeededFuture(task.call());
    });

    TokenUsage usage = new TokenUsage();
    usage.setTotalTokens(150);
    proxyContext.setTokenUsage(usage);

    // First report of 150 tokens: the check is expected to pass.
    Future<Void> firstIncrease = rateLimiter.increase(proxyContext);
    assertNotNull(firstIncrease);
    assertNull(firstIncrease.cause());

    Future<RateLimitResult> firstCheck = rateLimiter.limit(proxyContext);
    assertNotNull(firstCheck);
    assertNotNull(firstCheck.result());
    assertEquals(HttpStatus.OK, firstCheck.result().status());

    // Second report of the same usage: the check is expected to reject the request.
    Future<Void> secondIncrease = rateLimiter.increase(proxyContext);
    assertNotNull(secondIncrease);
    assertNull(secondIncrease.cause());

    Future<RateLimitResult> secondCheck = rateLimiter.limit(proxyContext);
    assertNotNull(secondCheck);
    assertNotNull(secondCheck.result());
    assertEquals(HttpStatus.TOO_MANY_REQUESTS, secondCheck.result().status());
}

/**
 * A user whose only role ("role1") is not present in the config: both limit checks are
 * expected to return OK, before and after a second usage report.
 */
@Test
public void testLimit_User_LimitNotFound() {
    // The config deliberately has no user roles registered.
    Config emptyConfig = new Config();

    ExtractedClaims claims = new ExtractedClaims("sub", List.of("role1"), "user-hash");
    ProxyContext proxyContext =
            new ProxyContext(emptyConfig, request, new ApiKeyData(), claims, "trace-id", "span-id");
    Model deployment = new Model();
    deployment.setName("model");
    proxyContext.setDeployment(deployment);

    // Run "blocking" work inline so the returned futures are already completed.
    when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
        Callable<?> task = invocation.getArgument(0);
        return Future.succeededFuture(task.call());
    });

    TokenUsage usage = new TokenUsage();
    usage.setTotalTokens(90);
    proxyContext.setTokenUsage(usage);

    // First usage report followed by a check.
    Future<Void> firstIncrease = rateLimiter.increase(proxyContext);
    assertNotNull(firstIncrease);
    assertNull(firstIncrease.cause());

    Future<RateLimitResult> firstCheck = rateLimiter.limit(proxyContext);
    assertNotNull(firstCheck);
    assertNotNull(firstCheck.result());
    assertEquals(HttpStatus.OK, firstCheck.result().status());

    // Second usage report: the check must still succeed.
    Future<Void> secondIncrease = rateLimiter.increase(proxyContext);
    assertNotNull(secondIncrease);
    assertNull(secondIncrease.cause());

    Future<RateLimitResult> secondCheck = rateLimiter.limit(proxyContext);
    assertNotNull(secondCheck);
    assertNotNull(secondCheck.result());
    assertEquals(HttpStatus.OK, secondCheck.result().status());
}

}

0 comments on commit 87d4627

Please sign in to comment.