From b633bcd6b6d4cb8b437bb341d2db549a0ff8938d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Wed, 5 Mar 2025 09:09:22 +0100 Subject: [PATCH] [Obs AI Assistant] Add API test for `get_alerts_dataset_info` tool (#212858) Follow-up to: https://github.com/elastic/kibana/pull/212077 This PR includes an API test that covers `get_alerts_dataset_info` and would have caught the bug fixed in https://github.com/elastic/kibana/pull/212077. It also contains the following bug fixes: - Fix system message in `select_relevant_fields` - Change prompt in `select_relevant_fields` so that the LLM consistently uses the right format when responding. (cherry picked from commit 0fb83efd82ae3ebd8a9fe27813e436b80cd240d3) --- .../get_relevant_field_names.ts | 24 +- .../server/functions/index.ts | 84 ++- .../server/service/client/index.ts | 2 +- .../server/functions/alerts.ts | 3 +- .../server/functions/query/index.ts | 15 +- .../ai_assistant/chat/chat.spec.ts | 4 +- .../ai_assistant/complete/complete.spec.ts | 10 +- .../complete/functions/alerts.spec.ts | 2 +- .../complete/functions/elasticsearch.spec.ts | 2 +- .../functions/get_alerts_dataset_info.spec.ts | 500 ++++++++++++++++++ .../complete/functions/helpers.ts | 7 + .../complete/functions/summarize.spec.ts | 2 +- .../apis/observability/ai_assistant/index.ts | 1 + .../knowledge_base_user_instructions.spec.ts | 6 +- .../public_complete/public_complete.spec.ts | 6 +- .../common/create_llm_proxy.ts | 187 ++++--- .../tests/contextual_insights/index.spec.ts | 2 +- .../tests/conversations/index.spec.ts | 34 +- 18 files changed, 730 insertions(+), 161 deletions(-) create mode 100644 x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts index 847d9f2980053..ffa83dbe92d77 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts @@ -13,6 +13,12 @@ import { MessageRole, ShortIdTable, type Message } from '../../../common'; import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; import { FunctionCallChatFunction } from '../../service/types'; +const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields'; +export const GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE = `You are a helpful assistant for Elastic Observability. +Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list. +The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id". +You must not output any field names — only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`; + export async function getRelevantFieldNames({ index, start, @@ -100,11 +106,7 @@ export async function getRelevantFieldNames({ await chat('get_relevant_dataset_names', { signal, stream: true, - systemMessage: `You are a helpful assistant for Elastic Observability. - Your task is to create a list of field names that are relevant - to the conversation, using ONLY the list of fields and - types provided in the last user message. DO NOT UNDER ANY - CIRCUMSTANCES include fields not mentioned in this list.`, + systemMessage: GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE, messages: [ // remove the last function request ...messages.slice(0, -1), @@ -112,7 +114,7 @@ export async function getRelevantFieldNames({ '@timestamp': new Date().toISOString(), message: { role: MessageRole.User, - content: `This is the list: + content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields: ${fieldsInChunk .map((field) => JSON.stringify({ field, id: shortIdTable.take(field) })) @@ -122,8 +124,12 @@ export async function getRelevantFieldNames({ ], functions: [ { - name: 'select_relevant_fields', - description: 'The IDs of the fields you consider relevant to the conversation', + name: SELECT_RELEVANT_FIELDS_NAME, + description: `Return only the field IDs (from the provided list) that you consider relevant to the conversation. Do not use any of the field names. Your response must be in the exact JSON format: + { + "fieldIds": ["id1", "id2", "id3"] + } + Only include IDs from the list provided in the user message.`, parameters: { type: 'object', properties: { @@ -138,7 +144,7 @@ export async function getRelevantFieldNames({ } as const, }, ], - functionCall: 'select_relevant_fields', + functionCall: SELECT_RELEVANT_FIELDS_NAME, }) ).pipe(concatenateChatCompletionChunks()); diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/index.ts index 36a0b66dc60b7..3599f3ca11f77 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/index.ts @@ -39,52 +39,46 @@ export const registerFunctions: RegistrationCallback = async ({ }; const isServerless = !!resources.plugins.serverless; - if (scopes.includes('observability')) { - functions.registerInstruction(`You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities. - It's very important to not assume what the user is meaning. Ask them for clarification if needed. - - If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation. - - In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\ - /\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important! - - You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response. - - Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language. - - If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results - returned to you, before executing the same tool or another tool again if needed. - - DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (\`service.name == "foo"\`) with "kqlFilter" (\`service.name:"foo"\`). - - The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${ - isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants` - }. - If the user asks how to change the language, reply in the same language the user asked in.`); - } - - if (scopes.length === 0 || (scopes.length === 1 && scopes[0] === 'all')) { - functions.registerInstruction( - `You are a helpful assistant for Elasticsearch. Your goal is to help Elasticsearch users accomplish tasks using Kibana and Elasticsearch. You can help them construct queries, index data, search data, use Elasticsearch APIs, generate sample data, visualise and analyze data. - - It's very important to not assume what the user means. Ask them for clarification if needed. - - If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation. - - In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\ - /\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important! - - You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response. - - If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results - returned to you, before executing the same tool or another tool again if needed. - - The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${ - isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants` - }. - If the user asks how to change the language, reply in the same language the user asked in.` - ); + const isObservabilityDeployment = scopes.includes('observability'); + const isGenericDeployment = scopes.length === 0 || (scopes.length === 1 && scopes[0] === 'all'); + + if (isObservabilityDeployment || isGenericDeployment) { + functions.registerInstruction(` +${ + isObservabilityDeployment + ? `You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities.` + : `You are a helpful assistant for Elasticsearch. Your goal is to help Elasticsearch users accomplish tasks using Kibana and Elasticsearch. You can help them construct queries, index data, search data, use Elasticsearch APIs, generate sample data, visualise and analyze data.` +} + It's very important to not assume what the user means. Ask them for clarification if needed. + + If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation. + + In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\ + /\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important! + + You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response. + + ${ + isObservabilityDeployment + ? 'Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.' + : '' + } + + If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results + returned to you, before executing the same tool or another tool again if needed. + + + ${ + isObservabilityDeployment + ? 'DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (`service.name == "foo"`) with "kqlFilter" (`service.name:"foo"`).' + : '' + } + + The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${ + isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants` + }. + If the user asks how to change the language, reply in the same language the user asked in.`); } const { ready: isKnowledgeBaseReady } = await client.getKnowledgeBaseStatus(); diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts index c59cce4099e78..2895ff596254a 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts @@ -274,8 +274,8 @@ export class ObservabilityAIAssistantClient { chat: (name, chatParams) => { // inject a chat function with predefined parameters return this.chat(name, { - ...chatParams, systemMessage, + ...chatParams, signal, simulateFunctionCalling, connectorId, diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/alerts.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/alerts.ts index 868082ad69b89..e1074eaa9e616 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/alerts.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/alerts.ts @@ -112,9 +112,10 @@ export function registerAlertsFunction({ signal, chat: ( operationName, - { messages: nextMessages, functionCall, functions: nextFunctions } + { messages: nextMessages, functionCall, functions: nextFunctions, systemMessage } ) => { return chat(operationName, { + systemMessage, messages: nextMessages, functionCall, functions: nextFunctions, diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts index 0316e97deeade..3952bacfb3c24 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts @@ -18,7 +18,6 @@ import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-p import { convertMessagesForInference } from '@kbn/observability-ai-assistant-plugin/common/convert_messages_for_inference'; import { map } from 'rxjs'; import { v4 } from 'uuid'; -import { RegisterInstructionCallback } from '@kbn/observability-ai-assistant-plugin/server/service/types'; import type { FunctionRegistrationParameters } from '..'; import { runAndValidateEsqlQuery } from './validate_esql_query'; @@ -30,9 +29,12 @@ export function registerQueryFunction({ resources, pluginsStart, }: FunctionRegistrationParameters) { - const instruction: RegisterInstructionCallback = ({ availableFunctionNames }) => - availableFunctionNames.includes(QUERY_FUNCTION_NAME) - ? `You MUST use the "${QUERY_FUNCTION_NAME}" function when the user wants to: + functions.registerInstruction(({ availableFunctionNames }) => { + if (!availableFunctionNames.includes(QUERY_FUNCTION_NAME)) { + return; + } + + return `You MUST use the "${QUERY_FUNCTION_NAME}" function when the user wants to: - visualize data - run any arbitrary query - breakdown or filter ES|QL queries that are displayed on the current page @@ -48,9 +50,8 @@ export function registerQueryFunction({ even if it has been called before. When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt. - If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.` - : undefined; - functions.registerInstruction(instruction); + If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`; + }); functions.registerFunction( { diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/chat/chat.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/chat/chat.spec.ts index 5b56118cbb92c..4a7772c904df7 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/chat/chat.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/chat/chat.spec.ts @@ -85,7 +85,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); expect(status).to.be(200); }); @@ -104,7 +104,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const simulator = await simulatorPromise; const requestData = simulator.requestBody; // This is the request sent to the LLM expect(requestData.messages[0].content).to.eql(SYSTEM_MESSAGE); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/complete.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/complete.spec.ts index 65aa6068c91b6..821a775abbb6a 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/complete.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/complete.spec.ts @@ -76,7 +76,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon scopes: ['all'], }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); return String(response.body) .split('\n') @@ -133,7 +133,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon await new Promise((resolve) => passThrough.on('end', () => resolve())); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); parsedEvents = decodeEvents(receivedChunks.join('')); }); @@ -243,7 +243,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const simulator = await simulatorPromise; const requestData = simulator.requestBody; expect(requestData.messages[0].role).to.eql('system'); @@ -420,7 +420,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon expect(createResponse.status).to.be(200); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); conversationCreatedEvent = getConversationCreatedEvent(createResponse.body); @@ -463,7 +463,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon expect(updatedResponse.status).to.be(200); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); }); after(async () => { diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/alerts.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/alerts.spec.ts index c5629f6844c60..5dc629eb5937e 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/alerts.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/alerts.spec.ts @@ -46,7 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); alertsEvents = getMessageAddedEvents(alertsResponseBody); }); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/elasticsearch.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/elasticsearch.spec.ts index 0a25f8d6111c2..c7ab9650be6f6 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/elasticsearch.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/elasticsearch.spec.ts @@ -65,7 +65,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); events = getMessageAddedEvents(responseBody); }); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts new file mode 100644 index 0000000000000..8fbb2b1be4ef3 --- /dev/null +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/get_alerts_dataset_info.spec.ts @@ -0,0 +1,500 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common'; +import expect from '@kbn/expect'; +import { ApmRuleType } from '@kbn/rule-data-utils'; +import { apm, timerange } from '@kbn/apm-synthtrace-client'; +import { ApmSynthtraceEsClient } from '@kbn/apm-synthtrace'; +import { RoleCredentials } from '@kbn/ftr-common-functional-services'; +import { last } from 'lodash'; +import { GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names'; +import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; +import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper'; +import { + LlmProxy, + createLlmProxy, +} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy'; +import { getMessageAddedEvents } from './helpers'; +import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; +import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper'; + +const USER_MESSAGE = 'How many alerts do I have for the past 10 days?'; + +export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) { + const log = getService('log'); + const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi'); + const alertingApi = getService('alertingApi'); + const samlAuth = getService('samlAuth'); + + describe('function: get_alerts_dataset_info', function () { + // Fails on MKI: https://github.com/elastic/kibana/issues/205581 + this.tags(['failsOnMKI']); + let llmProxy: LlmProxy; + let connectorId: string; + let messageAddedEvents: MessageAddEvent[]; + let apmSynthtraceEsClient: ApmSynthtraceEsClient; + let roleAuthc: RoleCredentials; + let createdRuleId: string; + let expectedRelevantFieldNames: string[]; + let primarySystemMessage: string; + + before(async () => { + ({ apmSynthtraceEsClient } = await createSyntheticApmData(getService)); + ({ roleAuthc, createdRuleId } = await createApmErrorCountRule(getService)); + + llmProxy = await createLlmProxy(log); + connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({ + port: llmProxy.getPort(), + }); + + void llmProxy.interceptWithFunctionRequest({ + name: 'get_alerts_dataset_info', + arguments: () => JSON.stringify({ start: 'now-10d', end: 'now' }), + when: () => true, + }); + + void llmProxy.interceptWithFunctionRequest({ + name: 'select_relevant_fields', + // @ts-expect-error + when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields', + arguments: (requestBody) => { + const userMessage = last(requestBody.messages); + const topFields = (userMessage?.content as string) + .slice(204) // remove the prefix message and only get the JSON + .trim() + .split('\n') + .map((line) => JSON.parse(line)) + .slice(0, 5); + + expectedRelevantFieldNames = topFields.map(({ field }) => field); + + const fieldIds = topFields.map(({ id }) => id); + + return JSON.stringify({ fieldIds }); + }, + }); + + void llmProxy.interceptWithFunctionRequest({ + name: 'alerts', + arguments: () => JSON.stringify({ start: 'now-10d', end: 'now' }), + when: () => true, + }); + + void llmProxy.interceptConversation( + `You have active alerts for the past 10 days. Back to work!` + ); + + const { status, body } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'POST /internal/observability_ai_assistant/chat/complete', + params: { + body: { + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: USER_MESSAGE, + }, + }, + ], + connectorId, + persist: false, + screenContexts: [], + scopes: ['observability' as const], + }, + }, + }); + + expect(status).to.be(200); + + await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); + messageAddedEvents = getMessageAddedEvents(body); + + const { + body: { systemMessage }, + } = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions', + params: { + query: { + scopes: ['observability'], + }, + }, + }); + + primarySystemMessage = systemMessage; + }); + + after(async () => { + llmProxy.close(); + await observabilityAIAssistantAPIClient.deleteActionConnector({ + actionId: connectorId, + }); + + await apmSynthtraceEsClient.clean(); + await alertingApi.cleanUpAlerts({ + roleAuthc, + ruleId: createdRuleId, + alertIndexName: APM_ALERTS_INDEX, + consumer: 'apm', + }); + + await samlAuth.invalidateM2mApiKeyWithRoleScope(roleAuthc); + }); + + describe('LLM requests', () => { + let firstRequestBody: ChatCompletionStreamParams; + let secondRequestBody: ChatCompletionStreamParams; + let thirdRequestBody: ChatCompletionStreamParams; + let fourthRequestBody: ChatCompletionStreamParams; + + before(async () => { + firstRequestBody = llmProxy.interceptedRequests[0].requestBody; + secondRequestBody = llmProxy.interceptedRequests[1].requestBody; + thirdRequestBody = llmProxy.interceptedRequests[2].requestBody; + fourthRequestBody = llmProxy.interceptedRequests[3].requestBody; + }); + + it('makes 4 requests to the LLM', () => { + expect(llmProxy.interceptedRequests.length).to.be(4); + }); + + describe('every request to the LLM', () => { + it('contains a system message', () => { + const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every( + ({ requestBody }) => { + const firstMessage = requestBody.messages[0]; + return ( + firstMessage.role === 'system' && + (firstMessage.content as string).includes('You are a helpful assistant') + ); + } + ); + expect(everyRequestHasSystemMessage).to.be(true); + }); + + it('contains the original user message', () => { + const everyRequestHasUserMessage = llmProxy.interceptedRequests.every(({ requestBody }) => + requestBody.messages.some( + (message) => message.role === 'user' && (message.content as string) === USER_MESSAGE + ) + ); + expect(everyRequestHasUserMessage).to.be(true); + }); + + it('contains the context function request and context function response', () => { + const everyRequestHasContextFunction = llmProxy.interceptedRequests.every( + ({ requestBody }) => { + const hasContextFunctionRequest = requestBody.messages.some( + (message) => + message.role === 'assistant' && + message.tool_calls?.[0]?.function?.name === 'context' + ); + + const hasContextFunctionResponse = requestBody.messages.some( + (message) => + message.role === 'tool' && + (message.content as string).includes('screen_description') && + (message.content as string).includes('learnings') + ); + + return hasContextFunctionRequest && hasContextFunctionResponse; + } + ); + + expect(everyRequestHasContextFunction).to.be(true); + }); + }); + + describe('The first request', () => { + it('contains the correct number of messages', () => { + expect(firstRequestBody.messages.length).to.be(4); + }); + + it('contains the `get_alerts_dataset_info` tool', () => { + const hasTool = firstRequestBody.tools?.some( + (tool) => tool.function.name === 'get_alerts_dataset_info' + ); + + expect(hasTool).to.be(true); + }); + + it('leaves the function calling decision to the LLM via tool_choice=auto', () => { + expect(firstRequestBody.tool_choice).to.be('auto'); + }); + + describe('The system message', () => { + it('has the primary system message', () => { + expect(sortSystemMessage(firstRequestBody.messages[0].content as string)).to.eql( + sortSystemMessage(primarySystemMessage) + ); + }); + + it('has a different system message from request 2', () => { + expect(firstRequestBody.messages[0]).not.to.eql(secondRequestBody.messages[0]); + }); + + it('has the same system message as request 3', () => { + expect(firstRequestBody.messages[0]).to.eql(thirdRequestBody.messages[0]); + }); + + it('has the same system message as request 4', () => { + expect(firstRequestBody.messages[0]).to.eql(fourthRequestBody.messages[0]); + }); + }); + }); + + describe('The second request', () => { + it('contains the correct number of messages', () => { + expect(secondRequestBody.messages.length).to.be(5); + }); + + it('contains a system generated user message with a list of field candidates', () => { + const hasList = secondRequestBody.messages.some( + (message) => + message.role === 'user' && + (message.content as string).includes('Below is a list of fields.') && + (message.content as string).includes('@timestamp') + ); + + expect(hasList).to.be(true); + }); + + it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => { + const hasToolChoice = + // @ts-expect-error + secondRequestBody.tool_choice?.function?.name === 'select_relevant_fields'; + + expect(hasToolChoice).to.be(true); + }); + + it('has a custom, function-specific system message', () => { + expect(secondRequestBody.messages[0].content).to.be( + GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE + ); + }); + }); + + describe('The third request', () => { + it('contains the correct number of messages', () => { + expect(thirdRequestBody.messages.length).to.be(6); + }); + + it('contains the `get_alerts_dataset_info` request', () => { + const hasFunctionRequest = thirdRequestBody.messages.some( + (message) => + message.role === 'assistant' && + message.tool_calls?.[0]?.function?.name === 'get_alerts_dataset_info' + ); + + expect(hasFunctionRequest).to.be(true); + }); + + it('contains the `get_alerts_dataset_info` response', () => { + const functionResponse = last(thirdRequestBody.messages); + const parsedContent = JSON.parse(functionResponse?.content as string) as { + fields: string[]; + }; + + const fieldNamesWithType = parsedContent.fields; + const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]); + + expect(fieldNamesWithoutType).to.eql(expectedRelevantFieldNames); + expect(fieldNamesWithType).to.eql([ + '@timestamp:date', + '_id:_id', + '_ignored:string', + '_index:_index', + '_score:number', + ]); + }); + + it('emits a messageAdded event with the `get_alerts_dataset_info` function response', async () => { + const messageWithDatasetInfo = messageAddedEvents.find( + ({ message }) => + message.message.role === MessageRole.User && + message.message.name === 'get_alerts_dataset_info' + ); + + const parsedContent = JSON.parse(messageWithDatasetInfo?.message.message.content!) as { + fields: string[]; + }; + + expect(parsedContent.fields).to.eql([ + '@timestamp:date', + '_id:_id', + '_ignored:string', + '_index:_index', + '_score:number', + ]); + }); + + it('contains the `alerts` tool', () => { + const hasTool = thirdRequestBody.tools?.some((tool) => tool.function.name === 'alerts'); + + expect(hasTool).to.be(true); + }); + }); + + describe('The fourth request', () => { + it('contains the correct number of messages', () => { + expect(fourthRequestBody.messages.length).to.be(8); + }); + + it('contains the `alerts` request', () => { + const hasFunctionRequest = fourthRequestBody.messages.some( + (message) => + message.role === 'assistant' && message.tool_calls?.[0]?.function?.name === 'alerts' + ); + + expect(hasFunctionRequest).to.be(true); + }); + + it('contains the `alerts` response', () => { + const functionResponseMessage = last(fourthRequestBody.messages); + const parsedContent = JSON.parse(functionResponseMessage?.content as string); + expect(Object.keys(parsedContent)).to.eql(['total', 'alerts']); + }); + + it('emits a messageAdded event with the `alert` function response', async () => { + const messageWithAlerts = messageAddedEvents.find( + ({ message }) => + message.message.role === MessageRole.User && message.message.name === 'alerts' + ); + + const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as { + total: number; + alerts: any[]; + }; + expect(parsedContent.total).to.be(1); + expect(parsedContent.alerts.length).to.be(1); + }); + }); + }); + + describe('messageAdded events', () => { + it('emits 7 messageAdded events', () => { + expect(messageAddedEvents.length).to.be(7); + }); + + it('emits messageAdded events in the correct order', async () => { + const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => { + const { role, name, function_call: functionCall } = message.message; + if (functionCall) { + return { function_call: functionCall, role }; + } + + return { name, role }; + }); + + expect(formattedMessageAddedEvents).to.eql([ + { + role: 'assistant', + function_call: { name: 'context', trigger: 'assistant' }, + }, + { name: 'context', role: 'user' }, + { + role: 'assistant', + function_call: { + name: 'get_alerts_dataset_info', + arguments: '{"start":"now-10d","end":"now"}', + trigger: 'assistant', + }, + }, + { name: 'get_alerts_dataset_info', role: 'user' }, + { + role: 'assistant', + function_call: { + name: 'alerts', + arguments: '{"start":"now-10d","end":"now"}', + trigger: 'assistant', + }, + }, + { name: 'alerts', role: 'user' }, + { + role: 'assistant', + function_call: { name: '', arguments: '', trigger: 'assistant' }, + }, + ]); + }); + }); + }); +} + +async function createApmErrorCountRule( + getService: DeploymentAgnosticFtrProviderContext['getService'] +) { + const alertingApi = getService('alertingApi'); + const samlAuth = getService('samlAuth'); + + const roleAuthc = await samlAuth.createM2mApiKeyWithRoleScope('editor'); + const createdRule = await alertingApi.createRule({ + ruleTypeId: ApmRuleType.ErrorCount, + name: 'APM error threshold', + consumer: 'apm', + schedule: { interval: '1m' }, + tags: ['apm'], + params: { + environment: 'production', + threshold: 1, + windowSize: 1, + windowUnit: 'h', + }, + roleAuthc, + }); + + const createdRuleId = createdRule.id as string; + const esResponse = await alertingApi.waitForDocumentInIndex({ + indexName: APM_ALERTS_INDEX, + ruleId: createdRuleId, + docCountTarget: 1, + }); + + return { + roleAuthc, + createdRuleId, + alerts: esResponse.hits.hits.map((hit) => hit._source!), + }; +} + +async function createSyntheticApmData( + getService: DeploymentAgnosticFtrProviderContext['getService'] +) { + const synthtrace = getService('synthtrace'); + const apmSynthtraceEsClient = await synthtrace.createApmSynthtraceEsClient(); + + const opbeansNode = apm + .service({ name: 'opbeans-node', environment: 'production', agentName: 'node' }) + .instance('instance'); + + const events = timerange('now-15m', 'now') + .ratePerMinute(1) + .generator((timestamp) => { + return [ + opbeansNode + .transaction({ transactionName: 'DELETE /user/:id' }) + .timestamp(timestamp) + .duration(100) + .failure() + .errors( + opbeansNode.error({ message: 'Unable to delete user' }).timestamp(timestamp + 50) + ), + ]; + }); + + await apmSynthtraceEsClient.index(events); + + return { apmSynthtraceEsClient }; +} + +// order of instructions can vary, so we sort to compare them +function sortSystemMessage(message: string) { + return message + .split('\n\n') + .map((line) => line.trim()) + .sort(); +} diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts index b64295d3a255b..f36b9b9eb6037 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/helpers.ts @@ -45,6 +45,13 @@ export async function invokeChatCompleteWithFunctionRequest({ params: { body: { messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello from user', + }, + }, { '@timestamp': new Date().toISOString(), message: { diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/summarize.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/summarize.spec.ts index eb2dc6aca3d31..999a94be56000 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/summarize.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/summarize.spec.ts @@ -64,7 +64,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); }); after(async () => { diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts index 1d3d41ddb4400..31d0b5f5c836c 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/index.ts @@ -16,6 +16,7 @@ export default function aiAssistantApiIntegrationTests({ loadTestFile(require.resolve('./chat/chat.spec.ts')); loadTestFile(require.resolve('./complete/complete.spec.ts')); loadTestFile(require.resolve('./complete/functions/alerts.spec.ts')); + loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts')); loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts')); loadTestFile(require.resolve('./complete/functions/summarize.spec.ts')); loadTestFile(require.resolve('./public_complete/public_complete.spec.ts')); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/knowledge_base/knowledge_base_user_instructions.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/knowledge_base/knowledge_base_user_instructions.spec.ts index ff519f91900b9..fa946fa29b8dc 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/knowledge_base/knowledge_base_user_instructions.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/knowledge_base/knowledge_base_user_instructions.spec.ts @@ -308,7 +308,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); expect(createResponse.status).to.be(200); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const conversationCreatedEvent = getConversationCreatedEvent(createResponse.body); const conversationId = conversationCreatedEvent.conversation.id; @@ -321,7 +321,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const conversation = res.body; return conversation; @@ -470,7 +470,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const simulator = await simulatorPromise; const requestData = simulator.requestBody; expect(requestData.messages[0].content).to.contain(userInstructionText); diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/public_complete/public_complete.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/public_complete/public_complete.spec.ts index c80e2b4b2d591..c3de648abaab7 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/public_complete/public_complete.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/public_complete/public_complete.spec.ts @@ -14,9 +14,9 @@ import { MessageAddEvent, type StreamingChatResponseEvent, } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; -import type OpenAI from 'openai'; import { type AdHocInstruction } from '@kbn/observability-ai-assistant-plugin/common/types'; import type { ChatCompletionChunkToolCall } from '@kbn/inference-common'; +import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { createLlmProxy, LlmProxy, @@ -72,7 +72,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }, }); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); const titleSimulator = await titleSimulatorPromise; const conversationSimulator = await conversationSimulatorPromise; @@ -156,7 +156,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); describe('after adding an instruction', () => { - let body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming; + let body: ChatCompletionStreamParams; before(async () => { const { conversationSimulator } = await addInterceptorsAndCallComplete({ diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts index 5fcf41a33ebdb..030550f7a0c67 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts @@ -7,26 +7,29 @@ import { ToolingLog } from '@kbn/tooling-log'; import getPort from 'get-port'; +import { v4 as uuidv4 } from 'uuid'; import http, { type Server } from 'http'; -import { isString, once, pull } from 'lodash'; -import OpenAI from 'openai'; +import { isString, once, pull, isFunction } from 'lodash'; import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title'; import pRetry from 'p-retry'; import type { ChatCompletionChunkToolCall } from '@kbn/inference-common'; +import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { createOpenAiChunk } from './create_openai_chunk'; type Request = http.IncomingMessage; type Response = http.ServerResponse & { req: http.IncomingMessage }; +type LLMMessage = string[] | ToolMessage | string | undefined; + type RequestHandler = ( request: Request, response: Response, - requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming + requestBody: ChatCompletionStreamParams ) => void; interface RequestInterceptor { name: string; - when: (body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) => boolean; + when: (body: ChatCompletionStreamParams) => boolean; } export interface ToolMessage { @@ -34,8 +37,8 @@ export interface ToolMessage { tool_calls?: ChatCompletionChunkToolCall[]; } export interface LlmResponseSimulator { - requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming; - status: (code: number) => Promise; + requestBody: ChatCompletionStreamParams; + status: (code: number) => void; next: (msg: string | ToolMessage) => Promise; error: (error: any) => Promise; complete: () => Promise; @@ -46,35 +49,47 @@ export interface LlmResponseSimulator { export class LlmProxy { server: Server; interval: NodeJS.Timeout; - interceptors: Array = []; + interceptedRequests: Array<{ + requestBody: ChatCompletionStreamParams; + matchingInterceptorName: string | undefined; + }> = []; constructor(private readonly port: number, private readonly log: ToolingLog) { - this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 1000); + this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 5000); this.server = http .createServer() .on('request', async (request, response) => { - this.log.info(`LLM request received`); - - const interceptors = this.interceptors.concat(); const requestBody = await getRequestBody(request); - while (interceptors.length) { - const interceptor = interceptors.shift()!; + const matchingInterceptor = this.interceptors.find(({ when }) => when(requestBody)); + this.interceptedRequests.push({ + requestBody, + matchingInterceptorName: matchingInterceptor?.name, + }); + if (matchingInterceptor) { + this.log.info(`Handling interceptor "${matchingInterceptor.name}"`); + matchingInterceptor.handle(request, response, requestBody); - if (interceptor.when(requestBody)) { - pull(this.interceptors, interceptor); - interceptor.handle(request, response, requestBody); - return; - } + this.log.debug(`Removing interceptor "${matchingInterceptor.name}"`); + pull(this.interceptors, matchingInterceptor); + return; } const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`; + const availableInterceptorNames = this.interceptors.map(({ name }) => name); + this.log.error( + `Available interceptors: ${JSON.stringify(availableInterceptorNames, null, 2)}` + ); + this.log.error( `${errorMessage}. Messages: ${JSON.stringify(requestBody.messages, null, 2)}` ); - response.writeHead(500, { errorMessage, messages: JSON.stringify(requestBody.messages) }); + response.writeHead(500, { + 'Elastic-Interceptor': 'Interceptor not found', + }); + response.write(sseEvent({ errorMessage, availableInterceptorNames })); response.end(); }) .on('error', (error) => { @@ -88,7 +103,8 @@ export class LlmProxy { } clear() { - this.interceptors.length = 0; + this.interceptors = []; + this.interceptedRequests = []; } close() { @@ -97,16 +113,18 @@ export class LlmProxy { this.server.close(); } - waitForAllInterceptorsSettled() { + waitForAllInterceptorsToHaveBeenCalled() { return pRetry( async () => { if (this.interceptors.length === 0) { return; } - const unsettledInterceptors = this.interceptors.map((i) => i.name).join(', '); + const unsettledInterceptors = this.interceptors.map((i) => i.name); this.log.debug( - `Waiting for the following interceptors to be called: ${unsettledInterceptors}` + `Waiting for the following interceptors to be called: ${JSON.stringify( + unsettledInterceptors + )}` ); if (this.interceptors.length > 0) { throw new Error(`Interceptors were not called: ${unsettledInterceptors}`); @@ -120,61 +138,71 @@ export class LlmProxy { } interceptConversation( - msg: Array | ToolMessage | string | undefined, + msg: LLMMessage, { - name = 'default_interceptor_conversation_name', + name, }: { name?: string; } = {} ) { return this.intercept( - name, - (body) => !isFunctionTitleRequest(body), + `Conversation interceptor: "${name ?? 'Unnamed'}"`, + // @ts-expect-error + (body) => body.tool_choice?.function?.name === undefined, msg ).completeAfterIntercept(); } - interceptTitle(title: string) { - return this.intercept( - `conversation_title_interceptor_${title.split(' ').join('_')}`, - (body) => isFunctionTitleRequest(body), - { + interceptWithFunctionRequest({ + name: name, + arguments: argumentsCallback, + when, + }: { + name: string; + arguments: (body: ChatCompletionStreamParams) => string; + when: RequestInterceptor['when']; + }) { + // @ts-expect-error + return this.intercept(`Function request interceptor: "${name}"`, when, (body) => { + return { content: '', tool_calls: [ { - index: 0, - toolCallId: 'id', function: { - name: TITLE_CONVERSATION_FUNCTION_NAME, - arguments: JSON.stringify({ title }), + name, + arguments: argumentsCallback(body), }, + index: 0, + id: `call_${uuidv4()}`, }, ], - } - ).completeAfterIntercept(); + }; + }).completeAfterIntercept(); } - intercept< - TResponseChunks extends - | Array - | ToolMessage - | string - | undefined = undefined - >( + interceptTitle(title: string) { + return this.interceptWithFunctionRequest({ + name: TITLE_CONVERSATION_FUNCTION_NAME, + arguments: () => JSON.stringify({ title }), + // @ts-expect-error + when: (body) => body.tool_choice?.function?.name === TITLE_CONVERSATION_FUNCTION_NAME, + }); + } + + intercept( name: string, when: RequestInterceptor['when'], - responseChunks?: TResponseChunks - ): TResponseChunks extends undefined - ? { waitForIntercept: () => Promise } - : { completeAfterIntercept: () => Promise } { + responseChunks?: LLMMessage | ((body: ChatCompletionStreamParams) => LLMMessage) + ): { + waitForIntercept: () => Promise; + completeAfterIntercept: () => Promise; + } { const waitForInterceptPromise = Promise.race([ new Promise((outerResolve) => { this.interceptors.push({ name, when, handle: (request, response, requestBody) => { - this.log.info(`LLM request intercepted by "${name}"`); - function write(chunk: string) { return new Promise((resolve) => response.write(chunk, () => resolve())); } @@ -184,24 +212,28 @@ export class LlmProxy { const simulator: LlmResponseSimulator = { requestBody, - status: once(async (status: number) => { + status: once((status: number) => { response.writeHead(status, { + 'Elastic-Interceptor': name, 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', Connection: 'keep-alive', }); }), next: (msg) => { + simulator.status(200); const chunk = createOpenAiChunk(msg); - return write(`data: ${JSON.stringify(chunk)}\n\n`); + return write(sseEvent(chunk)); }, rawWrite: (chunk: string) => { + simulator.status(200); return write(chunk); }, rawEnd: async () => { await end(); }, complete: async () => { + this.log.debug(`Completed intercept for "${name}"`); await write('data: [DONE]\n\n'); await end(); }, @@ -216,29 +248,41 @@ export class LlmProxy { }); }), new Promise((_, reject) => { - setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 20000ms`)), 20000); + setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 30000ms`)), 30000); }), ]); - if (responseChunks === undefined) { - return { waitForIntercept: () => waitForInterceptPromise } as any; - } - - const parsedChunks = Array.isArray(responseChunks) - ? responseChunks - : isString(responseChunks) - ? responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`)) - : [responseChunks]; - return { + waitForIntercept: () => waitForInterceptPromise, completeAfterIntercept: async () => { const simulator = await waitForInterceptPromise; + + function getParsedChunks(): Array { + const llmMessage = isFunction(responseChunks) + ? responseChunks(simulator.requestBody) + : responseChunks; + + if (!llmMessage) { + return []; + } + + if (Array.isArray(llmMessage)) { + return llmMessage; + } + + if (isString(llmMessage)) { + return llmMessage.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`)); + } + + return [llmMessage]; + } + + const parsedChunks = getParsedChunks(); for (const chunk of parsedChunks) { await simulator.next(chunk); } await simulator.complete(); - return simulator; }, } as any; @@ -251,9 +295,7 @@ export async function createLlmProxy(log: ToolingLog) { return new LlmProxy(port, log); } -async function getRequestBody( - request: http.IncomingMessage -): Promise { +async function getRequestBody(request: http.IncomingMessage): Promise { return new Promise((resolve, reject) => { let data = ''; @@ -271,11 +313,6 @@ async function getRequestBody( }); } -export function isFunctionTitleRequest( - requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming -) { - return ( - requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !== - undefined - ); +function sseEvent(chunk: unknown) { + return `data: ${JSON.stringify(chunk)}\n\n`; } diff --git a/x-pack/test/observability_ai_assistant_functional/tests/contextual_insights/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/contextual_insights/index.spec.ts index 45382a149a2ed..76249480ffa2d 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/contextual_insights/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/contextual_insights/index.spec.ts @@ -126,7 +126,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte await openContextualInsights(); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); await retry.tryForTime(5 * 1000, async () => { const llmResponse = await testSubjects.getVisibleText(ui.pages.contextualInsights.text); diff --git a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts index e000c562267f4..49eb42d06d8ad 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts @@ -247,7 +247,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello'); await testSubjects.pressEnter(ui.pages.conversations.chatInput); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); await header.waitUntilLoadingHasFinished(); }); @@ -256,6 +256,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte endpoint: 'POST /internal/observability_ai_assistant/conversations', }); + const functionResponse = await observabilityAIAssistantAPIClient.editor({ + endpoint: 'GET /internal/observability_ai_assistant/functions', + params: { + query: { + scopes: ['observability'], + }, + }, + }); + + const primarySystemMessage = functionResponse.body.systemMessage; + expect(response.body.conversations.length).to.eql(2); expect(response.body.conversations[0].conversation.title).to.be(expectedTitle); @@ -267,10 +278,13 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte const [firstUserMessage, contextRequest, contextResponse, assistantResponse] = messages.map((msg) => msg.message); - const systemMessageContent = - 'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities.\n\n It\'s very important to not assume what the user is meaning. Ask them for clarification if needed.\n\n If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation.\n\n In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: \':()\\ /". Always put a field value in double quotes. Best: service.name:"opbeans-go". Wrong: service.name:opbeans-go. This is very important!\n\n You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response.\n\n Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.\n\n If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results\n returned to you, before executing the same tool or another tool again if needed.\n\n DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (`service.name == "foo"`) with "kqlFilter" (`service.name:"foo"`).\n\n The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the Stack Management app under the option AI Assistants.\n If the user asks how to change the language, reply in the same language the user asked in.\n\nYou MUST use the "query" function when the user wants to:\n - visualize data\n - run any arbitrary query\n - breakdown or filter ES|QL queries that are displayed on the current page\n - convert queries from another language to ES|QL\n - asks general questions about ES|QL\n\n DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries or explain anything about the ES|QL query language yourself.\n DO NOT UNDER ANY CIRCUMSTANCES try to correct an ES|QL query yourself - always use the "query" function for this.\n\n If the user asks for a query, and one of the dataset info functions was called and returned no results, you should still call the query function to generate an example query.\n\n Even if the "query" function was used before that, follow it up with the "query" function. If a query fails, do not attempt to correct it yourself. Again you should call the "query" function,\n even if it has been called before.\n\n When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.\n If the "execute_query" function has been called, summarize these results for the user. The user does not see a visualization in this case.\n\nYou MUST use the "get_dataset_info" function before calling the "query" or the "changes" functions.\n\nIf a function requires an index, you MUST use the results from the dataset info functions.\n\nYou do not have a working memory. If the user expects you to remember the previous conversations, tell them they can set up the knowledge base.\n\nWhen asked questions about the Elastic stack or products, You should use the retrieve_elastic_doc function before answering,\n to retrieve documentation related to the question. Consider that the documentation returned by the function\n is always more up to date and accurate than any own internal knowledge you might have.'; + expect(systemMessage).to.contain( + 'You are a helpful assistant for Elastic Observability. Your goal is ' + ); - expect(systemMessage).to.eql(systemMessageContent); + expect(sortSystemMessage(systemMessage!)).to.eql( + sortSystemMessage(primarySystemMessage) + ); expect(firstUserMessage.content).to.eql('hello'); @@ -305,7 +319,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello'); await testSubjects.pressEnter(ui.pages.conversations.chatInput); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); await header.waitUntilLoadingHasFinished(); }); @@ -396,7 +410,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte await testSubjects.pressEnter(ui.pages.conversations.chatInput); log.info('SQREN: Waiting for the message to be displayed'); - await proxy.waitForAllInterceptorsSettled(); + await proxy.waitForAllInterceptorsToHaveBeenCalled(); await header.waitUntilLoadingHasFinished(); }); @@ -451,3 +465,11 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte }); }); } + +// order of instructions can vary, so we sort to compare them +function sortSystemMessage(message: string) { + return message + .split('\n\n') + .map((line) => line.trim()) + .sort(); +}