From eee61f0204e52fbdf4a4f91365f891b3ea196712 Mon Sep 17 00:00:00 2001 From: Michal Warda Date: Sun, 18 Feb 2024 19:56:27 +0100 Subject: [PATCH] Introduce better tokenizer and token counting --- apps/api/lib/buildel/langchain/chat_gpt.ex | 26 ++++++ .../buildel/langchain/chat_gpt_tokenizer.ex | 48 +++++------ apps/api/lib/buildel_web/google_token.ex | 1 - .../buildel/langchain/chat_gpt_tokenizer.exs | 79 +++++++++++++++++++ 4 files changed, 131 insertions(+), 23 deletions(-) create mode 100644 apps/api/test/buildel/langchain/chat_gpt_tokenizer.exs diff --git a/apps/api/lib/buildel/langchain/chat_gpt.ex b/apps/api/lib/buildel/langchain/chat_gpt.ex index 23da500d7..e43000274 100644 --- a/apps/api/lib/buildel/langchain/chat_gpt.ex +++ b/apps/api/lib/buildel/langchain/chat_gpt.ex @@ -13,6 +13,7 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do require Logger import Ecto.Changeset import LangChain.Utils.ApiOverride + alias Buildel.Langchain.TokenUsage alias __MODULE__ alias LangChain.Config alias LangChain.ChatModels.ChatModel @@ -287,6 +288,8 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do # parse the body and return it as parsed structs |> case do {:ok, %Req.Response{body: data}} -> + call_callback_with_token_usage(data, callback_fn) + case do_process_response(data) do {:error, reason} -> {:error, reason} @@ -327,6 +330,22 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do |> Req.post(into: Utils.handle_stream_fn(openai, &do_process_response/1, callback_fn)) |> case do {:ok, %Req.Response{body: data}} -> + chain_tokens = + Buildel.Langchain.ChatGptTokenizer.init(openai.model) + |> Buildel.Langchain.ChatGptTokenizer.count_chain_tokens(%{ + functions: functions, + input_messages: messages, + messages: data |> List.flatten() + }) + + callback_fn.( + TokenUsage.new!(%{ + prompt_tokens: chain_tokens.input_tokens, + completion_tokens: chain_tokens.output_tokens, + total_tokens: chain_tokens.input_tokens + chain_tokens.output_tokens + }) + ) + data {:error, %LangChainError{message: reason}} -> @@ -497,4 +516,11 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do req end end + + defp call_callback_with_token_usage(_data, nil), do: nil + + defp call_callback_with_token_usage(%{"usage" => usage}, callback_fn) when is_map(usage), + do: callback_fn.(TokenUsage.new!(usage)) + + defp call_callback_with_token_usage(_data, _callback_fn), do: nil end diff --git a/apps/api/lib/buildel/langchain/chat_gpt_tokenizer.ex b/apps/api/lib/buildel/langchain/chat_gpt_tokenizer.ex index aa1c58b9e..db37134b9 100644 --- a/apps/api/lib/buildel/langchain/chat_gpt_tokenizer.ex +++ b/apps/api/lib/buildel/langchain/chat_gpt_tokenizer.ex @@ -14,42 +14,38 @@ defmodule Buildel.Langchain.ChatGptTokenizer do def count_chain_tokens(%__MODULE__{} = tokenizer, %{ functions: functions, - messages: messages, + messages: output_messages, input_messages: input_messages }) do function_metadata_tokens = functions - |> Enum.map(&count_function_tokens(tokenizer, &1)) + |> Enum.map(fn function -> count_function_tokens(tokenizer, function) + 4 end) |> Enum.sum() + function_metadata_tokens = + if function_metadata_tokens > 0 do + function_metadata_tokens + 4 + else + function_metadata_tokens + end + input_message_tokens = - messages - |> Enum.take(Enum.count(input_messages)) - |> Enum.map(&count_message_tokens(tokenizer, &1)) + input_messages + |> Enum.map(fn message -> count_message_tokens(tokenizer, message) + 4 end) |> Enum.sum() - output_messages = - messages - |> Enum.drop(Enum.count(input_messages)) - - output_function_messages_tokens = - output_messages - |> Enum.filter(&(&1.function_name != nil && &1.arguments == nil)) - |> Enum.map(&count_message_tokens(tokenizer, &1)) - |> Enum.sum() + input_message_tokens = input_message_tokens + 3 - output_text_messages_tokens = + output_message_tokens = output_messages - |> Enum.filter(&(&1.function_name == nil || &1.arguments != nil)) |> Enum.map(&count_message_tokens(tokenizer, &1)) |> Enum.sum() summary = %Buildel.Langchain.ChatTokenSummary{ model: tokenizer.model, endpoint: "openai", - input_tokens: - input_message_tokens + function_metadata_tokens + output_function_messages_tokens, - output_tokens: output_text_messages_tokens + input_tokens: input_message_tokens + function_metadata_tokens, + output_tokens: output_message_tokens } Logger.debug("ChatTokenSumary: #{inspect(summary)}") @@ -57,10 +53,18 @@ defmodule Buildel.Langchain.ChatGptTokenizer do summary end - def count_message_tokens(%__MODULE__{}, %{content: nil, function_name: nil}) do + def count_message_tokens(%__MODULE__{}, %{content: nil, function_name: nil, arguments: nil}) do 0 end + def count_message_tokens(%__MODULE__{} = tokenizer, %{ + content: nil, + function_name: nil, + arguments: arguments + }) do + count_text_tokens(tokenizer, arguments) + end + def count_message_tokens(%__MODULE__{} = tokenizer, %{ content: nil, function_name: function_name, @@ -76,14 +80,14 @@ defmodule Buildel.Langchain.ChatGptTokenizer do }) when is_binary(function_name) do count_text_tokens(tokenizer, function_name) + - count_text_tokens(tokenizer, arguments |> Jason.encode!()) + count_text_tokens(tokenizer, arguments |> Jason.encode!()) + 7 end def count_message_tokens(%__MODULE__{} = tokenizer, %{ content: content, function_name: nil }) do - count_text_tokens(tokenizer, content) + 7 + count_text_tokens(tokenizer, content) end def count_message_tokens(%__MODULE__{} = tokenizer, %{ diff --git a/apps/api/lib/buildel_web/google_token.ex b/apps/api/lib/buildel_web/google_token.ex index ab83ccc75..6408e19af 100644 --- a/apps/api/lib/buildel_web/google_token.ex +++ b/apps/api/lib/buildel_web/google_token.ex @@ -4,7 +4,6 @@ defmodule BuildelWeb.GoogleToken do add_hook(JokenJwks, strategy: BuildelWeb.GoogleJwksStrategy) def token_config do - IO.puts("config") %{} end end diff --git a/apps/api/test/buildel/langchain/chat_gpt_tokenizer.exs b/apps/api/test/buildel/langchain/chat_gpt_tokenizer.exs new file mode 100644 index 000000000..30327d005 --- /dev/null +++ b/apps/api/test/buildel/langchain/chat_gpt_tokenizer.exs @@ -0,0 +1,79 @@ +defmodule Buildel.LangChain.ChatGptTokenizerTest do + alias Buildel.Langchain.ChatGptTokenizer + + use Buildel.LangChain.BaseCase + + test "correctly counts tokens" do + tokenizer = ChatGptTokenizer.init("gpt-3.5-turbo") + + messages = [ + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: "query", + role: :assistant, + arguments: "" + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: nil, + role: :unknown, + arguments: "{\"" + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: nil, + role: :unknown, + arguments: "query" + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: nil, + role: :unknown, + arguments: "\":\"" + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: nil, + role: :unknown, + arguments: "TEST" + }, + %LangChain.MessageDelta{ + content: nil, + status: :incomplete, + index: 0, + function_name: nil, + role: :unknown, + arguments: "\"}" + }, + %LangChain.MessageDelta{ + content: nil, + status: :complete, + index: 0, + function_name: nil, + role: :unknown, + arguments: nil + } + ] + + assert ChatGptTokenizer.count_chain_tokens(tokenizer, %{ + functions: [], + input_messages: [], + messages: messages + }) == %Buildel.Langchain.ChatTokenSummary{ + model: "gpt-3.5-turbo", + endpoint: "openai", + input_tokens: 3, + output_tokens: 14 + } + end +end