Skip to content

Commit

Permalink
[Security Solution] [AI Assistant] Clean up content references code (e…
Browse files Browse the repository at this point in the history
…lastic#208902)

## Summary

This PR addressed the remaining comments left in:
elastic#206683. This PR does not contain
any material changes. It is just fixing some types and variable naming.

Changes:
- Fix the
[type](https://github.com/elastic/kibana/pull/208902/files#diff-9f3f1c92910d7207ed15dd7bc3289d0a8a6bd7f656584fce33cfbad40823a32bL52)
of the optional content reference store. Once the feature flag is
removed, the content reference store will no longer be optional.
- Rename `contentReferencesStoreFactory()` to
`newContentReferencesStore()` because it is not actually a factory
method and was named poorly.
- Update [structured system
prompt](https://github.com/elastic/kibana/pull/208902/files#diff-1efcb0cc37b72d43ee9ff1036fad33f143c577a9c9818e3c8ace2efbfc9e64b0R26)
to include instructions for citations too.

### Checklist

Check the PR satisfies following conditions. 

Reviewers should verify this PR satisfies this list as well.

- [X] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)
- [X]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [X] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [X] If a plugin configuration key changed, check if it needs to be
allowlisted in the cloud and added to the [docker
list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)
- [X] This was checked for breaking HTTP API changes, and any breaking
changes have been approved by the breaking-change committee. The
`release_note:breaking` label should be applied in these situations.
- [X] [Flaky Test
Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was
used on any tests changed
- [X] The PR description includes the appropriate Release Notes section,
and the correct `release_note:*` label is applied per the
[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

### Identify risks

Does this PR introduce any risks? For example, consider risks like hard
to test bugs, performance regression, potential of data loss.

Describe the risk, its severity, and mitigation for each identified
risk. Invite stakeholders and evaluate how to proceed before merging.

- [ ] [See some risk
examples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)
- [ ] ...

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
KDKHD and elasticmachine authored Feb 3, 2025
1 parent 4e0c0a7 commit 0f62fa1
Show file tree
Hide file tree
Showing 25 changed files with 53 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import { ContentReferencesStore } from '../../types';

export const contentReferencesStoreFactoryMock: () => ContentReferencesStore = jest
export const newContentReferencesStoreMock: () => ContentReferencesStore = jest
.fn()
.mockReturnValue({
add: jest.fn().mockImplementation((creator: Parameters<ContentReferencesStore['add']>[0]) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
* 2.0.
*/

import { contentReferencesStoreFactory } from './content_references_store_factory';
import { newContentReferencesStore } from './content_references_store';
import { securityAlertsPageReference } from '../references';
import { ContentReferencesStore } from '../types';

describe('contentReferencesStoreFactory', () => {
describe('newContentReferencesStore', () => {
let contentReferencesStore: ContentReferencesStore;
beforeEach(() => {
contentReferencesStore = contentReferencesStoreFactory();
contentReferencesStore = newContentReferencesStore();
});

it('adds multiple content reference', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const CONTENT_REFERENCE_ID_ALPHABET =
/**
* Creates a new ContentReferencesStore used for storing references (also known as citations)
*/
export const contentReferencesStoreFactory: () => ContentReferencesStore = () => {
export const newContentReferencesStore: () => ContentReferencesStore = () => {
const store = new Map<string, ContentReference>();

const add: ContentReferencesStore['add'] = (creator) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import { pruneContentReferences } from './prune_content_references';
import { securityAlertsPageReference } from '../references';
import { contentReferenceBlock } from '../references/utils';
import { ContentReferencesStore } from '../types';
import { contentReferencesStoreFactory } from './content_references_store_factory';
import { newContentReferencesStore } from './content_references_store';

describe('pruneContentReferences', () => {
let contentReferencesStore: ContentReferencesStore;
beforeEach(() => {
contentReferencesStore = contentReferencesStoreFactory();
contentReferencesStore = newContentReferencesStore();
});

it('prunes content references correctly', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

export { contentReferencesStoreFactory } from './content_references_store/content_references_store_factory';
export { newContentReferencesStore } from './content_references_store/content_references_store';
export { pruneContentReferences } from './content_references_store/prune_content_references';
export {
securityAlertReference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export {
} from './impl/data_anonymization/helpers';

export {
contentReferencesStoreFactory,
newContentReferencesStore,
securityAlertReference,
knowledgeBaseReference,
securityAlertsPageReference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
EsqlContentReference,
IndexEntry,
} from '@kbn/elastic-assistant-common';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';

// Mock dependencies
jest.mock('@elastic/elasticsearch');
Expand Down Expand Up @@ -149,7 +149,7 @@ describe('getStructuredToolForIndexEntry', () => {
const mockEsClient = {} as ElasticsearchClient;

const mockIndexEntry = getCreateKnowledgeBaseEntrySchemaMock({ type: 'index' }) as IndexEntry;
const contentReferencesStore = contentReferencesStoreFactoryMock();
const contentReferencesStore = newContentReferencesStoreMock();

it('should return a DynamicStructuredTool with correct name and schema', () => {
const tool = getStructuredToolForIndexEntry({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ export const getStructuredToolForIndexEntry = ({
}: {
indexEntry: IndexEntry;
esClient: ElasticsearchClient;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
logger: Logger;
}): DynamicStructuredTool => {
const inputSchema = indexEntry.inputSchema?.reduce((prev, input) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import {
getSecurityLabsDocsCount,
} from '../../lib/langchain/content_loaders/security_labs_loader';
import { DynamicStructuredTool } from '@langchain/core/tools';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
jest.mock('../../lib/langchain/content_loaders/security_labs_loader');
jest.mock('p-retry');
const date = '2023-03-28T22:27:28.159Z';
Expand Down Expand Up @@ -522,7 +522,7 @@ describe('AIAssistantKnowledgeBaseDataClient', () => {

const result = await client.getAssistantTools({
esClient: esClientMock,
contentReferencesStore: contentReferencesStoreFactoryMock(),
contentReferencesStore: newContentReferencesStoreMock(),
});

expect(result).toHaveLength(1);
Expand All @@ -537,7 +537,7 @@ describe('AIAssistantKnowledgeBaseDataClient', () => {

const result = await client.getAssistantTools({
esClient: esClientMock,
contentReferencesStore: contentReferencesStoreFactoryMock(),
contentReferencesStore: newContentReferencesStoreMock(),
});

expect(result).toEqual([]);
Expand All @@ -550,7 +550,7 @@ describe('AIAssistantKnowledgeBaseDataClient', () => {

const result = await client.getAssistantTools({
esClient: esClientMock,
contentReferencesStore: contentReferencesStoreFactoryMock(),
contentReferencesStore: newContentReferencesStoreMock(),
});

expect(result).toEqual([]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ export class AIAssistantKnowledgeBaseDataClient extends AIAssistantDataClient {
contentReferencesStore,
esClient,
}: {
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
esClient: ElasticsearchClient;
}): Promise<StructuredTool[]> => {
const user = this.options.currentUser;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export interface AgentExecutorParams<T extends boolean> {
assistantTools?: AssistantTool[];
connectorId: string;
conversationId?: string;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
dataClients?: AssistantDataClients;
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
jest.mock('./graph');
Expand Down Expand Up @@ -79,7 +79,7 @@ describe('callAssistantGraph', () => {
telemetryParams: {},
traceOptions: {},
responseLanguage: 'English',
contentReferencesStore: contentReferencesStoreFactoryMock(),
contentReferencesStore: newContentReferencesStoreMock(),
} as unknown as AgentExecutorParams<boolean>;

beforeEach(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH} {include_
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response. ALWAYS return the exact response from NaturalLanguageESQLTool verbatim in the final response, without adding further description.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} {include_citations_prompt_placeholder} You have access to the following tools:
{tools}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
transformRawData,
getAnonymizedValue,
ConversationResponse,
contentReferencesStoreFactory,
newContentReferencesStore,
pruneContentReferences,
} from '@kbn/elastic-assistant-common';
import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common';
Expand Down Expand Up @@ -186,8 +186,9 @@ export const chatCompleteRoute = (
}));
}

const contentReferencesStore =
contentReferencesEnabled && contentReferencesStoreFactory();
const contentReferencesStore = contentReferencesEnabled
? newContentReferencesStore()
: undefined;

const onLlmResponse = async (
content: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ export function getAssistantToolParams({
langSmithProject?: string;
langSmithApiKey?: string;
logger: Logger;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
latestReplacements: Replacements;
onNewReplacements: (newReplacements: Replacements) => void;
request: KibanaRequest<unknown, unknown, DefendInsightsPostRequestBody>;
Expand All @@ -136,7 +136,7 @@ export function getAssistantToolParams({
langChainTimeout: number;
llm: ActionsClientLlm;
logger: Logger;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
replacements: Replacements;
onNewReplacements: (newReplacements: Replacements) => void;
request: KibanaRequest<unknown, unknown, DefendInsightsPostRequestBody>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ export const postDefendInsightsRoute = (router: IRouter<ElasticAssistantRequestH
apiConfig,
esClient,
latestReplacements,
contentReferencesStore: false,
contentReferencesStore: undefined,
connectorTimeout: CONNECTOR_TIMEOUT,
langChainTimeout: LANG_CHAIN_TIMEOUT,
langSmithProject,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { v4 as uuidv4 } from 'uuid';
import { getRequestAbortedSignal } from '@kbn/data-plugin/server';
import {
API_VERSIONS,
contentReferencesStoreFactory,
newContentReferencesStore,
ELASTIC_AI_ASSISTANT_EVALUATE_URL,
ExecuteConnectorRequestBody,
INTERNAL_API_ACCESS,
Expand Down Expand Up @@ -292,8 +292,9 @@ export const postEvaluateRoute = (
assistantContext.getRegisteredFeatures(
DEFAULT_PLUGIN_NAME
).contentReferencesEnabled;
const contentReferencesStore =
contentReferencesEnabled && contentReferencesStoreFactory();
const contentReferencesStore = contentReferencesEnabled
? newContentReferencesStore()
: undefined;

// Fetch any applicable tools that the source plugin may have registered
const assistantToolParams: AssistantToolParams = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ export interface LangChainExecuteParams {
telemetry: AnalyticsServiceSetup;
actionTypeId: string;
connectorId: string;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
llmTasks?: LlmTasksPluginStart;
inference: InferenceServerStart;
isOssModel?: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { getRequestAbortedSignal } from '@kbn/data-plugin/server';
import { schema } from '@kbn/config-schema';
import {
API_VERSIONS,
contentReferencesStoreFactory,
newContentReferencesStore,
ExecuteConnectorRequestBody,
Message,
Replacements,
Expand Down Expand Up @@ -119,8 +119,9 @@ export const postActionsConnectorExecuteRoute = (
});
const promptsDataClient = await assistantContext.getAIAssistantPromptsDataClient();

const contentReferencesStore =
contentReferencesEnabled && contentReferencesStoreFactory();
const contentReferencesStore = contentReferencesEnabled
? newContentReferencesStore()
: undefined;

onLlmResponse = async (
content: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ export interface AssistantToolParams {
inference?: InferenceServerStart;
isEnabledKnowledgeBase: boolean;
connectorId?: string;
contentReferencesStore: ContentReferencesStore | false;
contentReferencesStore: ContentReferencesStore | undefined;
esClient: ElasticsearchClient;
kbDataClient?: AIAssistantKnowledgeBaseDataClient;
langChainTimeout?: number;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { ALERT_COUNTS_TOOL } from './alert_counts_tool';
import type { RetrievalQAChain } from 'langchain/chains';
import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
import type { ContentReferencesStore } from '@kbn/elastic-assistant-common';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';

describe('AlertCountsTool', () => {
const alertsIndexPattern = 'alerts-index';
Expand All @@ -32,7 +32,7 @@ describe('AlertCountsTool', () => {
const isEnabledKnowledgeBase = true;
const chain = {} as unknown as RetrievalQAChain;
const logger = loggerMock.create();
const contentReferencesStore = contentReferencesStoreFactoryMock();
const contentReferencesStore = newContentReferencesStoreMock();
const rest = {
isEnabledKnowledgeBase,
chain,
Expand Down Expand Up @@ -191,7 +191,7 @@ describe('AlertCountsTool', () => {
replacements,
request,
...rest,
contentReferencesStore: false,
contentReferencesStore: undefined,
}) as DynamicTool;

const result = await tool.func('');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import type {
ContentReferencesStore,
KnowledgeBaseEntryContentReference,
} from '@kbn/elastic-assistant-common';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { loggerMock } from '@kbn/logging-mocks';
import { Document } from 'langchain/document';

describe('KnowledgeBaseRetievalTool', () => {
const logger = loggerMock.create();
const contentReferencesStore = contentReferencesStoreFactoryMock();
const contentReferencesStore = newContentReferencesStoreMock();
const getKnowledgeBaseDocumentEntries = jest.fn();
const kbDataClient = { getKnowledgeBaseDocumentEntries };
const defaultArgs = {
Expand Down Expand Up @@ -68,7 +68,7 @@ describe('KnowledgeBaseRetievalTool', () => {
it('does not include citations if contentReferenceStore is false', async () => {
const tool = KNOWLEDGE_BASE_RETRIEVAL_TOOL.getTool({
...defaultArgs,
contentReferencesStore: false,
contentReferencesStore: undefined,
}) as DynamicStructuredTool;

getKnowledgeBaseDocumentEntries.mockResolvedValue([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import type {
ContentReferencesStore,
SecurityAlertContentReference,
} from '@kbn/elastic-assistant-common';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';

const MAX_SIZE = 10000;

Expand All @@ -44,7 +44,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
const isEnabledKnowledgeBase = true;
const chain = {} as unknown as RetrievalQAChain;
const logger = loggerMock.create();
const contentReferencesStore = contentReferencesStoreFactoryMock();
const contentReferencesStore = newContentReferencesStoreMock();
const rest = {
isEnabledKnowledgeBase,
esClient,
Expand Down Expand Up @@ -274,7 +274,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
request,
size: request.body.size,
...rest,
contentReferencesStore: false,
contentReferencesStore: undefined,
}) as DynamicTool;

(esClient.search as jest.Mock).mockResolvedValue({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import type {
ContentReferencesStore,
ProductDocumentationContentReference,
} from '@kbn/elastic-assistant-common';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';

describe('ProductDocumentationTool', () => {
const chain = {} as RetrievalQAChain;
Expand All @@ -35,7 +35,7 @@ describe('ProductDocumentationTool', () => {
retrieveDocumentationAvailable: jest.fn(),
} as LlmTasksPluginStart;
const connectorId = 'fake-connector';
const contentReferencesStore = contentReferencesStoreFactoryMock();
const contentReferencesStore = newContentReferencesStoreMock();
const defaultArgs = {
chain,
esClient,
Expand Down Expand Up @@ -144,7 +144,7 @@ describe('ProductDocumentationTool', () => {
it('does not include citations if contentReferencesStore is false', async () => {
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool({
...defaultArgs,
contentReferencesStore: false,
contentReferencesStore: undefined,
}) as DynamicStructuredTool;

(retrieveDocumentation as jest.Mock).mockResolvedValue({
Expand Down
Loading

0 comments on commit 0f62fa1

Please sign in to comment.