Commit

feature-adding-max-token-size-memory | lint
zivkovicn committed Jul 24, 2023
1 parent 0b144d4 commit bf100ce
Showing 5 changed files with 15 additions and 46 deletions.
1 change: 0 additions & 1 deletion chains/chains_test.go
@@ -8,7 +8,6 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/require"
-
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/prompts"
 	"github.com/tmc/langchaingo/schema"
15 changes: 10 additions & 5 deletions chains/conversation_test.go
@@ -7,7 +7,6 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/require"
-
 	"github.com/tmc/langchaingo/llms/openai"
 	"github.com/tmc/langchaingo/memory"
 )
@@ -18,10 +17,10 @@ func TestConversation(t *testing.T) {
 	if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
 		t.Skip("OPENAI_API_KEY not set")
 	}
-	model, err := openai.New()
+	llm, err := openai.New()
 	require.NoError(t, err)
 
-	c := NewConversation(model, memory.NewBuffer())
+	c := NewConversation(llm, memory.NewBuffer())
 	_, err = Run(context.Background(), c, "Hi! I'm Jim")
 	require.NoError(t, err)
 
@@ -31,19 +30,25 @@ func TestConversation(t *testing.T) {
 }
 
 func TestConversationMemoryPrune(t *testing.T) {
+	t.Parallel()
+
 	if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
 		t.Skip("OPENAI_API_KEY not set")
 	}
 
 	llm, err := openai.New()
 	require.NoError(t, err)
 
-	c := NewConversation(llm, memory.NewTokenBuffer(llm, 100, memory.WithReturnMessages(true)))
+	c := NewConversation(llm, memory.NewTokenBuffer(llm, 50))
 	_, err = Run(context.Background(), c, "Hi! I'm Jim")
 	require.NoError(t, err)
 
 	res, err := Run(context.Background(), c, "What is my name?")
 	require.NoError(t, err)
-	require.True(t, strings.Contains(res, "Jim"), `result does not contain the keyword 'Jim'`)
+	require.True(t, strings.Contains(res, "Jim"), `result does contain the keyword 'Jim'`)
 
+	// this message will hit the maxTokenLimit and will initiate the prune of the messages to fit the context
+	res, err = Run(context.Background(), c, "Are you sure that my name is Jim?")
+	require.NoError(t, err)
+	require.True(t, strings.Contains(res, "Jim"), `result does contain the keyword 'Jim'`)
 }
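
Note: the test above lowers the memory limit from 100 tokens to 50 so that the third exchange overflows the buffer and actually exercises the pruning path. For reference, a minimal sketch of how this token-limited memory might be wired up outside the test suite; it only uses constructors that appear in this diff, and the chains.NewConversation / chains.Run package qualifiers are assumptions based on the package names shown here:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/memory"
)

func main() {
	// Assumes OPENAI_API_KEY is set in the environment, as in the tests above.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// Keep at most 50 tokens of history; older messages are pruned first.
	c := chains.NewConversation(llm, memory.NewTokenBuffer(llm, 50))

	res, err := chains.Run(context.Background(), c, "Hi! I'm Jim")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res)
}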
3 changes: 1 addition & 2 deletions memory/buffer_test.go
@@ -5,7 +5,6 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-
 	"github.com/tmc/langchaingo/schema"
 )
@@ -87,7 +86,7 @@ func (t testChatMessageHistory) AddMessage(_ schema.ChatMessage) {
 func (t testChatMessageHistory) Clear() {
 }
 
-func (t testChatMessageHistory) SetMessages(messages []schema.ChatMessage) {
+func (t testChatMessageHistory) SetMessages(_ []schema.ChatMessage) {
 }
 
 func (t testChatMessageHistory) Messages() []schema.ChatMessage {
7 changes: 4 additions & 3 deletions memory/token_buffer.go
@@ -36,7 +36,7 @@ func (tb *TokenBuffer) LoadMemoryVariables(inputs map[string]any) (map[string]an
 	return tb.Buffer.LoadMemoryVariables(inputs)
 }
 
-// SaveContext uses Buffer method for saving context and prunes memory buffer.
+// SaveContext uses Buffer method for saving context and prunes memory buffer if needed.
 func (tb *TokenBuffer) SaveContext(inputValues map[string]any, outputValues map[string]any) error {
 	err := tb.Buffer.SaveContext(inputValues, outputValues)
 	if err != nil {
@@ -49,8 +49,9 @@ func (tb *TokenBuffer) SaveContext(inputValues map[string]any, outputValues map[
 
 	if currBufferLength > tb.MaxTokenLimit {
 		// while currBufferLength is greater than MaxTokenLimit we keep removing messages from the memory
+		// from the oldest
 		for currBufferLength > tb.MaxTokenLimit {
-			tb.chatHistory.SetMessages(append(tb.chatHistory.Messages()[:0], tb.chatHistory.Messages()[1:]...))
+			tb.ChatHistory.SetMessages(append(tb.ChatHistory.Messages()[:0], tb.ChatHistory.Messages()[1:]...))
 			currBufferLength, err = tb.getNumTokensFromMessages()
 			if err != nil {
 				return err
@@ -68,7 +69,7 @@ func (tb *TokenBuffer) Clear() error {
 
 func (tb *TokenBuffer) getNumTokensFromMessages() (int, error) {
 	sum := 0
-	for _, message := range tb.chatHistory.Messages() {
+	for _, message := range tb.ChatHistory.Messages() {
 		bufferString, err := schema.GetBufferString([]schema.ChatMessage{message}, tb.Buffer.HumanPrefix, tb.Buffer.AIPrefix)
 		if err != nil {
 			return 0, err
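
Note: the pruning in SaveContext above repeatedly drops the oldest message and recounts until the history fits MaxTokenLimit. A self-contained sketch of that oldest-first strategy, with an illustrative one-word-per-token counter standing in for the LLM's GetNumTokens; none of these names are part of langchaingo:

package main

import (
	"fmt"
	"strings"
)

type message struct{ text string }

// countTokens is a stand-in for llm.GetNumTokens: here, one word = one token.
func countTokens(msgs []message) int {
	n := 0
	for _, m := range msgs {
		n += len(strings.Fields(m.text))
	}
	return n
}

// prune mirrors the loop in TokenBuffer.SaveContext: drop the oldest
// message until the remaining history fits within maxTokens.
func prune(history []message, maxTokens int) []message {
	for len(history) > 0 && countTokens(history) > maxTokens {
		history = history[1:] // remove the oldest message first
	}
	return history
}

func main() {
	h := []message{{"one two three"}, {"four five"}, {"six"}}
	// 6 tokens > 3, so the oldest message is dropped and 3 tokens remain.
	fmt.Println(prune(h, 3)) // [{four five} {six}]
}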
35 changes: 0 additions & 35 deletions memory/token_buffer_test.go
@@ -5,7 +5,6 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-
 	"github.com/tmc/langchaingo/llms/openai"
 	"github.com/tmc/langchaingo/schema"
 )
@@ -79,37 +78,3 @@ func TestTokenBufferMemoryWithPreLoadedHistory(t *testing.T) {
 	expected := map[string]any{"history": "Human: bar\nAI: foo"}
 	assert.Equal(t, expected, result)
 }
-
-func TestTokenBufferMemoryPrune(t *testing.T) {
-	t.Parallel()
-
-	llm, err := openai.New()
-	require.NoError(t, err)
-
-	m := NewTokenBuffer(llm, 20, WithChatHistory(NewChatMessageHistory(
-		WithPreviousMessages([]schema.ChatMessage{
-			schema.HumanChatMessage{Text: "human message test for max token"},
-			schema.AIChatMessage{Text: "ai message test for max token"},
-		}),
-	)))
-
-	buffStringMsg1, err := schema.GetBufferString([]schema.ChatMessage{
-		schema.HumanChatMessage{Text: "human message test for max token"},
-	}, "Human", "AI")
-	require.NoError(t, err)
-	tokenNumMsg1 := m.LLM.GetNumTokens(buffStringMsg1)
-	assert.Equal(t, 9, tokenNumMsg1)
-
-	buffStringMsg2, err := schema.GetBufferString([]schema.ChatMessage{
-		schema.AIChatMessage{Text: "ai message test for max token"},
-	}, "Human", "AI")
-	require.NoError(t, err)
-	tokenNumMsg2 := m.LLM.GetNumTokens(buffStringMsg2)
-	assert.Equal(t, 8, tokenNumMsg2)
-
-	assert.Equal(t, tokenNumMsg1+tokenNumMsg2, 17)
-
-	_, err = llm.Call()
-	require.NoError(t, err)
-
-}