Skip to content

Commit

Permalink
Add REST API changes for backend rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
SavinduDimal committed Sep 19, 2024
1 parent 20c8634 commit 205323c
Show file tree
Hide file tree
Showing 14 changed files with 494 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@

public class TokenBaseThrottlingCountHolder {

private Long productionMaxPromptTokenCount = -1L;
private Long productionMaxCompletionTokenCount = -1L;
private Long productionMaxTotalTokenCount = -1L;
private Long sandboxMaxPromptTokenCount = -1L;
private Long sandboxMaxCompletionTokenCount = -1L;
private Long sandboxMaxTotalTokenCount = -1L;
private boolean isTokenBasedThrottlingEnabled = false;
private String productionMaxPromptTokenCount;
private String productionMaxCompletionTokenCount;
private String productionMaxTotalTokenCount;
private String sandboxMaxPromptTokenCount;
private String sandboxMaxCompletionTokenCount;
private String sandboxMaxTotalTokenCount;
private Boolean isTokenBasedThrottlingEnabled = false;

public TokenBaseThrottlingCountHolder() {

}

public TokenBaseThrottlingCountHolder(Long productionMaxPromptTokenCount, Long productionMaxCompletionTokenCount,
Long productionMaxTotalTokenCount, Long sandboxMaxPromptTokenCount,
Long sandboxMaxCompletionTokenCount, Long sandboxMaxTotalTokenCount,
public TokenBaseThrottlingCountHolder(String productionMaxPromptTokenCount, String productionMaxCompletionTokenCount,
String productionMaxTotalTokenCount, String sandboxMaxPromptTokenCount,
String sandboxMaxCompletionTokenCount, String sandboxMaxTotalTokenCount,
boolean isTokenBasedThrottlingEnabled) {
this.productionMaxPromptTokenCount = productionMaxPromptTokenCount;
this.productionMaxCompletionTokenCount = productionMaxCompletionTokenCount;
Expand All @@ -45,59 +45,59 @@ public TokenBaseThrottlingCountHolder(Long productionMaxPromptTokenCount, Long p
this.isTokenBasedThrottlingEnabled = isTokenBasedThrottlingEnabled;
}

public Long getProductionMaxPromptTokenCount() {
public String getProductionMaxPromptTokenCount() {
return productionMaxPromptTokenCount;
}

public void setProductionMaxPromptTokenCount(Long productionMaxPromptTokenCount) {
public void setProductionMaxPromptTokenCount(String productionMaxPromptTokenCount) {
this.productionMaxPromptTokenCount = productionMaxPromptTokenCount;
}

public Long getProductionMaxCompletionTokenCount() {
public String getProductionMaxCompletionTokenCount() {
return productionMaxCompletionTokenCount;
}

public void setProductionMaxCompletionTokenCount(Long productionMaxCompletionTokenCount) {
public void setProductionMaxCompletionTokenCount(String productionMaxCompletionTokenCount) {
this.productionMaxCompletionTokenCount = productionMaxCompletionTokenCount;
}

public Long getProductionMaxTotalTokenCount() {
public String getProductionMaxTotalTokenCount() {
return productionMaxTotalTokenCount;
}

public void setProductionMaxTotalTokenCount(Long productionMaxTotalTokenCount) {
public void setProductionMaxTotalTokenCount(String productionMaxTotalTokenCount) {
this.productionMaxTotalTokenCount = productionMaxTotalTokenCount;
}

public Long getSandboxMaxPromptTokenCount() {
public String getSandboxMaxPromptTokenCount() {
return sandboxMaxPromptTokenCount;
}

public void setSandboxMaxPromptTokenCount(Long sandboxMaxPromptTokenCount) {
public void setSandboxMaxPromptTokenCount(String sandboxMaxPromptTokenCount) {
this.sandboxMaxPromptTokenCount = sandboxMaxPromptTokenCount;
}

public Long getSandboxMaxCompletionTokenCount() {
public String getSandboxMaxCompletionTokenCount() {
return sandboxMaxCompletionTokenCount;
}

public void setSandboxMaxCompletionTokenCount(Long sandboxMaxCompletionTokenCount) {
public void setSandboxMaxCompletionTokenCount(String sandboxMaxCompletionTokenCount) {
this.sandboxMaxCompletionTokenCount = sandboxMaxCompletionTokenCount;
}

public Long getSandboxMaxTotalTokenCount() {
public String getSandboxMaxTotalTokenCount() {
return sandboxMaxTotalTokenCount;
}

public void setSandboxMaxTotalTokenCount(Long sandboxMaxTotalTokenCount) {
public void setSandboxMaxTotalTokenCount(String sandboxMaxTotalTokenCount) {
this.sandboxMaxTotalTokenCount = sandboxMaxTotalTokenCount;
}

public boolean isTokenBasedThrottlingEnabled() {
public Boolean isTokenBasedThrottlingEnabled() {
return isTokenBasedThrottlingEnabled;
}

public void setTokenBasedThrottlingEnabled(boolean tokenBasedThrottlingEnabled) {
public void setTokenBasedThrottlingEnabled(Boolean tokenBasedThrottlingEnabled) {
isTokenBasedThrottlingEnabled = tokenBasedThrottlingEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ public class API implements Serializable {

// Used for keeping Production & Sandbox Throttling limits.
private String productionMaxTps;
private String productionTimeUnit = "1000";
private String sandboxMaxTps;
private String sandboxTimeUnit = "1000";

private String visibility;
private String visibleRoles;
Expand Down Expand Up @@ -484,10 +486,26 @@ public void setProductionMaxTps(String productionMaxTps) {
this.productionMaxTps = productionMaxTps;
}

public String getProductionTimeUnit() {
return productionTimeUnit;
}

public void setProductionTimeUnit(String productionTimeUnit) {
this.productionTimeUnit = productionTimeUnit;
}

public String getSandboxMaxTps() {
return sandboxMaxTps;
}

public String getSandboxTimeUnit() {
return sandboxTimeUnit;
}

public void setSandboxTimeUnit(String sandboxTimeUnit) {
this.sandboxTimeUnit = sandboxTimeUnit;
}

public void setSandboxMaxTps(String sandboxMaxTps) {
this.sandboxMaxTps = sandboxMaxTps;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ public class ThrottleHandler extends AbstractHandler implements ManagedLifecycle
private String productionUnitTime = "1000";
private String sandboxMaxCount;
private String productionMaxCount;
private String isTokenBasedThrottlingEnabled = "false";
private String productionMaxPromptTokenCount;
private String productionMaxCompletionTokenCount;
private String productionMaxTotalTokenCount;
private String sandboxMaxPromptTokenCount;
private String sandboxMaxCompletionTokenCount;
private String sandboxMaxTotalTokenCount;
private RoleBasedAccessRateController roleBasedAccessController;

public ThrottleHandler() {
Expand Down Expand Up @@ -1288,6 +1295,63 @@ public void setProductionUnitTime(String productionUnitTime) {
this.productionUnitTime = productionUnitTime;
}

public String getIsTokenBasedThrottlingEnabled() {
return isTokenBasedThrottlingEnabled;
}

public void setIsTokenBasedThrottlingEnabled(String isTokenBasedThrottlingEnabled) {
this.isTokenBasedThrottlingEnabled = isTokenBasedThrottlingEnabled;
}

public String getProductionMaxPromptTokenCount() {
return productionMaxPromptTokenCount;
}

public void setProductionMaxPromptTokenCount(String productionMaxPromptTokenCount) {
this.productionMaxPromptTokenCount = productionMaxPromptTokenCount;
}

public String getProductionMaxCompletionTokenCount() {
return productionMaxCompletionTokenCount;
}

public void setProductionMaxCompletionTokenCount(String productionMaxCompletionTokenCount) {
this.productionMaxCompletionTokenCount = productionMaxCompletionTokenCount;
}

public String getProductionMaxTotalTokenCount() {
return productionMaxTotalTokenCount;
}

public void setProductionMaxTotalTokenCount(String productionMaxTotalTokenCount) {
this.productionMaxTotalTokenCount = productionMaxTotalTokenCount;
}

public String getSandboxMaxPromptTokenCount() {
return sandboxMaxPromptTokenCount;
}

public void setSandboxMaxPromptTokenCount(String sandboxMaxPromptTokenCount) {
this.sandboxMaxPromptTokenCount = sandboxMaxPromptTokenCount;
}

public String getSandboxMaxCompletionTokenCount() {
return sandboxMaxCompletionTokenCount;
}

public void setSandboxMaxCompletionTokenCount(String sandboxMaxCompletionTokenCount) {
this.sandboxMaxCompletionTokenCount = sandboxMaxCompletionTokenCount;
}

public String getSandboxMaxTotalTokenCount() {
return sandboxMaxTotalTokenCount;
}

public void setSandboxMaxTotalTokenCount(String sandboxMaxTotalTokenCount) {
this.sandboxMaxTotalTokenCount = sandboxMaxTotalTokenCount;
}


public void init(SynapseEnvironment synapseEnvironment) {
initThrottleForHardLimitThrottling();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ public final class APIConstants {
public static final String PROTOTYPE_OVERVIEW_IMPLEMENTATION = "overview_implementation";
public static final String API_PRODUCTION_THROTTLE_MAXTPS = "overview_productionTps";
public static final String API_SANDBOX_THROTTLE_MAXTPS = "overview_sandboxTps";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_SECOND = "SECOND";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_SECOND_MS = "1000";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_MINUTE = "MINUTE";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_MINUTE_MS = "60000";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_HOUR = "HOUR";
public static final String API_BACKEND_THROTTLE_TIMEUNIT_HOUR_MS = "3600000";

public static final String IMPLEMENTATION_TYPE_ENDPOINT = "ENDPOINT";
public static final String IMPLEMENTATION_TYPE_INLINE = "INLINE";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,16 @@ public final class APIConstants {

public static final String API_OVERVIEW_TESTKEY = "overview_testKey";
public static final String API_PRODUCTION_THROTTLE_MAXTPS = "overview_productionTps";
public static final String API_PRODUCTION_THROTTLE_TIMEUNIT = "overview_productionTimeUnit";
public static final String API_SANDBOX_THROTTLE_MAXTPS = "overview_sandboxTps";
public static final String API_SANDBOX_THROTTLE_TIMEUNIT = "overview_sandboxTimeUnit";
public static final String AI_PRODUCTION_MAX_PROMPT_TOKEN_COUNT = "overview_aiProductionMaxPromptTokenCount";
public static final String AI_PRODUCTION_MAX_COMPLETION_TOKEN_COUNT = "overview_aiProductionMaxCompletionTokenCount";
public static final String AI_PRODUCTION_MAX_TOTAL_TOKEN_COUNT = "overview_aiProductionMaxTotalTokenCount";
public static final String AI_SANDBOX_MAX_PROMPT_TOKEN_COUNT = "overview_aiSandboxMaxPromptTokenCount";
public static final String AI_SANDBOX_MAX_COMPLETION_TOKEN_COUNT = "overview_aiSandboxMaxCompletionTokenCount";
public static final String AI_SANDBOX_MAX_TOTAL_TOKEN_COUNT = "overview_aiSandboxMaxTotalTokenCount";
public static final String AI_TOKEN_BASED_THROTTLING_ENABLED = "overview_aiTokenBasedThrottlingEnabled";
public static final String SUPER_TENANT_DOMAIN = "carbon.super";
public static final String VERSION_PLACEHOLDER = "{version}";
public static final String TENANT_PREFIX = "/t/";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ public class PublisherAPI extends PublisherAPIInfo {
private String subscriptionAvailableOrgs; // subscriptionAvailableTenants;
private String implementation;
private String productionMaxTps;
private String productionTimeUnit;
private String sandboxMaxTps;
private String sandboxTimeUnit;
private String authorizationHeader;
private String apiKeyHeader;
private String apiSecurity; // ?check whether same to private List<String> securityScheme = new ArrayList<>();
Expand Down Expand Up @@ -362,6 +364,23 @@ public void setSandboxMaxTps(String sandboxMaxTps) {
this.sandboxMaxTps = sandboxMaxTps;
}


public String getProductionTimeUnit() {
return productionTimeUnit;
}

public void setProductionTimeUnit(String productionTimeUnit) {
this.productionTimeUnit = productionTimeUnit;
}

public String getSandboxTimeUnit() {
return sandboxTimeUnit;
}

public void setSandboxTimeUnit(String sandboxTimeUnit) {
this.sandboxTimeUnit = sandboxTimeUnit;
}

public String getAuthorizationHeader() {
return authorizationHeader;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.json.simple.parser.ParseException;
import org.wso2.carbon.CarbonConstants;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.TokenBaseThrottlingCountHolder;
import org.wso2.carbon.apimgt.api.model.AIConfiguration;
import org.wso2.carbon.apimgt.api.model.API;
import org.wso2.carbon.apimgt.api.model.APICategory;
import org.wso2.carbon.apimgt.api.model.APIIdentifier;
Expand Down Expand Up @@ -167,7 +169,9 @@ public static GenericArtifact createAPIArtifactContent(GenericArtifact artifact,
artifact.setAttribute(APIConstants.PROTOTYPE_OVERVIEW_IMPLEMENTATION, api.getImplementation());

artifact.setAttribute(APIConstants.API_PRODUCTION_THROTTLE_MAXTPS, api.getProductionMaxTps());
artifact.setAttribute(APIConstants.API_PRODUCTION_THROTTLE_TIMEUNIT, api.getProductionTimeUnit());
artifact.setAttribute(APIConstants.API_SANDBOX_THROTTLE_MAXTPS, api.getSandboxMaxTps());
artifact.setAttribute(APIConstants.API_SANDBOX_THROTTLE_TIMEUNIT, api.getSandboxTimeUnit());
artifact.setAttribute(APIConstants.API_OVERVIEW_AUTHORIZATION_HEADER, api.getAuthorizationHeader());
artifact.setAttribute(APIConstants.API_OVERVIEW_API_KEY_HEADER, api.getApiKeyHeader());
artifact.setAttribute(APIConstants.API_OVERVIEW_API_SECURITY, api.getApiSecurity());
Expand Down Expand Up @@ -206,6 +210,40 @@ public static GenericArtifact createAPIArtifactContent(GenericArtifact artifact,
policyBuilder.append("||");
}

if (api.getAiConfiguration() != null
&& api.getAiConfiguration().getTokenBasedThrottlingConfiguration() != null
&& api.getAiConfiguration().getTokenBasedThrottlingConfiguration()
.isTokenBasedThrottlingEnabled()) {
TokenBaseThrottlingCountHolder tokenBaseThrottlingCountHolder = api.getAiConfiguration()
.getTokenBasedThrottlingConfiguration();
artifact.setAttribute(APIConstants.AI_TOKEN_BASED_THROTTLING_ENABLED,
tokenBaseThrottlingCountHolder.isTokenBasedThrottlingEnabled().toString());
if (tokenBaseThrottlingCountHolder.getProductionMaxPromptTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_PRODUCTION_MAX_PROMPT_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getProductionMaxPromptTokenCount());
}
if (tokenBaseThrottlingCountHolder.getProductionMaxCompletionTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_PRODUCTION_MAX_COMPLETION_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getProductionMaxCompletionTokenCount());
}
if (tokenBaseThrottlingCountHolder.getProductionMaxTotalTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_PRODUCTION_MAX_TOTAL_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getProductionMaxTotalTokenCount());
}
if (tokenBaseThrottlingCountHolder.getSandboxMaxPromptTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_SANDBOX_MAX_PROMPT_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getSandboxMaxPromptTokenCount());
}
if (tokenBaseThrottlingCountHolder.getSandboxMaxCompletionTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_SANDBOX_MAX_COMPLETION_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getSandboxMaxCompletionTokenCount());
}
if (tokenBaseThrottlingCountHolder.getSandboxMaxTotalTokenCount() != null) {
artifact.setAttribute(APIConstants.AI_SANDBOX_MAX_TOTAL_TOKEN_COUNT,
tokenBaseThrottlingCountHolder.getSandboxMaxTotalTokenCount());
}
}

String policies = policyBuilder.toString();

if (!"".equals(policies)) {
Expand Down Expand Up @@ -622,7 +660,30 @@ public static API getAPI(GovernanceArtifact artifact, Registry registry)
api.setImplementation(artifact.getAttribute(APIConstants.PROTOTYPE_OVERVIEW_IMPLEMENTATION));
api.setType(artifact.getAttribute(APIConstants.API_OVERVIEW_TYPE));
api.setProductionMaxTps(artifact.getAttribute(APIConstants.API_PRODUCTION_THROTTLE_MAXTPS));
api.setProductionTimeUnit(artifact.getAttribute(APIConstants.API_PRODUCTION_THROTTLE_TIMEUNIT));
api.setSandboxMaxTps(artifact.getAttribute(APIConstants.API_SANDBOX_THROTTLE_MAXTPS));
api.setSandboxTimeUnit(artifact.getAttribute(APIConstants.API_SANDBOX_THROTTLE_TIMEUNIT));

if (artifact.getAttribute(APIConstants.AI_TOKEN_BASED_THROTTLING_ENABLED) != null && Boolean.parseBoolean(
artifact.getAttribute(APIConstants.AI_TOKEN_BASED_THROTTLING_ENABLED))) {
TokenBaseThrottlingCountHolder aiThrottlingConfiguration = new TokenBaseThrottlingCountHolder();
aiThrottlingConfiguration.setTokenBasedThrottlingEnabled(true);
aiThrottlingConfiguration.setProductionMaxPromptTokenCount(
artifact.getAttribute(APIConstants.AI_PRODUCTION_MAX_PROMPT_TOKEN_COUNT));
aiThrottlingConfiguration.setProductionMaxCompletionTokenCount(
artifact.getAttribute(APIConstants.AI_PRODUCTION_MAX_COMPLETION_TOKEN_COUNT));
aiThrottlingConfiguration.setProductionMaxTotalTokenCount(
artifact.getAttribute(APIConstants.AI_PRODUCTION_MAX_TOTAL_TOKEN_COUNT));
aiThrottlingConfiguration.setSandboxMaxPromptTokenCount(
artifact.getAttribute(APIConstants.AI_SANDBOX_MAX_PROMPT_TOKEN_COUNT));
aiThrottlingConfiguration.setSandboxMaxCompletionTokenCount(
artifact.getAttribute(APIConstants.AI_SANDBOX_MAX_COMPLETION_TOKEN_COUNT));
aiThrottlingConfiguration.setSandboxMaxTotalTokenCount(
artifact.getAttribute(APIConstants.AI_SANDBOX_MAX_TOTAL_TOKEN_COUNT));
AIConfiguration aiConfiguration = new AIConfiguration();
aiConfiguration.setTokenBasedThrottlingConfiguration(aiThrottlingConfiguration);
api.setAiConfiguration(aiConfiguration);
}
api.setGatewayVendor(artifact.getAttribute(APIConstants.API_OVERVIEW_GATEWAY_VENDOR));
api.setAsyncTransportProtocols(artifact.getAttribute(APIConstants.ASYNC_API_TRANSPORT_PROTOCOLS));

Expand Down
Loading

0 comments on commit 205323c

Please sign in to comment.