Assistant fixes (#780)
* Resetting instructions on Langchain::Assistant with Google Gemini no longer throws an error.
* Raise an ArgumentError when add_message_callback is not a callable object (a Proc or lambda).
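A rough sketch of how both fixes surface to callers, using the class and method names shown in the diff below; the API keys and the callback body are illustrative placeholders, not part of the commit:

```ruby
require "langchain"

# 1. Passing something that does not respond to #call now fails fast with an ArgumentError.
begin
  Langchain::Assistant.new(
    llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
    add_message_callback: "not callable"
  )
rescue ArgumentError => e
  puts e.message # => "add_message_callback must be a callable object, like Proc or lambda"
end

# 2. A Proc or lambda is accepted, and resetting instructions on a Google Gemini-backed
#    assistant no longer raises: Gemini messages have no "system" role, so only the
#    @instructions attribute is updated.
assistant = Langchain::Assistant.new(
  llm: Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"]),
  instructions: "You are a helpful assistant",
  add_message_callback: ->(message) { puts "Added #{message.role} message" }
)
assistant.instructions = "You are an expert translator"
```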
andreibondarev authored Sep 16, 2024
1 parent ad55da1 commit 4e6c274
Showing 3 changed files with 114 additions and 82 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -421,7 +421,7 @@ assistant = Langchain::Assistant.new(
)

# Add a user message and run the assistant
assistant.add_message_and_run(content: "What's the latest news about AI?")
assistant.add_message_and_run!(content: "What's the latest news about AI?")

# Access the conversation thread
messages = assistant.messages
lib/langchain/assistants/assistant.rb (37 changes: 29 additions & 8 deletions)
@@ -17,14 +17,16 @@ class Assistant

attr_reader :llm, :thread, :instructions, :state, :llm_adapter, :tool_choice
attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens
attr_accessor :tools
attr_accessor :tools, :add_message_callback

# Create a new assistant
#
# @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
# @param thread [Langchain::Thread] The thread that'll keep track of the conversation
# @param tools [Array<Langchain::Tool::Base>] Tools that the assistant has access to
# @param instructions [String] The system instructions to include in the thread
# @param tool_choice [String] Specify how tools should be selected. Options: "auto", "any", "none", or <specific function name>
# @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
def initialize(
llm:,
thread: nil,
@@ -41,6 +43,11 @@ def initialize(
@llm_adapter = LLM::Adapter.build(llm)

@thread = thread || Langchain::Thread.new

# TODO: Validate that it is, indeed, a Proc or lambda
if !add_message_callback.nil? && !add_message_callback.respond_to?(:call)
raise ArgumentError, "add_message_callback must be a callable object, like Proc or lambda"
end
@thread.add_message_callback = add_message_callback

@tools = tools
@@ -157,26 +164,40 @@ def clear_thread!

# Set new instructions
#
# @param [String] New instructions that will be set as a system message
# @param new_instructions [String] New instructions that will be set as a system message
# @return [Array<Langchain::Message>] The messages in the thread
def instructions=(new_instructions)
@instructions = new_instructions

# Find message with role: "system" in thread.messages and delete it from the thread.messages array
thread.messages.delete_if(&:system?)

# Set new instructions by adding new system message
message = build_message(role: "system", content: new_instructions)
thread.messages.unshift(message)
# This only needs to be done for LLMs that support Message#@role="system"
if !llm.is_a?(Langchain::LLM::GoogleGemini) && !llm.is_a?(Langchain::LLM::Anthropic)
# Find message with role: "system" in thread.messages and delete it from the thread.messages array
replace_system_message!(content: new_instructions)
end
end

# Set tool_choice, how tools should be selected
#
# @param new_tool_choice [String] Tool choice
# @return [String] Selected tool choice
def tool_choice=(new_tool_choice)
validate_tool_choice!(new_tool_choice)
@tool_choice = new_tool_choice
end

private

# Replace old system message with new one
#
# @param content [String] New system message content
# @return [Array<Langchain::Message>] The messages in the thread
def replace_system_message!(content:)
thread.messages.delete_if(&:system?)

message = build_message(role: "system", content: content)
thread.messages.unshift(message)
end

# TODO: If tool_choice = "tool_function_name" and then tool is removed from the assistant, should we set tool_choice back to "auto"?
def validate_tool_choice!(tool_choice)
allowed_tool_choices = llm_adapter.allowed_tool_choices.concat(available_tool_names)
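For LLMs whose message classes do support role: "system" (OpenAI, for example), the new replace_system_message! helper replaces the system message at the head of the thread rather than leaving a stale one behind. A minimal sketch of the observable behavior, with a placeholder API key:

```ruby
assistant = Langchain::Assistant.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  instructions: "You are an expert assistant"
)

assistant.instructions = "You are a terse assistant"

assistant.messages.first.role    # => "system"
assistant.messages.first.content # => "You are a terse assistant"
assistant.messages.count         # => 1 (the old system message is replaced, not duplicated)
```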
spec/langchain/assistants/assistant_spec.rb (157 changes: 84 additions & 73 deletions)
@@ -1,31 +1,71 @@
# frozen_string_literal: true

RSpec.describe Langchain::Assistant do
context "when llm is OpenAI" do
context "initialization" do
let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") }
let(:calculator) { Langchain::Tool::Calculator.new }
let(:instructions) { "You are an expert assistant" }

subject {
described_class.new(
llm: llm,
tools: [calculator],
instructions: instructions
)
}

it "raises an error if tools array contains non-Langchain::Tool instance(s)" do
expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError)
end

describe "#add_message_callback" do
it "raises an error if the callback is not a Proc" do
expect { described_class.new(llm: llm, add_message_callback: "foo") }.to raise_error(ArgumentError)
end

it "does not raise an error if the callback is a Proc" do
expect { described_class.new(llm: llm, add_message_callback: -> {}) }.not_to raise_error
end
end

it "raises an error if thread is not an instance of Langchain::Thread" do
expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError)
end

it "raises an error if LLM class does not implement `chat()` method" do
llm = Langchain::LLM::Replicate.new(api_key: "123")
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)
end

it "raises an error if thread is not an instance of Langchain::Thread" do
expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError)
it "sets new thread if thread is not provided" do
subject = described_class.new(llm: llm)
expect(subject.thread).to be_a(Langchain::Thread)
end
end

context "methods" do
let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") }

describe "#clear_thread!" do
it "clears the thread" do
assistant = described_class.new(llm: llm)
assistant.add_message(content: "foo")
expect { assistant.clear_thread! }.to change { assistant.messages.count }.from(1).to(0)
end
end

describe "#replace_system_message!" do
it "replaces the system message" do
assistant = described_class.new(llm: llm)
assistant.add_message(content: "foo")
assistant.send(:replace_system_message!, content: "bar")
expect(assistant.messages.first.content).to eq("bar")
end
end
end

context "when llm is OpenAI" do
let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") }
let(:calculator) { Langchain::Tool::Calculator.new }
let(:instructions) { "You are an expert assistant" }

subject {
described_class.new(
llm: llm,
tools: [calculator],
instructions: instructions
)
}

describe "#initialize" do
it "adds a system message to the thread" do
@@ -34,11 +74,6 @@
expect(assistant.messages.first.content).to eq("You are an expert assistant")
end

it "sets new thread if thread is not provided" do
subject = described_class.new(llm: llm, instructions: instructions)
expect(subject.thread).to be_a(Langchain::Thread)
end

it "the system message always comes first" do
thread = Langchain::Thread.new
system_message = Langchain::Messages::OpenAIMessage.new(role: "system", content: "System message")
@@ -60,7 +95,7 @@
end

it "calls the add_message_callback with the message" do
callback = double("callback")
callback = double("callback", call: true)
thread = described_class.new(llm: llm, instructions: instructions, add_message_callback: callback)

expect(callback).to receive(:call).with(instance_of(Langchain::Messages::OpenAIMessage))
@@ -340,7 +375,7 @@
end
end

context "tool_choice" do
describe "tool_choice" do
it "initiliazes to 'auto' by default" do
expect(subject.tool_choice).to eq("auto")
end
@@ -360,6 +395,14 @@
expect { subject.tool_choice = "invalid_choice" }.to raise_error(ArgumentError)
end
end

describe "#instructions=" do
it "resets instructions" do
subject.instructions = "New instructions"
expect(subject.messages.first.content).to eq("New instructions")
expect(subject.instructions).to eq("New instructions")
end
end
end

context "when llm is MistralAI" do
@@ -375,31 +418,13 @@
)
}

it "raises an error if tools array contains non-Langchain::Tool instance(s)" do
expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError)
end

it "raises an error if LLM class does not implement `chat()` method" do
llm = Langchain::LLM::Replicate.new(api_key: "123")
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)
end

it "raises an error if thread is not an instance of Langchain::Thread" do
expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError)
end

describe "#initialize" do
it "adds a system message to the thread" do
described_class.new(llm: llm, instructions: instructions)
expect(subject.messages.first.role).to eq("system")
expect(subject.messages.first.content).to eq("You are an expert assistant")
end

it "sets new thread if thread is not provided" do
subject = described_class.new(llm: llm, instructions: instructions)
expect(subject.thread).to be_a(Langchain::Thread)
end

it "the system message always comes first" do
thread = Langchain::Thread.new
system_message = Langchain::Messages::OpenAIMessage.new(role: "system", content: "System message")
@@ -421,7 +446,7 @@
end

it "calls the add_message_callback with the message" do
callback = double("callback")
callback = double("callback", call: true)
thread = described_class.new(llm: llm, instructions: instructions, add_message_callback: callback)

expect(callback).to receive(:call).with(instance_of(Langchain::Messages::MistralAIMessage))
@@ -701,7 +726,7 @@
end
end

context "tool_choice" do
describe "tool_choice" do
it "initiliazes to 'auto' by default" do
expect(subject.tool_choice).to eq("auto")
end
@@ -721,6 +746,14 @@
expect { subject.tool_choice = "invalid_choice" }.to raise_error(ArgumentError)
end
end

describe "#instructions=" do
it "resets instructions" do
subject.instructions = "New instructions"
expect(subject.messages.first.content).to eq("New instructions")
expect(subject.instructions).to eq("New instructions")
end
end
end

context "when llm is GoogleGemini" do
@@ -736,19 +769,6 @@
)
}

it "raises an error if tools array contains non-Langchain::Tool instance(s)" do
expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError)
end

it "raises an error if LLM class does not implement `chat()` method" do
llm = Langchain::LLM::Replicate.new(api_key: "123")
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)
end

it "raises an error if thread is not an instance of Langchain::Thread" do
expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError)
end

describe "#add_message" do
it "adds a message to the thread" do
subject.add_message(content: "foo")
@@ -757,7 +777,7 @@
end

it "calls the add_message_callback with the message" do
callback = double("callback")
callback = double("callback", call: true)
thread = described_class.new(llm: llm, instructions: instructions, add_message_callback: callback)

expect(callback).to receive(:call).with(instance_of(Langchain::Messages::GoogleGeminiMessage))
@@ -892,7 +912,7 @@
end
end

context "tool_choice" do
describe "tool_choice" do
it "initiliazes to 'auto' by default" do
expect(subject.tool_choice).to eq("auto")
end
@@ -912,6 +932,14 @@
expect { subject.tool_choice = "invalid_choice" }.to raise_error(ArgumentError)
end
end

describe "#instructions=" do
it "resets instructions" do
subject.instructions = "New instructions"
expect(subject).not_to receive(:replace_system_message!)
expect(subject.instructions).to eq("New instructions")
end
end
end

context "when llm is Anthropic" do
@@ -927,19 +955,6 @@
)
}

it "raises an error if tools array contains non-Langchain::Tool instance(s)" do
expect { described_class.new(tools: [Langchain::Tool::Calculator.new, "foo"]) }.to raise_error(ArgumentError)
end

it "raises an error if LLM class does not implement `chat()` method" do
llm = Langchain::LLM::Replicate.new(api_key: "123")
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)
end

it "raises an error if thread is not an instance of Langchain::Thread" do
expect { described_class.new(thread: "foo") }.to raise_error(ArgumentError)
end

describe "#add_message" do
it "adds a message to the thread" do
subject.add_message(content: "foo")
@@ -948,7 +963,7 @@
end

it "calls the add_message_callback with the message" do
callback = double("callback")
callback = double("callback", call: true)
thread = described_class.new(llm: llm, instructions: instructions, add_message_callback: callback)

expect(callback).to receive(:call).with(instance_of(Langchain::Messages::AnthropicMessage))
@@ -1151,10 +1166,6 @@
end
end

xdescribe "#clear_thread!"

xdescribe "#instructions="

xdescribe "when llm is Ollama" do
xdescribe "#set_state_for" do
xcontext "when response contains completion" do
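As a usage note, the kind of callback the new validation is meant to guard could look like the sketch below; the in-memory log is an illustrative assumption, and any object responding to #call is accepted:

```ruby
message_log = []

assistant = Langchain::Assistant.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  instructions: "You are an expert assistant",
  # Invoked with each Langchain::Messages::* instance as it is added to the thread.
  add_message_callback: ->(message) { message_log << [message.role, message.content] }
)

assistant.add_message(content: "What's 2 + 2?")
message_log.last # => ["user", "What's 2 + 2?"]
```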
