Wrap OpenAI streaming responses #646

Status: Draft. Wants to merge 5 commits into main. Showing changes from all commits.
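In effect, streaming callers now receive the same wrapped object mid-stream that they already receive for the final response. A minimal caller-side sketch, assuming the public chat API is otherwise unchanged (the initialization and prompt below are illustrative, not part of this PR):

  # Hypothetical usage: each streamed chunk now arrives as a
  # Langchain::LLM::OpenAIResponse rather than a raw Hash.
  require "langchain"

  llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

  response = llm.chat(messages: [{role: "user", content: "Hello"}]) do |chunk|
    # chunk wraps a streaming delta; the usual reader methods work on it
    print chunk.chat_completion
  end

  # The return value is the full response reassembled from the chunks.
  response.chat_completion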
28 changes: 14 additions & 14 deletions lib/langchain/llm/openai.rb
@@ -127,9 +127,9 @@ def chat(params = {}, &block)
      if block
        @response_chunks = []
        parameters[:stream] = proc do |chunk, _bytesize|
-         chunk_content = chunk.dig("choices", 0)
-         @response_chunks << chunk
-         yield chunk_content
+         wrapped_chunk = OpenAIResponse.new(chunk)
+         @response_chunks << wrapped_chunk
+         yield wrapped_chunk
        end
      end

@@ -140,7 +140,7 @@ def chat(params = {}, &block)
      response = response_from_chunks if block
      reset_response_chunks

-     Langchain::LLM::OpenAIResponse.new(response)
+     OpenAIResponse.new(response)
    end

# Generate a summary for a given text
@@ -182,34 +182,34 @@ def validate_max_tokens(messages, model, max_tokens = nil)
    end

    def response_from_chunks
-     grouped_chunks = @response_chunks.group_by { |chunk| chunk.dig("choices", 0, "index") }
+     grouped_chunks = @response_chunks.group_by { |chunk| chunk.chat_completions.dig(0, "index") }
      final_choices = grouped_chunks.map do |index, chunks|
        {
          "index" => index,
          "message" => {
            "role" => "assistant",
-           "content" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "content") }.join,
+           "content" => chunks.map { |chunk| chunk.chat_completions.dig(0, "delta", "content") }.join,
            "tool_calls" => tool_calls_from_choice_chunks(chunks)
          }.compact,
-         "finish_reason" => chunks.last.dig("choices", 0, "finish_reason")
+         "finish_reason" => chunks.last.chat_completions.dig(0, "finish_reason")
        }
      end
-     @response_chunks.first&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices})
+     @response_chunks.first&.raw_response&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices})
    end

    def tool_calls_from_choice_chunks(choice_chunks)
-     tool_call_chunks = choice_chunks.select { |chunk| chunk.dig("choices", 0, "delta", "tool_calls") }
+     tool_call_chunks = choice_chunks.select { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls") }
      return nil if tool_call_chunks.empty?

-     tool_call_chunks.group_by { |chunk| chunk.dig("choices", 0, "delta", "tool_calls", 0, "index") }.map do |index, chunks|
+     tool_call_chunks.group_by { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "index") }.map do |index, chunks|
        first_chunk = chunks.first

        {
-         "id" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "id"),
-         "type" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "type"),
+         "id" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "id"),
+         "type" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "type"),
          "function" => {
-           "name" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "function", "name"),
-           "arguments" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "tool_calls", 0, "function", "arguments") }.join
+           "name" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "function", "name"),
+           "arguments" => chunks.map { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "function", "arguments") }.join
          }
        }
      end
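For readers skimming the diff, the tool-call merging above can be hard to follow. Below is a standalone sketch of the same idea in plain Ruby, with fabricated deltas (it does not touch the library code): each streamed delta carries a fragment of one tool call's JSON arguments, and fragments are grouped by tool-call index and concatenated in arrival order.

  # Fabricated deltas, shaped like the "tool_calls" entries in the spec below.
  deltas = [
    {"index" => 0, "id" => "call_123456", "type" => "function",
     "function" => {"name" => "foo", "arguments" => "{\"va"}},
    {"index" => 0, "function" => {"arguments" => "lue\": \"my_strin"}},
    {"index" => 0, "function" => {"arguments" => "g\"}"}}
  ]

  # Group fragments by tool-call index; take id/type/name from the first
  # fragment and join the argument pieces in arrival order.
  merged = deltas.group_by { |delta| delta["index"] }.map do |index, parts|
    {
      "id" => parts.first["id"],
      "type" => parts.first["type"],
      "function" => {
        "name" => parts.first.dig("function", "name"),
        "arguments" => parts.map { |part| part.dig("function", "arguments") }.join
      }
    }
  end

  merged.dig(0, "function", "arguments") #=> "{\"value\": \"my_string\"}"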
21 changes: 17 additions & 4 deletions lib/langchain/llm/response/openai_response.rb
@@ -13,20 +13,20 @@ def created_at
    end

    def completion
-     completions&.dig(0, "message", "content")
+     completions&.dig(0, message_key, "content")
    end

    def role
-     completions&.dig(0, "message", "role")
+     completions&.dig(0, message_key, "role")
    end

    def chat_completion
      completion
    end

    def tool_calls
-     if chat_completions.dig(0, "message").has_key?("tool_calls")
-       chat_completions.dig(0, "message", "tool_calls")
+     if chat_completions.dig(0, message_key).has_key?("tool_calls")
+       chat_completions.dig(0, message_key, "tool_calls")
      else
        []
      end
@@ -59,5 +59,18 @@ def completion_tokens
    def total_tokens
      raw_response.dig("usage", "total_tokens")
    end
+
+   private
+
+   def message_key
+     done? ? "message" : "delta"
+   end
+
+   # Check if OpenAI response is done streaming or not
+   #
+   # @return [Boolean] true if response is done, false otherwise
+   def done?
+     !!raw_response.dig("choices", 0).has_key?("message")
+   end
  end
end
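To make the new dispatch concrete, here is an illustrative sketch with fabricated payloads, assuming (as elsewhere in this class) that the reader methods dig into the raw hash's "choices" array: the same reader answers from "delta" while streaming and from "message" once the response is complete.

  streaming_chunk = Langchain::LLM::OpenAIResponse.new(
    {"choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hel"}}]}
  )
  final_response = Langchain::LLM::OpenAIResponse.new(
    {"choices" => [{"index" => 0, "message" => {"role" => "assistant", "content" => "Hello"}}]}
  )

  # done? is false for the chunk (no "message" key), so message_key is "delta"
  streaming_chunk.chat_completion #=> "Hel"

  # done? is true for the final response, so message_key is "message"
  final_response.chat_completion #=> "Hello"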
14 changes: 6 additions & 8 deletions spec/langchain/llm/openai_spec.rb
@@ -647,7 +647,7 @@
        end

        expect(response).to be_a(Langchain::LLM::OpenAIResponse)
-       expect(response.raw_response.dig("choices", 0, "message", "tool_calls")).to eq(expected_tool_calls)
+       expect(response.tool_calls).to eq(expected_tool_calls)
      end
    end

@@ -688,8 +688,8 @@
    context "without tool_calls" do
      let(:chunks) do
        [
-         {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => nil}}]},
-         {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hello"}}]}
+         Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => nil}}]}),
+         Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hello"}}]})
        ]
      end

@@ -710,11 +710,9 @@
{"tool_calls" => [{"index" => 0, "function" => {"arguments" => "g\"}"}}]}
]
end
let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } }
let(:chunks) { chunk_deltas.map { |delta| Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]}) } }
let(:expected_tool_calls) do
[
{"id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}}
]
[{"id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}}]
end

it "returns the tool_calls" do
@@ -740,7 +738,7 @@
{"tool_calls" => [{"index" => 1, "function" => {"arguments" => "g\"}"}}]}
]
end
let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } }
let(:chunks) { chunk_deltas.map { |delta| Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]}) } }
let(:expected_tool_calls) do
[
{"id" => "call_123", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}},