Skip to content

Commit

Permalink
[Security assistant] Use inference connector in security AI features (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Jan 8, 2025
1 parent 5a3c914 commit c6501da
Show file tree
Hide file tree
Showing 19 changed files with 715 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export interface AssistantProviderProps {
children: React.ReactNode;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled?: boolean;
baseConversations: Record<string, Conversation>;
nameSpace?: string;
navigateToApp: (appId: string, options?: NavigateToAppOptions | undefined) => Promise<void>;
Expand Down Expand Up @@ -104,6 +105,7 @@ export interface UseAssistantContext {
currentUserAvatar?: UserAvatar;
getComments: GetAssistantMessages;
http: HttpSetup;
inferenceEnabled: boolean;
knowledgeBase: KnowledgeBaseConfig;
getLastConversationId: (conversationTitle?: string) => string;
promptContexts: Record<string, PromptContext>;
Expand Down Expand Up @@ -147,6 +149,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
children,
getComments,
http,
inferenceEnabled = false,
baseConversations,
navigateToApp,
nameSpace = DEFAULT_ASSISTANT_NAMESPACE,
Expand Down Expand Up @@ -280,6 +283,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
knowledgeBase: {
...DEFAULT_KNOWLEDGE_BASE_SETTINGS,
...localStorageKnowledgeBase,
Expand Down Expand Up @@ -322,6 +326,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
docLinks,
getComments,
http,
inferenceEnabled,
localStorageKnowledgeBase,
promptContexts,
navigateToApp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,10 @@ export const ConnectorSelector: React.FC<Props> = React.memo(
const connectorOptions = useMemo(
() =>
(aiConnectors ?? []).map((connector) => {
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const connectorDetails = connector.isPreconfigured
? i18n.PRECONFIGURED_CONNECTOR
: connectorTypeTitle;
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const attackDiscoveryStats =
stats !== null
? stats.statsPerConnector.find((s) => s.connectorId === connector.id) ?? null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ interface Props {
actionTypeSelectorInline: boolean;
}
const itemClassName = css`
inline-size: 220px;
inline-size: 150px;
.euiKeyPadMenuItem__label {
white-space: nowrap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ export const getConnectorTypeTitle = (
if (!connector) {
return null;
}
const connectorTypeTitle =
getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
const actionType = connector.isPreconfigured ? PRECONFIGURED_CONNECTOR : connectorTypeTitle;

const actionType = connector.isPreconfigured
? PRECONFIGURED_CONNECTOR
: getGenAiConfig(connector)?.apiProvider ??
getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));

return actionType;
};
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,12 @@ export const useLoadActionTypes = ({
featureId: GenerativeAIForSecurityConnectorFeatureId,
});

const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
// TODO add .inference once all the providers support unified completion
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

const sortedData = queryResult
.filter((p) =>
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(p.id)
)
return queryResult
.filter((p) => actionTypes.includes(p.id))
.sort((a, b) => a.name.localeCompare(b.name));
return sortedData;
},
{
retry: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import { waitFor, renderHook } from '@testing-library/react';
import { useLoadConnectors, Props } from '.';
import { mockConnectors } from '../../mock/connectors';
import { TestProviders } from '../../mock/test_providers/test_providers';
import React, { ReactNode } from 'react';

const mockConnectorsAndExtras = [
...mockConnectors,
Expand Down Expand Up @@ -45,50 +47,73 @@ const loadConnectorsResult = mockConnectors.map((c) => ({
isSystemAction: false,
}));

jest.mock('@tanstack/react-query', () => ({
useQuery: jest.fn().mockImplementation(async (queryKey, fn, opts) => {
try {
const res = await fn();
return Promise.resolve(res);
} catch (e) {
opts.onError(e);
}
}),
}));

const http = {
get: jest.fn().mockResolvedValue(connectorsApiResponse),
};
const toasts = {
addError: jest.fn(),
};
const defaultProps = { http, toasts } as unknown as Props;

const createWrapper = (inferenceEnabled = false) => {
// eslint-disable-next-line react/display-name
return ({ children }: { children: ReactNode }) => (
<TestProviders providerContext={{ inferenceEnabled }}>{children}</TestProviders>
);
};

describe('useLoadConnectors', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should call api to load action types', async () => {
renderHook(() => useLoadConnectors(defaultProps));
renderHook(() => useLoadConnectors(defaultProps), {
wrapper: TestProviders,
});
await waitFor(() => {
expect(defaultProps.http.get).toHaveBeenCalledWith('/api/actions/connectors');
expect(toasts.addError).not.toHaveBeenCalled();
});
});

it('should return sorted action types, removing isMissingSecrets and wrong action type ids', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps));
it('should return sorted action types, removing isMissingSecrets and wrong action type ids, excluding .inference results', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
wrapper: TestProviders,
});
await waitFor(() => {
expect(result.current.data).toStrictEqual(
loadConnectorsResult
.filter((c) => c.actionTypeId !== '.inference')
// @ts-ignore ts does not like config, but we define it in the mock data
.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
);
});
});

it('includes preconfigured .inference results when inferenceEnabled is true', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
wrapper: createWrapper(true),
});
await waitFor(() => {
expect(result.current).resolves.toStrictEqual(
// @ts-ignore ts does not like config, but we define it in the mock data
loadConnectorsResult.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
expect(result.current.data).toStrictEqual(
mockConnectors
.filter(
(c) =>
c.actionTypeId !== '.inference' ||
(c.actionTypeId === '.inference' && c.isPreconfigured)
)
// @ts-ignore ts does not like config, but we define it in the mock data
.map((c) => ({ ...c, referencedByCount: 0, apiProvider: c?.config?.apiProvider }))
);
});
});
it('should display error toast when api throws error', async () => {
const mockHttp = {
get: jest.fn().mockRejectedValue(new Error('this is an error')),
} as unknown as Props['http'];
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }));
renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }), {
wrapper: TestProviders,
});
await waitFor(() => expect(toasts.addError).toHaveBeenCalled());
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type { IHttpFetchError } from '@kbn/core-http-browser';
import { HttpSetup } from '@kbn/core-http-browser';
import { IToasts } from '@kbn/core-notifications-browser';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import { useAssistantContext } from '../../assistant_context';
import { AIConnector } from '../connector_selector';
import * as i18n from '../translations';

Expand All @@ -27,16 +28,17 @@ export interface Props {
toasts?: IToasts;
}

const actionTypeKey = {
bedrock: '.bedrock',
openai: '.gen-ai',
gemini: '.gemini',
};
const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

export const useLoadConnectors = ({
http,
toasts,
}: Props): UseQueryResult<AIConnector[], IHttpFetchError> => {
const { inferenceEnabled } = useAssistantContext();
if (inferenceEnabled) {
actionTypes.push('.inference');
}

return useQuery(
QUERY_KEY,
async () => {
Expand All @@ -45,9 +47,9 @@ export const useLoadConnectors = ({
(acc: AIConnector[], connector) => [
...acc,
...(!connector.isMissingSecrets &&
[actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
connector.actionTypeId
)
actionTypes.includes(connector.actionTypeId) &&
// only include preconfigured .inference connectors
(connector.actionTypeId !== '.inference' || connector.isPreconfigured)
? [
{
...connector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,26 @@ export const mockConnectors: AIConnector[] = [
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9386-a1f4d42ec542',
name: 'Regular Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
secrets: {},
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
config: {
apiProvider: 'OpenAI',
},
},
{
id: 'c29c28a0-20fe-11ee-9396-a1f4d42ec542',
name: 'Preconfigured Inference Connector',
isMissingSecrets: false,
actionTypeId: '.inference',
isPreconfigured: true,
isDeprecated: false,
isSystemAction: false,
},
];
Loading

0 comments on commit c6501da

Please sign in to comment.