Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abeatrix committed Mar 6, 2025
1 parent b776e63 commit 25a6bfe
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 22 deletions.
215 changes: 213 additions & 2 deletions lib/shared/src/chat/chat.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import { describe, expect, it } from 'vitest'

import { type Mock, afterEach, beforeEach, vi } from 'vitest'
import { AUTH_STATUS_FIXTURE_AUTHED, graphqlClient } from '..'
import { mockAuthStatus } from '../auth/authStatus'
import { ps } from '../prompt/prompt-string'
import type { Message } from '../sourcegraph-api'
import { sanitizeMessages } from './chat'
import * as siteVersionModule from '../sourcegraph-api/siteVersion'
import { ChatClient, buildChatRequestParams, sanitizeMessages } from './chat'

const hello = ps`Hello`
const hiThere = ps`Hi there!`
const isAnyoneThere = ps`Is anyone there?`
const followUpQuestion = ps`Can you explain more?`

describe('sanitizeMessages', () => {
it('removes empty assistant messages and the human question before it', () => {
Expand Down Expand Up @@ -51,3 +55,210 @@ describe('sanitizeMessages', () => {
expect(result).toEqual(messages)
})
})

describe('buildChatRequestParams', () => {
it('sets apiVersion to 0 for Claude models older than 3.5', () => {
const result = buildChatRequestParams({
model: 'claude-2-sonnet',
codyAPIVersion: 8,
isFireworksTracingEnabled: false,
})

expect(result.apiVersion).toBe(0)
expect(result.customHeaders).toEqual({})
})

it('keeps default apiVersion for Claude models 3.5 or newer', () => {
const result = buildChatRequestParams({
model: 'claude-3-5-sonnet',
codyAPIVersion: 8,
isFireworksTracingEnabled: false,
})

expect(result.apiVersion).toBe(8)
expect(result.customHeaders).toEqual({})
})

it('adds X-Fireworks-Genie header for Fireworks models with tracing enabled', () => {
const result = buildChatRequestParams({
model: 'fireworks/model',
codyAPIVersion: 8,
isFireworksTracingEnabled: true,
})

expect(result.apiVersion).toBe(8)
expect(result.customHeaders).toEqual({ 'X-Fireworks-Genie': 'true' })
})

it('passes through interactionId when provided', () => {
const result = buildChatRequestParams({
model: 'model-name',
codyAPIVersion: 8,
isFireworksTracingEnabled: false,
interactionId: 'test-interaction-id',
})

expect(result.interactionId).toBe('test-interaction-id')
})
})

// Add this test suite after existing describe blocks
describe('ChatClient.chat', () => {
let chatClient: ChatClient
let mockCompletions: { stream: Mock }

beforeEach(() => {
mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED)

// Mock inferCodyApiVersion to return a specific version
vi.spyOn(siteVersionModule, 'inferCodyApiVersion').mockReturnValue(8)

// Mock currentSiteVersion to return a consistent object with your desired codyAPIVersion
vi.spyOn(siteVersionModule, 'currentSiteVersion').mockResolvedValue({
siteVersion: '1.2.3',
codyAPIVersion: 8,
})

// Mock stream method that returns an async generator
mockCompletions = {
stream: vi.fn().mockImplementation(async function* () {
yield { text: 'mocked response' }
}),
}

chatClient = new ChatClient(mockCompletions as any)

vi.spyOn(graphqlClient, 'getSiteVersion').mockResolvedValue('1.2.3')
})

afterEach(() => {
vi.restoreAllMocks()
})

it('streams chat completion with standard parameters', async () => {
const messages: Message[] = [
{ speaker: 'human', text: hello },
{ speaker: 'assistant', text: hiThere },
]

const params = {
maxTokensToSample: 2000,
model: 'anthropic/claude-3-sonnet',
}

const generator = await chatClient.chat(messages, params)
const firstResponse = await generator.next()

expect(mockCompletions.stream).toHaveBeenCalledWith(
expect.objectContaining({
messages: [
{ speaker: 'human', text: hello, cacheEnabled: undefined, content: undefined },
{ speaker: 'assistant', text: hiThere },
],
maxTokensToSample: 2000,
model: 'anthropic/claude-3-sonnet',
temperature: 0.2,
topK: -1,
topP: -1,
}),
expect.objectContaining({
apiVersion: 0,
customHeaders: {},
interactionId: undefined,
}),
undefined
)

expect(firstResponse.value).toEqual({ text: 'mocked response' })
})

it('throws error when not authenticated', async () => {
mockAuthStatus({ ...AUTH_STATUS_FIXTURE_AUTHED, authenticated: false })

const messages: Message[] = [{ speaker: 'human', text: hello }]
const params = {
maxTokensToSample: 1000,
model: 'anthropic/claude-3-sonnet',
}

await expect(chatClient.chat(messages, params)).rejects.toThrow('not authenticated')
})

it('appends empty assistant message for older API versions when last message is human', async () => {
vi.spyOn(graphqlClient, 'getSiteVersion').mockResolvedValue('1.2.3')

const messages: Message[] = [
{ speaker: 'human', text: hello },
{ speaker: 'assistant', text: hiThere },
{ speaker: 'human', text: followUpQuestion },
]

const params = {
maxTokensToSample: 1000,
model: 'claude-2-sonnet',
}

await chatClient.chat(messages, params)

expect(mockCompletions.stream).toHaveBeenCalledWith(
expect.objectContaining({
messages: [
{ speaker: 'human', text: hello },
{ speaker: 'assistant', text: hiThere },
{ speaker: 'human', text: followUpQuestion },
{ speaker: 'assistant' },
],
}),
expect.any(Object),
undefined
)
})

it('passes through abort signal and interaction ID', async () => {
const messages: Message[] = [{ speaker: 'human', text: hello }]
const params = {
maxTokensToSample: 1000,
model: 'anthropic/claude-3-sonnet',
}

const abortController = new AbortController()
const interactionId = 'test-interaction-id'

await chatClient.chat(messages, params, abortController.signal, interactionId)

expect(mockCompletions.stream).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({
interactionId: 'test-interaction-id',
}),
abortController.signal
)
})

it('sanitizes messages before sending them', async () => {
const messagesWithEmpty: Message[] = [
{ speaker: 'human', text: hello },
{ speaker: 'assistant', text: ps`` }, // Empty assistant message
{ speaker: 'human', text: followUpQuestion },
{ speaker: 'assistant', text: ps`` },
]

const params = {
maxTokensToSample: 1000,
model: 'anthropic/claude-3.5-sonnet',
cacheEnabled: undefined,
content: undefined,
}

await chatClient.chat(messagesWithEmpty, params)

// Expect sanitized messages (first human message and empty assistant removed)
expect(mockCompletions.stream).toHaveBeenCalledWith(
expect.objectContaining({
messages: [{ speaker: 'human', text: followUpQuestion }],
}),
expect.any(Object),
undefined
)
})
})
84 changes: 64 additions & 20 deletions lib/shared/src/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ const DEFAULT_CHAT_COMPLETION_PARAMETERS: Omit<ChatParameters, 'maxTokensToSampl
topP: -1,
}

const claudeRegex = /claude-(\d+\.\d+)-/

export class ChatClient {
constructor(private completions: SourcegraphCompletionsClient) {}

Expand Down Expand Up @@ -48,26 +46,15 @@ export class ChatClient {
throw new Error('not authenticated')
}

const requestParams = {
apiVersion: versions.codyAPIVersion,
const requestParams = buildChatRequestParams({
model: params.model,
codyAPIVersion: versions.codyAPIVersion,
isFireworksTracingEnabled: !!authStatus_.isFireworksTracingEnabled,
interactionId,
customHeaders: {},
}
})

// TODO: We should probably do this check on prompt building instead of here?
const isClaude = params.model?.match(claudeRegex)?.[1]
const isFireworks = params?.model?.startsWith('fireworks')

// Enabled Fireworks tracing for Sourcegraph teammates.
// https://readme.fireworks.ai/docs/enabling-tracing
if (isFireworks && authStatus_.isFireworksTracingEnabled) {
requestParams.customHeaders = { 'X-Fireworks-Genie': 'true' }
messages = sanitizeMessages(messages)
} else if (isClaude && Number.parseFloat(isClaude) < 3.5) {
// Set api version to 0 (unversion) for Claude models older than 3.5.
// Example: claude-3-haiku or claude-2-sonnet or claude-2.1-instant v.s. claude-3-5-haiku or 3.5-haiku or 3-7-haiku
requestParams.apiVersion = 0
}
// Sanitize messages before sending them to the completions API.
messages = sanitizeMessages(messages)

// Older models or API versions look for prepended assistant messages.
if (requestParams.apiVersion === 0 && messages.at(-1)?.speaker === 'human') {
Expand All @@ -91,6 +78,18 @@ export class ChatClient {
}
}

/**
* Sanitizes an array of conversation messages to ensure proper formatting for model processing.
*
* Performs three cleaning operations:
* 1. Removes trailing empty assistant messages
* 2. Removes pairs of messages where an assistant message in the middle has empty content
* (also removes the preceding message that prompted the empty response)
* 3. Trims trailing whitespace from the final assistant message
*
* @param messages - The array of Message objects representing the conversation
* @returns A new array with sanitized messages
*/
export function sanitizeMessages(messages: Message[]): Message[] {
let sanitizedMessages = messages

Expand Down Expand Up @@ -131,3 +130,48 @@ export function sanitizeMessages(messages: Message[]): Message[] {

return sanitizedMessages
}

// Check if model is Claude and extract version
// It should capture the numbers between "claude-" and the "-" after the digits
// It should take in the form of "claude-3.5-haiku" or "claude-3-5-haiku" or "claude-2-1-sonnet" or "claude-2.1-instant" or "claude-2-instant"
// And then turn it into "3.5" or "3.5" or "2.1" or "2.1" or "2"
const claudeRegex = /claude-([\d.-]+)-[^-]*$/

/**
* Builds the request parameters for the chat API.
*
* @param options - The options for building the chat request parameters.
* @returns The request parameters for the chat API.
*/
export function buildChatRequestParams({
model,
codyAPIVersion,
isFireworksTracingEnabled,
interactionId,
}: {
model?: string
codyAPIVersion: number
isFireworksTracingEnabled: boolean
interactionId?: string
}): { apiVersion: number; interactionId?: string; customHeaders: Record<string, string> } {
const requestParams = { apiVersion: codyAPIVersion, interactionId, customHeaders: {} }

const isClaude = model?.match(claudeRegex)
const claudeVersion = Number.parseFloat(isClaude?.[1]?.replace(/-/g, '.') ?? '3.5')
const isFireworks = model?.startsWith('fireworks')

// Enabled Fireworks tracing for Sourcegraph teammates.
// https://readme.fireworks.ai/docs/enabling-tracing
if (isFireworks && isFireworksTracingEnabled) {
requestParams.customHeaders = { 'X-Fireworks-Genie': 'true' }
}

// Set api version to 0 (unversion) for Claude models older than 3.5.
// E.g. claude-3-haiku or claude-2-sonnet or claude-2.1-instant v.s. claude-3-5-haiku or 3.5-haiku or 3-7-haiku
if (codyAPIVersion > 0 && claudeVersion < 3.5) {
// Set api version to 0 (unversion) for Claude models older than 3.5
requestParams.apiVersion = 0
}

return requestParams
}

0 comments on commit 25a6bfe

Please sign in to comment.