diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/functions/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/functions/types.ts index ce07d3de03308..dcfd39651bf0a 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/functions/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/functions/types.ts @@ -25,14 +25,12 @@ export type FunctionResponse = } | Observable; -export interface FunctionDefinition< - TParameters extends CompatibleJSONSchema = CompatibleJSONSchema -> { +export interface FunctionDefinition { name: string; description: string; visibility?: FunctionVisibility; descriptionForUser?: string; - parameters: TParameters; + parameters?: TParameters; contexts: string[]; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts index b32161ca0195e..e5494fed64d10 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts @@ -4,6 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ +import type { ObservabilityAIAssistantChatService } from '../public'; +import type { CompatibleJSONSchema, FunctionResponse } from './functions/types'; export enum MessageRole { System = 'system', @@ -77,6 +79,31 @@ export interface KnowledgeBaseEntry { role: KnowledgeBaseEntryRole; } +export interface ObservabilityAIAssistantScreenContextRequest { + screenDescription?: string; + data?: Array<{ + name: string; + description: string; + value: any; + }>; + actions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>; +} + +export type ScreenContextActionRespondFunction = ({}: { + args: TArguments; + signal: AbortSignal; + connectorId: string; + client: Pick; + messages: Message[]; +}) => Promise; + +export interface ScreenContextActionDefinition { + name: string; + description: string; + parameters?: CompatibleJSONSchema; + respond: ScreenContextActionRespondFunction; +} + export interface ObservabilityAIAssistantScreenContext { screenDescription?: string; data?: Array<{ @@ -84,4 +111,5 @@ export interface ObservabilityAIAssistantScreenContext { description: string; value: any; }>; + actions?: ScreenContextActionDefinition[]; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_request_message.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts similarity index 83% rename from x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_request_message.ts rename to x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts index 8c38b03040794..45399ea651bb3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_request_message.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_request_message.ts @@ -6,11 +6,8 @@ */ import { v4 } from 'uuid'; -import { MessageRole } from '../../../common'; -import { - MessageAddEvent, - StreamingChatResponseEventType, -} from '../../../common/conversation_complete'; +import { MessageRole } from '..'; +import { MessageAddEvent, StreamingChatResponseEventType } from '../conversation_complete'; export function createFunctionRequestMessage({ name, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts new file mode 100644 index 0000000000000..79f6e5d4ff6df --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_error.ts @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { createFunctionResponseMessage } from './create_function_response_message'; + +export function createFunctionResponseError({ + name, + error, + message, +}: { + name: string; + error: Error; + message?: string; +}) { + return createFunctionResponseMessage({ + name, + content: { + error: { + ...error, + name: error.name, + message: error.message, + cause: error.cause, + stack: error.stack, + }, + message: message || error.message, + }, + }); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_response_message.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_message.ts similarity index 82% rename from x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_response_message.ts rename to x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_message.ts index 186ff117734c3..b382e09c19e37 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/create_function_response_message.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/create_function_response_message.ts @@ -6,11 +6,8 @@ */ import { v4 } from 'uuid'; -import { MessageRole } from '../../../common'; -import { - type MessageAddEvent, - StreamingChatResponseEventType, -} from '../../../common/conversation_complete'; +import { MessageRole } from '..'; +import { type MessageAddEvent, StreamingChatResponseEventType } from '../conversation_complete'; export function createFunctionResponseMessage({ name, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_chat.ts index 1ccc48ddb3934..2ab4fd294dffa 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_chat.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_chat.ts @@ -161,7 +161,7 @@ function useChatWithoutContext({ setChatState(ChatState.Loading); const next$ = chatService.complete({ - screenContexts: service.getScreenContexts(), + getScreenContexts: () => service.getScreenContexts(), connectorId, messages: getWithSystemMessage(nextMessages, systemMessage), persist, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts index a3bb72746a86b..80d01706700b8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts @@ -33,6 +33,8 @@ export { useAbortableAsync, type AbortableAsyncState } from './hooks/use_abortab export { createStorybookChatService, createStorybookService } from './storybook_mock'; +export { createScreenContextAction } from './utils/create_screen_context_action'; + export { ChatState } from './hooks/use_chat'; export { FeedbackButtons, type Feedback } from './components/buttons/feedback_buttons'; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx index fb6438dbe8580..c38307e920641 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx @@ -7,8 +7,9 @@ import { i18n } from '@kbn/i18n'; import { noop } from 'lodash'; import React from 'react'; -import { Observable } from 'rxjs'; +import { Observable, of } from 'rxjs'; import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete'; +import { ScreenContextActionDefinition } from '../common/types'; import type { ObservabilityAIAssistantAPIClient } from './api'; import type { ObservabilityAIAssistantChatService, @@ -49,6 +50,7 @@ export const mockService: ObservabilityAIAssistantService = { openNewConversation: noop, predefinedConversation$: new Observable(), }, + navigate: async () => of(), }; function createSetupContract(): ObservabilityAIAssistantPublicSetup { @@ -75,6 +77,7 @@ function createStartContract(): ObservabilityAIAssistantPublicStart { getPreferredLanguage: () => 'English', }), getContextualInsightMessages: () => [], + createScreenContextAction: () => ({} as ScreenContextActionDefinition), }; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/plugin.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/plugin.tsx index 1e82b135a837a..c003856b03699 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/plugin.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/plugin.tsx @@ -28,6 +28,7 @@ import type { } from './types'; import { useUserPreferredLanguage } from './hooks/use_user_preferred_language'; import { getContextualInsightMessages } from './utils/get_contextual_insight_messages'; +import { createScreenContextAction } from './utils/create_screen_context_action'; export class ObservabilityAIAssistantPlugin implements @@ -107,6 +108,7 @@ export class ObservabilityAIAssistantPlugin ) : null, getContextualInsightMessages, + createScreenContextAction, }; } } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts new file mode 100644 index 0000000000000..eca7b6977a7c3 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts @@ -0,0 +1,404 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { filter, last, lastValueFrom, map, of, throwError, toArray } from 'rxjs'; +import { v4 } from 'uuid'; +import { + type Message, + MessageRole, + StreamingChatResponseEventType, + type StreamingChatResponseEvent, + ChatCompletionErrorCode, + ChatCompletionError, + MessageAddEvent, + createInternalServerError, +} from '../../common'; +import type { ObservabilityAIAssistantChatService } from '../types'; +import { complete } from './complete'; + +const client = { + chat: jest.fn(), + complete: jest.fn(), +} as unknown as ObservabilityAIAssistantChatService; + +const connectorId = 'foo'; + +const messages: Message[] = [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System message', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'User message', + }, + }, +]; + +const createLlmResponse = ( + chunks: Array<{ content: string; function_call?: { name: string; arguments: string } }> +): StreamingChatResponseEvent[] => { + const id = v4(); + const message = chunks.reduce( + (prev, current) => { + prev.content += current.content ?? ''; + prev.function_call!.name += current.function_call?.name ?? ''; + prev.function_call!.arguments! += current.function_call?.arguments ?? ''; + return prev; + }, + { + content: '', + role: MessageRole.Assistant, + function_call: { name: '', arguments: '', trigger: MessageRole.Assistant }, + } + ); + + const events: StreamingChatResponseEvent[] = [ + ...chunks.map((msg) => ({ + id, + message: msg, + type: StreamingChatResponseEventType.ChatCompletionChunk as const, + })), + { + id, + message: { + '@timestamp': new Date().toString(), + message, + }, + type: StreamingChatResponseEventType.MessageAdd as const, + }, + ]; + + return events; +}; + +type CompleteParameters = Parameters[0]; + +describe('complete', () => { + const requestCallback: jest.MockedFunction[1]> = jest.fn(); + + beforeEach(() => { + requestCallback.mockReset(); + }); + + function callComplete(params?: Partial) { + return complete( + { + client, + connectorId, + getScreenContexts: () => [], + messages, + persist: false, + signal: new AbortController().signal, + ...params, + }, + requestCallback + ); + } + + describe('when an error is emitted', () => { + beforeEach(() => { + requestCallback.mockImplementation(() => + of({ + type: StreamingChatResponseEventType.ChatCompletionError, + error: { + message: 'Not found', + code: ChatCompletionErrorCode.NotFoundError, + }, + }) + ); + }); + + it('the observable errors out', async () => { + await expect(async () => await lastValueFrom(callComplete())).rejects.toThrowError( + 'Not found' + ); + + await expect(async () => await lastValueFrom(callComplete())).rejects.toBeInstanceOf( + ChatCompletionError + ); + + await expect(async () => await lastValueFrom(callComplete())).rejects.toHaveProperty( + 'code', + ChatCompletionErrorCode.NotFoundError + ); + }); + }); + + describe('with screen context and an action is called', () => { + const respondFn: jest.MockedFn = jest.fn(); + + const getScreenContexts: CompleteParameters['getScreenContexts'] = jest.fn().mockReturnValue([ + { + actions: [ + { + name: 'my_action', + description: 'My action', + parameters: { + type: 'object', + properties: { + foo: { + type: 'string', + }, + }, + }, + respond: respondFn, + }, + ], + }, + ]); + + beforeEach(() => { + requestCallback.mockImplementationOnce(() => + of( + ...createLlmResponse([ + { + content: '', + function_call: { name: 'my_action', arguments: JSON.stringify({ foo: 'bar' }) }, + }, + ]) + ) + ); + }); + describe('and it succeeds', () => { + let allMessages: Message[] = []; + beforeEach(async () => { + respondFn.mockResolvedValueOnce({ content: { bar: 'foo' } }); + + requestCallback.mockImplementationOnce(() => + of(...createLlmResponse([{ content: 'Great action call' }])) + ); + + allMessages = await lastValueFrom( + callComplete({ + getScreenContexts, + }).pipe( + filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ), + map((event) => event.message), + toArray(), + last() + ) + ); + }); + + it('calls the request callback again with the executed message', () => { + expect(requestCallback).toHaveBeenCalledTimes(2); + + const nextMessages = requestCallback.mock.lastCall![0].params.body.messages; + + const expectedMessages = [ + { + '@timestamp': expect.any(String), + message: { + content: '', + function_call: { + arguments: JSON.stringify({ foo: 'bar' }), + name: 'my_action', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: JSON.stringify({ bar: 'foo' }), + name: 'my_action', + role: MessageRole.User, + }, + }, + ]; + + expect(nextMessages).toEqual([...messages, ...expectedMessages]); + }); + + it('calls the action handler with the arguments from the LLM', () => { + expect(respondFn).toHaveBeenCalledWith( + expect.objectContaining({ + args: { + foo: 'bar', + }, + }) + ); + }); + + it('returns all the messages in the created observable', () => { + expect(allMessages[allMessages.length - 1]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: 'Great action call', + function_call: { + arguments: '', + name: '', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }); + }); + }); + + describe('and it fails', () => { + beforeEach(async () => { + respondFn.mockRejectedValueOnce(new Error('foo')); + + requestCallback.mockImplementationOnce(() => + of(...createLlmResponse([{ content: 'Action call failed' }])) + ); + + await lastValueFrom( + callComplete({ + getScreenContexts, + }).pipe( + filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ), + map((event) => event.message), + toArray(), + last() + ) + ); + }); + + it('calls the request callback again with the error', () => { + expect(requestCallback).toHaveBeenCalledTimes(2); + + const nextMessages = requestCallback.mock.lastCall![0].params.body.messages; + + const errorMessage = nextMessages[nextMessages.length - 1]; + + expect(errorMessage).toEqual({ + '@timestamp': expect.any(String), + message: { + content: expect.any(String), + name: 'my_action', + role: MessageRole.User, + }, + }); + + expect(JSON.parse(errorMessage.message.content ?? '{}')).toEqual({ + error: expect.objectContaining({ + message: 'foo', + }), + message: 'foo', + }); + }); + }); + + describe('and it returns an observable that completes', () => { + let allMessages: Message[] = []; + let allEvents: StreamingChatResponseEvent[] = []; + beforeEach(async () => { + respondFn.mockResolvedValueOnce( + of(...createLlmResponse([{ content: 'My function response' }])) + ); + + allEvents = await lastValueFrom( + callComplete({ + getScreenContexts, + }).pipe(toArray(), last()) + ); + + allMessages = allEvents + .filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ) + .map((event) => event.message); + }); + + it('propagates all the events from the responded observable', () => { + expect(allEvents.length).toEqual(5); + expect( + allEvents.filter( + (event) => event.type === StreamingChatResponseEventType.ChatCompletionChunk + ).length + ).toEqual(2); + }); + + it('automatically adds a function response message', () => { + expect(allMessages[allMessages.length - 2]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: JSON.stringify({ executed: true }), + name: 'my_action', + role: MessageRole.User, + }, + }); + }); + + it('adds the messages from the observable', () => { + expect(allMessages[allMessages.length - 1]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: 'My function response', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }); + }); + }); + + describe('and it returns an observable that errors out', () => { + let allMessages: Message[] = []; + let allEvents: StreamingChatResponseEvent[] = []; + beforeEach(async () => { + respondFn.mockResolvedValueOnce(throwError(() => createInternalServerError('Foo'))); + + requestCallback.mockImplementationOnce(() => + of( + ...createLlmResponse([ + { + content: 'Looks like your action failed', + }, + ]) + ) + ); + + allEvents = await lastValueFrom( + callComplete({ + getScreenContexts, + }).pipe(toArray(), last()) + ); + + allMessages = allEvents + .filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ) + .map((event) => event.message); + }); + + it('appends the error message', () => { + expect(allMessages[allMessages.length - 1]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: 'Looks like your action failed', + function_call: { + arguments: '', + name: '', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }); + }); + }); + }); +}); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts new file mode 100644 index 0000000000000..812d486317b57 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { + catchError, + combineLatest, + filter, + isObservable, + last, + map, + Observable, + of, + shareReplay, + toArray, +} from 'rxjs'; +import { + MessageRole, + StreamingChatResponseEventType, + type BufferFlushEvent, + type ConversationCreateEvent, + type ConversationUpdateEvent, + type Message, + type MessageAddEvent, + type StreamingChatResponseEvent, + type StreamingChatResponseEventWithoutError, +} from '../../common'; +import { ObservabilityAIAssistantScreenContext } from '../../common/types'; +import { createFunctionResponseError } from '../../common/utils/create_function_response_error'; +import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message'; +import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors'; +import type { ObservabilityAIAssistantAPIClientRequestParamsOf } from '../api'; +import { ObservabilityAIAssistantChatService } from '../types'; + +export function complete( + { + client, + getScreenContexts, + connectorId, + conversationId, + messages: initialMessages, + persist, + signal, + }: { + client: Pick; + getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; + connectorId: string; + conversationId?: string; + messages: Message[]; + persist: boolean; + signal: AbortSignal; + }, + requestCallback: ( + params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'> + ) => Observable +): Observable { + return new Observable((subscriber) => { + const screenContexts = getScreenContexts(); + const allActions = screenContexts.flatMap((context) => context.actions ?? []); + + const response$ = requestCallback({ + params: { + body: { connectorId, messages: initialMessages, persist, screenContexts, conversationId }, + }, + }).pipe( + filter( + (event): event is StreamingChatResponseEvent => + event.type !== StreamingChatResponseEventType.BufferFlush + ), + throwSerializedChatCompletionErrors(), + shareReplay() + ); + + const messages$ = response$.pipe( + filter( + (event): event is MessageAddEvent => + event.type === StreamingChatResponseEventType.MessageAdd + ), + map((event) => event.message), + toArray(), + last() + ); + + const conversationId$ = response$.pipe( + last( + (event): event is ConversationCreateEvent | ConversationUpdateEvent => + event.type === StreamingChatResponseEventType.ConversationCreate || + event.type === StreamingChatResponseEventType.ConversationUpdate + ), + map((event) => event.conversation.id), + catchError(() => { + return of(conversationId); + }) + ); + + response$.subscribe({ + next: (val) => { + subscriber.next(val); + }, + error: (error) => { + subscriber.error(error); + }, + }); + + combineLatest([conversationId$, messages$, response$.pipe(last())]).subscribe({ + next: ([nextConversationId, allMessages]) => { + const functionCall = allMessages[allMessages.length - 1]?.message.function_call; + + if (!functionCall?.name) { + subscriber.complete(); + return; + } + + const requestedAction = allActions.find((action) => action.name === functionCall.name); + + function next(nextMessages: Message[]) { + if ( + nextMessages[nextMessages.length - 1].message.role === MessageRole.Assistant && + !persist + ) { + subscriber.complete(); + return; + } + + complete( + { + client, + getScreenContexts, + connectorId, + conversationId: nextConversationId || conversationId, + messages: initialMessages.concat(nextMessages), + signal, + persist, + }, + requestCallback + ).subscribe(subscriber); + } + + if (!requestedAction) { + const errorMessage = createFunctionResponseError({ + name: functionCall.name, + error: new Error(`Requested action ${functionCall.name} was not found`), + }); + + subscriber.next(errorMessage); + + next([...allMessages, errorMessage.message]); + return; + } + + requestedAction + .respond({ + signal, + client, + args: JSON.parse(functionCall.arguments || '{}'), + connectorId, + messages: allMessages, + }) + .then(async (functionResponse) => { + if (isObservable(functionResponse)) { + const executedMessage = createFunctionResponseMessage({ + name: functionCall.name, + content: { + executed: true, + }, + }); + + allMessages.push(executedMessage.message); + + subscriber.next(executedMessage); + + return await new Promise((resolve, reject) => { + functionResponse.subscribe({ + next: (val) => { + if (val.type === StreamingChatResponseEventType.MessageAdd) { + allMessages.push(val.message); + } + subscriber.next(val); + }, + error: (error) => { + reject(error); + }, + complete: () => { + resolve(); + }, + }); + }); + } + + return createFunctionResponseMessage({ + name: functionCall.name, + content: functionResponse.content, + data: functionResponse.data, + }); + }) + .catch((error) => { + return createFunctionResponseError({ + name: functionCall.name, + error, + }); + }) + .then((event) => { + if (event) { + allMessages.push(event.message); + + subscriber.next(event); + } + next(allMessages); + }); + }, + error: (error) => { + subscriber.error(error); + }, + }); + }); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts index 683792d5cf708..2a333d742c5ec 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.test.ts @@ -73,7 +73,7 @@ describe('createChatService', () => { reportEvent: () => {}, telemetryCounter$: new Observable(), }, - client: clientSpy, + apiClient: clientSpy, registrations: [], signal: new AbortController().signal, }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts index 257a657ce6487..c92d2f7b3daf9 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts @@ -5,15 +5,27 @@ * 2.0. */ -import { AnalyticsServiceStart, HttpResponse } from '@kbn/core/public'; +import type { AnalyticsServiceStart, HttpResponse } from '@kbn/core/public'; import { AbortError } from '@kbn/kibana-utils-plugin/common'; -import { IncomingMessage } from 'http'; +import type { IncomingMessage } from 'http'; import { pick } from 'lodash'; -import { concatMap, delay, filter, map, Observable, of, scan, shareReplay, timestamp } from 'rxjs'; import { - BufferFlushEvent, + concatMap, + delay, + filter, + from, + map, + Observable, + of, + scan, + shareReplay, + switchMap, + timestamp, +} from 'rxjs'; +import { + type BufferFlushEvent, StreamingChatResponseEventType, - StreamingChatResponseEventWithoutError, + type StreamingChatResponseEventWithoutError, type StreamingChatResponseEvent, } from '../../common/conversation_complete'; import { @@ -21,7 +33,6 @@ import { FunctionResponse, FunctionVisibility, } from '../../common/functions/types'; -import { type Message } from '../../common/types'; import { filterFunctionDefinitions } from '../../common/utils/filter_function_definitions'; import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors'; import { sendEvent } from '../analytics'; @@ -32,6 +43,7 @@ import type { RenderFunction, } from '../types'; import { readableStreamReaderIntoObservable } from '../utils/readable_stream_reader_into_observable'; +import { complete } from './complete'; const MIN_DELAY = 35; @@ -82,19 +94,19 @@ export async function createChatService({ analytics, signal: setupAbortSignal, registrations, - client, + apiClient, }: { analytics: AnalyticsServiceStart; signal: AbortSignal; registrations: ChatRegistrationRenderFunction[]; - client: ObservabilityAIAssistantAPIClient; + apiClient: ObservabilityAIAssistantAPIClient; }): Promise { const functionRegistry: FunctionRegistry = new Map(); const renderFunctionRegistry: Map> = new Map(); const [{ functionDefinitions, contextDefinitions }] = await Promise.all([ - client('GET /internal/observability_ai_assistant/functions', { + apiClient('GET /internal/observability_ai_assistant/functions', { signal: setupAbortSignal, }), ...registrations.map((registration) => { @@ -117,100 +129,8 @@ export async function createChatService({ }); }; - return { - sendAnalyticsEvent: (event) => { - sendEvent(analytics, event); - }, - renderFunction: (name, args, response, onActionClick) => { - const fn = renderFunctionRegistry.get(name); - - if (!fn) { - throw new Error(`Function ${name} not found`); - } - - const parsedArguments = args ? JSON.parse(args) : {}; - - const parsedResponse = { - content: JSON.parse(response.content ?? '{}'), - data: JSON.parse(response.data ?? '{}'), - }; - - return fn?.({ - response: parsedResponse, - arguments: parsedArguments, - onActionClick, - }); - }, - getContexts: () => contextDefinitions, - getFunctions, - hasFunction: (name: string) => { - return functionRegistry.has(name); - }, - hasRenderFunction: (name: string) => { - return renderFunctionRegistry.has(name); - }, - complete({ - screenContexts, - connectorId, - conversationId, - messages, - persist, - signal, - responseLanguage, - }) { - return new Observable((subscriber) => { - client('POST /internal/observability_ai_assistant/chat/complete', { - params: { - body: { - connectorId, - conversationId, - screenContexts, - messages, - persist, - responseLanguage, - }, - }, - signal, - asResponse: true, - rawResponse: true, - }) - .then((_response) => { - const response = _response as unknown as HttpResponse; - const response$ = toObservable(response) - .pipe( - map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), - filter( - (line): line is StreamingChatResponseEvent => - line.type !== StreamingChatResponseEventType.BufferFlush - ), - throwSerializedChatCompletionErrors() - ) - .subscribe(subscriber); - - signal.addEventListener('abort', () => { - response$.unsubscribe(); - }); - }) - .catch((err) => { - subscriber.error(err); - subscriber.complete(); - }); - }); - }, - chat( - name: string, - { - connectorId, - messages, - function: callFunctions = 'auto', - signal, - }: { - connectorId: string; - messages: Message[]; - function?: 'none' | 'auto'; - signal: AbortSignal; - } - ) { + const client: Pick = { + chat(name: string, { connectorId, messages, function: callFunctions = 'auto', signal }) { return new Observable((subscriber) => { const contexts = ['core', 'apm']; @@ -222,7 +142,7 @@ export async function createChatService({ ); }); - client('POST /internal/observability_ai_assistant/chat', { + apiClient('POST /internal/observability_ai_assistant/chat', { params: { body: { name, @@ -279,5 +199,68 @@ export async function createChatService({ shareReplay() ); }, + complete({ getScreenContexts, connectorId, conversationId, messages, persist, signal }) { + return complete( + { + getScreenContexts, + connectorId, + conversationId, + messages, + persist, + signal, + client, + }, + ({ params }) => { + return from( + apiClient('POST /internal/observability_ai_assistant/chat/complete', { + params, + signal, + asResponse: true, + rawResponse: true, + }) + ).pipe( + map((_response) => toObservable(_response as unknown as HttpResponse)), + switchMap((response$) => response$), + map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), + shareReplay() + ); + } + ); + }, + }; + + return { + sendAnalyticsEvent: (event) => { + sendEvent(analytics, event); + }, + renderFunction: (name, args, response, onActionClick) => { + const fn = renderFunctionRegistry.get(name); + + if (!fn) { + throw new Error(`Function ${name} not found`); + } + + const parsedArguments = args ? JSON.parse(args) : {}; + + const parsedResponse = { + content: JSON.parse(response.content ?? '{}'), + data: JSON.parse(response.data ?? '{}'), + }; + + return fn?.({ + response: parsedResponse, + arguments: parsedArguments, + onActionClick, + }); + }, + getContexts: () => contextDefinitions, + getFunctions, + hasFunction: (name: string) => { + return functionRegistry.has(name); + }, + hasRenderFunction: (name: string) => { + return renderFunctionRegistry.has(name); + }, + ...client, }; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts index 0db5b9cf8b5ba..9ccfca66e53b3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts @@ -6,9 +6,11 @@ */ import type { AnalyticsServiceStart, CoreStart } from '@kbn/core/public'; -import { without } from 'lodash'; -import { BehaviorSubject, Subject } from 'rxjs'; +import { compact, without } from 'lodash'; +import { BehaviorSubject, debounceTime, filter, lastValueFrom, of, Subject, take } from 'rxjs'; import type { Message, ObservabilityAIAssistantScreenContext } from '../../common/types'; +import { createFunctionRequestMessage } from '../../common/utils/create_function_request_message'; +import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message'; import { createCallObservabilityAIAssistantAPI } from '../api'; import type { ChatRegistrationRenderFunction, ObservabilityAIAssistantService } from '../types'; @@ -21,7 +23,7 @@ export function createService({ coreStart: CoreStart; enabled: boolean; }): ObservabilityAIAssistantService { - const client = createCallObservabilityAIAssistantAPI(coreStart); + const apiClient = createCallObservabilityAIAssistantAPI(coreStart); const registrations: ChatRegistrationRenderFunction[] = []; @@ -37,17 +39,50 @@ export function createService({ }, start: async ({ signal }) => { const mod = await import('./create_chat_service'); - return await mod.createChatService({ analytics, client, signal, registrations }); + return await mod.createChatService({ analytics, apiClient, signal, registrations }); }, - callApi: client, + callApi: apiClient, getScreenContexts() { return screenContexts$.value; }, setScreenContext: (context: ObservabilityAIAssistantScreenContext) => { screenContexts$.next(screenContexts$.value.concat(context)); - return () => { + + function unsubscribe() { screenContexts$.next(without(screenContexts$.value, context)); - }; + } + + return unsubscribe; + }, + navigate: async (cb) => { + cb(); + + // wait for at least 1s of no network activity + await lastValueFrom( + coreStart.http.getLoadingCount$().pipe( + filter((count) => count === 0), + debounceTime(1000), + take(1) + ) + ); + + return of( + createFunctionRequestMessage({ + name: 'context', + args: { + queries: [], + categories: [], + }, + }), + createFunctionResponseMessage({ + name: 'context', + content: { + screenDescription: compact( + screenContexts$.value.map((context) => context.screenDescription) + ).join('\n\n'), + }, + }) + ); }, conversations: { openNewConversation: ({ messages, title }: { messages: Message[]; title?: string }) => { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx index 1b2d71a71b345..01c2f658e360b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx @@ -7,7 +7,7 @@ import { i18n } from '@kbn/i18n'; import { noop } from 'lodash'; import React from 'react'; -import { Observable } from 'rxjs'; +import { Observable, of } from 'rxjs'; import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete'; import type { ObservabilityAIAssistantAPIClient } from './api'; import type { ObservabilityAIAssistantChatService, ObservabilityAIAssistantService } from './types'; @@ -44,4 +44,5 @@ export const createStorybookService = (): ObservabilityAIAssistantService => ({ openNewConversation: noop, predefinedConversation$: new Observable(), }, + navigate: async () => of(), }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts index d8c719dfa0364..bf26ce44eb81f 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts @@ -9,7 +9,10 @@ import type { LicensingPluginStart } from '@kbn/licensing-plugin/public'; import type { MlPluginSetup, MlPluginStart } from '@kbn/ml-plugin/public'; import type { SecurityPluginSetup, SecurityPluginStart } from '@kbn/security-plugin/public'; import type { Observable } from 'rxjs'; -import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete'; +import type { + MessageAddEvent, + StreamingChatResponseEventWithoutError, +} from '../common/conversation_complete'; import type { ContextDefinition, FunctionDefinition, @@ -30,6 +33,7 @@ import { useChat } from './hooks/use_chat'; import type { UseGenAIConnectorsResult } from './hooks/use_genai_connectors'; import { useObservabilityAIAssistantChatService } from './hooks/use_observability_ai_assistant_chat_service'; import type { UseUserPreferredLanguageResult } from './hooks/use_user_preferred_language'; +import { createScreenContextAction } from './utils/create_screen_context_action'; /* eslint-disable @typescript-eslint/no-empty-interface*/ @@ -47,7 +51,7 @@ export interface ObservabilityAIAssistantChatService { } ) => Observable; complete: (options: { - screenContexts: ObservabilityAIAssistantScreenContext[]; + getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; conversationId?: string; connectorId: string; messages: Message[]; @@ -80,6 +84,7 @@ export interface ObservabilityAIAssistantService { setScreenContext: (screenContext: ObservabilityAIAssistantScreenContext) => () => void; getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; conversations: ObservabilityAIAssistantConversationService; + navigate: (callback: () => void) => Promise>; } export type RenderFunction = (options: { @@ -123,4 +128,5 @@ export interface ObservabilityAIAssistantPublicStart { useChat: typeof useChat; useUserPreferredLanguage: () => UseUserPreferredLanguageResult; getContextualInsightMessages: ({}: { message: string; instructions: string }) => Message[]; + createScreenContextAction: typeof createScreenContextAction; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts new file mode 100644 index 0000000000000..3dbc4dbaf36f0 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { FromSchema } from 'json-schema-to-ts'; +import { CompatibleJSONSchema } from '../../common/functions/types'; +import type { + ScreenContextActionDefinition, + ScreenContextActionRespondFunction, +} from '../../common/types'; + +type ReturnOf> = + TActionDefinition['parameters'] extends CompatibleJSONSchema + ? FromSchema + : undefined; + +export function createScreenContextAction< + TActionDefinition extends Omit, + TResponse = ReturnOf +>( + definition: TActionDefinition, + respond: ScreenContextActionRespondFunction +): ScreenContextActionDefinition { + return { + ...definition, + respond, + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index 225a248b160ac..61603210a44f2 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -18,8 +18,8 @@ import { MessageAddEvent } from '../../common/conversation_complete'; import { FunctionVisibility } from '../../common/functions/types'; import { MessageRole, type Message } from '../../common/types'; import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_chat_completion_chunks'; +import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message'; import type { ObservabilityAIAssistantClient } from '../service/client'; -import { createFunctionResponseMessage } from '../service/util/create_function_response_message'; import { parseSuggestionScores } from './parse_suggestion_scores'; const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts index 84057b04e3bd4..7d375a6ae9d2c 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts @@ -25,11 +25,15 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ messages: t.array(messageRt), connectorId: t.string, functions: t.array( - t.type({ - name: t.string, - description: t.string, - parameters: t.any, - }) + t.intersection([ + t.type({ + name: t.string, + description: t.string, + }), + t.partial({ + parameters: t.any, + }), + ]) ), }), t.partial({ diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/runtime_types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/runtime_types.ts index cef56f673e235..1e185018b84c3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/runtime_types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/runtime_types.ts @@ -13,7 +13,7 @@ import { type ConversationUpdateRequest, type Message, MessageRole, - type ObservabilityAIAssistantScreenContext, + type ObservabilityAIAssistantScreenContextRequest, } from '../../common/types'; const serializeableRt = t.any; @@ -94,7 +94,7 @@ export const conversationRt: t.Type = t.intersection([ }), ]); -export const screenContextRt: t.Type = t.partial({ +export const screenContextRt: t.Type = t.partial({ description: t.string, data: t.array( t.type({ @@ -103,4 +103,15 @@ export const screenContextRt: t.Type = t. value: t.any, }) ), + actions: t.array( + t.intersection([ + t.type({ + name: t.string, + description: t.string, + }), + t.partial({ + parameters: t.record(t.string, t.any), + }), + ]) + ), }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts index 618f7eef00276..5f55f11c29764 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts @@ -10,13 +10,13 @@ import Ajv, { type ErrorObject, type ValidateFunction } from 'ajv'; import dedent from 'dedent'; import { compact, keyBy } from 'lodash'; import { - type ContextRegistry, FunctionVisibility, - type RegisterContextDefinition, type ContextDefinition, + type ContextRegistry, type FunctionResponse, + type RegisterContextDefinition, } from '../../../common/functions/types'; -import type { Message, ObservabilityAIAssistantScreenContext } from '../../../common/types'; +import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types'; import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions'; import type { FunctionHandler, FunctionHandlerRegistry, RegisterFunction } from '../types'; @@ -35,9 +35,13 @@ export class ChatFunctionClient { private readonly functionRegistry: FunctionHandlerRegistry = new Map(); private readonly validators: Map = new Map(); - constructor(private readonly screenContexts: ObservabilityAIAssistantScreenContext[]) { + private readonly actions: Required['actions']; + + constructor(private readonly screenContexts: ObservabilityAIAssistantScreenContextRequest[]) { const allData = compact(screenContexts.flatMap((context) => context.data)); + this.actions = compact(screenContexts.flatMap((context) => context.actions)); + if (allData.length) { this.registerFunction( { @@ -74,10 +78,18 @@ export class ChatFunctionClient { } ); } + + this.actions.forEach((action) => { + if (action.parameters) { + this.validators.set(action.name, ajv.compile(action.parameters)); + } + }); } registerFunction: RegisterFunction = (definition, respond) => { - this.validators.set(definition.name, ajv.compile(definition.parameters)); + if (definition.parameters) { + this.validators.set(definition.name, ajv.compile(definition.parameters)); + } this.functionRegistry.set(definition.name, { definition, respond }); }; @@ -85,8 +97,12 @@ export class ChatFunctionClient { this.contextRegistry.set(context.name, context); }; - private validate(name: string, parameters: unknown) { + validate(name: string, parameters: unknown) { const validator = this.validators.get(name)!; + if (!validator) { + return; + } + const result = validator(parameters); if (!result) { throw new FunctionArgsValidationError(validator.errors!); @@ -97,6 +113,10 @@ export class ChatFunctionClient { return Array.from(this.contextRegistry.values()); } + hasAction(name: string) { + return !!this.actions.find((action) => action.name === name)!; + } + getFunctions({ contexts, filter, @@ -117,6 +137,10 @@ export class ChatFunctionClient { return matchingDefinitions.map((definition) => functionsByName[definition.name]); } + getActions() { + return this.actions; + } + hasFunction(name: string): boolean { return this.functionRegistry.has(name); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock_claude_adapter.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock_claude_adapter.ts index f864ba02c8ac5..de727f1168aa0 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock_claude_adapter.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/bedrock_claude_adapter.ts @@ -90,7 +90,9 @@ export const createBedrockClaudeAdapter: LlmApiAdapterFactory = ({ (fn) => ` ${fn.name} ${fn.description} - + ${ + fn.parameters + ? ` ${jsonSchemaToFlatParameters(fn.parameters).map((param) => { return ` ${param.name} @@ -107,7 +109,9 @@ export const createBedrockClaudeAdapter: LlmApiAdapterFactory = ({ `; })} - + ` + : '' + } ` ) .join('\n')} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/openai_adapter.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/openai_adapter.ts index 61935d891a1db..5eb6834d44482 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/openai_adapter.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/openai_adapter.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { compact, isEmpty, omit } from 'lodash'; +import { compact, isEmpty, merge, omit } from 'lodash'; import OpenAI from 'openai'; import { MessageRole } from '../../../../common'; import { processOpenAiStream } from '../../../../common/utils/process_openai_stream'; @@ -44,7 +44,16 @@ export const createOpenAiAdapter: LlmApiAdapterFactory = ({ }) ); - const functionsForOpenAI = functions; + const functionsForOpenAI = functions?.map((fn) => ({ + ...fn, + parameters: merge( + { + type: 'object', + properties: {}, + }, + fn.parameters + ), + })); const request: Omit & { model?: string } = { messages: messagesForOpenAI as OpenAI.ChatCompletionCreateParams['messages'], diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_bedrock_stream.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_bedrock_stream.ts index 03fa2fa86461b..bf747051a0347 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_bedrock_stream.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_bedrock_stream.ts @@ -24,7 +24,7 @@ async function parseFunctionCallXml({ functions, }: { xml: string; - functions?: Array<{ name: string; description: string; parameters: JSONSchema }>; + functions?: Array<{ name: string; description: string; parameters?: JSONSchema }>; }) { const parser = new Parser(); @@ -45,7 +45,9 @@ async function parseFunctionCallXml({ ); } - const args = convertDeserializedXmlWithJsonSchema(parameters, functionDef.parameters); + const args = functionDef.parameters + ? convertDeserializedXmlWithJsonSchema(parameters, functionDef.parameters) + : {}; return { name: fnName, @@ -58,7 +60,7 @@ export function processBedrockStream({ functions, }: { logger: Logger; - functions?: Array<{ name: string; description: string; parameters: JSONSchema }>; + functions?: Array<{ name: string; description: string; parameters?: JSONSchema }>; }) { return (source: Observable) => new Observable((subscriber) => { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts index fff3edeccb7db..44c8b5711be32 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/types.ts @@ -15,7 +15,7 @@ import { CompatibleJSONSchema } from '../../../../common/functions/types'; export type LlmApiAdapterFactory = (options: { logger: Logger; messages: Message[]; - functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>; + functions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>; functionCall?: string; }) => LlmApiAdapter; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index 96ef44adfee64..216e3dfce1361 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -23,10 +23,10 @@ import { MessageAddEvent, StreamingChatResponseEventType, } from '../../../common/conversation_complete'; +import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message'; import type { CreateChatCompletionResponseChunk } from '../../../common/utils/process_openai_stream'; -import type { ChatFunctionClient } from '../chat_function_client'; +import { ChatFunctionClient } from '../chat_function_client'; import type { KnowledgeBaseService } from '../knowledge_base_service'; -import { createFunctionResponseMessage } from '../util/create_function_response_message'; import { observableIntoStream } from '../util/observable_into_stream'; type ChunkDelta = CreateChatCompletionResponseChunk['choices'][number]['delta']; @@ -116,6 +116,9 @@ describe('Observability AI Assistant client', () => { executeFunction: jest.fn(), getFunctions: jest.fn(), hasFunction: jest.fn(), + hasAction: jest.fn(), + getActions: jest.fn(), + validate: jest.fn(), } as any; let llmSimulator: LlmSimulator; @@ -128,6 +131,9 @@ describe('Observability AI Assistant client', () => { return name !== 'context'; }); + functionClientMock.hasAction.mockReturnValue(false); + functionClientMock.getActions.mockReturnValue([]); + actionsClientMock.get.mockResolvedValue({ actionTypeId: ObservabilityAIAssistantConnectorType.OpenAI, id: 'foo', @@ -1468,4 +1474,123 @@ describe('Observability AI Assistant client', () => { 'You MUST respond in the users preferred language which is: Orcish. This is a system message' ); }); + + describe('when executing an action', () => { + let completePromise: Promise; + + beforeEach(async () => { + client = createClient(); + + llmSimulator = createLlmSimulator(); + + actionsClientMock.execute.mockImplementation(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + const complete$ = await client.complete({ + connectorId: 'foo', + messages: [ + system('This is a system message'), + user('Can you call the my_action function?'), + ], + functionClient: new ChatFunctionClient([ + { + actions: [ + { + name: 'my_action', + description: 'My action description', + parameters: { + type: 'object', + properties: { + foo: { + type: 'string', + }, + }, + required: ['foo'], + }, + }, + ], + }, + ]), + signal: new AbortController().signal, + title: 'My predefined title', + persist: false, + }); + + const messages: Message[] = []; + + completePromise = new Promise((resolve, reject) => { + complete$.subscribe({ + next: (event) => { + if (event.type === StreamingChatResponseEventType.MessageAdd) { + messages.push(event.message); + } + }, + complete: () => resolve(messages), + }); + }); + }); + + describe('and validation succeeds', () => { + beforeEach(async () => { + await llmSimulator.next({ + function_call: { name: 'my_action', arguments: JSON.stringify({ foo: 'bar' }) }, + }); + await llmSimulator.complete(); + }); + + it('completes the observable function request being the last event', async () => { + const messages = await completePromise; + expect(messages.length).toBe(1); + + expect(messages[0].message.function_call).toEqual({ + name: 'my_action', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }); + }); + }); + + describe('and validation fails', () => { + beforeEach(async () => { + await llmSimulator.next({ + function_call: { name: 'my_action', arguments: JSON.stringify({ bar: 'foo' }) }, + }); + + await llmSimulator.complete(); + + await waitFor(() => + actionsClientMock.execute.mock.calls.length === 2 + ? Promise.resolve() + : Promise.reject(new Error('Waiting until execute is called again')) + ); + + await nextTick(); + + await llmSimulator.next({ + content: 'Looks like the function call failed', + }); + + await llmSimulator.complete(); + }); + + it('appends a function response error and sends it back to the LLM', async () => { + const messages = await completePromise; + expect(messages.length).toBe(3); + + expect(messages[0].message.function_call?.name).toBe('my_action'); + + expect(messages[1].message.name).toBe('my_action'); + + expect(JSON.parse(messages[1].message.content ?? '{}')).toHaveProperty('error'); + + expect(messages[2].message.content).toBe('Looks like the function call failed'); + }); + }); + }); }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index cec07bc949cab..932adbc3a0f5d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -48,6 +48,7 @@ import { type Message, } from '../../../common/types'; import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; +import { createFunctionResponseError } from '../../../common/utils/create_function_response_error'; import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message'; import type { ChatFunctionClient } from '../chat_function_client'; import { @@ -159,6 +160,19 @@ export class ObservabilityAIAssistantClient { const MAX_FUNCTION_CALLS = 5; const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000; + const allFunctions = functionClient + .getFunctions() + .filter((fn) => { + const visibility = fn.definition.visibility ?? FunctionVisibility.All; + return ( + visibility === FunctionVisibility.All || + visibility === FunctionVisibility.AssistantOnly + ); + }) + .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')); + + const allActions = functionClient.getActions(); + const next = async (nextMessages: Message[]): Promise => { const lastMessage = last(nextMessages); @@ -199,18 +213,7 @@ export class ObservabilityAIAssistantClient { return await next(nextMessages.concat(addedMessage)); } else if (isUserMessage) { const functions = - numFunctionsCalled >= MAX_FUNCTION_CALLS - ? [] - : functionClient - .getFunctions() - .filter((fn) => { - const visibility = fn.definition.visibility ?? FunctionVisibility.All; - return ( - visibility === FunctionVisibility.All || - visibility === FunctionVisibility.AssistantOnly - ); - }) - .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')); + numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : allFunctions.concat(allActions); if (numFunctionsCalled >= MAX_FUNCTION_CALLS) { this.dependencies.logger.debug( @@ -254,9 +257,37 @@ export class ObservabilityAIAssistantClient { } if (isAssistantMessageWithFunctionRequest) { - const span = apm.startSpan( - `execute_function ${lastMessage.message.function_call!.name}` - ); + const functionCallName = lastMessage.message.function_call!.name; + + if (functionClient.hasAction(functionCallName)) { + this.dependencies.logger.debug(`Executing client-side action: ${functionCallName}`); + + // if validation fails, return the error to the LLM. + // otherwise, close the stream. + + try { + functionClient.validate( + functionCallName, + JSON.parse(lastMessage.message.function_call!.arguments || '{}') + ); + } catch (error) { + const functionResponseMessage = createFunctionResponseError({ + name: functionCallName, + error, + }); + nextMessages = nextMessages.concat(functionResponseMessage.message); + + subscriber.next(functionResponseMessage); + + return await next(nextMessages); + } + + subscriber.complete(); + + return; + } + + const span = apm.startSpan(`execute_function ${functionCallName}`); span?.addLabels({ ai_assistant_args: JSON.stringify(lastMessage.message.function_call!.arguments ?? {}), @@ -273,7 +304,7 @@ export class ObservabilityAIAssistantClient { : await functionClient .executeFunction({ connectorId, - name: lastMessage.message.function_call!.name, + name: functionCallName, messages: nextMessages, args: lastMessage.message.function_call!.arguments, signal, @@ -313,6 +344,7 @@ export class ObservabilityAIAssistantClient { numFunctionsCalled++; if (signal.aborted) { + span?.end(); return; } @@ -455,7 +487,7 @@ export class ObservabilityAIAssistantClient { }: { messages: Message[]; connectorId: string; - functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>; + functions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>; functionCall?: string; signal: AbortSignal; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/index.ts index 1df243fc2ba35..e1a19df6b44b0 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/index.ts @@ -13,7 +13,10 @@ import type { SecurityPluginStart } from '@kbn/security-plugin/server'; import { getSpaceIdFromPath } from '@kbn/spaces-plugin/common'; import type { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server'; import { once } from 'lodash'; -import { KnowledgeBaseEntryRole, ObservabilityAIAssistantScreenContext } from '../../common/types'; +import { + KnowledgeBaseEntryRole, + ObservabilityAIAssistantScreenContextRequest, +} from '../../common/types'; import type { ObservabilityAIAssistantPluginStartDependencies } from '../types'; import { ChatFunctionClient } from './chat_function_client'; import { ObservabilityAIAssistantClient } from './client'; @@ -291,7 +294,7 @@ export class ObservabilityAIAssistantService { resources, client, }: { - screenContexts: ObservabilityAIAssistantScreenContext[]; + screenContexts: ObservabilityAIAssistantScreenContextRequest[]; signal: AbortSignal; resources: RespondFunctionResources; client: ObservabilityAIAssistantClient; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts index 1ade4fc0e179a..cc21d373be4eb 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts @@ -11,7 +11,7 @@ import type { FunctionDefinition, FunctionResponse, } from '../../common/functions/types'; -import type { Message, ObservabilityAIAssistantScreenContext } from '../../common/types'; +import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../common/types'; import type { ObservabilityAIAssistantRouteHandlerResources } from '../routes/types'; import { ChatFunctionClient } from './chat_function_client'; import type { ObservabilityAIAssistantClient } from './client'; @@ -26,7 +26,7 @@ type RespondFunction = ( arguments: TArguments; messages: Message[]; connectorId: string; - screenContexts: ObservabilityAIAssistantScreenContext[]; + screenContexts: ObservabilityAIAssistantScreenContextRequest[]; }, signal: AbortSignal ) => Promise; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/buttons/new_chat_button.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/buttons/new_chat_button.tsx index 2d7387d9a1040..6dadba00f2394 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/buttons/new_chat_button.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/buttons/new_chat_button.tsx @@ -11,8 +11,9 @@ import { i18n } from '@kbn/i18n'; export function NewChatButton( props: React.ComponentProps & { collapsed?: boolean } ) { + const { collapsed, ...nextProps } = props; return !props.collapsed ? ( - + {i18n.translate('xpack.observabilityAiAssistant.newChatButton', { defaultMessage: 'New chat', @@ -23,7 +24,7 @@ export function NewChatButton( data-test-subj="observabilityAiAssistantNewChatButtonButton" iconType={EuiIconNewChat} size="xs" - {...props} + {...nextProps} /> ); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/nav_control/index.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/nav_control/index.tsx index 1366c41464241..2210f50d7a6b7 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/nav_control/index.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/nav_control/index.tsx @@ -15,6 +15,7 @@ import { ChatFlyout } from '../chat/chat_flyout'; import { useKibana } from '../../hooks/use_kibana'; import { useIsNavControlVisible } from '../../hooks/is_nav_control_visible'; import { useTheme } from '../../hooks/use_theme'; +import { useNavControlScreenContext } from '../../hooks/use_nav_control_screen_context'; export function NavControl({}: {}) { const service = useObservabilityAIAssistantAppService(); @@ -31,6 +32,8 @@ export function NavControl({}: {}) { const [hasBeenOpened, setHasBeenOpened] = useState(false); + useNavControlScreenContext(); + const chatService = useAbortableAsync( ({ signal }) => { return hasBeenOpened ? service.start({ signal }) : undefined; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_conversation.test.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_conversation.test.tsx index 2f7d872d191d7..e190cc0c34782 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_conversation.test.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_conversation.test.tsx @@ -13,7 +13,7 @@ import { } from '@testing-library/react-hooks'; import { merge } from 'lodash'; import React from 'react'; -import { Observable, Subject } from 'rxjs'; +import { Observable, of, Subject } from 'rxjs'; import { MessageRole, StreamingChatResponseEventType, @@ -54,6 +54,7 @@ const mockService: MockedService = { openNewConversation: jest.fn(), predefinedConversation$: new Observable(), }, + navigate: jest.fn().mockReturnValue(of()), }; const mockChatService = createMockChatService(); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_nav_control_screen_context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_nav_control_screen_context.ts new file mode 100644 index 0000000000000..7baf9aae82512 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/hooks/use_nav_control_screen_context.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useEffect, useState } from 'react'; +import datemath from '@elastic/datemath'; +import moment from 'moment'; +import { useKibana } from './use_kibana'; +import { useObservabilityAIAssistantAppService } from './use_observability_ai_assistant_app_service'; + +export function useNavControlScreenContext() { + const service = useObservabilityAIAssistantAppService(); + + const { + services: { + plugins: { + start: { data }, + }, + }, + } = useKibana(); + + const { from, to } = data.query.timefilter.timefilter.getTime(); + + const [href, setHref] = useState(window.location.href); + + useEffect(() => { + const originalPushState = window.history.pushState.bind(window.history); + const originalReplaceState = window.history.replaceState.bind(window.history); + + let unmounted: boolean = false; + + function updateHref() { + if (!unmounted) { + setHref(window.location.href); + } + } + + window.history.pushState = (...args: Parameters) => { + originalPushState(...args); + updateHref(); + }; + + window.history.replaceState = (...args: Parameters) => { + originalReplaceState(...args); + updateHref(); + }; + window.addEventListener('popstate', updateHref); + + window.addEventListener('hashchange', updateHref); + + return () => { + unmounted = true; + window.removeEventListener('popstate', updateHref); + window.removeEventListener('hashchange', updateHref); + }; + }, []); + + useEffect(() => { + const start = datemath.parse(from)?.format() ?? moment().subtract(1, 'day').toISOString(); + const end = datemath.parse(to)?.format() ?? moment().toISOString(); + + return service.setScreenContext({ + screenDescription: `The user is looking at ${href}. The current time range is ${start} - ${end}.`, + }); + }, [service, from, to, href]); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/types.ts index 5bce062b9eecd..2876ddaf3332d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/types.ts @@ -27,6 +27,7 @@ import type { TriggersAndActionsUIPublicPluginSetup, TriggersAndActionsUIPublicPluginStart, } from '@kbn/triggers-actions-ui-plugin/public'; +import type { DataPublicPluginStart } from '@kbn/data-plugin/public'; // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface ObservabilityAIAssistantAppPublicStart {} @@ -44,6 +45,7 @@ export interface ObservabilityAIAssistantAppPluginStartDependencies { observabilityShared: ObservabilitySharedPluginStart; ml: MlPluginStart; triggersActionsUi: TriggersAndActionsUIPublicPluginStart; + data: DataPublicPluginStart; } export interface ObservabilityAIAssistantAppPluginSetupDependencies { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts index ff36484381b2a..a3f86aaa03de7 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/index.ts @@ -22,7 +22,7 @@ import { } from '@kbn/observability-ai-assistant-plugin/common/utils/concatenate_chat_completion_chunks'; import { ChatCompletionChunkEvent } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; import { emitWithConcatenatedMessage } from '@kbn/observability-ai-assistant-plugin/common/utils/emit_with_concatenated_message'; -import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-plugin/server/service/util/create_function_response_message'; +import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-plugin/common/utils/create_function_response_message'; import type { FunctionRegistrationParameters } from '..'; import { correctCommonEsqlMistakes } from './correct_common_esql_mistakes'; @@ -352,11 +352,16 @@ export function registerQueryFunction({ ], connectorId, signal, + functions: functions.getActions(), } ); return esqlResponse$.pipe( emitWithConcatenatedMessage((msg) => { + if (msg.message.function_call.name) { + return msg; + } + const esqlQuery = correctCommonEsqlMistakes(msg.message.content, resources.logger).match( /```esql([\s\S]*?)```/ )?.[1]; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/tsconfig.json b/x-pack/plugins/observability_solution/observability_ai_assistant_app/tsconfig.json index ac80f9b74f1cf..65fe909cd7e04 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/tsconfig.json +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/tsconfig.json @@ -48,7 +48,8 @@ "@kbn/ml-plugin", "@kbn/react-kibana-context-theme", "@kbn/shared-ux-link-redirect-app", - "@kbn/shared-ux-utility" + "@kbn/shared-ux-utility", + "@kbn/data-plugin" ], "exclude": ["target/**/*"] } diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts index c4913ce3e41d2..14b6dc11c6437 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -6,7 +6,7 @@ */ import { Response } from 'supertest'; import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common'; -import { omit } from 'lodash'; +import { omit, pick } from 'lodash'; import { PassThrough } from 'stream'; import expect from '@kbn/expect'; import { @@ -17,7 +17,8 @@ import { StreamingChatResponseEventType, } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; import type OpenAI from 'openai'; -import { createLlmProxy, LlmProxy } from '../../common/create_llm_proxy'; +import { ObservabilityAIAssistantScreenContextRequest } from '@kbn/observability-ai-assistant-plugin/common/types'; +import { createLlmProxy, LlmProxy, LlmResponseSimulator } from '../../common/create_llm_proxy'; import { createOpenAiChunk } from '../../common/create_openai_chunk'; import { FtrProviderContext } from '../../common/ftr_provider_context'; @@ -48,6 +49,66 @@ export default function ApiTest({ getService }: FtrProviderContext) { let proxy: LlmProxy; let connectorId: string; + async function getEvents( + params: { screenContexts?: ObservabilityAIAssistantScreenContextRequest[] }, + cb: (conversationSimulator: LlmResponseSimulator) => Promise + ) { + const titleInterceptor = proxy.intercept( + 'title', + (body) => + (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming).functions?.find( + (fn) => fn.name === 'title_conversation' + ) !== undefined + ); + + const conversationInterceptor = proxy.intercept( + 'conversation', + (body) => + (JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming).functions?.find( + (fn) => fn.name === 'title_conversation' + ) === undefined + ); + + const responsePromise = new Promise((resolve, reject) => { + supertest + .post(COMPLETE_API_URL) + .set('kbn-xsrf', 'foo') + .send({ + messages, + connectorId, + persist: true, + screenContexts: params.screenContexts || [], + }) + .end((err, response) => { + if (err) { + return reject(err); + } + return resolve(response); + }); + }); + + const [conversationSimulator, titleSimulator] = await Promise.all([ + conversationInterceptor.waitForIntercept(), + titleInterceptor.waitForIntercept(), + ]); + + await titleSimulator.status(200); + await titleSimulator.next('My generated title'); + await titleSimulator.complete(); + + await conversationSimulator.status(200); + await cb(conversationSimulator); + + const response = await responsePromise; + + return String(response.body) + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as StreamingChatResponseEvent) + .slice(2); // ignore context request/response, we're testing this elsewhere + } + before(async () => { proxy = await createLlmProxy(); @@ -185,80 +246,30 @@ export default function ApiTest({ getService }: FtrProviderContext) { }); describe('when creating a new conversation', async () => { - let lines: StreamingChatResponseEvent[]; - before(async () => { - const titleInterceptor = proxy.intercept( - 'title', - (body) => - ( - JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming - ).functions?.find((fn) => fn.name === 'title_conversation') !== undefined - ); + let events: StreamingChatResponseEvent[]; - const conversationInterceptor = proxy.intercept( - 'conversation', - (body) => - ( - JSON.parse(body) as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming - ).functions?.find((fn) => fn.name === 'title_conversation') === undefined - ); - - const responsePromise = new Promise((resolve, reject) => { - supertest - .post(COMPLETE_API_URL) - .set('kbn-xsrf', 'foo') - .send({ - messages, - connectorId, - persist: true, - screenContexts: [], - }) - .end((err, response) => { - if (err) { - return reject(err); - } - return resolve(response); - }); + before(async () => { + events = await getEvents({}, async (conversationSimulator) => { + await conversationSimulator.next('Hello'); + await conversationSimulator.next(' again'); + await conversationSimulator.complete(); }); - - const [conversationSimulator, titleSimulator] = await Promise.all([ - conversationInterceptor.waitForIntercept(), - titleInterceptor.waitForIntercept(), - ]); - - await titleSimulator.status(200); - await titleSimulator.next('My generated title'); - await titleSimulator.complete(); - - await conversationSimulator.status(200); - await conversationSimulator.next('Hello'); - await conversationSimulator.next(' again'); - await conversationSimulator.complete(); - - const response = await responsePromise; - - lines = String(response.body) - .split('\n') - .map((line) => line.trim()) - .filter(Boolean) - .map((line) => JSON.parse(line) as StreamingChatResponseEvent) - .slice(2); // ignore context request/response, we're testing this elsewhere }); it('creates a new conversation', async () => { - expect(omit(lines[0], 'id')).to.eql({ + expect(omit(events[0], 'id')).to.eql({ type: StreamingChatResponseEventType.ChatCompletionChunk, message: { content: 'Hello', }, }); - expect(omit(lines[1], 'id')).to.eql({ + expect(omit(events[1], 'id')).to.eql({ type: StreamingChatResponseEventType.ChatCompletionChunk, message: { content: ' again', }, }); - expect(omit(lines[2], 'id', 'message.@timestamp')).to.eql({ + expect(omit(events[2], 'id', 'message.@timestamp')).to.eql({ type: StreamingChatResponseEventType.MessageAdd, message: { message: { @@ -272,7 +283,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { }, }, }); - expect(omit(lines[3], 'conversation.id', 'conversation.last_updated')).to.eql({ + expect(omit(events[3], 'conversation.id', 'conversation.last_updated')).to.eql({ type: StreamingChatResponseEventType.ConversationCreate, conversation: { title: 'My generated title', @@ -281,7 +292,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { }); after(async () => { - const createdConversationId = lines.filter( + const createdConversationId = events.filter( (line): line is ConversationCreateEvent => line.type === StreamingChatResponseEventType.ConversationCreate )[0]?.conversation.id; @@ -299,6 +310,79 @@ export default function ApiTest({ getService }: FtrProviderContext) { }); }); + describe('after executing a screen context action', async () => { + let events: StreamingChatResponseEvent[]; + + before(async () => { + events = await getEvents( + { + screenContexts: [ + { + actions: [ + { + name: 'my_action', + description: 'My action', + parameters: { + type: 'object', + properties: { + foo: { + type: 'string', + }, + }, + }, + }, + ], + }, + ], + }, + async (conversationSimulator) => { + await conversationSimulator.next({ + function_call: { name: 'my_action', arguments: JSON.stringify({ foo: 'bar' }) }, + }); + await conversationSimulator.complete(); + } + ); + }); + + it('closes the stream without persisting the conversation', () => { + expect( + pick( + events[events.length - 1], + 'message.message.content', + 'message.message.function_call', + 'message.message.role' + ) + ).to.eql({ + message: { + message: { + content: '', + function_call: { + name: 'my_action', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + }); + }); + + it('does not store the conversation', async () => { + expect( + events.filter((event) => event.type === StreamingChatResponseEventType.ConversationCreate) + .length + ).to.eql(0); + + const conversations = await observabilityAIAssistantAPIClient + .writeUser({ + endpoint: 'POST /internal/observability_ai_assistant/conversations', + }) + .expect(200); + + expect(conversations.body.conversations.length).to.be(0); + }); + }); + // todo it.skip('updates an existing conversation', async () => {}); diff --git a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts index 5c33cd1891d24..1d4553d1f36cd 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts @@ -212,10 +212,11 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte arguments: JSON.stringify({ queries: [], categories: [] }), }); - expect(pick(contextResponse, 'name', 'content')).to.eql({ - name: 'context', - content: JSON.stringify({ screen_description: '', learnings: [] }), - }); + expect(contextResponse.name).to.eql('context'); + + const parsedContext = JSON.parse(contextResponse.content || ''); + + expect(parsedContext.screen_description).to.contain('The user is looking at'); expect(pick(assistantResponse, 'role', 'content')).to.eql({ role: 'assistant', @@ -275,10 +276,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte arguments: JSON.stringify({ queries: [], categories: [] }), }); - expect(pick(contextResponse, 'name', 'content')).to.eql({ - name: 'context', - content: JSON.stringify({ screen_description: '', learnings: [] }), - }); + expect(contextResponse.name).to.eql('context'); expect(pick(assistantResponse, 'role', 'content')).to.eql({ role: 'assistant',