Skip to content

Commit

Permalink
[Security Solution] [GenAi] refactor security ai assistant tools to u…
Browse files Browse the repository at this point in the history
…se tool helper method (#212865)

## Summary

Clean up some security ai assistant code.

- Replace the usage of `new DynamicStructuredTool()` with the `tool()`
helper method. This is the recommended approach today and has the
correct types to work with
[`Command`](https://langchain-ai.github.io/langgraphjs/concepts/low_level/#command).
- Extract code such as the default assistant graph state and
agentRunnableFactory to reduce cognitive load.
- Update AssistantTool type definition

### Checklist

Check the PR satisfies following conditions.

Reviewers should verify this PR satisfies this list as well.

- [X] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)
- [X]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [X] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [X] If a plugin configuration key changed, check if it needs to be
allowlisted in the cloud and added to the [docker
list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)
- [X] This was checked for breaking HTTP API changes, and any breaking
changes have been approved by the breaking-change committee. The
`release_note:breaking` label should be applied in these situations.
- [X] [Flaky Test
Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was
used on any tests changed
- [X] The PR description includes the appropriate Release Notes section,
and the correct `release_note:*` label is applied per the
[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

### Identify risks

Does this PR introduce any risks? For example, consider risks like
hard-to-test bugs, performance regression, or potential data loss.

Describe the risk, its severity, and mitigation for each identified
risk. Invite stakeholders and evaluate how to proceed before merging.

- [ ] [See some risk
examples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)
- [ ] ...

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
(cherry picked from commit d37fcb6)
  • Loading branch information
KDKHD committed Mar 3, 2025
1 parent d6e8b96 commit dc3e07e
Show file tree
Hide file tree
Showing 14 changed files with 318 additions and 248 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ToolDefinition } from '@langchain/core/language_models/base';
import {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import type { StructuredToolInterface } from '@langchain/core/tools';
import {
AgentRunnableSequence,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { ChatPromptTemplate } from '@langchain/core/prompts';

export const TOOL_CALLING_LLM_TYPES = new Set(['bedrock', 'gemini']);

export const agentRunableFactory = async ({
llm,
isOpenAI,
llmType,
tools,
isStream,
prompt,
}: {
llm: ActionsClientChatBedrockConverse | ActionsClientChatVertexAI | ActionsClientChatOpenAI;
isOpenAI: boolean;
llmType: string | undefined;
tools: StructuredToolInterface[] | ToolDefinition[];
isStream: boolean;
prompt: ChatPromptTemplate;
}): Promise<AgentRunnableSequence> => {
const params = {
llm,
tools,
streamRunnable: isStream,
prompt,
} as const;

if (isOpenAI || llmType === 'inference') {
return createOpenAIToolsAgent(params);
}

if (llmType && TOOL_CALLING_LLM_TYPES.has(llmType)) {
return createToolCallingAgent(params);
}

return createStructuredChatAgent(params);
};
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
* 2.0.
*/

import { Annotation, END, START, StateGraph } from '@langchain/langgraph';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import { END, START, StateGraph } from '@langchain/langgraph';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
import { StructuredTool } from '@langchain/core/tools';
import type { Logger } from '@kbn/logging';

import { BaseMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
import { Replacements } from '@kbn/elastic-assistant-common';
import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
Expand All @@ -29,6 +27,7 @@ import { getPersistedConversation } from './nodes/get_persisted_conversation';
import { persistConversationChanges } from './nodes/persist_conversation_changes';
import { respond } from './nodes/respond';
import { NodeType } from './constants';
import { getStateAnnotation } from './state';

export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';

Expand Down Expand Up @@ -61,87 +60,17 @@ export const getDefaultAssistantGraph = ({
getFormattedTime,
}: GetDefaultAssistantGraphParams) => {
try {
// Default graph state
const graphAnnotation = Annotation.Root({
input: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
lastNode: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'start',
}),
steps: Annotation<AgentStep[]>({
reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
default: () => [],
}),
hasRespondStep: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
reducer: (
x: AgentAction | AgentFinish | undefined,
y?: AgentAction | AgentFinish | undefined
) => y ?? x,
default: () => undefined,
}),
messages: Annotation<BaseMessage[]>({
reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
default: () => [],
}),
chatTitle: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
llmType: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'unknown',
}),
isStream: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
isOssModel: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
}),
connectorId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
conversation: Annotation<ConversationResponse | undefined>({
reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
default: () => undefined,
}),
conversationId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
responseLanguage: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'English',
}),
provider: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
formattedTime: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: getFormattedTime ?? (() => ''),
}),
});

// Default node parameters
const nodeParams: NodeParamsBase = {
actionsClient,
logger,
savedObjectsClient,
};

const stateAnnotation = getStateAnnotation({ getFormattedTime });

// Put together a new graph using default state from above
const graph = new StateGraph(graphAnnotation)
const graph = new StateGraph(stateAnnotation)
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
getPersistedConversation({
...nodeParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@

import { StructuredTool } from '@langchain/core/tools';
import { getDefaultArguments } from '@kbn/langchain/server';
import {
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
Expand All @@ -25,12 +20,13 @@ import { getLlmClass } from '../../../../routes/utils';
import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types';
import { AssistantToolParams } from '../../../../types';
import { AgentExecutor } from '../../executors/types';
import { formatPrompt, formatPromptStructured } from './prompts';
import { formatPrompt } from './prompts';
import { GraphInputs } from './types';
import { getDefaultAssistantGraph } from './graph';
import { invokeGraph, streamGraph } from './helpers';
import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers';
import { DEFAULT_DATE_FORMAT_TZ } from '../../../../../common/constants';
import { agentRunableFactory } from './agentRunnable';

export const callAssistantGraph: AgentExecutor<true | false> = async ({
abortSignal,
Expand Down Expand Up @@ -179,28 +175,21 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
savedObjectsClient,
});

const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
});
const chatPromptTemplate = formatPrompt({
prompt: defaultSystemPrompt,
additionalPrompt: systemPrompt,
llmType,
isOpenAI,
});

const agentRunnable = await agentRunableFactory({
llm: createLlmInstance(),
isOpenAI,
llmType,
tools,
isStream,
prompt: chatPromptTemplate,
});

const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
const telemetryTracer = telemetryParams
Expand All @@ -214,6 +203,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;

const { provider } =
!llmType || llmType === 'inference'
? await resolveProviderAndModel({
Expand All @@ -240,7 +230,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
getFormattedTime: () =>
getFormattedTime({
screenContextTimezone: request.body.screenContext?.timeZone,
screenContextTimezone: screenContext?.timeZone,
uiSettingsDateFormatTimezone,
}),
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
*/

import { ChatPromptTemplate } from '@langchain/core/prompts';
import { TOOL_CALLING_LLM_TYPES } from './agentRunnable';

export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
const formatPromptToolcalling = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
['placeholder', '{knowledge_history}'],
Expand All @@ -16,7 +17,7 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
['placeholder', '{agent_scratchpad}'],
]);

export const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
['placeholder', '{knowledge_history}'],
Expand All @@ -26,3 +27,20 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string
'{input}\n\n{agent_scratchpad}\n\n(reminder to respond in a JSON blob no matter what)',
],
]);

/**
 * Returns the chat prompt template matching the agent flavor that will be
 * created for this model (see `agentRunableFactory`): tool-calling capable
 * models get the tool-calling prompt, everything else gets the
 * structured-chat prompt (which asks for a JSON blob response).
 */
export const formatPrompt = ({
  isOpenAI,
  llmType,
  prompt,
  additionalPrompt,
}: {
  isOpenAI: boolean;
  llmType: string | undefined;
  prompt: string;
  additionalPrompt?: string;
}) => {
  // Mirrors the provider selection logic used when building the agent runnable.
  const supportsToolCalling =
    isOpenAI || llmType === 'inference' || (!!llmType && TOOL_CALLING_LLM_TYPES.has(llmType));

  return supportsToolCalling
    ? formatPromptToolcalling(prompt, additionalPrompt)
    : formatPromptStructured(prompt, additionalPrompt);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ConversationResponse } from '@kbn/elastic-assistant-common';
import { BaseMessage } from '@langchain/core/messages';
import { Annotation } from '@langchain/langgraph';
import { AgentStep, AgentAction, AgentFinish } from 'langchain/agents';

/**
 * Builds the LangGraph state annotation (the graph's channel schema) for the
 * default security assistant graph.
 *
 * Every channel uses a "last write wins" reducer (`y ?? x`) except `steps`,
 * which accumulates updates by concatenation.
 *
 * @param getFormattedTime - optional getter used as the default factory for
 *   the `formattedTime` channel; when omitted the default is an empty string.
 */
export const getStateAnnotation = ({ getFormattedTime }: { getFormattedTime?: () => string }) => {
  const graphAnnotation = Annotation.Root({
    // Raw input for the current invocation; last write wins, defaults to ''.
    input: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => '',
    }),
    // Name of the most recently executed node; used for routing/telemetry
    // elsewhere in the graph. Defaults to the sentinel 'start'.
    lastNode: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => 'start',
    }),
    // Accumulates agent steps across updates (concat, never replaced).
    steps: Annotation<AgentStep[]>({
      reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
      default: () => [],
    }),
    hasRespondStep: Annotation<boolean>({
      reducer: (x: boolean, y?: boolean) => y ?? x,
      default: () => false,
    }),
    // Either the next action the agent wants to take or its final answer;
    // undefined until the agent has run.
    agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
      reducer: (
        x: AgentAction | AgentFinish | undefined,
        y?: AgentAction | AgentFinish | undefined
      ) => y ?? x,
      default: () => undefined,
    }),
    // NOTE(review): unlike `steps`, this reducer replaces the message list
    // (`y ?? x`) rather than appending — matches the pre-refactor graph
    // state, but confirm replacement (not concat) is the intended semantics.
    messages: Annotation<BaseMessage[]>({
      reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
      default: () => [],
    }),
    chatTitle: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => '',
    }),
    // LLM provider type (e.g. 'bedrock', 'gemini'); 'unknown' when unset.
    llmType: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => 'unknown',
    }),
    isStream: Annotation<boolean>({
      reducer: (x: boolean, y?: boolean) => y ?? x,
      default: () => false,
    }),
    isOssModel: Annotation<boolean>({
      reducer: (x: boolean, y?: boolean) => y ?? x,
      default: () => false,
    }),
    connectorId: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => '',
    }),
    // Persisted conversation, when one exists for this run.
    conversation: Annotation<ConversationResponse | undefined>({
      reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
        y ?? x,
      default: () => undefined,
    }),
    conversationId: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => '',
    }),
    // Language the assistant should respond in; defaults to English.
    responseLanguage: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => 'English',
    }),
    provider: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: () => '',
    }),
    // Default comes from the injected getter so the timestamp is computed
    // lazily at state-initialization time, not at schema-construction time.
    formattedTime: Annotation<string>({
      reducer: (x: string, y?: string) => y ?? x,
      default: getFormattedTime ?? (() => ''),
    }),
  });

  return graphAnnotation;
};
Loading

0 comments on commit dc3e07e

Please sign in to comment.