Skip to content

Commit

Permalink
[8.x] [Security Solution] [GenAi] refactor security ai assistant tool…
Browse files Browse the repository at this point in the history
…s to use tool helper method (#212865) (#212928)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security Solution] [GenAi] refactor security ai assistant tools to
use tool helper method
(#212865)](#212865)

<!--- Backport version: 9.6.6 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Kenneth
Kreindler","email":"42113355+KDKHD@users.noreply.github.com"},"sourceCommit":{"committedDate":"2025-03-03T14:35:22Z","message":"[Security
Solution] [GenAi] refactor security ai assistant tools to use tool
helper method (#212865)\n\n## Summary\n\nClean up some security ai
assistant code.\n\n- Replace the usage of `new DynamicStructuredTool()`
with the `tool()`\nhelper method. This is the recommended approach today
and has the\ncorrect types to work
with\n[`Command`](https://langchain-ai.github.io/langgraphjs/concepts/low_level/#command).\n-
Extract code such as the default assistant graph state
and\nagentRunnableFactory to reduce cognitive overload.\n- Update
AssistantTool type definition\n\n### Checklist\n\nCheck the PR satisfies
following conditions. \n\nReviewers should verify this PR satisfies this
list as well.\n\n- [X] Any text added follows [EUI's
writing\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\nsentence case text and includes
[i18n\nsupport](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)\n-
[X]\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\nwas
added for features that require explanation or tutorials\n- [X] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common scenarios\n- [X] If a plugin
configuration key changed, check if it needs to be\nallowlisted in the
cloud and added to the
[docker\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\n-
[X] This was checked for breaking HTTP API changes, and any
breaking\nchanges have been approved by the breaking-change committee.
The\n`release_note:breaking` label should be applied in these
situations.\n- [X] [Flaky
Test\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1)
was\nused on any tests changed\n- [X] The PR description includes the
appropriate Release Notes section,\nand the correct `release_note:*`
label is applied per
the\n[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)\n\n###
Identify risks\n\nDoes this PR introduce any risks? For example,
consider risks like hard\nto test bugs, performance regression,
potential of data loss.\n\nDescribe the risk, its severity, and
mitigation for each identified\nrisk. Invite stakeholders and evaluate
how to proceed before merging.\n\n- [ ] [See some
risk\nexamples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)\n-
[ ] ...\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>\nCo-authored-by:
Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"d37fcb6fb686e59f18e541113ef16d01030c8d86","branchLabelMapping":{"^v9.1.0$":"main","^v8.19.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:Security
Generative AI","backport:version","v9.1.0","v8.19.0"],"title":"[Security
Solution] [GenAi] refactor security ai assistant tools to use tool
helper
method","number":212865,"url":"https://github.com/elastic/kibana/pull/212865","mergeCommit":{"message":"[Security
Solution] [GenAi] refactor security ai assistant tools to use tool
helper method (#212865)\n\n## Summary\n\nClean up some security ai
assistant code.\n\n- Replace the usage of `new DynamicStructuredTool()`
with the `tool()`\nhelper method. This is the recommended approach today
and has the\ncorrect types to work
with\n[`Command`](https://langchain-ai.github.io/langgraphjs/concepts/low_level/#command).\n-
Extract code such as the default assistant graph state
and\nagentRunnableFactory to reduce cognitive overload.\n- Update
AssistantTool type definition\n\n### Checklist\n\nCheck the PR satisfies
following conditions. \n\nReviewers should verify this PR satisfies this
list as well.\n\n- [X] Any text added follows [EUI's
writing\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\nsentence case text and includes
[i18n\nsupport](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)\n-
[X]\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\nwas
added for features that require explanation or tutorials\n- [X] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common scenarios\n- [X] If a plugin
configuration key changed, check if it needs to be\nallowlisted in the
cloud and added to the
[docker\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\n-
[X] This was checked for breaking HTTP API changes, and any
breaking\nchanges have been approved by the breaking-change committee.
The\n`release_note:breaking` label should be applied in these
situations.\n- [X] [Flaky
Test\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1)
was\nused on any tests changed\n- [X] The PR description includes the
appropriate Release Notes section,\nand the correct `release_note:*`
label is applied per
the\n[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)\n\n###
Identify risks\n\nDoes this PR introduce any risks? For example,
consider risks like hard\nto test bugs, performance regression,
potential of data loss.\n\nDescribe the risk, its severity, and
mitigation for each identified\nrisk. Invite stakeholders and evaluate
how to proceed before merging.\n\n- [ ] [See some
risk\nexamples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)\n-
[ ] ...\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>\nCo-authored-by:
Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"d37fcb6fb686e59f18e541113ef16d01030c8d86"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.1.0","branchLabelMappingKey":"^v9.1.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/212865","number":212865,"mergeCommit":{"message":"[Security
Solution] [GenAi] refactor security ai assistant tools to use tool
helper method (#212865)\n\n## Summary\n\nClean up some security ai
assistant code.\n\n- Replace the usage of `new DynamicStructuredTool()`
with the `tool()`\nhelper method. This is the recommended approach today
and has the\ncorrect types to work
with\n[`Command`](https://langchain-ai.github.io/langgraphjs/concepts/low_level/#command).\n-
Extract code such as the default assistant graph state
and\nagentRunnableFactory to reduce cognitive overload.\n- Update
AssistantTool type definition\n\n### Checklist\n\nCheck the PR satisfies
following conditions. \n\nReviewers should verify this PR satisfies this
list as well.\n\n- [X] Any text added follows [EUI's
writing\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\nsentence case text and includes
[i18n\nsupport](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)\n-
[X]\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\nwas
added for features that require explanation or tutorials\n- [X] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common scenarios\n- [X] If a plugin
configuration key changed, check if it needs to be\nallowlisted in the
cloud and added to the
[docker\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\n-
[X] This was checked for breaking HTTP API changes, and any
breaking\nchanges have been approved by the breaking-change committee.
The\n`release_note:breaking` label should be applied in these
situations.\n- [X] [Flaky
Test\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1)
was\nused on any tests changed\n- [X] The PR description includes the
appropriate Release Notes section,\nand the correct `release_note:*`
label is applied per
the\n[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)\n\n###
Identify risks\n\nDoes this PR introduce any risks? For example,
consider risks like hard\nto test bugs, performance regression,
potential of data loss.\n\nDescribe the risk, its severity, and
mitigation for each identified\nrisk. Invite stakeholders and evaluate
how to proceed before merging.\n\n- [ ] [See some
risk\nexamples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)\n-
[ ] ...\n\n---------\n\nCo-authored-by: kibanamachine
<42973632+kibanamachine@users.noreply.github.com>\nCo-authored-by:
Elastic Machine
<elasticmachine@users.noreply.github.com>","sha":"d37fcb6fb686e59f18e541113ef16d01030c8d86"}},{"branch":"8.x","label":"v8.19.0","branchLabelMappingKey":"^v8.19.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Kenneth Kreindler <42113355+KDKHD@users.noreply.github.com>
  • Loading branch information
kibanamachine and KDKHD authored Mar 3, 2025
1 parent dd55c99 commit 9d69b19
Show file tree
Hide file tree
Showing 14 changed files with 318 additions and 248 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ToolDefinition } from '@langchain/core/language_models/base';
import {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import type { StructuredToolInterface } from '@langchain/core/tools';
import {
AgentRunnableSequence,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { ChatPromptTemplate } from '@langchain/core/prompts';

export const TOOL_CALLING_LLM_TYPES = new Set(['bedrock', 'gemini']);

export const agentRunableFactory = async ({
llm,
isOpenAI,
llmType,
tools,
isStream,
prompt,
}: {
llm: ActionsClientChatBedrockConverse | ActionsClientChatVertexAI | ActionsClientChatOpenAI;
isOpenAI: boolean;
llmType: string | undefined;
tools: StructuredToolInterface[] | ToolDefinition[];
isStream: boolean;
prompt: ChatPromptTemplate;
}): Promise<AgentRunnableSequence> => {
const params = {
llm,
tools,
streamRunnable: isStream,
prompt,
} as const;

if (isOpenAI || llmType === 'inference') {
return createOpenAIToolsAgent(params);
}

if (llmType && TOOL_CALLING_LLM_TYPES.has(llmType)) {
return createToolCallingAgent(params);
}

return createStructuredChatAgent(params);
};
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
* 2.0.
*/

import { Annotation, END, START, StateGraph } from '@langchain/langgraph';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import { END, START, StateGraph } from '@langchain/langgraph';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
import { StructuredTool } from '@langchain/core/tools';
import type { Logger } from '@kbn/logging';

import { BaseMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
import { Replacements } from '@kbn/elastic-assistant-common';
import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
Expand All @@ -29,6 +27,7 @@ import { getPersistedConversation } from './nodes/get_persisted_conversation';
import { persistConversationChanges } from './nodes/persist_conversation_changes';
import { respond } from './nodes/respond';
import { NodeType } from './constants';
import { getStateAnnotation } from './state';

export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';

Expand Down Expand Up @@ -61,87 +60,17 @@ export const getDefaultAssistantGraph = ({
getFormattedTime,
}: GetDefaultAssistantGraphParams) => {
try {
// Default graph state
const graphAnnotation = Annotation.Root({
input: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
lastNode: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'start',
}),
steps: Annotation<AgentStep[]>({
reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
default: () => [],
}),
hasRespondStep: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
reducer: (
x: AgentAction | AgentFinish | undefined,
y?: AgentAction | AgentFinish | undefined
) => y ?? x,
default: () => undefined,
}),
messages: Annotation<BaseMessage[]>({
reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
default: () => [],
}),
chatTitle: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
llmType: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'unknown',
}),
isStream: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
isOssModel: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
connectorId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
conversation: Annotation<ConversationResponse | undefined>({
reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
default: () => undefined,
}),
conversationId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
responseLanguage: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'English',
}),
provider: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
formattedTime: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: getFormattedTime ?? (() => ''),
}),
});

// Default node parameters
const nodeParams: NodeParamsBase = {
actionsClient,
logger,
savedObjectsClient,
};

const stateAnnotation = getStateAnnotation({ getFormattedTime });

// Put together a new graph using default state from above
const graph = new StateGraph(graphAnnotation)
const graph = new StateGraph(stateAnnotation)
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
getPersistedConversation({
...nodeParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@

import { StructuredTool } from '@langchain/core/tools';
import { getDefaultArguments } from '@kbn/langchain/server';
import {
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
Expand All @@ -25,12 +20,13 @@ import { getLlmClass } from '../../../../routes/utils';
import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types';
import { AssistantToolParams } from '../../../../types';
import { AgentExecutor } from '../../executors/types';
import { formatPrompt, formatPromptStructured } from './prompts';
import { formatPrompt } from './prompts';
import { GraphInputs } from './types';
import { getDefaultAssistantGraph } from './graph';
import { invokeGraph, streamGraph } from './helpers';
import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers';
import { DEFAULT_DATE_FORMAT_TZ } from '../../../../../common/constants';
import { agentRunableFactory } from './agentRunnable';

export const callAssistantGraph: AgentExecutor<true | false> = async ({
abortSignal,
Expand Down Expand Up @@ -179,28 +175,21 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
savedObjectsClient,
});

const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
});
const chatPromptTemplate = formatPrompt({
prompt: defaultSystemPrompt,
additionalPrompt: systemPrompt,
llmType,
isOpenAI,
});

const agentRunnable = await agentRunableFactory({
llm: createLlmInstance(),
isOpenAI,
llmType,
tools,
isStream,
prompt: chatPromptTemplate,
});

const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
const telemetryTracer = telemetryParams
Expand All @@ -214,6 +203,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;

const { provider } =
!llmType || llmType === 'inference'
? await resolveProviderAndModel({
Expand All @@ -240,7 +230,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
getFormattedTime: () =>
getFormattedTime({
screenContextTimezone: request.body.screenContext?.timeZone,
screenContextTimezone: screenContext?.timeZone,
uiSettingsDateFormatTimezone,
}),
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
*/

import { ChatPromptTemplate } from '@langchain/core/prompts';
import { TOOL_CALLING_LLM_TYPES } from './agentRunnable';

export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
const formatPromptToolcalling = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
['placeholder', '{knowledge_history}'],
Expand All @@ -16,7 +17,7 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
['placeholder', '{agent_scratchpad}'],
]);

export const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
['placeholder', '{knowledge_history}'],
Expand All @@ -26,3 +27,20 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string
'{input}\n\n{agent_scratchpad}\n\n(reminder to respond in a JSON blob no matter what)',
],
]);

export const formatPrompt = ({
isOpenAI,
llmType,
prompt,
additionalPrompt,
}: {
isOpenAI: boolean;
llmType: string | undefined;
prompt: string;
additionalPrompt?: string;
}) => {
if (isOpenAI || llmType === 'inference' || (llmType && TOOL_CALLING_LLM_TYPES.has(llmType))) {
return formatPromptToolcalling(prompt, additionalPrompt);
}
return formatPromptStructured(prompt, additionalPrompt);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ConversationResponse } from '@kbn/elastic-assistant-common';
import { BaseMessage } from '@langchain/core/messages';
import { Annotation } from '@langchain/langgraph';
import { AgentStep, AgentAction, AgentFinish } from 'langchain/agents';

export const getStateAnnotation = ({ getFormattedTime }: { getFormattedTime?: () => string }) => {
const graphAnnotation = Annotation.Root({
input: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
lastNode: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'start',
}),
steps: Annotation<AgentStep[]>({
reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
default: () => [],
}),
hasRespondStep: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
reducer: (
x: AgentAction | AgentFinish | undefined,
y?: AgentAction | AgentFinish | undefined
) => y ?? x,
default: () => undefined,
}),
messages: Annotation<BaseMessage[]>({
reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
default: () => [],
}),
chatTitle: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
llmType: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'unknown',
}),
isStream: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
isOssModel: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
connectorId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
conversation: Annotation<ConversationResponse | undefined>({
reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
default: () => undefined,
}),
conversationId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
responseLanguage: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'English',
}),
provider: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
formattedTime: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: getFormattedTime ?? (() => ''),
}),
});

return graphAnnotation;
};
Loading

0 comments on commit 9d69b19

Please sign in to comment.