Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Obs AI Assistant] Add API test for get_alerts_dataset_info tool #212858

Merged
merged 15 commits into from
Mar 5, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import { MessageRole, ShortIdTable, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';

const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields';

export async function getRelevantFieldNames({
index,
start,
Expand Down Expand Up @@ -100,19 +102,18 @@ export async function getRelevantFieldNames({
await chat('get_relevant_dataset_names', {
signal,
stream: true,
systemMessage: `You are a helpful assistant for Elastic Observability.
Your task is to create a list of field names that are relevant
to the conversation, using ONLY the list of fields and
types provided in the last user message. DO NOT UNDER ANY
CIRCUMSTANCES include fields not mentioned in this list.`,
systemMessage: `You are a helpful assistant for Elastic Observability.
Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list.
The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id".
You must not output any field names — only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`,
messages: [
// remove the last function request
...messages.slice(0, -1),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: `This is the list:
content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:

${fieldsInChunk
.map((field) => JSON.stringify({ field, id: shortIdTable.take(field) }))
Expand All @@ -122,8 +123,12 @@ export async function getRelevantFieldNames({
],
functions: [
{
name: 'select_relevant_fields',
description: 'The IDs of the fields you consider relevant to the conversation',
name: SELECT_RELEVANT_FIELDS_NAME,
description: `Return only the field IDs (from the provided list) that you consider relevant to the conversation. Do not use any of the field names. Your response must be in the exact JSON format:
{
"fieldIds": ["id1", "id2", "id3"]
}
Only include IDs from the list provided in the user message.`,
parameters: {
type: 'object',
properties: {
Expand All @@ -138,7 +143,7 @@ export async function getRelevantFieldNames({
} as const,
},
],
functionCall: 'select_relevant_fields',
functionCall: SELECT_RELEVANT_FIELDS_NAME,
})
).pipe(concatenateChatCompletionChunks());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ export function registerAlertsFunction({
signal,
chat: (
operationName,
{ messages: nextMessages, functionCall, functions: nextFunctions }
{ messages: nextMessages, functionCall, functions: nextFunctions, systemMessage }
) => {
return chat(operationName, {
systemMessage,
messages: nextMessages,
functionCall,
functions: nextFunctions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
expect(status).to.be(200);
});

Expand All @@ -104,7 +104,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody; // This is the request sent to the LLM
expect(requestData.messages[0].content).to.eql(SYSTEM_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
scopes: ['all'],
});

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();

return String(response.body)
.split('\n')
Expand Down Expand Up @@ -133,7 +133,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();

parsedEvents = decodeEvents(receivedChunks.join(''));
});
Expand Down Expand Up @@ -243,7 +243,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody;
expect(requestData.messages[0].role).to.eql('system');
Expand Down Expand Up @@ -420,7 +420,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

expect(createResponse.status).to.be(200);

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();

conversationCreatedEvent = getConversationCreatedEvent(createResponse.body);

Expand Down Expand Up @@ -463,7 +463,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

expect(updatedResponse.status).to.be(200);

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
});

after(async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();

alertsEvents = getMessageAddedEvents(alertsResponseBody);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});

await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();

events = getMessageAddedEvents(responseBody);
});
Expand Down
Loading