From 15aac3606e6fb025928f546bed7e0fe573fd582c Mon Sep 17 00:00:00 2001 From: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> Date: Tue, 4 Mar 2025 04:06:24 +1100 Subject: [PATCH] [8.x] [Security Solution] [GenAi] refactor security ai assistant tools to use tool helper method (#212865) (#212928) # Backport This will backport the following commits from `main` to `8.x`: - [[Security Solution] [GenAi] refactor security ai assistant tools to use tool helper method (#212865)](https://github.com/elastic/kibana/pull/212865) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport) Co-authored-by: Kenneth Kreindler <42113355+KDKHD@users.noreply.github.com> --- .../default_assistant_graph/agentRunnable.ts | 56 ++++++++++++ .../graphs/default_assistant_graph/graph.ts | 83 ++---------------- .../graphs/default_assistant_graph/index.ts | 48 ++++------- .../graphs/default_assistant_graph/prompts.ts | 22 ++++- .../graphs/default_assistant_graph/state.ts | 86 +++++++++++++++++++ .../server/routes/evaluate/post_evaluate.ts | 48 ++++------- .../plugins/elastic_assistant/server/types.ts | 4 +- .../tools/alert_counts/alert_counts_tool.ts | 18 ++-- .../assistant/tools/esql/nl_to_esql_tool.ts | 29 ++++--- .../knowledge_base_retrieval_tool.ts | 23 ++--- .../knowledge_base_write_tool.ts | 41 ++++----- .../open_and_acknowledged_alerts_tool.ts | 18 ++-- .../product_documentation_tool.ts | 59 ++++++------- .../tools/security_labs/security_labs_tool.ts | 31 +++---- 14 files changed, 318 insertions(+), 248 deletions(-) create mode 100644 x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts create mode 100644 x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/state.ts diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts new file mode 100644 index 0000000000000..8fc4faa371d3f --- /dev/null +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ToolDefinition } from '@langchain/core/language_models/base'; +import { + ActionsClientChatBedrockConverse, + ActionsClientChatVertexAI, + ActionsClientChatOpenAI, +} from '@kbn/langchain/server'; +import type { StructuredToolInterface } from '@langchain/core/tools'; +import { + AgentRunnableSequence, + createOpenAIToolsAgent, + createStructuredChatAgent, + createToolCallingAgent, +} from 'langchain/agents'; +import { ChatPromptTemplate } from '@langchain/core/prompts'; + +export const TOOL_CALLING_LLM_TYPES = new Set(['bedrock', 'gemini']); + +export const agentRunableFactory = async ({ + llm, + isOpenAI, + llmType, + tools, + isStream, + prompt, +}: { + llm: ActionsClientChatBedrockConverse | ActionsClientChatVertexAI | ActionsClientChatOpenAI; + isOpenAI: boolean; + llmType: string | undefined; + tools: StructuredToolInterface[] | ToolDefinition[]; + isStream: boolean; + prompt: ChatPromptTemplate; +}): Promise => { + const params = { + llm, + tools, + streamRunnable: isStream, + prompt, + } as const; + + if (isOpenAI || llmType === 'inference') { + return createOpenAIToolsAgent(params); + } + + if (llmType && TOOL_CALLING_LLM_TYPES.has(llmType)) { + return createToolCallingAgent(params); + } + + return createStructuredChatAgent(params); +}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 8e85b20b06c8a..7f8502cf4b4c7 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -5,15 +5,13 @@ * 2.0. */ -import { Annotation, END, START, StateGraph } from '@langchain/langgraph'; -import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents'; +import { END, START, StateGraph } from '@langchain/langgraph'; import { AgentRunnableSequence } from 'langchain/dist/agents/agent'; import { StructuredTool } from '@langchain/core/tools'; import type { Logger } from '@kbn/logging'; -import { BaseMessage } from '@langchain/core/messages'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common'; +import { Replacements } from '@kbn/elastic-assistant-common'; import { PublicMethodsOf } from '@kbn/utility-types'; import { ActionsClient } from '@kbn/actions-plugin/server'; import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; @@ -29,6 +27,7 @@ import { getPersistedConversation } from './nodes/get_persisted_conversation'; import { persistConversationChanges } from './nodes/persist_conversation_changes'; import { respond } from './nodes/respond'; import { NodeType } from './constants'; +import { getStateAnnotation } from './state'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; @@ -61,78 +60,6 @@ export const getDefaultAssistantGraph = ({ getFormattedTime, }: GetDefaultAssistantGraphParams) => { try { - // Default graph state - const graphAnnotation = Annotation.Root({ - input: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => '', - }), - lastNode: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => 'start', - }), - steps: Annotation({ - reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y), - default: () => [], - }), - hasRespondStep: Annotation({ - reducer: (x: boolean, y?: boolean) => y ?? x, - default: () => false, - }), - agentOutcome: Annotation({ - reducer: ( - x: AgentAction | AgentFinish | undefined, - y?: AgentAction | AgentFinish | undefined - ) => y ?? x, - default: () => undefined, - }), - messages: Annotation({ - reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x, - default: () => [], - }), - chatTitle: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => '', - }), - llmType: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => 'unknown', - }), - isStream: Annotation({ - reducer: (x: boolean, y?: boolean) => y ?? x, - default: () => false, - }), - isOssModel: Annotation({ - reducer: (x: boolean, y?: boolean) => y ?? x, - default: () => false, - }), - connectorId: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => '', - }), - conversation: Annotation({ - reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => - y ?? x, - default: () => undefined, - }), - conversationId: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => '', - }), - responseLanguage: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => 'English', - }), - provider: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: () => '', - }), - formattedTime: Annotation({ - reducer: (x: string, y?: string) => y ?? x, - default: getFormattedTime ?? (() => ''), - }), - }); - // Default node parameters const nodeParams: NodeParamsBase = { actionsClient, @@ -140,8 +67,10 @@ export const getDefaultAssistantGraph = ({ savedObjectsClient, }; + const stateAnnotation = getStateAnnotation({ getFormattedTime }); + // Put together a new graph using default state from above - const graph = new StateGraph(graphAnnotation) + const graph = new StateGraph(stateAnnotation) .addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) => getPersistedConversation({ ...nodeParams, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 3bfd41329ebae..da1d2244e5c5e 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -7,11 +7,6 @@ import { StructuredTool } from '@langchain/core/tools'; import { getDefaultArguments } from '@kbn/langchain/server'; -import { - createOpenAIToolsAgent, - createStructuredChatAgent, - createToolCallingAgent, -} from 'langchain/agents'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry'; import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common'; @@ -25,12 +20,13 @@ import { getLlmClass } from '../../../../routes/utils'; import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; import { AgentExecutor } from '../../executors/types'; -import { formatPrompt, formatPromptStructured } from './prompts'; +import { formatPrompt } from './prompts'; import { GraphInputs } from './types'; import { getDefaultAssistantGraph } from './graph'; import { invokeGraph, streamGraph } from './helpers'; import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers'; import { DEFAULT_DATE_FORMAT_TZ } from '../../../../../common/constants'; +import { agentRunableFactory } from './agentRunnable'; export const callAssistantGraph: AgentExecutor = async ({ abortSignal, @@ -179,28 +175,21 @@ export const callAssistantGraph: AgentExecutor = async ({ savedObjectsClient, }); - const agentRunnable = - isOpenAI || llmType === 'inference' - ? await createOpenAIToolsAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPrompt(defaultSystemPrompt, systemPrompt), - streamRunnable: isStream, - }) - : llmType && ['bedrock', 'gemini'].includes(llmType) - ? await createToolCallingAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPrompt(defaultSystemPrompt, systemPrompt), - streamRunnable: isStream, - }) - : // used with OSS models - await createStructuredChatAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPromptStructured(defaultSystemPrompt, systemPrompt), - streamRunnable: isStream, - }); + const chatPromptTemplate = formatPrompt({ + prompt: defaultSystemPrompt, + additionalPrompt: systemPrompt, + llmType, + isOpenAI, + }); + + const agentRunnable = await agentRunableFactory({ + llm: createLlmInstance(), + isOpenAI, + llmType, + tools, + isStream, + prompt: chatPromptTemplate, + }); const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); const telemetryTracer = telemetryParams @@ -214,6 +203,7 @@ export const callAssistantGraph: AgentExecutor = async ({ logger ) : undefined; + const { provider } = !llmType || llmType === 'inference' ? await resolveProviderAndModel({ @@ -240,7 +230,7 @@ export const callAssistantGraph: AgentExecutor = async ({ ...(llmType === 'bedrock' ? { signal: abortSignal } : {}), getFormattedTime: () => getFormattedTime({ - screenContextTimezone: request.body.screenContext?.timeZone, + screenContextTimezone: screenContext?.timeZone, uiSettingsDateFormatTimezone, }), }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index bc28f00e5d76e..79327648dde34 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -6,8 +6,9 @@ */ import { ChatPromptTemplate } from '@langchain/core/prompts'; +import { TOOL_CALLING_LLM_TYPES } from './agentRunnable'; -export const formatPrompt = (prompt: string, additionalPrompt?: string) => +const formatPromptToolcalling = (prompt: string, additionalPrompt?: string) => ChatPromptTemplate.fromMessages([ ['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt], ['placeholder', '{knowledge_history}'], @@ -16,7 +17,7 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) => ['placeholder', '{agent_scratchpad}'], ]); -export const formatPromptStructured = (prompt: string, additionalPrompt?: string) => +const formatPromptStructured = (prompt: string, additionalPrompt?: string) => ChatPromptTemplate.fromMessages([ ['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt], ['placeholder', '{knowledge_history}'], @@ -26,3 +27,20 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string '{input}\n\n{agent_scratchpad}\n\n(reminder to respond in a JSON blob no matter what)', ], ]); + +export const formatPrompt = ({ + isOpenAI, + llmType, + prompt, + additionalPrompt, +}: { + isOpenAI: boolean; + llmType: string | undefined; + prompt: string; + additionalPrompt?: string; +}) => { + if (isOpenAI || llmType === 'inference' || (llmType && TOOL_CALLING_LLM_TYPES.has(llmType))) { + return formatPromptToolcalling(prompt, additionalPrompt); + } + return formatPromptStructured(prompt, additionalPrompt); +}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/state.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/state.ts new file mode 100644 index 0000000000000..f1ab308adfefb --- /dev/null +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/state.ts @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ConversationResponse } from '@kbn/elastic-assistant-common'; +import { BaseMessage } from '@langchain/core/messages'; +import { Annotation } from '@langchain/langgraph'; +import { AgentStep, AgentAction, AgentFinish } from 'langchain/agents'; + +export const getStateAnnotation = ({ getFormattedTime }: { getFormattedTime?: () => string }) => { + const graphAnnotation = Annotation.Root({ + input: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => '', + }), + lastNode: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => 'start', + }), + steps: Annotation({ + reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y), + default: () => [], + }), + hasRespondStep: Annotation({ + reducer: (x: boolean, y?: boolean) => y ?? x, + default: () => false, + }), + agentOutcome: Annotation({ + reducer: ( + x: AgentAction | AgentFinish | undefined, + y?: AgentAction | AgentFinish | undefined + ) => y ?? x, + default: () => undefined, + }), + messages: Annotation({ + reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x, + default: () => [], + }), + chatTitle: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => '', + }), + llmType: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => 'unknown', + }), + isStream: Annotation({ + reducer: (x: boolean, y?: boolean) => y ?? x, + default: () => false, + }), + isOssModel: Annotation({ + reducer: (x: boolean, y?: boolean) => y ?? x, + default: () => false, + }), + connectorId: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => '', + }), + conversation: Annotation({ + reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => + y ?? x, + default: () => undefined, + }), + conversationId: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => '', + }), + responseLanguage: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => 'English', + }), + provider: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: () => '', + }), + formattedTime: Annotation({ + reducer: (x: string, y?: string) => y ?? x, + default: getFormattedTime ?? (() => ''), + }), + }); + + return graphAnnotation; +}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index bb440f034605a..ae82fec6ceeca 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -26,21 +26,13 @@ import { import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getDefaultArguments } from '@kbn/langchain/server'; import { StructuredTool } from '@langchain/core/tools'; -import { - AgentFinish, - createOpenAIToolsAgent, - createStructuredChatAgent, - createToolCallingAgent, -} from 'langchain/agents'; +import { AgentFinish } from 'langchain/agents'; import { omit } from 'lodash/fp'; import { localToolPrompts, promptGroupId as toolsGroupId } from '../../lib/prompt/tool_prompts'; import { promptGroupId } from '../../lib/prompt/local_prompt_object'; import { getFormattedTime, getModelOrOss } from '../../lib/prompt/helpers'; import { getAttackDiscoveryPrompts } from '../../lib/attack_discovery/graphs/default_attack_discovery_graph/nodes/helpers/prompts'; -import { - formatPrompt, - formatPromptStructured, -} from '../../lib/langchain/graphs/default_assistant_graph/prompts'; +import { formatPrompt } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; import { getPrompt as localGetPrompt, promptDictionary } from '../../lib/prompt'; import { buildResponse } from '../../lib/build_response'; import { AssistantDataClients } from '../../lib/langchain/executors/types'; @@ -57,6 +49,7 @@ import { import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils'; import { getGraphsFromNames } from './get_graphs_from_names'; import { DEFAULT_DATE_FORMAT_TZ } from '../../../common/constants'; +import { agentRunableFactory } from '../../lib/langchain/graphs/default_assistant_graph/agentRunnable'; const DEFAULT_SIZE = 20; const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes @@ -356,27 +349,20 @@ export const postEvaluateRoute = ( savedObjectsClient, }); - const agentRunnable = - isOpenAI || llmType === 'inference' - ? await createOpenAIToolsAgent({ - llm, - tools, - prompt: formatPrompt(defaultSystemPrompt), - streamRunnable: false, - }) - : llmType && ['bedrock', 'gemini'].includes(llmType) - ? createToolCallingAgent({ - llm, - tools, - prompt: formatPrompt(defaultSystemPrompt), - streamRunnable: false, - }) - : await createStructuredChatAgent({ - llm, - tools, - prompt: formatPromptStructured(defaultSystemPrompt), - streamRunnable: false, - }); + const chatPromptTemplate = formatPrompt({ + prompt: defaultSystemPrompt, + llmType, + isOpenAI, + }); + + const agentRunnable = await agentRunableFactory({ + llm: createLlmInstance(), + isOpenAI, + llmType, + tools, + isStream: false, + prompt: chatPromptTemplate, + }); const uiSettingsDateFormatTimezone = await ctx.core.uiSettings.client.get( DEFAULT_DATE_FORMAT_TZ diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/types.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/types.ts index e6d2b90787002..00d70c3845563 100755 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/types.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/types.ts @@ -24,7 +24,7 @@ import { } from '@kbn/core/server'; import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server'; import { type MlPluginSetup } from '@kbn/ml-plugin/server'; -import { DynamicStructuredTool, Tool } from '@langchain/core/tools'; +import { StructuredToolInterface } from '@langchain/core/tools'; import { SpacesPluginSetup, SpacesPluginStart } from '@kbn/spaces-plugin/server'; import { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server'; import { @@ -224,7 +224,7 @@ export interface AssistantTool { description: string; sourceRegister: string; isSupported: (params: AssistantToolParams) => boolean; - getTool: (params: AssistantToolParams) => Tool | DynamicStructuredTool | null; + getTool: (params: AssistantToolParams) => StructuredToolInterface | null; } export type AssistantToolLlm = diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/alert_counts/alert_counts_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/alert_counts/alert_counts_tool.ts index 801b4054a8ef3..cc36d029f1c2e 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/alert_counts/alert_counts_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/alert_counts/alert_counts_tool.ts @@ -6,8 +6,7 @@ */ import type { SearchResponse } from '@elastic/elasticsearch/lib/api/types'; -import { DynamicStructuredTool } from '@langchain/core/tools'; -import { z } from '@kbn/zod'; +import { tool } from '@langchain/core/tools'; import { requestHasRequiredAnonymizationParams } from '@kbn/elastic-assistant-plugin/server/lib/langchain/helpers'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { contentReferenceString, securityAlertsPageReference } from '@kbn/elastic-assistant-common'; @@ -36,11 +35,8 @@ export const ALERT_COUNTS_TOOL: AssistantTool = { if (!this.isSupported(params)) return null; const { alertsIndexPattern, esClient, contentReferencesStore } = params as AlertCountsToolParams; - return new DynamicStructuredTool({ - name: 'AlertCountsTool', - description: params.description || ALERT_COUNTS_TOOL_DESCRIPTION, - schema: z.object({}), - func: async () => { + return tool( + async () => { const query = getAlertsCountQuery(alertsIndexPattern); const result = await esClient.search(query); const alertsCountReference = contentReferencesStore?.add((p) => @@ -51,7 +47,11 @@ export const ALERT_COUNTS_TOOL: AssistantTool = { return `${JSON.stringify(result)}${reference}`; }, - tags: ['alerts', 'alerts-count'], - }); + { + name: 'AlertCountsTool', + description: params.description || ALERT_COUNTS_TOOL_DESCRIPTION, + tags: ['alerts', 'alerts-count'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/esql/nl_to_esql_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/esql/nl_to_esql_tool.ts index 0d2af41232f70..33a4286c020b1 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/esql/nl_to_esql_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/esql/nl_to_esql_tool.ts @@ -5,11 +5,11 @@ * 2.0. */ -import { DynamicStructuredTool } from '@langchain/core/tools'; -import { z } from '@kbn/zod'; +import { tool } from '@langchain/core/tools'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { lastValueFrom } from 'rxjs'; import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; +import { z } from '@kbn/zod'; import { APP_UI_ID } from '../../../../common'; import { getPromptSuffixForOssModel } from './common'; @@ -57,23 +57,24 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { ); }; - return new DynamicStructuredTool({ - name: toolDetails.name, - description: - (params.description || toolDetails.description) + - (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), - schema: z.object({ - question: z.string().describe(`The user's exact question about ESQL`), - }), - func: async (input) => { + return tool( + async (input) => { const generateEvent = await callNaturalLanguageToEsql(input.question); const answer = generateEvent.content ?? 'An error occurred in the tool'; logger.debug(`Received response from NL to ESQL tool: ${answer}`); return answer; }, - tags: ['esql', 'query-generation', 'knowledge-base'], - // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts - }) as unknown as DynamicStructuredTool; + { + name: toolDetails.name, + description: + (params.description || toolDetails.description) + + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), + schema: z.object({ + question: z.string().describe(`The user's exact question about ESQL`), + }), + tags: ['esql', 'query-generation', 'knowledge-base'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts index beddd4efeadb9..6cff2ccb63722 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_retrieval_tool.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { tool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base'; @@ -41,13 +41,8 @@ export const KNOWLEDGE_BASE_RETRIEVAL_TOOL: AssistantTool = { params as KnowledgeBaseRetrievalToolParams; if (kbDataClient == null) return null; - return new DynamicStructuredTool({ - name: toolDetails.name, - description: params.description || toolDetails.description, - schema: z.object({ - query: z.string().describe(`Summary of items/things to search for in the knowledge base`), - }), - func: async (input) => { + return tool( + async (input) => { logger.debug( () => `KnowledgeBaseRetrievalToolParams:input\n ${JSON.stringify(input, null, 2)}` ); @@ -60,9 +55,15 @@ export const KNOWLEDGE_BASE_RETRIEVAL_TOOL: AssistantTool = { return JSON.stringify(docs.map(enrichDocument(contentReferencesStore))); }, - tags: ['knowledge-base'], - // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts - }) as unknown as DynamicStructuredTool; + { + name: toolDetails.name, + description: params.description || toolDetails.description, + schema: z.object({ + query: z.string().describe(`Summary of items/things to search for in the knowledge base`), + }), + tags: ['knowledge-base'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts index d3fb2110e7c79..abaca93e43ada 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/knowledge_base/knowledge_base_write_tool.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { tool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base'; @@ -41,22 +41,8 @@ export const KNOWLEDGE_BASE_WRITE_TOOL: AssistantTool = { const { telemetry, kbDataClient, logger } = params as KnowledgeBaseWriteToolParams; if (kbDataClient == null) return null; - return new DynamicStructuredTool({ - name: toolDetails.name, - description: params.description || toolDetails.description, - schema: z.object({ - name: z - .string() - .describe(`This is what the user will use to refer to the entry in the future.`), - query: z.string().describe(`Summary of items/things to save in the knowledge base`), - required: z - .boolean() - .describe( - `Whether or not the entry is required to always be included in conversations. Is only true if the user explicitly asks for it to be required or always included in conversations, otherwise this is always false.` - ) - .default(false), - }), - func: async (input) => { + return tool( + async (input) => { logger.debug( () => `KnowledgeBaseWriteToolParams:input\n ${JSON.stringify(input, null, 2)}` ); @@ -78,8 +64,23 @@ export const KNOWLEDGE_BASE_WRITE_TOOL: AssistantTool = { } return "I've successfully saved this entry to your knowledge base. You can ask me to recall this information at any time."; }, - tags: ['knowledge-base'], - // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts - }) as unknown as DynamicStructuredTool; + { + name: toolDetails.name, + description: params.description || toolDetails.description, + schema: z.object({ + name: z + .string() + .describe(`This is what the user will use to refer to the entry in the future.`), + query: z.string().describe(`Summary of items/things to save in the knowledge base`), + required: z + .boolean() + .describe( + `Whether or not the entry is required to always be included in conversations. Is only true if the user explicitly asks for it to be required or always included in conversations, otherwise this is always false.` + ) + .default(false), + }), + tags: ['knowledge-base'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool.ts index d73bc266239d5..835c6a8d38d0c 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool.ts @@ -16,9 +16,8 @@ import { transformRawData, contentReferenceBlock, } from '@kbn/elastic-assistant-common'; -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { tool } from '@langchain/core/tools'; import { requestHasRequiredAnonymizationParams } from '@kbn/elastic-assistant-plugin/server/lib/langchain/helpers'; -import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { APP_UI_ID } from '../../../../common'; @@ -63,11 +62,8 @@ export const OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL: AssistantTool = { size, contentReferencesStore, } = params as OpenAndAcknowledgedAlertsToolParams; - return new DynamicStructuredTool({ - name: 'OpenAndAcknowledgedAlertsTool', - description: params.description || OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL_DESCRIPTION, - schema: z.object({}), - func: async () => { + return tool( + async () => { const query = getOpenAndAcknowledgedAlertsQuery({ alertsIndexPattern, anonymizationFields: anonymizationFields ?? [], @@ -105,7 +101,11 @@ export const OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL: AssistantTool = { }) ); }, - tags: ['alerts', 'open-and-acknowledged-alerts'], - }); + { + name: 'OpenAndAcknowledgedAlertsTool', + description: params.description || OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL_DESCRIPTION, + tags: ['alerts', 'open-and-acknowledged-alerts'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/product_docs/product_documentation_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/product_docs/product_documentation_tool.ts index 30dd2a1ec50cb..014be943cca3e 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/product_docs/product_documentation_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/product_docs/product_documentation_tool.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { tool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; @@ -41,31 +41,8 @@ export const PRODUCT_DOCUMENTATION_TOOL: AssistantTool = { // This check is here in order to satisfy TypeScript if (llmTasks == null || connectorId == null) return null; - return new DynamicStructuredTool({ - name: toolDetails.name, - description: params.description || toolDetails.description, - schema: z.object({ - query: z.string().describe( - `The query to use to retrieve documentation - Examples: - - "How to enable TLS for Elasticsearch?" - - "What is Kibana Security?"` - ), - product: z - .enum(['kibana', 'elasticsearch', 'observability', 'security']) - .describe( - `If specified, will filter the products to retrieve documentation for - Possible options are: - - "kibana": Kibana product - - "elasticsearch": Elasticsearch product - - "observability": Elastic Observability solution - - "security": Elastic Security solution - If not specified, will search against all products - ` - ) - .optional(), - }), - func: async ({ query, product }) => { + return tool( + async ({ query, product }) => { const response = await llmTasks.retrieveDocumentation({ searchTerm: query, products: product ? [product] : undefined, @@ -83,9 +60,33 @@ export const PRODUCT_DOCUMENTATION_TOOL: AssistantTool = { }, }; }, - tags: ['product-documentation'], - // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts - }) as unknown as DynamicStructuredTool; + { + name: toolDetails.name, + description: params.description || toolDetails.description, + schema: z.object({ + query: z.string().describe( + `The query to use to retrieve documentation + Examples: + - "How to enable TLS for Elasticsearch?" + - "What is Kibana Security?"` + ), + product: z + .enum(['kibana', 'elasticsearch', 'observability', 'security']) + .describe( + `If specified, will filter the products to retrieve documentation for + Possible options are: + - "kibana": Kibana product + - "elasticsearch": Elasticsearch product + - "observability": Elastic Observability solution + - "security": Elastic Security solution + If not specified, will search against all products + ` + ) + .optional(), + }), + tags: ['product-documentation'], + } + ); }, }; diff --git a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/security_labs/security_labs_tool.ts b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/security_labs/security_labs_tool.ts index 2faad9ba71c06..61c7c1dc82299 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/security_labs/security_labs_tool.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/assistant/tools/security_labs/security_labs_tool.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { tool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; @@ -35,17 +35,8 @@ export const SECURITY_LABS_KNOWLEDGE_BASE_TOOL: AssistantTool = { const { kbDataClient, contentReferencesStore } = params as AssistantToolParams; if (kbDataClient == null) return null; - return new DynamicStructuredTool({ - name: toolDetails.name, - description: params.description || toolDetails.description, - schema: z.object({ - question: z - .string() - .describe( - `Key terms to retrieve Elastic Security Labs content for, like specific malware names or attack techniques.` - ), - }), - func: async (input) => { + return tool( + async (input) => { const docs = await kbDataClient.getKnowledgeBaseDocumentEntries({ kbResource: SECURITY_LABS_RESOURCE, query: input.question, @@ -61,8 +52,18 @@ export const SECURITY_LABS_KNOWLEDGE_BASE_TOOL: AssistantTool = { const citation = contentReferenceString(reference); return `${result}\n${citation}`; }, - tags: ['security-labs', 'knowledge-base'], - // TODO: Remove after ZodAny is fixed https://github.com/langchain-ai/langchainjs/blob/main/langchain-core/src/tools.ts - }) as unknown as DynamicStructuredTool; + { + name: toolDetails.name, + description: params.description || toolDetails.description, + schema: z.object({ + question: z + .string() + .describe( + `Key terms to retrieve Elastic Security Labs content for, like specific malware names or attack techniques.` + ), + }), + tags: ['security-labs', 'knowledge-base'], + } + ); }, };