diff --git a/src/platform/packages/shared/kbn-apm-synthtrace-client/src/lib/logs/index.ts b/src/platform/packages/shared/kbn-apm-synthtrace-client/src/lib/logs/index.ts index 5ffd12ff250d5..c521b7423df7d 100644 --- a/src/platform/packages/shared/kbn-apm-synthtrace-client/src/lib/logs/index.ts +++ b/src/platform/packages/shared/kbn-apm-synthtrace-client/src/lib/logs/index.ts @@ -57,6 +57,7 @@ export type LogDocument = Fields & 'cloud.availability_zone'?: string; 'cloud.project.id'?: string; 'cloud.instance.id'?: string; + 'client.ip'?: string; 'error.stack_trace'?: string; 'error.exception'?: unknown; 'error.log'?: unknown; @@ -68,6 +69,9 @@ export type LogDocument = Fields & 'event.duration': number; 'event.start': Date; 'event.end': Date; + 'event.category'?: string; + 'event.type'?: string; + 'event.outcome'?: string; labels?: Record; test_field: string | string[]; date: Date; @@ -76,8 +80,11 @@ export type LogDocument = Fields & svc: string; hostname: string; [LONG_FIELD_NAME]: string; - 'http.status_code'?: number; + 'http.response.status_code'?: number; + 'http.response.bytes'?: number; 'http.request.method'?: string; + 'http.request.referrer'?: string; + 'http.version'?: string; 'url.path'?: string; 'process.name'?: string; 'kubernetes.namespace'?: string; @@ -85,6 +92,7 @@ export type LogDocument = Fields & 'kubernetes.container.name'?: string; 'orchestrator.resource.name'?: string; tags?: string | string[]; + 'user_agent.name'?: string; }>; class Log extends Serializable { diff --git a/src/platform/packages/shared/kbn-apm-synthtrace/src/scenarios/apache_logs.ts b/src/platform/packages/shared/kbn-apm-synthtrace/src/scenarios/apache_logs.ts new file mode 100644 index 0000000000000..6a3b30867ff36 --- /dev/null +++ b/src/platform/packages/shared/kbn-apm-synthtrace/src/scenarios/apache_logs.ts @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +import { LogDocument, log } from '@kbn/apm-synthtrace-client'; +import moment from 'moment'; +import { random } from 'lodash'; +import { Scenario } from '../cli/scenario'; +import { withClient } from '../lib/utils/with_client'; +import { parseLogsScenarioOpts } from './helpers/logs_scenario_opts_parser'; +import { IndexTemplateName } from '../lib/logs/custom_logsdb_index_templates'; + +const scenario: Scenario = async (runOptions) => { + const { isLogsDb } = parseLogsScenarioOpts(runOptions.scenarioOpts); + + return { + bootstrap: async ({ logsEsClient }) => { + if (isLogsDb) await logsEsClient.createIndexTemplate(IndexTemplateName.LogsDb); + }, + teardown: async ({ logsEsClient }) => { + if (isLogsDb) await logsEsClient.deleteIndexTemplate(IndexTemplateName.LogsDb); + }, + + generate: ({ range, clients: { logsEsClient } }) => { + const { logger } = runOptions; + + // Normal access logs + const normalAccessLogs = range + .interval('1m') + .rate(50) + .generator((timestamp) => { + return Array(5) + .fill(0) + .map(() => { + const logsData = constructApacheLogData(); + + return log + .create({ isLogsDb }) + .message( + `${logsData['client.ip']} - - [${moment(timestamp).format( + 'DD/MMM/YYYY:HH:mm:ss Z' + )}] "${logsData['http.request.method']} ${logsData['url.path']} HTTP/${ + logsData['http.version'] + }" ${logsData['http.response.status_code']} ${logsData['http.response.bytes']}` + ) + .dataset('apache.access') + .defaults(logsData) + .timestamp(timestamp); + }); + }); + + // attack simulation logs + const attackSimulationLogs = range + .interval('1m') + .rate(2) + .generator((timestamp) => { + return Array(2) + .fill(0) + .map(() => { + const logsData = constructApacheLogData(); + + return log + .create({ isLogsDb }) + .message( + `ATTACK SIMULATION: ${logsData['client.ip']} attempted access to restricted path ${logsData['url.path']}` + ) + .dataset('apache.security') + .logLevel('warning') + .defaults({ + ...logsData, + 'event.category': 'network', + 'event.type': 'access', + 'event.outcome': 'failure', + }) + .timestamp(timestamp); + }); + }); + + return withClient( + logsEsClient, + logger.perf('generating_apache_logs', () => [normalAccessLogs, attackSimulationLogs]) + ); + }, + }; +}; + +export default scenario; + +function constructApacheLogData(): LogDocument { + const APACHE_LOG_SCENARIOS = [ + { + method: 'GET', + path: '/index.html', + responseCode: 200, + userAgent: 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36', + referrer: 'https://www.google.com', + }, + { + method: 'POST', + path: '/login', + responseCode: 401, + userAgent: 'PostmanRuntime/7.29.0', + referrer: '-', + }, + { + method: 'GET', + path: '/admin/dashboard', + responseCode: 403, + userAgent: 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)', + referrer: 'https://example.com/home', + }, + ]; + + const HOSTNAMES = ['www.example.com', 'blog.example.com', 'api.example.com']; + const CLOUD_REGIONS = ['us-east-1', 'eu-west-2', 'ap-southeast-1']; + + const index = Math.floor(Math.random() * APACHE_LOG_SCENARIOS.length); + const { method, path, responseCode, userAgent, referrer } = APACHE_LOG_SCENARIOS[index]; + + const clientIp = generateIpAddress(); + const hostname = HOSTNAMES[Math.floor(Math.random() * HOSTNAMES.length)]; + const cloudRegion = CLOUD_REGIONS[Math.floor(Math.random() * CLOUD_REGIONS.length)]; + + return { + 'http.request.method': method, + 'url.path': path, + 'http.response.status_code': responseCode, + hostname, + 'cloud.region': cloudRegion, + 'cloud.availability_zone': `${cloudRegion}a`, + 'client.ip': clientIp, + 'user_agent.name': userAgent, + 'http.request.referrer': referrer, + }; +} + +function generateIpAddress() { + return `${random(0, 255)}.${random(0, 255)}.${random(0, 255)}.${random(0, 255)}`; +} diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts index ffa83dbe92d77..8acbc59903dcf 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts @@ -16,7 +16,7 @@ import { FunctionCallChatFunction } from '../../service/types'; const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields'; export const GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE = `You are a helpful assistant for Elastic Observability. Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list. -The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id". +The list in the user message consists of JSON objects that map a human-readable field "name" to its unique "id". You must not output any field names — only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`; export async function getRelevantFieldNames({ @@ -114,10 +114,12 @@ export async function getRelevantFieldNames({ '@timestamp': new Date().toISOString(), message: { role: MessageRole.User, - content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields: + content: `Below is a list of fields. Each entry is a JSON object that contains a "name" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields: ${fieldsInChunk - .map((field) => JSON.stringify({ field, id: shortIdTable.take(field) })) + .map((fieldName) => + JSON.stringify({ name: fieldName, id: shortIdTable.take(fieldName) }) + ) .join('\n')}`, }, }, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/index.ts index 77ba9afb18260..56a12bfaf1fcf 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/index.ts @@ -5,8 +5,11 @@ * 2.0. */ +import { IScopedClusterClient, Logger } from '@kbn/core/server'; +import { Message } from '../../../common'; import { FunctionRegistrationParameters } from '..'; import { FunctionVisibility } from '../../../common/functions/types'; +import { FunctionCallChatFunction, RespondFunctionResources } from '../../service/types'; import { getRelevantFieldNames } from './get_relevant_field_names'; export const GET_DATASET_INFO_FUNCTION_NAME = 'get_dataset_info'; @@ -32,68 +35,102 @@ export function registerGetDatasetInfoFunction({ index: { type: 'string', description: - 'index pattern the user is interested in or empty string to get information about all available indices', + 'Index pattern the user is interested in. You are allowed to specify multiple, comma-separated patterns like "index1,index2". If you provide an empty string, all indices will be returned. By default matching indicies in local and remote indices are searched. If you want to limit the search to a specific cluster you can prefix the index pattern with the cluster name. For example, "cluster1:my-index".', }, }, required: ['index'], } as const, }, - async ({ arguments: { index }, messages, chat }, signal) => { - const coreContext = await resources.context.core; + async ({ arguments: { index: indexPattern }, messages, chat }, signal) => { + const content = await getDatasetInfo({ resources, indexPattern, signal, messages, chat }); + return { content }; + } + ); +} - const esClient = coreContext.elasticsearch.client; - const savedObjectsClient = coreContext.savedObjects.client; +export async function getDatasetInfo({ + resources, + indexPattern, + signal, + messages, + chat, +}: { + resources: RespondFunctionResources; + indexPattern: string; + signal: AbortSignal; + messages: Message[]; + chat: FunctionCallChatFunction; +}) { + const coreContext = await resources.context.core; + const esClient = coreContext.elasticsearch.client; + const savedObjectsClient = coreContext.savedObjects.client; - let indices: string[] = []; + const indices = await getIndicesFromIndexPattern(indexPattern, esClient, resources.logger); + if (indices.length === 0 || indexPattern === '') { + return { indices, fields: [] }; + } - try { - const body = await esClient.asCurrentUser.indices.resolveIndex({ - name: index === '' ? ['*', '*:*'] : index.split(','), - expand_wildcards: 'open', - }); - indices = [ - ...body.indices.map((i) => i.name), - ...body.data_streams.map((d) => d.name), - ...body.aliases.map((d) => d.name), - ]; - } catch (e) { - indices = []; - } + try { + const { fields, stats } = await getRelevantFieldNames({ + index: indices, + messages, + esClient: esClient.asCurrentUser, + dataViews: await resources.plugins.dataViews.start(), + savedObjectsClient, + signal, + chat, + }); + return { indices, fields, stats }; + } catch (e) { + resources.logger.error(`Error getting relevant field names: ${e.message}`); + return { indices, fields: [] }; + } +} - if (index === '') { - return { - content: { - indices, - fields: [], - }, - }; +async function getIndicesFromIndexPattern( + indexPattern: string, + esClient: IScopedClusterClient, + logger: Logger +) { + let name: string[] = []; + if (indexPattern === '') { + name = ['*', '*:*']; + } else { + name = indexPattern.split(',').flatMap((pattern) => { + // search specific cluster + if (pattern.includes(':')) { + const [cluster, p] = pattern.split(':'); + return `${cluster}:*${p}*`; } - if (indices.length === 0) { - return { - content: { - indices, - fields: [], - }, - }; - } + // search across local and remote clusters + return [`*${pattern}*`, `*:*${pattern}*`]; + }); + } - const relevantFieldNames = await getRelevantFieldNames({ - index, - messages, - esClient: esClient.asCurrentUser, - dataViews: await resources.plugins.dataViews.start(), - savedObjectsClient, - signal, - chat, - }); - return { - content: { - indices: [index], - fields: relevantFieldNames.fields, - stats: relevantFieldNames.stats, - }, - }; + try { + const body = await esClient.asCurrentUser.indices.resolveIndex({ + name, + expand_wildcards: 'open', // exclude hidden and closed indices + }); + + // if there is an exact match, only return that + const hasExactMatch = + body.indices.some((i) => i.name === indexPattern) || + body.aliases.some((i) => i.name === indexPattern); + + if (hasExactMatch) { + return [indexPattern]; } - ); + + // otherwise return all matching indices, data streams, and aliases + return [ + ...body.indices.map((i) => i.name), + ...body.data_streams.map((d) => d.name), + ...body.aliases.map((d) => d.name), + ]; + } catch (e) { + logger.error(`Error resolving index pattern: ${e.message}`); + return []; + } } diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/functions/route.ts index c5f571769dfb6..12e71bbabfda1 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/functions/route.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/functions/route.ts @@ -6,6 +6,7 @@ */ import { notImplemented } from '@hapi/boom'; import { nonEmptyStringRt, toBooleanRt } from '@kbn/io-ts-utils'; +import { context as otelContext } from '@opentelemetry/api'; import * as t from 'io-ts'; import { v4 } from 'uuid'; import { FunctionDefinition } from '../../../common/functions/types'; @@ -14,6 +15,8 @@ import type { RecalledEntry } from '../../service/knowledge_base_service'; import { getSystemMessageFromInstructions } from '../../service/util/get_system_message_from_instructions'; import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; import { assistantScopeType } from '../runtime_types'; +import { getDatasetInfo } from '../../functions/get_dataset_info'; +import { LangTracer } from '../../service/client/instrumentation/lang_tracer'; const getFunctionsRoute = createObservabilityAIAssistantServerRoute({ endpoint: 'GET /internal/observability_ai_assistant/functions', @@ -78,6 +81,47 @@ const getFunctionsRoute = createObservabilityAIAssistantServerRoute({ }, }); +const functionDatasetInfoRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info', + params: t.type({ + query: t.type({ index: t.string, connectorId: t.string }), + }), + security: { + authz: { + requiredPrivileges: ['ai_assistant'], + }, + }, + handler: async (resources) => { + const client = await resources.service.getClient({ request: resources.request }); + + const { + query: { index, connectorId }, + } = resources.params; + + const controller = new AbortController(); + resources.request.events.aborted$.subscribe(() => { + controller.abort(); + }); + + const resp = await getDatasetInfo({ + resources, + indexPattern: index, + signal: controller.signal, + messages: [], + chat: (operationName, params) => { + return client.chat(operationName, { + ...params, + stream: true, + tracer: new LangTracer(otelContext.active()), + connectorId, + }); + }, + }); + + return resp; + }, +}); + const functionRecallRoute = createObservabilityAIAssistantServerRoute({ endpoint: 'POST /internal/observability_ai_assistant/functions/recall', params: t.type({ @@ -176,4 +220,5 @@ export const functionRoutes = { ...getFunctionsRoute, ...functionRecallRoute, ...functionSummariseRoute, + ...functionDatasetInfoRoute, }; diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts index 8fbb2b1be4ef3..c0618ce4c048c 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts @@ -17,9 +17,10 @@ import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper'; import { LlmProxy, + RelevantField, createLlmProxy, } from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; -import { getMessageAddedEvents } from './helpers'; +import { getMessageAddedEvents, getSystemMessage, systemMessageSorted } from './helpers'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper'; @@ -32,7 +33,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon const samlAuth = getService('samlAuth'); describe('function: get_alerts_dataset_info', function () { - // Fails on MKI: https://github.com/elastic/kibana/issues/205581 this.tags(['failsOnMKI']); let llmProxy: LlmProxy; let connectorId: string; @@ -40,8 +40,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon let apmSynthtraceEsClient: ApmSynthtraceEsClient; let roleAuthc: RoleCredentials; let createdRuleId: string; - let expectedRelevantFieldNames: string[]; - let primarySystemMessage: string; + let getRelevantFields: () => Promise; before(async () => { ({ apmSynthtraceEsClient } = await createSyntheticApmData(getService)); @@ -58,26 +57,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon when: () => true, }); - void llmProxy.interceptWithFunctionRequest({ - name: 'select_relevant_fields', - // @ts-expect-error - when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields', - arguments: (requestBody) => { - const userMessage = last(requestBody.messages); - const topFields = (userMessage?.content as string) - .slice(204) // remove the prefix message and only get the JSON - .trim() - .split('\n') - .map((line) => JSON.parse(line)) - .slice(0, 5); - - expectedRelevantFieldNames = topFields.map(({ field }) => field); - - const fieldIds = topFields.map(({ id }) => id); - - return JSON.stringify({ fieldIds }); - }, - }); + ({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice()); void llmProxy.interceptWithFunctionRequest({ name: 'alerts', @@ -114,19 +94,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); messageAddedEvents = getMessageAddedEvents(body); - - const { - body: { systemMessage }, - } = await observabilityAIAssistantAPIClient.editor({ - endpoint: 'GET /internal/observability_ai_assistant/functions', - params: { - query: { - scopes: ['observability'], - }, - }, - }); - - primarySystemMessage = systemMessage; }); after(async () => { @@ -228,9 +195,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); describe('The system message', () => { - it('has the primary system message', () => { - expect(sortSystemMessage(firstRequestBody.messages[0].content as string)).to.eql( - sortSystemMessage(primarySystemMessage) + it('has the primary system message', async () => { + const primarySystemMessage = await getSystemMessage(getService); + expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql( + systemMessageSorted(primarySystemMessage) ); }); @@ -254,14 +222,11 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); it('contains a system generated user message with a list of field candidates', () => { - const hasList = secondRequestBody.messages.some( - (message) => - message.role === 'user' && - (message.content as string).includes('Below is a list of fields.') && - (message.content as string).includes('@timestamp') - ); + const lastMessage = last(secondRequestBody.messages); - expect(hasList).to.be(true); + expect(lastMessage?.role).to.be('user'); + expect(lastMessage?.content).to.contain('Below is a list of fields'); + expect(lastMessage?.content).to.contain('@timestamp'); }); it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => { @@ -294,7 +259,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon expect(hasFunctionRequest).to.be(true); }); - it('contains the `get_alerts_dataset_info` response', () => { + it('contains the `get_alerts_dataset_info` response', async () => { const functionResponse = last(thirdRequestBody.messages); const parsedContent = JSON.parse(functionResponse?.content as string) as { fields: string[]; @@ -303,7 +268,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon const fieldNamesWithType = parsedContent.fields; const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]); - expect(fieldNamesWithoutType).to.eql(expectedRelevantFieldNames); + const relevantFields = await getRelevantFields(); + expect(fieldNamesWithoutType).to.eql(relevantFields.map(({ name }) => name)); expect(fieldNamesWithType).to.eql([ '@timestamp:date', '_id:_id', @@ -314,13 +280,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); it('emits a messageAdded event with the `get_alerts_dataset_info` function response', async () => { - const messageWithDatasetInfo = messageAddedEvents.find( + const eventWithDatasetInfo = messageAddedEvents.find( ({ message }) => message.message.role === MessageRole.User && message.message.name === 'get_alerts_dataset_info' ); - const parsedContent = JSON.parse(messageWithDatasetInfo?.message.message.content!) as { + const parsedContent = JSON.parse(eventWithDatasetInfo?.message.message.content!) as { fields: string[]; }; @@ -361,12 +327,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); it('emits a messageAdded event with the `alert` function response', async () => { - const messageWithAlerts = messageAddedEvents.find( + const event = messageAddedEvents.find( ({ message }) => message.message.role === MessageRole.User && message.message.name === 'alerts' ); - const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as { + const parsedContent = JSON.parse(event?.message.message.content!) as { total: number; alerts: any[]; }; @@ -490,11 +456,3 @@ async function createSyntheticApmData( return { apmSynthtraceEsClient }; } - -// order of instructions can vary, so we sort to compare them -function sortSystemMessage(message: string) { - return message - .split('\n\n') - .map((line) => line.trim()) - .sort(); -} diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_dataset_info.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_dataset_info.spec.ts new file mode 100644 index 0000000000000..7f810bada6692 --- /dev/null +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_dataset_info.spec.ts @@ -0,0 +1,377 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common'; +import expect from '@kbn/expect'; +import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace'; +import { last } from 'lodash'; +import { GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names'; +import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; +import { + LlmProxy, + RelevantField, + createLlmProxy, +} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; +import { getMessageAddedEvents, getSystemMessage, systemMessageSorted } from './helpers'; +import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; +import { createSimpleSyntheticLogs } from '../../synthtrace_scenarios/simple_logs'; + +export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) { + const log = getService('log'); + const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi'); + const synthtrace = getService('synthtrace'); + + describe('get_dataset_info', function () { + this.tags(['failsOnMKI']); + let llmProxy: LlmProxy; + let connectorId: string; + + before(async () => { + llmProxy = await createLlmProxy(log); + connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({ + port: llmProxy.getPort(), + }); + }); + + after(async () => { + llmProxy.close(); + await observabilityAIAssistantAPIClient.deleteActionConnector({ + actionId: connectorId, + }); + }); + + // Calling `get_dataset_info` via the chat/complete endpoint + describe('POST /internal/observability_ai_assistant/chat/complete', function () { + let messageAddedEvents: MessageAddEvent[]; + let logsSynthtraceEsClient: LogsSynthtraceEsClient; + let getRelevantFields: () => Promise; + + const USER_MESSAGE = 'Do I have any Apache logs?'; + + before(async () => { + logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient(); + await createSimpleSyntheticLogs({ logsSynthtraceEsClient }); + + void llmProxy.interceptWithFunctionRequest({ + name: 'get_dataset_info', + arguments: () => JSON.stringify({ index: 'logs*' }), + when: () => true, + }); + + ({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice()); + + void llmProxy.interceptConversation(`Yes, you do have logs. Congratulations! 🎈️🎈️🎈️`); + + const { status, body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'POST /internal/observability_ai_assistant/chat/complete', + params: { + body: { + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: USER_MESSAGE, + }, + }, + ], + connectorId, + persist: false, + screenContexts: [], + scopes: ['observability' as const], + }, + }, + }); + + expect(status).to.be(200); + + await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); + messageAddedEvents = getMessageAddedEvents(body); + }); + + after(async () => { + await logsSynthtraceEsClient.clean(); + }); + + describe('LLM requests', () => { + let firstRequestBody: ChatCompletionStreamParams; + let secondRequestBody: ChatCompletionStreamParams; + let thirdRequestBody: ChatCompletionStreamParams; + + before(async () => { + firstRequestBody = llmProxy.interceptedRequests[0].requestBody; + secondRequestBody = llmProxy.interceptedRequests[1].requestBody; + thirdRequestBody = llmProxy.interceptedRequests[2].requestBody; + }); + + it('makes 3 requests to the LLM', () => { + expect(llmProxy.interceptedRequests.length).to.be(3); + }); + + describe('every request to the LLM', () => { + it('contains a system message', () => { + const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every( + ({ requestBody }) => { + const firstMessage = requestBody.messages[0]; + return ( + firstMessage.role === 'system' && + (firstMessage.content as string).includes('You are a helpful assistant') + ); + } + ); + expect(everyRequestHasSystemMessage).to.be(true); + }); + + it('contains the original user message', () => { + const everyRequestHasUserMessage = llmProxy.interceptedRequests.every( + ({ requestBody }) => + requestBody.messages.some( + (message) => + message.role === 'user' && (message.content as string) === USER_MESSAGE + ) + ); + expect(everyRequestHasUserMessage).to.be(true); + }); + + it('contains the context function request and context function response', () => { + const everyRequestHasContextFunction = llmProxy.interceptedRequests.every( + ({ requestBody }) => { + const hasContextFunctionRequest = requestBody.messages.some( + (message) => + message.role === 'assistant' && + message.tool_calls?.[0]?.function?.name === 'context' + ); + + const hasContextFunctionResponse = requestBody.messages.some( + (message) => + message.role === 'tool' && + (message.content as string).includes('screen_description') && + (message.content as string).includes('learnings') + ); + + return hasContextFunctionRequest && hasContextFunctionResponse; + } + ); + + expect(everyRequestHasContextFunction).to.be(true); + }); + }); + + describe('The first request', () => { + it('contains the correct number of messages', () => { + expect(firstRequestBody.messages.length).to.be(4); + }); + + it('contains the `get_dataset_info` tool', () => { + const hasTool = firstRequestBody.tools?.some( + (tool) => tool.function.name === 'get_dataset_info' + ); + + expect(hasTool).to.be(true); + }); + + it('leaves the function calling decision to the LLM via tool_choice=auto', () => { + expect(firstRequestBody.tool_choice).to.be('auto'); + }); + + describe('The system message', () => { + it('has the primary system message', async () => { + const primarySystemMessage = await getSystemMessage(getService); + expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql( + systemMessageSorted(primarySystemMessage) + ); + }); + + it('has a different system message from request 2', () => { + expect(firstRequestBody.messages[0]).not.to.eql(secondRequestBody.messages[0]); + }); + + it('has the same system message as request 3', () => { + expect(firstRequestBody.messages[0]).to.eql(thirdRequestBody.messages[0]); + }); + }); + }); + + describe('The second request', () => { + it('contains the correct number of messages', () => { + expect(secondRequestBody.messages.length).to.be(5); + }); + + it('contains a system generated user message with a list of field candidates', () => { + const lastMessage = last(secondRequestBody.messages); + + expect(lastMessage?.role).to.be('user'); + expect(lastMessage?.content).to.contain('Below is a list of fields'); + expect(lastMessage?.content).to.contain('@timestamp'); + }); + + it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => { + const hasToolChoice = + // @ts-expect-error + secondRequestBody.tool_choice?.function?.name === 'select_relevant_fields'; + + expect(hasToolChoice).to.be(true); + }); + + it('has a custom, function-specific system message', () => { + expect(secondRequestBody.messages[0].content).to.be( + GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE + ); + }); + }); + + describe('The third request', () => { + it('contains the correct number of messages', () => { + expect(thirdRequestBody.messages.length).to.be(6); + }); + + it('contains the `get_dataset_info` request', () => { + const hasFunctionRequest = thirdRequestBody.messages.some( + (message) => + message.role === 'assistant' && + message.tool_calls?.[0]?.function?.name === 'get_dataset_info' + ); + + expect(hasFunctionRequest).to.be(true); + }); + + it('contains the `get_dataset_info` response', () => { + const functionResponseMessage = last(thirdRequestBody.messages); + const parsedContent = JSON.parse(functionResponseMessage?.content as string); + expect(Object.keys(parsedContent)).to.eql(['indices', 'fields', 'stats']); + expect(parsedContent.indices).to.eql([ + 'logs-web.access-default', + '.alerts-observability.logs.alerts-default', + ]); + }); + + it('emits a messageAdded event with the `get_dataset_info` function response', async () => { + const event = messageAddedEvents.find( + ({ message }) => + message.message.role === MessageRole.User && + message.message.name === 'get_dataset_info' + ); + + const parsedContent = JSON.parse(event?.message.message.content!) as { + indices: string[]; + fields: string[]; + }; + + const fieldNamesWithType = parsedContent.fields; + const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]); + + const relevantFields = await getRelevantFields(); + expect(fieldNamesWithoutType).to.eql(relevantFields.map(({ name }) => name)); + expect(parsedContent.indices).to.eql([ + 'logs-web.access-default', + '.alerts-observability.logs.alerts-default', + ]); + }); + }); + }); + + describe('messageAdded events', () => { + it('emits 5 messageAdded events', () => { + expect(messageAddedEvents.length).to.be(5); + }); + }); + }); + + // Calling `get_dataset_info` directly + describe('GET /internal/observability_ai_assistant/functions/get_dataset_info', () => { + let logsSynthtraceEsClient: LogsSynthtraceEsClient; + + before(async () => { + logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient(); + await Promise.all([ + createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'zookeeper.access' }), + createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'apache.access' }), + ]); + }); + + after(async () => { + await logsSynthtraceEsClient.clean(); + }); + + it('returns Zookeeper logs but not the Apache logs', async () => { + llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 }); + + const { body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info', + params: { + query: { + index: 'zookeeper', + connectorId, + }, + }, + }); + + expect(body.indices).to.eql(['logs-zookeeper.access-default']); + expect(body.fields.length).to.be.greaterThan(0); + }); + + it('returns both Zookeeper and Apache logs', async () => { + llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 }); + + const { body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info', + params: { + query: { + index: 'logs', + connectorId, + }, + }, + }); + + await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); + + expect(body.indices).to.eql([ + 'logs-apache.access-default', + 'logs-zookeeper.access-default', + '.alerts-observability.logs.alerts-default', + ]); + expect(body.fields.length).to.be.greaterThan(0); + }); + + it('accepts a comma-separated of patterns', async () => { + llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 }); + + const { body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info', + params: { + query: { + index: 'zookeeper,apache', + connectorId, + }, + }, + }); + + await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); + + expect(body.indices).to.eql([ + 'logs-apache.access-default', + 'logs-zookeeper.access-default', + ]); + }); + + it('handles no matching indices gracefully', async () => { + const { body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info', + params: { + query: { + index: 'foobarbaz', + connectorId, + }, + }, + }); + + expect(body.indices).to.eql([]); + expect(body.fields).to.eql([]); + }); + }); + }); +} diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts index f36b9b9eb6037..cc0b9a8aaf8c5 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts @@ -13,6 +13,7 @@ import { } from '@kbn/observability-ai-assistant-plugin/common'; import { Readable } from 'stream'; import type { AssistantScope } from '@kbn/ai-assistant-common'; +import { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; import type { ObservabilityAIAssistantApiClient } from '../../../../../services/observability_ai_assistant_api'; function decodeEvents(body: Readable | string) { @@ -73,3 +74,28 @@ export async function invokeChatCompleteWithFunctionRequest({ return body; } + +// order of instructions can vary, so we sort to compare them +export function systemMessageSorted(message: string) { + return message + .split('\n\n') + .map((line) => line.trim()) + .sort(); +} + +export async function getSystemMessage( + getService: DeploymentAgnosticFtrProviderContext['getService'] +) { + const apiClient = getService('observabilityAIAssistantApi'); + + const { body } = await apiClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions', + params: { + query: { + scopes: ['observability'], + }, + }, + }); + + return body.systemMessage; +} diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts index 31d0b5f5c836c..c7db9d728337c 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts @@ -17,6 +17,7 @@ export default function aiAssistantApiIntegrationTests({ loadTestFile(require.resolve('./complete/complete.spec.ts')); loadTestFile(require.resolve('./complete/functions/alerts.spec.ts')); loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts')); + loadTestFile(require.resolve('./complete/functions/get_dataset_info.spec.ts')); loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts')); loadTestFile(require.resolve('./complete/functions/summarize.spec.ts')); loadTestFile(require.resolve('./public_complete/public_complete.spec.ts')); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/synthtrace_scenarios/simple_logs.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/synthtrace_scenarios/simple_logs.ts new file mode 100644 index 0000000000000..554bdaa7012c0 --- /dev/null +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/synthtrace_scenarios/simple_logs.ts @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { timerange, log } from '@kbn/apm-synthtrace-client'; +import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace'; + +export async function createSimpleSyntheticLogs({ + logsSynthtraceEsClient, + message, + dataset, +}: { + logsSynthtraceEsClient: LogsSynthtraceEsClient; + message?: string; + dataset?: string; +}) { + const range = timerange('now-15m', 'now'); + + const simpleLogs = range + .interval('1m') + .rate(1) + .generator((timestamp) => + log + .create() + .message(message ?? 'simple log message') + .dataset(dataset ?? 'web.access') + .timestamp(timestamp) + ); + + await logsSynthtraceEsClient.index([simpleLogs]); +} diff --git a/x-pack/test/common/utils/synthtrace/apm_es_client.ts b/x-pack/test/common/utils/synthtrace/apm_es_client.ts index 9bdc258c1e1be..932b55cafd056 100644 --- a/x-pack/test/common/utils/synthtrace/apm_es_client.ts +++ b/x-pack/test/common/utils/synthtrace/apm_es_client.ts @@ -13,7 +13,7 @@ interface GetApmSynthtraceEsClientParams { packageVersion: string; } -export async function getApmSynthtraceEsClient({ +export function getApmSynthtraceEsClient({ client, packageVersion, }: GetApmSynthtraceEsClientParams) { diff --git a/x-pack/test/common/utils/synthtrace/infra_es_client.ts b/x-pack/test/common/utils/synthtrace/infra_es_client.ts index 7e39942a9a46c..1eccad6a2ab3c 100644 --- a/x-pack/test/common/utils/synthtrace/infra_es_client.ts +++ b/x-pack/test/common/utils/synthtrace/infra_es_client.ts @@ -8,7 +8,7 @@ import { Client } from '@elastic/elasticsearch'; import { InfraSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace'; -export async function getInfraSynthtraceEsClient(client: Client) { +export function getInfraSynthtraceEsClient(client: Client) { return new InfraSynthtraceEsClient({ client, logger: createLogger(LogLevel.info), diff --git a/x-pack/test/common/utils/synthtrace/logs_es_client.ts b/x-pack/test/common/utils/synthtrace/logs_es_client.ts index 4d7222818bb9c..a6d049d8d3b6b 100644 --- a/x-pack/test/common/utils/synthtrace/logs_es_client.ts +++ b/x-pack/test/common/utils/synthtrace/logs_es_client.ts @@ -8,7 +8,7 @@ import { Client } from '@elastic/elasticsearch'; import { LogsSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace'; -export async function getLogsSynthtraceEsClient(client: Client) { +export function getLogsSynthtraceEsClient(client: Client) { return new LogsSynthtraceEsClient({ client, logger: createLogger(LogLevel.info), diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts index 030550f7a0c67..2cf3548a5f6e5 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts @@ -9,7 +9,7 @@ import { ToolingLog } from '@kbn/tooling-log'; import getPort from 'get-port'; import { v4 as uuidv4 } from 'uuid'; import http, { type Server } from 'http'; -import { isString, once, pull, isFunction } from 'lodash'; +import { isString, once, pull, isFunction, last } from 'lodash'; import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title'; import pRetry from 'p-retry'; import type { ChatCompletionChunkToolCall } from '@kbn/inference-common'; @@ -36,6 +36,12 @@ export interface ToolMessage { content?: string; tool_calls?: ChatCompletionChunkToolCall[]; } + +export interface RelevantField { + id: string; + name: string; +} + export interface LlmResponseSimulator { requestBody: ChatCompletionStreamParams; status: (code: number) => void; @@ -180,6 +186,39 @@ export class LlmProxy { }).completeAfterIntercept(); } + interceptSelectRelevantFieldsToolChoice({ + from = 0, + to = 5, + }: { from?: number; to?: number } = {}) { + let relevantFields: RelevantField[] = []; + const simulator = this.interceptWithFunctionRequest({ + name: 'select_relevant_fields', + // @ts-expect-error + when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields', + arguments: (requestBody) => { + const messageWithFieldIds = last(requestBody.messages); + relevantFields = (messageWithFieldIds?.content as string) + .split('\n\n') + .slice(1) + .join('') + .trim() + .split('\n') + .slice(from, to) + .map((line) => JSON.parse(line) as RelevantField); + + return JSON.stringify({ fieldIds: relevantFields.map(({ id }) => id) }); + }, + }); + + return { + simulator, + getRelevantFields: async () => { + await simulator; + return relevantFields; + }, + }; + } + interceptTitle(title: string) { return this.interceptWithFunctionRequest({ name: TITLE_CONVERSATION_FUNCTION_NAME,