Skip to content

Commit

Permalink
feat: Calculate cost for models with "char without space" #154 (#155)
Browse files Browse the repository at this point in the history
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
  • Loading branch information
astsiapanay and astsiapanay authored Jan 18, 2024
1 parent 825aa61 commit 8e4fcd9
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.epam.aidial.core.util.BufferingReadStream;
import com.epam.aidial.core.util.HttpException;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ModelCostCalculator;
import com.epam.aidial.core.util.ProxyUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
Expand Down Expand Up @@ -297,6 +298,12 @@ void handleResponse() {
context.setTokenUsage(tokenUsage);
proxy.getRateLimiter().increase(context);
tokenUsageFuture = Future.succeededFuture(tokenUsage);
try {
tokenUsage.setCost(ModelCostCalculator.calculate(context));
} catch (Throwable e) {
log.warn("Failed to calculate cost for model={}. Trace: {}. Span: {}",
context.getDeployment().getName(), context.getTraceId(), context.getSpanId());
}
if (tokenUsage == null) {
log.warn("Can't find token usage. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}",
context.getTraceId(), context.getSpanId(),
Expand All @@ -305,13 +312,6 @@ void handleResponse() {
context.getUpstreamRoute().get().getEndpoint(),
context.getResponse().getStatusCode(),
context.getResponseBody().length());
} else {
Model model = (Model) context.getDeployment();
try {
tokenUsage.calculateCost(model.getPricing());
} catch (Throwable e) {
log.warn("Failed to calculate cost for model={}", model.getName());
}
}
}
} else {
Expand Down
21 changes: 0 additions & 21 deletions src/main/java/com/epam/aidial/core/token/TokenUsage.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.epam.aidial.core.token;

import com.epam.aidial.core.config.Pricing;
import lombok.Data;

@Data
Expand All @@ -11,26 +10,6 @@ public class TokenUsage {
private Double cost;
private Double aggCost;

/**
 * Computes the cost of this usage from token-based pricing and stores it in {@code cost}.
 *
 * <p>Nothing is changed when {@code pricing} is {@code null}, when its unit is not
 * {@code "token"}, or when neither a prompt rate nor a completion rate is configured.
 *
 * @param pricing pricing configuration of the called model, may be {@code null}
 * @throws NumberFormatException if a configured rate is not a parsable number
 */
public void calculateCost(Pricing pricing) {
    if (pricing == null || !"token".equals(pricing.getUnit())) {
        return;
    }

    String promptRate = pricing.getPrompt();
    String completionRate = pricing.getCompletion();

    double total = 0.0;
    if (promptRate != null) {
        total += promptTokens * Double.parseDouble(promptRate);
    }
    if (completionRate != null) {
        total += completionTokens * Double.parseDouble(completionRate);
    }

    // Only publish a cost when at least one rate actually contributed to it.
    if (promptRate == null && completionRate == null) {
        return;
    }
    this.cost = total;
}

public void increase(TokenUsage other) {
if (other == null) {
return;
Expand Down
158 changes: 158 additions & 0 deletions src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package com.epam.aidial.core.util;

import com.epam.aidial.core.ProxyContext;
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.ModelType;
import com.epam.aidial.core.config.Pricing;
import com.epam.aidial.core.token.TokenUsage;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.buffer.ByteBufInputStream;
import io.vertx.core.buffer.Buffer;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Scanner;

/**
 * Calculates the cost of a single model call from the {@link Pricing} configured on the model.
 *
 * <p>Two pricing units are supported:
 * <ul>
 *   <li>{@code "token"} - cost is derived from the prompt/completion token counts stored in the
 *       context's {@link TokenUsage};</li>
 *   <li>{@code "char_without_whitespace"} - cost is derived from the number of non-whitespace
 *       characters found in the request and response bodies.</li>
 * </ul>
 */
@Slf4j
@UtilityClass
public class ModelCostCalculator {

    /**
     * Calculates the cost of the call described by {@code context}.
     *
     * @param context proxy context holding the deployment, token usage and request/response bodies
     * @return the calculated cost, or {@code null} when the deployment is not a {@link Model}, no
     *         pricing is configured, the pricing unit is missing or unsupported, no rate is
     *         configured, or (token pricing only) no token usage is available
     * @throws RuntimeException if a body cannot be parsed as JSON or a rate is not a number
     */
    public static Double calculate(ProxyContext context) {
        Deployment deployment = context.getDeployment();
        if (!(deployment instanceof Model model)) {
            return null;
        }

        Pricing pricing = model.getPricing();
        if (pricing == null) {
            return null;
        }

        String unit = pricing.getUnit();
        if (unit == null) {
            // A switch over a null String throws NullPointerException; treat it as "unsupported".
            return null;
        }

        return switch (unit) {
            case "token" -> calculate(context.getTokenUsage(), pricing.getPrompt(), pricing.getCompletion());
            case "char_without_whitespace" ->
                    calculate(model.getType(), context.getRequestBody(), context.getResponseBody(), pricing.getPrompt(), pricing.getCompletion());
            default -> null;
        };
    }

    /**
     * Token-based cost: {@code promptTokens * promptRate + completionTokens * completionRate}.
     *
     * @return the cost, or {@code null} when there is no token usage or no rate is configured
     */
    private static Double calculate(TokenUsage tokenUsage, String promptRate, String completionRate) {
        if (tokenUsage == null) {
            return null;
        }
        return applyRates(tokenUsage.getPromptTokens(), tokenUsage.getCompletionTokens(), promptRate, completionRate);
    }

    /**
     * Character-based cost computed from the non-whitespace characters of the request and response.
     *
     * @return the cost, or {@code null} when no rate is configured
     */
    private static Double calculate(ModelType modelType, Buffer requestBody, Buffer responseBody, String promptRate, String completionRate) {
        if (promptRate == null && completionRate == null) {
            // Nothing to charge: skip parsing the bodies altogether.
            return null;
        }
        RequestLengthResult requestLengthResult = getRequestContentLength(modelType, requestBody);
        int responseLength = getResponseContentLength(modelType, responseBody, requestLengthResult.stream());
        return applyRates(requestLengthResult.length(), responseLength, promptRate, completionRate);
    }

    /**
     * Multiplies the prompt/completion unit counts by their rates and sums the results.
     * A {@code null} rate contributes nothing.
     *
     * @return the cost, or {@code null} when both rates are {@code null}
     * @throws NumberFormatException if a rate is not a parsable number
     */
    private static Double applyRates(long promptUnits, long completionUnits, String promptRate, String completionRate) {
        if (promptRate == null && completionRate == null) {
            return null;
        }
        double cost = 0.0;
        if (promptRate != null) {
            cost += promptUnits * Double.parseDouble(promptRate);
        }
        if (completionRate != null) {
            cost += completionUnits * Double.parseDouble(completionRate);
        }
        return cost;
    }

    /**
     * Counts the non-whitespace characters generated by the model.
     * Embedding models produce no text, so their response length is always zero.
     */
    private static int getResponseContentLength(ModelType modelType, Buffer responseBody, boolean isStreamingResponse) {
        if (modelType == ModelType.EMBEDDING) {
            return 0;
        }
        return isStreamingResponse
                ? getStreamingResponseContentLength(responseBody)
                : getPlainResponseContentLength(responseBody);
    }

    /**
     * Sums the content length over all {@code data: ...} chunks of a streamed response, stopping
     * at the {@code [DONE]} marker.
     */
    private static int getStreamingResponseContentLength(Buffer responseBody) {
        // Decode explicitly as UTF-8: the platform default charset would miscount non-ASCII text.
        try (Scanner scanner = new Scanner(new ByteBufInputStream(responseBody.getByteBuf()), StandardCharsets.UTF_8)) {
            scanner.useDelimiter("\n*data: *");
            int len = 0;
            while (scanner.hasNext()) {
                String chunk = scanner.next();
                if (chunk.startsWith("[DONE]")) {
                    break;
                }
                ObjectNode tree = requireObject(ProxyUtil.MAPPER.readTree(chunk), "Response chunk");
                len += getChoicesContentLength(tree, "delta");
            }
            return len;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to parse streaming response body", e);
        }
    }

    /** Reads the content length from a regular (non-streamed) JSON response. */
    private static int getPlainResponseContentLength(Buffer responseBody) {
        try (InputStream stream = new ByteBufInputStream(responseBody.getByteBuf())) {
            ObjectNode tree = requireObject(ProxyUtil.MAPPER.readTree(stream), "Response body");
            return getChoicesContentLength(tree, "message");
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to parse response body", e);
        }
    }

    /**
     * Sums the non-whitespace length of {@code choices[*].<messageField>.content}.
     *
     * <p>Every choice is counted because each one is generated output. Absent fields - an empty
     * {@code choices} array, a chunk without {@code delta}, a message without textual
     * {@code content} - contribute zero instead of failing the whole calculation.
     */
    private static int getChoicesContentLength(ObjectNode tree, String messageField) {
        int len = 0;
        for (JsonNode choice : tree.path("choices")) {
            len += getLengthWithoutWhitespace(choice.path(messageField).path("content").textValue());
        }
        return len;
    }

    /**
     * Counts the non-whitespace characters sent to the model and detects whether the client
     * requested a streamed response.
     *
     * <p>For chat models the {@code content} of every entry in {@code messages} is counted; for
     * other models the {@code input} field is counted, which may be a string or an array of strings.
     */
    private static RequestLengthResult getRequestContentLength(ModelType modelType, Buffer requestBody) {
        try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) {
            ObjectNode tree = requireObject(ProxyUtil.MAPPER.readTree(stream), "Request body");
            int len = 0;
            if (modelType == ModelType.CHAT) {
                for (JsonNode message : tree.path("messages")) {
                    len += getLengthWithoutWhitespace(message.path("content").textValue());
                }
                // "stream" is optional; path() yields a missing node that reads as false.
                return new RequestLengthResult(len, tree.path("stream").asBoolean(false));
            }

            JsonNode input = tree.path("input");
            if (input instanceof ArrayNode array) {
                for (JsonNode item : array) {
                    len += getLengthWithoutWhitespace(item.textValue());
                }
            } else {
                len = getLengthWithoutWhitespace(input.textValue());
            }
            return new RequestLengthResult(len, false);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to parse request body", e);
        }
    }

    /**
     * Returns {@code node} as an {@link ObjectNode}.
     *
     * @throws IllegalArgumentException if {@code node} is {@code null} or not a JSON object
     */
    private static ObjectNode requireObject(JsonNode node, String description) {
        if (node instanceof ObjectNode object) {
            return object;
        }
        throw new IllegalArgumentException(description + " is not a JSON object");
    }

    /**
     * Returns the number of characters in {@code s} that are not whitespace
     * (as defined by {@link Character#isWhitespace(char)}); {@code null} counts as zero.
     */
    private static int getLengthWithoutWhitespace(String s) {
        if (s == null) {
            return 0;
        }
        int len = 0;
        for (int i = 0; i < s.length(); i++) {
            // The pricing unit excludes all whitespace (tabs, line breaks, ...), not only ' '.
            if (!Character.isWhitespace(s.charAt(i))) {
                len++;
            }
        }
        return len;
    }

    /**
     * Result of inspecting a request body.
     *
     * @param length number of non-whitespace characters in the request content
     * @param stream whether the request asked for a streamed response
     */
    private record RequestLengthResult(int length, boolean stream) {

    }

}
77 changes: 0 additions & 77 deletions src/test/java/com/epam/aidial/core/token/TokenUsageTest.java
Original file line number Diff line number Diff line change
@@ -1,88 +1,11 @@
package com.epam.aidial.core.token;

import com.epam.aidial.core.config.Pricing;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

public class TokenUsageTest {

/** A {@code null} pricing must leave the cost unset. */
@Test
public void testCalculateCost_NullPricing() {
    TokenUsage usage = new TokenUsage();

    usage.calculateCost(null);

    assertNull(usage.getCost());
}

/** A pricing unit other than {@code "token"} must leave the cost unset. */
@Test
public void testCalculateCost_DifferentUnit() {
    Pricing otherUnitPricing = new Pricing();
    otherUnitPricing.setUnit("other");
    TokenUsage usage = new TokenUsage();

    usage.calculateCost(otherUnitPricing);

    assertNull(usage.getCost());
}

/** Token pricing without any rate configured must leave the cost unset. */
@Test
public void testCalculateCost_PromptCompletionNulls() {
    Pricing ratelessPricing = new Pricing();
    ratelessPricing.setUnit("token");
    TokenUsage usage = new TokenUsage();

    usage.calculateCost(ratelessPricing);

    assertNull(usage.getCost());
}

/** Both rates configured: 10 * 0.5 + 50 * 0.8 = 45. */
@Test
public void testCalculateCost_Normal() {
    Pricing tokenPricing = new Pricing();
    tokenPricing.setUnit("token");
    tokenPricing.setPrompt("0.5");
    tokenPricing.setCompletion("0.8");

    TokenUsage usage = new TokenUsage();
    usage.setPromptTokens(10);
    usage.setCompletionTokens(50);

    usage.calculateCost(tokenPricing);

    assertEquals(45, usage.getCost());
}

/** Only the completion rate is configured: 50 * 0.8 = 40, prompt tokens are free. */
@Test
public void testCalculateCost_PromptNull() {
    Pricing completionOnlyPricing = new Pricing();
    completionOnlyPricing.setUnit("token");
    completionOnlyPricing.setCompletion("0.8");

    TokenUsage usage = new TokenUsage();
    usage.setPromptTokens(10);
    usage.setCompletionTokens(50);

    usage.calculateCost(completionOnlyPricing);

    assertEquals(40, usage.getCost());
}

/** Only the prompt rate is configured: 10 * 0.5 = 5, completion tokens are free. */
@Test
public void testCalculateCost_CompletionNull() {
    Pricing promptOnlyPricing = new Pricing();
    promptOnlyPricing.setUnit("token");
    promptOnlyPricing.setPrompt("0.5");

    TokenUsage usage = new TokenUsage();
    usage.setPromptTokens(10);
    usage.setCompletionTokens(50);

    usage.calculateCost(promptOnlyPricing);

    assertEquals(5, usage.getCost());
}

@Test
public void testIncrease_Model() {
TokenUsage tokenUsage = new TokenUsage();
Expand Down
Loading

0 comments on commit 8e4fcd9

Please sign in to comment.