diff --git a/lib/langchain/llm/openai.rb b/lib/langchain/llm/openai.rb
index da573557d..0767ceea6 100644
--- a/lib/langchain/llm/openai.rb
+++ b/lib/langchain/llm/openai.rb
@@ -127,9 +127,9 @@ def chat(params = {}, &block)
       if block
         @response_chunks = []
         parameters[:stream] = proc do |chunk, _bytesize|
-          chunk_content = chunk.dig("choices", 0)
-          @response_chunks << chunk
-          yield chunk_content
+          wrapped_chunk = OpenAIResponse.new(chunk)
+          @response_chunks << wrapped_chunk
+          yield wrapped_chunk
         end
       end
 
@@ -140,7 +140,7 @@ def chat(params = {}, &block)
       response = response_from_chunks if block
       reset_response_chunks
 
-      Langchain::LLM::OpenAIResponse.new(response)
+      OpenAIResponse.new(response)
     end
 
     # Generate a summary for a given text
@@ -182,34 +182,34 @@ def validate_max_tokens(messages, model, max_tokens = nil)
     end
 
     def response_from_chunks
-      grouped_chunks = @response_chunks.group_by { |chunk| chunk.dig("choices", 0, "index") }
+      grouped_chunks = @response_chunks.group_by { |chunk| chunk.chat_completions.dig(0, "index") }
       final_choices = grouped_chunks.map do |index, chunks|
         {
           "index" => index,
           "message" => {
             "role" => "assistant",
-            "content" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "content") }.join,
+            "content" => chunks.map { |chunk| chunk.chat_completions.dig(0, "delta", "content") }.join,
             "tool_calls" => tool_calls_from_choice_chunks(chunks)
           }.compact,
-          "finish_reason" => chunks.last.dig("choices", 0, "finish_reason")
+          "finish_reason" => chunks.last.chat_completions.dig(0, "finish_reason")
         }
       end
 
-      @response_chunks.first&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices})
+      @response_chunks.first&.raw_response&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices})
     end
 
     def tool_calls_from_choice_chunks(choice_chunks)
-      tool_call_chunks = choice_chunks.select { |chunk| chunk.dig("choices", 0, "delta", "tool_calls") }
+      tool_call_chunks = choice_chunks.select { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls") }
       return nil if tool_call_chunks.empty?
 
-      tool_call_chunks.group_by { |chunk| chunk.dig("choices", 0, "delta", "tool_calls", 0, "index") }.map do |index, chunks|
+      tool_call_chunks.group_by { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "index") }.map do |index, chunks|
        first_chunk = chunks.first
 
        {
-          "id" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "id"),
-          "type" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "type"),
+          "id" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "id"),
+          "type" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "type"),
           "function" => {
-            "name" => first_chunk.dig("choices", 0, "delta", "tool_calls", 0, "function", "name"),
-            "arguments" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "tool_calls", 0, "function", "arguments") }.join
+            "name" => first_chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "function", "name"),
+            "arguments" => chunks.map { |chunk| chunk.chat_completions.dig(0, "delta", "tool_calls", 0, "function", "arguments") }.join
           }
        }
      end
diff --git a/lib/langchain/llm/response/openai_response.rb b/lib/langchain/llm/response/openai_response.rb
index 0e3006855..3d911fb72 100644
--- a/lib/langchain/llm/response/openai_response.rb
+++ b/lib/langchain/llm/response/openai_response.rb
@@ -13,11 +13,11 @@ def created_at
     end
 
     def completion
-      completions&.dig(0, "message", "content")
+      completions&.dig(0, message_key, "content")
     end
 
     def role
-      completions&.dig(0, "message", "role")
+      completions&.dig(0, message_key, "role")
     end
 
     def chat_completion
@@ -25,8 +25,8 @@ def chat_completion
     end
 
     def tool_calls
-      if chat_completions.dig(0, "message").has_key?("tool_calls")
-        chat_completions.dig(0, "message", "tool_calls")
+      if chat_completions.dig(0, message_key).has_key?("tool_calls")
+        chat_completions.dig(0, message_key, "tool_calls")
       else
         []
       end
@@ -59,5 +59,18 @@ def completion_tokens
     def total_tokens
       raw_response.dig("usage", "total_tokens")
     end
+
+    private
+
+    def message_key
+      done? ? "message" : "delta"
+    end
+
+    # Check if OpenAI response is done streaming or not
+    #
+    # @return [Boolean] true if response is done, false otherwise
+    def done?
+      !!raw_response.dig("choices", 0).has_key?("message")
+    end
   end
 end
diff --git a/spec/langchain/llm/openai_spec.rb b/spec/langchain/llm/openai_spec.rb
index 33945ee67..82c297a60 100644
--- a/spec/langchain/llm/openai_spec.rb
+++ b/spec/langchain/llm/openai_spec.rb
@@ -647,7 +647,7 @@
         end
 
         expect(response).to be_a(Langchain::LLM::OpenAIResponse)
-        expect(response.raw_response.dig("choices", 0, "message", "tool_calls")).to eq(expected_tool_calls)
+        expect(response.tool_calls).to eq(expected_tool_calls)
       end
     end
 
@@ -688,8 +688,8 @@
     context "without tool_calls" do
       let(:chunks) do
         [
-          {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => nil}}]},
-          {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hello"}}]}
+          Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => nil}}]}),
+          Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hello"}}]})
         ]
       end
 
@@ -710,11 +710,9 @@
           {"tool_calls" => [{"index" => 0, "function" => {"arguments" => "g\"}"}}]}
         ]
       end
-      let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } }
+      let(:chunks) { chunk_deltas.map { |delta| Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]}) } }
       let(:expected_tool_calls) do
-        [
-          {"id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}}
-        ]
+        [{"id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}}]
       end
 
       it "returns the tool_calls" do
@@ -740,7 +738,7 @@
           {"tool_calls" => [{"index" => 1, "function" => {"arguments" => "g\"}"}}]}
         ]
       end
-      let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } }
+      let(:chunks) { chunk_deltas.map { |delta| Langchain::LLM::OpenAIResponse.new({"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]}) } }
      let(:expected_tool_calls) do
        [
          {"id" => "call_123", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}},
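
Reviewer note: the caller-visible change is that streamed chunks now arrive as wrapped Langchain::LLM::OpenAIResponse objects rather than raw hashes. A minimal usage sketch of the new streaming flow; the message content and API-key handling are illustrative assumptions, not part of this patch:

    llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Each yielded chunk is now a Langchain::LLM::OpenAIResponse wrapping the raw
    # streamed hash. Mid-stream, #chat_completion reads the "delta" content,
    # because OpenAIResponse#done? stays false until a full "message" arrives.
    llm.chat(messages: [{role: "user", content: "Hello"}]) do |chunk|
      print chunk.chat_completion
    end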