diff --git a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java index b602e11f9..46bc5106d 100644 --- a/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java +++ b/src/main/java/com/epam/aidial/core/controller/DeploymentPostController.java @@ -20,6 +20,7 @@ import com.epam.aidial.core.util.BufferingReadStream; import com.epam.aidial.core.util.HttpException; import com.epam.aidial.core.util.HttpStatus; +import com.epam.aidial.core.util.ModelCostCalculator; import com.epam.aidial.core.util.ProxyUtil; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; @@ -297,6 +298,12 @@ void handleResponse() { context.setTokenUsage(tokenUsage); proxy.getRateLimiter().increase(context); tokenUsageFuture = Future.succeededFuture(tokenUsage); + try { + tokenUsage.setCost(ModelCostCalculator.calculate(context)); + } catch (Throwable e) { + log.warn("Failed to calculate cost for model={}. Trace: {}. Span: {}", + context.getDeployment().getName(), context.getTraceId(), context.getSpanId()); + } if (tokenUsage == null) { log.warn("Can't find token usage. Trace: {}. Span: {}. Key: {}. Deployment: {}. Endpoint: {}. Upstream: {}. Status: {}. Length: {}", context.getTraceId(), context.getSpanId(), @@ -305,13 +312,6 @@ void handleResponse() { context.getUpstreamRoute().get().getEndpoint(), context.getResponse().getStatusCode(), context.getResponseBody().length()); - } else { - Model model = (Model) context.getDeployment(); - try { - tokenUsage.calculateCost(model.getPricing()); - } catch (Throwable e) { - log.warn("Failed to calculate cost for model={}", model.getName()); - } } } } else { diff --git a/src/main/java/com/epam/aidial/core/token/TokenUsage.java b/src/main/java/com/epam/aidial/core/token/TokenUsage.java index baf0a9713..8d29cc2b4 100644 --- a/src/main/java/com/epam/aidial/core/token/TokenUsage.java +++ b/src/main/java/com/epam/aidial/core/token/TokenUsage.java @@ -1,6 +1,5 @@ package com.epam.aidial.core.token; -import com.epam.aidial.core.config.Pricing; import lombok.Data; @Data @@ -11,26 +10,6 @@ public class TokenUsage { private Double cost; private Double aggCost; - public void calculateCost(Pricing pricing) { - if (pricing == null) { - return; - } - String unit = pricing.getUnit(); - if (!"token".equals(unit)) { - return; - } - double cost = 0.0; - if (pricing.getPrompt() != null) { - cost += promptTokens * Double.parseDouble(pricing.getPrompt()); - } - if (pricing.getCompletion() != null) { - cost += completionTokens * Double.parseDouble(pricing.getCompletion()); - } - if (pricing.getPrompt() != null || pricing.getCompletion() != null) { - this.cost = cost; - } - } - public void increase(TokenUsage other) { if (other == null) { return; diff --git a/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java b/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java new file mode 100644 index 000000000..abc0f396c --- /dev/null +++ b/src/main/java/com/epam/aidial/core/util/ModelCostCalculator.java @@ -0,0 +1,158 @@ +package com.epam.aidial.core.util; + +import com.epam.aidial.core.ProxyContext; +import com.epam.aidial.core.config.Deployment; +import com.epam.aidial.core.config.Model; +import com.epam.aidial.core.config.ModelType; +import com.epam.aidial.core.config.Pricing; +import com.epam.aidial.core.token.TokenUsage; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.netty.buffer.ByteBufInputStream; +import io.vertx.core.buffer.Buffer; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; + +import java.io.InputStream; +import java.util.Scanner; + +@Slf4j +@UtilityClass +public class ModelCostCalculator { + + public static Double calculate(ProxyContext context) { + Deployment deployment = context.getDeployment(); + if (!(deployment instanceof Model model)) { + return null; + } + + Pricing pricing = model.getPricing(); + if (pricing == null) { + return null; + } + + return switch (pricing.getUnit()) { + case "token" -> calculate(context.getTokenUsage(), pricing.getPrompt(), pricing.getCompletion()); + case "char_without_whitespace" -> + calculate(model.getType(), context.getRequestBody(), context.getResponseBody(), pricing.getPrompt(), pricing.getCompletion()); + default -> null; + }; + } + + private static Double calculate(TokenUsage tokenUsage, String promptRate, String completionRate) { + if (tokenUsage == null) { + return null; + } + double cost = 0.0; + if (promptRate != null) { + cost += tokenUsage.getPromptTokens() * Double.parseDouble(promptRate); + } + if (completionRate != null) { + cost += tokenUsage.getCompletionTokens() * Double.parseDouble(completionRate); + } + if (promptRate != null || completionRate != null) { + return cost; + } + return null; + } + + private static Double calculate(ModelType modelType, Buffer requestBody, Buffer responseBody, String promptRate, String completionRate) { + RequestLengthResult requestLengthResult = getRequestContentLength(modelType, requestBody); + int responseLength = getResponseContentLength(modelType, responseBody, requestLengthResult.stream()); + double cost = 0.0; + if (promptRate != null) { + cost += requestLengthResult.length() * Double.parseDouble(promptRate); + } + if (completionRate != null) { + cost += responseLength * Double.parseDouble(completionRate); + } + if (promptRate != null || completionRate != null) { + return cost; + } + return null; + } + + private static int getResponseContentLength(ModelType modelType, Buffer responseBody, boolean isStreamingResponse) { + if (modelType == ModelType.EMBEDDING) { + return 0; + } + if (isStreamingResponse) { + try (Scanner scanner = new Scanner(new ByteBufInputStream(responseBody.getByteBuf()))) { + scanner.useDelimiter("\n*data: *"); + int len = 0; + while (scanner.hasNext()) { + String chunk = scanner.next(); + if (chunk.startsWith("[DONE]")) { + break; + } + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(chunk); + ArrayNode choices = (ArrayNode) tree.get("choices"); + JsonNode contentNode = choices.get(0).get("delta").get("content"); + if (contentNode != null) { + len += getLengthWithoutWhitespace(contentNode.textValue()); + } + } + return len; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } else { + try (InputStream stream = new ByteBufInputStream(responseBody.getByteBuf())) { + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream); + ArrayNode choices = (ArrayNode) tree.get("choices"); + JsonNode contentNode = choices.get(0).get("message").get("content"); + return getLengthWithoutWhitespace(contentNode.textValue()); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + } + + private static RequestLengthResult getRequestContentLength(ModelType modelType, Buffer requestBody) { + try (InputStream stream = new ByteBufInputStream(requestBody.getByteBuf())) { + int len; + ObjectNode tree = (ObjectNode) ProxyUtil.MAPPER.readTree(stream); + if (modelType == ModelType.CHAT) { + ArrayNode messages = (ArrayNode) tree.get("messages"); + len = 0; + for (int i = 0; i < messages.size(); i++) { + JsonNode message = messages.get(i); + len += getLengthWithoutWhitespace(message.get("content").textValue()); + } + return new RequestLengthResult(len, tree.get("stream").asBoolean(false)); + } else { + JsonNode input = tree.get("input"); + if (input instanceof ArrayNode array) { + len = 0; + for (int i = 0; i < array.size(); i++) { + len += getLengthWithoutWhitespace(array.get(i).textValue()); + } + } else { + len = getLengthWithoutWhitespace(input.textValue()); + } + } + return new RequestLengthResult(len, false); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + private static int getLengthWithoutWhitespace(String s) { + if (s == null) { + return 0; + } + int len = 0; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) != ' ') { + len++; + } + } + return len; + } + + private record RequestLengthResult(int length, boolean stream) { + + } + +} diff --git a/src/test/java/com/epam/aidial/core/token/TokenUsageTest.java b/src/test/java/com/epam/aidial/core/token/TokenUsageTest.java index 3c31d460b..df4aaafa7 100644 --- a/src/test/java/com/epam/aidial/core/token/TokenUsageTest.java +++ b/src/test/java/com/epam/aidial/core/token/TokenUsageTest.java @@ -1,88 +1,11 @@ package com.epam.aidial.core.token; -import com.epam.aidial.core.config.Pricing; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; public class TokenUsageTest { - @Test - public void testCalculateCost_NullPricing() { - TokenUsage tokenUsage = new TokenUsage(); - tokenUsage.calculateCost(null); - assertNull(tokenUsage.getCost()); - } - - @Test - public void testCalculateCost_DifferentUnit() { - TokenUsage tokenUsage = new TokenUsage(); - Pricing pricing = new Pricing(); - pricing.setUnit("other"); - - tokenUsage.calculateCost(pricing); - - assertNull(tokenUsage.getCost()); - } - - @Test - public void testCalculateCost_PromptCompletionNulls() { - TokenUsage tokenUsage = new TokenUsage(); - Pricing pricing = new Pricing(); - pricing.setUnit("token"); - - tokenUsage.calculateCost(pricing); - - assertNull(tokenUsage.getCost()); - } - - @Test - public void testCalculateCost_Normal() { - TokenUsage tokenUsage = new TokenUsage(); - tokenUsage.setPromptTokens(10); - tokenUsage.setCompletionTokens(50); - - Pricing pricing = new Pricing(); - pricing.setUnit("token"); - pricing.setPrompt("0.5"); - pricing.setCompletion("0.8"); - - tokenUsage.calculateCost(pricing); - - assertEquals(45, tokenUsage.getCost()); - } - - @Test - public void testCalculateCost_PromptNull() { - TokenUsage tokenUsage = new TokenUsage(); - tokenUsage.setPromptTokens(10); - tokenUsage.setCompletionTokens(50); - - Pricing pricing = new Pricing(); - pricing.setUnit("token"); - pricing.setCompletion("0.8"); - - tokenUsage.calculateCost(pricing); - - assertEquals(40, tokenUsage.getCost()); - } - - @Test - public void testCalculateCost_CompletionNull() { - TokenUsage tokenUsage = new TokenUsage(); - tokenUsage.setPromptTokens(10); - tokenUsage.setCompletionTokens(50); - - Pricing pricing = new Pricing(); - pricing.setUnit("token"); - pricing.setPrompt("0.5"); - - tokenUsage.calculateCost(pricing); - - assertEquals(5, tokenUsage.getCost()); - } - @Test public void testIncrease_Model() { TokenUsage tokenUsage = new TokenUsage(); diff --git a/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java b/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java new file mode 100644 index 000000000..e602e7c6d --- /dev/null +++ b/src/test/java/com/epam/aidial/core/util/ModelCostCalculatorTest.java @@ -0,0 +1,222 @@ +package com.epam.aidial.core.util; + +import com.epam.aidial.core.ProxyContext; +import com.epam.aidial.core.config.Model; +import com.epam.aidial.core.config.ModelType; +import com.epam.aidial.core.config.Pricing; +import com.epam.aidial.core.token.TokenUsage; +import io.vertx.core.buffer.Buffer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.when; + +@SuppressWarnings("checkstyle:LineLength") +@ExtendWith(MockitoExtension.class) +public class ModelCostCalculatorTest { + + @Mock + private ProxyContext context; + + @Test + public void testCalculate_DeploymentIsNotModel() { + assertNull(ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_PricingIsNull() { + when(context.getDeployment()).thenReturn(new Model()); + assertNull(ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_UnknownCostUnit() { + Model model = new Model(); + Pricing pricing = new Pricing(); + pricing.setUnit("unknown"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + assertNull(ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_TokenCost() { + Model model = new Model(); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("token"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + TokenUsage tokenUsage = new TokenUsage(); + tokenUsage.setCompletionTokens(10); + tokenUsage.setPromptTokens(10); + when(context.getTokenUsage()).thenReturn(tokenUsage); + + assertEquals(6.0, ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_Chat_StreamIsFalse() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + { + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "A file is a named collection." + } + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 343, + "total_tokens": 347 + }, + "id": "fd3be95a-c208-4dca-90cf-67e5082a4e5b", + "created": 1705319789, + "object": "chat.completion" + } + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": false + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(24 * 0.5 + 10 * 0.1, ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_Chat_StreamIsTrue() { + Model model = new Model(); + model.setType(ModelType.CHAT); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"role":"assistant"}}],"usage":null} + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"this"}}],"usage":null} + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":" is "}}],"usage":null} + + + + data: {"id":"chatcmpl-7VfCSOSOS1gYQbDFiEMyh71RJSy1m","object":"chat.completion.chunk","created":1687780896,"model":"gpt-35-turbo","choices":[{"index":0,"finish_reason":null,"delta":{"content":"a text"}}],"usage":null} + + data: [DONE] + + + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "How are you?" + } + ], + "max_tokens": 500, + "temperature": 1, + "stream": true + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(11 * 0.5 + 10 * 0.1, ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_EmbeddingInputIsArray() { + Model model = new Model(); + model.setType(ModelType.EMBEDDING); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + {} + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "input": ["text", "123"] + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(0.1 * 7, ModelCostCalculator.calculate(context)); + } + + @Test + public void testCalculate_LengthCost_EmbeddingInputIsString() { + Model model = new Model(); + model.setType(ModelType.EMBEDDING); + Pricing pricing = new Pricing(); + pricing.setPrompt("0.1"); + pricing.setCompletion("0.5"); + pricing.setUnit("char_without_whitespace"); + model.setPricing(pricing); + when(context.getDeployment()).thenReturn(model); + + String response = """ + {} + """; + when(context.getResponseBody()).thenReturn(Buffer.buffer(response)); + + String request = """ + { + "input": "text" + } + """; + when(context.getRequestBody()).thenReturn(Buffer.buffer(request)); + + assertEquals(0.1 * 4, ModelCostCalculator.calculate(context)); + } +}