Skip to content

Commit

Permalink
Introduce better tokenizer and token counting
Browse files Browse the repository at this point in the history
  • Loading branch information
michalwarda committed Feb 18, 2024
1 parent 004b058 commit eee61f0
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 23 deletions.
26 changes: 26 additions & 0 deletions apps/api/lib/buildel/langchain/chat_gpt.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do
require Logger
import Ecto.Changeset
import LangChain.Utils.ApiOverride
alias Buildel.Langchain.TokenUsage
alias __MODULE__
alias LangChain.Config
alias LangChain.ChatModels.ChatModel
Expand Down Expand Up @@ -287,6 +288,8 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do
# parse the body and return it as parsed structs
|> case do
{:ok, %Req.Response{body: data}} ->
call_callback_with_token_usage(data, callback_fn)

case do_process_response(data) do
{:error, reason} ->
{:error, reason}
Expand Down Expand Up @@ -327,6 +330,22 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do
|> Req.post(into: Utils.handle_stream_fn(openai, &do_process_response/1, callback_fn))
|> case do
{:ok, %Req.Response{body: data}} ->
chain_tokens =
Buildel.Langchain.ChatGptTokenizer.init(openai.model)
|> Buildel.Langchain.ChatGptTokenizer.count_chain_tokens(%{
functions: functions,
input_messages: messages,
messages: data |> List.flatten()
})

callback_fn.(
TokenUsage.new!(%{
prompt_tokens: chain_tokens.input_tokens,
completion_tokens: chain_tokens.output_tokens,
total_tokens: chain_tokens.input_tokens + chain_tokens.output_tokens
})
)

data

{:error, %LangChainError{message: reason}} ->
Expand Down Expand Up @@ -497,4 +516,11 @@ defmodule Buildel.LangChain.ChatModels.ChatOpenAI do
req
end
end

defp call_callback_with_token_usage(_data, nil), do: nil

defp call_callback_with_token_usage(%{"usage" => usage}, callback_fn) when is_map(usage),
do: callback_fn.(TokenUsage.new!(usage))

defp call_callback_with_token_usage(_data, _callback_fn), do: nil
end
48 changes: 26 additions & 22 deletions apps/api/lib/buildel/langchain/chat_gpt_tokenizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,57 @@ defmodule Buildel.Langchain.ChatGptTokenizer do

def count_chain_tokens(%__MODULE__{} = tokenizer, %{
functions: functions,
messages: messages,
messages: output_messages,
input_messages: input_messages
}) do
function_metadata_tokens =
functions
|> Enum.map(&count_function_tokens(tokenizer, &1))
|> Enum.map(fn function -> count_function_tokens(tokenizer, function) + 4 end)
|> Enum.sum()

function_metadata_tokens =
if function_metadata_tokens > 0 do
function_metadata_tokens + 4
else
function_metadata_tokens
end

input_message_tokens =
messages
|> Enum.take(Enum.count(input_messages))
|> Enum.map(&count_message_tokens(tokenizer, &1))
input_messages
|> Enum.map(fn message -> count_message_tokens(tokenizer, message) + 4 end)
|> Enum.sum()

output_messages =
messages
|> Enum.drop(Enum.count(input_messages))

output_function_messages_tokens =
output_messages
|> Enum.filter(&(&1.function_name != nil && &1.arguments == nil))
|> Enum.map(&count_message_tokens(tokenizer, &1))
|> Enum.sum()
input_message_tokens = input_message_tokens + 3

output_text_messages_tokens =
output_message_tokens =
output_messages
|> Enum.filter(&(&1.function_name == nil || &1.arguments != nil))
|> Enum.map(&count_message_tokens(tokenizer, &1))
|> Enum.sum()

summary = %Buildel.Langchain.ChatTokenSummary{
model: tokenizer.model,
endpoint: "openai",
input_tokens:
input_message_tokens + function_metadata_tokens + output_function_messages_tokens,
output_tokens: output_text_messages_tokens
input_tokens: input_message_tokens + function_metadata_tokens,
output_tokens: output_message_tokens
}

Logger.debug("ChatTokenSumary: #{inspect(summary)}")

summary
end

def count_message_tokens(%__MODULE__{}, %{content: nil, function_name: nil}) do
def count_message_tokens(%__MODULE__{}, %{content: nil, function_name: nil, arguments: nil}) do
0
end

def count_message_tokens(%__MODULE__{} = tokenizer, %{
content: nil,
function_name: nil,
arguments: arguments
}) do
count_text_tokens(tokenizer, arguments)
end

def count_message_tokens(%__MODULE__{} = tokenizer, %{
content: nil,
function_name: function_name,
Expand All @@ -76,14 +80,14 @@ defmodule Buildel.Langchain.ChatGptTokenizer do
})
when is_binary(function_name) do
count_text_tokens(tokenizer, function_name) +
count_text_tokens(tokenizer, arguments |> Jason.encode!())
count_text_tokens(tokenizer, arguments |> Jason.encode!()) + 7
end

def count_message_tokens(%__MODULE__{} = tokenizer, %{
content: content,
function_name: nil
}) do
count_text_tokens(tokenizer, content) + 7
count_text_tokens(tokenizer, content)
end

def count_message_tokens(%__MODULE__{} = tokenizer, %{
Expand Down
1 change: 0 additions & 1 deletion apps/api/lib/buildel_web/google_token.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ defmodule BuildelWeb.GoogleToken do
add_hook(JokenJwks, strategy: BuildelWeb.GoogleJwksStrategy)

def token_config do
IO.puts("config")
%{}
end
end
79 changes: 79 additions & 0 deletions apps/api/test/buildel/langchain/chat_gpt_tokenizer.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
defmodule Buildel.LangChain.ChatGptTokenizerTest do
  alias Buildel.Langchain.ChatGptTokenizer

  use Buildel.LangChain.BaseCase

  # Builds a streaming delta chunk with the common defaults used by the
  # OpenAI function-call stream; `overrides` sets only the fields that differ.
  defp delta(overrides) do
    struct!(
      %LangChain.MessageDelta{
        content: nil,
        status: :incomplete,
        index: 0,
        function_name: nil,
        role: :unknown,
        arguments: nil
      },
      overrides
    )
  end

  test "correctly counts tokens" do
    tokenizer = ChatGptTokenizer.init("gpt-3.5-turbo")

    # Simulates a streamed assistant function call: the first chunk carries the
    # function name, the middle chunks stream the JSON arguments piecewise,
    # and the final chunk closes the stream with no payload.
    message_deltas = [
      delta(function_name: "query", role: :assistant, arguments: ""),
      delta(arguments: "{\""),
      delta(arguments: "query"),
      delta(arguments: "\":\""),
      delta(arguments: "TEST"),
      delta(arguments: "\"}"),
      delta(status: :complete)
    ]

    summary =
      ChatGptTokenizer.count_chain_tokens(tokenizer, %{
        functions: [],
        input_messages: [],
        messages: message_deltas
      })

    # No input messages/functions → only the fixed 3-token chain overhead on
    # input; the streamed function call accounts for all 14 output tokens.
    assert summary == %Buildel.Langchain.ChatTokenSummary{
             model: "gpt-3.5-turbo",
             endpoint: "openai",
             input_tokens: 3,
             output_tokens: 14
           }
  end
end

0 comments on commit eee61f0

Please sign in to comment.