[Security AI] Bedrock prompt tuning and inference corrections #209011

Merged (8 commits) on Jan 31, 2025
@@ -6,7 +6,7 @@
*/

export { promptType } from './src/saved_object_mappings';
-export { getPrompt, getPromptsByGroupId } from './src/get_prompt';
+export { getPrompt, getPromptsByGroupId, resolveProviderAndModel } from './src/get_prompt';
export {
type PromptArray,
type Prompt,
@@ -129,15 +129,15 @@ export const getPrompt = async ({
return prompt;
};

-const resolveProviderAndModel = async ({
+export const resolveProviderAndModel = async ({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector,
}: {
-  providedProvider: string | undefined;
-  providedModel: string | undefined;
+  providedProvider?: string;
+  providedModel?: string;
connectorId: string;
actionsClient: PublicMethodsOf<ActionsClient>;
providedConnector?: Connector;
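For context, a minimal sketch of calling the newly exported helper now that `providedProvider` and `providedModel` are optional. The connector id is hypothetical, and `actionsClient` is declared only so the sketch type-checks; in real code it comes from the request context:

```ts
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';

// Stand-in for the Kibana actions client a route handler would supply.
declare const actionsClient: Parameters<
  typeof resolveProviderAndModel
>[0]['actionsClient'];

// Both providedProvider and providedModel may now be omitted; the helper
// falls back to inspecting the connector's configuration.
const { provider, model } = await resolveProviderAndModel({
  connectorId: 'my-inference-connector', // hypothetical id
  actionsClient,
});
```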
@@ -127,6 +127,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => contentReferencesEnabled,
},
+    provider: {
+      value: (x: string, y?: string) => y ?? x,
+      default: () => '',
+    },
};

// Default node parameters
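The new `provider` entry follows the same channel shape as the graph's other state fields: a reducer that prefers the latest defined update, plus a default. A small sketch of how such a channel behaves, with hypothetical values:

```ts
// The reducer keeps the existing value x unless an update y is provided.
const providerChannel = {
  value: (x: string, y?: string) => y ?? x,
  default: () => '',
};

let provider = providerChannel.default(); // ''
provider = providerChannel.value(provider, 'bedrock'); // 'bedrock'
provider = providerChannel.value(provider, undefined); // stays 'bedrock'
```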
@@ -56,6 +56,7 @@ describe('streamGraph', () => {
input: 'input',
responseLanguage: 'English',
llmType: 'openai',
+provider: 'openai',
connectorId: '123',
},
logger: mockLogger,
@@ -291,6 +292,7 @@ describe('streamGraph', () => {
inputs: {
...requestArgs.inputs,
llmType: 'gemini',
+provider: 'gemini',
},
});

@@ -306,6 +308,7 @@
inputs: {
...requestArgs.inputs,
llmType: 'bedrock',
+provider: 'bedrock',
},
});

@@ -136,17 +136,21 @@ export const streamGraph = async ({

// Stream is from openai functions agent
let finalMessage = '';
-  const stream = assistantGraph.streamEvents(inputs, {
-    callbacks: [
-      apmTracer,
-      ...(traceOptions?.tracers ?? []),
-      ...(telemetryTracer ? [telemetryTracer] : []),
-    ],
-    runName: DEFAULT_ASSISTANT_GRAPH_ID,
-    streamMode: 'values',
-    tags: traceOptions?.tags ?? [],
-    version: 'v1',
-  });
+  const stream = assistantGraph.streamEvents(
+    inputs,
+    {
+      callbacks: [
+        apmTracer,
+        ...(traceOptions?.tracers ?? []),
+        ...(telemetryTracer ? [telemetryTracer] : []),
+      ],
+      runName: DEFAULT_ASSISTANT_GRAPH_ID,
+      streamMode: 'values',
+      tags: traceOptions?.tags ?? [],
+      version: 'v1',
+    },
+    inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
+  );

const pushStreamUpdate = async () => {
for await (const { event, data, tags } of stream) {
@@ -155,8 +159,6 @@
const chunk = data?.chunk;
const msg = chunk.message;
      if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
-        // I don't think we hit this anymore because of our check for AGENT_NODE_TAG
-        // however, no harm to keep it in
        /* empty */
      } else if (!didEnd) {
        push({ payload: msg.content, type: 'content' });

Author's review comment on the removed lines: "this is actually an important param for OpenAI, removing my comment"
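The new third argument to `streamEvents` above is the stream-options filter from the LangChain JS API; `includeNames` restricts emitted events to runs whose name matches, so for Bedrock only the 'Summarizer' node's output is streamed. A rough sketch of consuming such a filtered stream, reusing `stream` and `push` from the code above (chunk shape assumed to match the handler in this file):

```ts
// With includeNames: ['Summarizer'], only events originating from the run
// named 'Summarizer' reach this loop, so token chunks from other graph
// nodes are never pushed to the client.
for await (const { event, data } of stream) {
  if (event === 'on_chat_model_stream') {
    const content = data?.chunk?.message?.content;
    if (content) {
      push({ payload: content, type: 'content' });
    }
  }
}
```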
@@ -20,12 +20,15 @@ import {
} from 'langchain/agents';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
+import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
jest.mock('@kbn/langchain/server/tracers/apm');
jest.mock('@kbn/langchain/server/tracers/telemetry');
+jest.mock('@kbn/security-ai-prompts');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
+const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
@@ -83,6 +86,9 @@
jest.clearAllMocks();
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
getDefaultAssistantGraphMock.mockReturnValue({});
+    resolveProviderAndModelMock.mockResolvedValue({
+      provider: 'bedrock',
+    });
(invokeGraph as jest.Mock).mockResolvedValue({
output: 'test-output',
traceData: {},
@@ -224,5 +230,23 @@
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
+  it('does not call resolveProviderAndModel when llmType === openai', async () => {
+    const params = { ...defaultParams, llmType: 'openai' };
+    await callAssistantGraph(params);
+
+    expect(resolveProviderAndModelMock).not.toHaveBeenCalled();
+  });
+  it('calls resolveProviderAndModel when llmType === inference', async () => {
+    const params = { ...defaultParams, llmType: 'inference' };
+    await callAssistantGraph(params);
+
+    expect(resolveProviderAndModelMock).toHaveBeenCalled();
+  });
+  it('calls resolveProviderAndModel when llmType === undefined', async () => {
+    const params = { ...defaultParams, llmType: undefined };
+    await callAssistantGraph(params);
+
+    expect(resolveProviderAndModelMock).toHaveBeenCalled();
+  });
});
});
@@ -15,6 +15,7 @@ import {
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
+import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
@@ -183,6 +184,13 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;
+  const { provider } =
+    !llmType || llmType === 'inference'
+      ? await resolveProviderAndModel({
+          connectorId,
+          actionsClient,
+        })
+      : { provider: llmType };
const assistantGraph = getDefaultAssistantGraph({
agentRunnable,
dataClients,
@@ -205,6 +213,7 @@
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
+    provider: provider ?? '',
};

if (isStream) {
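The conditional added above encodes the resolution rule: when `llmType` already names a provider ('openai', 'bedrock', 'gemini') it is used directly; for the generic 'inference' connector, or when `llmType` is undefined, the provider is looked up from the connector. A standalone sketch of the same rule (the helper name is ours, not from the PR):

```ts
// Hypothetical helper mirroring the logic in callAssistantGraph above.
async function resolveGraphProvider(
  llmType: string | undefined,
  lookupFromConnector: () => Promise<{ provider?: string }>
): Promise<string> {
  if (!llmType || llmType === 'inference') {
    // Generic inference connectors don't reveal the backing provider in
    // llmType, so ask the connector itself.
    const { provider } = await lookupFromConnector();
    return provider ?? '';
  }
  return llmType; // 'openai' | 'bedrock' | 'gemini' already name the provider
}
```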
@@ -21,8 +21,7 @@ interface ModelInputParams extends NodeParamsBase {
*/
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);
-
-  const hasRespondStep = state.isStream && (state.isOssModel || state.llmType === 'bedrock');
+  const hasRespondStep = state.isStream && (state.isOssModel || state.provider === 'bedrock');

return {
hasRespondStep,
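Switching the check from `state.llmType` to the resolved `state.provider` means a streaming request through an 'inference' connector that is actually backed by Bedrock now takes the respond step too. A small illustration (state shape abbreviated):

```ts
const hasRespondStep = (s: { isStream: boolean; isOssModel: boolean; provider: string }) =>
  s.isStream && (s.isOssModel || s.provider === 'bedrock');

// Before this PR, llmType === 'inference' missed the Bedrock branch entirely;
// now the connector-resolved provider routes it correctly.
console.log(hasRespondStep({ isStream: true, isOssModel: false, provider: 'bedrock' })); // true
console.log(hasRespondStep({ isStream: true, isOssModel: false, provider: 'openai' })); // false
```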
@@ -25,6 +25,7 @@ export interface GraphInputs {
isStream?: boolean;
isOssModel?: boolean;
input: string;
+  provider: string;
responseLanguage?: string;
}

@@ -37,6 +38,7 @@ export interface AgentState extends AgentStateBase {
isStream: boolean;
isOssModel: boolean;
llmType: string;
+  provider: string;
responseLanguage: string;
connectorId: string;
conversation: ConversationResponse | undefined;
@@ -20,7 +20,7 @@ const BASE_GEMINI_PROMPT =
const KB_CATCH =
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH} {include_citations_prompt_placeholder}`;
-export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
+export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response. ALWAYS return the exact response from NaturalLanguageESQLTool verbatim in the final response, without adding further description.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
@@ -26,7 +26,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
import { getDefaultArguments } from '@kbn/langchain/server';
import { StructuredTool } from '@langchain/core/tools';
import {
-  createOpenAIFunctionsAgent,
+  createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
@@ -331,26 +331,27 @@ export const postEvaluateRoute = (
savedObjectsClient,
});

-        const agentRunnable = isOpenAI
-          ? await createOpenAIFunctionsAgent({
-              llm,
-              tools,
-              prompt: formatPrompt(defaultSystemPrompt),
-              streamRunnable: false,
-            })
-          : llmType && ['bedrock', 'gemini'].includes(llmType)
-          ? createToolCallingAgent({
-              llm,
-              tools,
-              prompt: formatPrompt(defaultSystemPrompt),
-              streamRunnable: false,
-            })
-          : await createStructuredChatAgent({
-              llm,
-              tools,
-              prompt: formatPromptStructured(defaultSystemPrompt),
-              streamRunnable: false,
-            });
+        const agentRunnable =
+          isOpenAI || llmType === 'inference'
+            ? await createOpenAIToolsAgent({
+                llm,
+                tools,
+                prompt: formatPrompt(defaultSystemPrompt),
+                streamRunnable: false,
+              })
+            : llmType && ['bedrock', 'gemini'].includes(llmType)
+            ? createToolCallingAgent({
+                llm,
+                tools,
+                prompt: formatPrompt(defaultSystemPrompt),
+                streamRunnable: false,
+              })
+            : await createStructuredChatAgent({
+                llm,
+                tools,
+                prompt: formatPromptStructured(defaultSystemPrompt),
+                streamRunnable: false,
+              });

return {
connectorId: connector.id,
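`createOpenAIToolsAgent` targets OpenAI's tool-calling API, the successor to the function-calling API that `createOpenAIFunctionsAgent` used, and the new `llmType === 'inference'` branch routes inference connectors through it as well. A minimal sketch of constructing such an agent; the prompt content is illustrative and not the PR's `formatPrompt`, and `llm`/`tools` are declared only so the sketch type-checks:

```ts
import { createOpenAIToolsAgent } from 'langchain/agents';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import type { ChatOpenAI } from '@langchain/openai';
import type { StructuredTool } from '@langchain/core/tools';

declare const llm: ChatOpenAI; // provided by the route in real code
declare const tools: StructuredTool[];

// createOpenAIToolsAgent expects an agent_scratchpad placeholder in the prompt.
const prompt = ChatPromptTemplate.fromMessages([
  ['system', 'You are a helpful security assistant.'],
  ['human', '{input}'],
  ['placeholder', '{agent_scratchpad}'],
]);

const agentRunnable = await createOpenAIToolsAgent({
  llm,
  tools,
  prompt,
  streamRunnable: false,
});
```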