Skip to content

Commit

Permalink
tweaks based on testing
Browse files Browse the repository at this point in the history
  • Loading branch information
stakach committed Nov 28, 2023
1 parent ae7ba19 commit 0271643
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion shard.lock
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ shards:

openai:
git: https://github.com/spider-gazelle/crystal-openai.git
version: 0.9.1+git.commit.4f3127df2154434e02ed355d2fc321da5e1df10a
version: 0.9.1+git.commit.5ff22991f74b1fa09361728bec3e99c9d9661ab8

openapi-generator:
git: https://github.com/place-labs/openapi-generator.git
Expand Down
39 changes: 20 additions & 19 deletions src/placeos-rest-api/controllers/chat_gpt/chat_manager.cr
Original file line number Diff line number Diff line change
Expand Up @@ -167,27 +167,11 @@ module PlaceOS::Api
# perform function calls until we get a response for the user
if tool_calls = msg.tool_calls
discardable_tokens += resp.usage.completion_tokens

# handle the AI not providing a valid function name, we want it to retry
func_res = begin
executor.execute(tool_calls)
rescue ex
Log.error(exception: ex) { "executing function call" }
reply = "Encountered error: #{ex.message}"
result = DriverResponse.new(reply).as(JSON::Serializable)
func_name = tool_calls.first?.try &.function.name || begin
# try to get function name from the exception. Assuming it was raised by executor
if (msg = ex.message) && msg.starts_with?("OpenAI called unknown function: name: '")
msg.lchop("OpenAI called unknown function: name: '")[...-2]
end
end || "unknown_function"
request.messages << OpenAI::ChatMessage.new(:tool, result.to_pretty_json, func_name, tool_call_id: tool_calls.first?.try &.id)
next
end
func_res = executor.execute(tool_calls)

# process the function result
func_res.each_with_index do |res, index|
func_call = (tool_calls.find(tool_calls[index]) { |func| func.function.name == res.name }).function
func_call = (tool_calls.find(tool_calls[index]) { |func| func.id == res.tool_call_id }).function
case res.name
when "task_complete"
cleanup_messages(request, discardable_tokens)
Expand Down Expand Up @@ -268,8 +252,24 @@ module PlaceOS::Api
end

private def cleanup_messages(request, discardable_tokens)
keep = [] of String
request.messages.each do |mess|
calls = mess.tool_calls
next unless calls
ids = calls.compact_map do |call|
call.id if call.function.name == "task_complete"
end
keep.concat ids
end

# keep task summaries
request.messages.reject! { |mess| mess.tool_calls || (mess.role.function? && mess.name != "task_complete") }
request.messages.reject! do |mess|
if (tool_calls = mess.tool_calls) && !tool_calls.empty?
(tool_calls.map(&.id) & keep).empty?
elsif call_id = mess.tool_call_id
!keep.includes?(call_id)
end
end

# a good estimate of the total tokens once the cleanup is complete
@total_tokens = @total_tokens - discardable_tokens
Expand Down Expand Up @@ -306,6 +306,7 @@ module PlaceOS::Api
str << "my swipe card number is: #{user.card_number}\n" if user.card_number.presence
str << "my user_id is: #{user.id}\n"
str << "use these details in function calls as required.\n"
str << "if you encounter an error, check the schema, check the error message and try again. An empty response is not an error, just the absence of something.\n"
str << "perform one task at a time, making as many function calls as required to complete a task. Once a task is complete call the task_complete function with details of the progress you've made.\n"
str << "the chat client prepends the date-time each message was sent at in the following format YYYY-MM-DD HH:mm:ss +ZZ:ZZ:ZZ"
}
Expand Down

0 comments on commit 0271643

Please sign in to comment.