Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DERCBOT-919] Add question condensing LLM and prompt template #1827

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion bot/admin/server/src/main/kotlin/BotAdminService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,9 @@ object BotAdminService {
// delete the RAG configuration
ragConfigurationDAO.findByNamespaceAndBotId(app.namespace, app.name)?.let { config ->
ragConfigurationDAO.delete(config._id)
config.llmSetting.apiKey?.let { SecurityUtils.deleteSecret(it) }
config.questionCondensingLlmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
config.questionAnsweringLlmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
config.llmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
config.emSetting.apiKey?.let { SecurityUtils.deleteSecret(it) }
}

Expand Down
34 changes: 30 additions & 4 deletions bot/admin/server/src/main/kotlin/model/BotRAGConfigurationDTO.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ai.tock.bot.admin.model

import ai.tock.bot.admin.bot.rag.BotRAGConfiguration
import ai.tock.bot.admin.service.VectorStoreService
import ai.tock.genai.orchestratorclient.requests.PromptTemplate
import ai.tock.genai.orchestratorcore.mappers.EMSettingMapper
import ai.tock.genai.orchestratorcore.mappers.LLMSettingMapper
import ai.tock.genai.orchestratorcore.models.Constants
Expand All @@ -34,26 +35,39 @@ data class BotRAGConfigurationDTO(
val namespace: String,
val botId: String,
val enabled: Boolean = false,
val llmSetting: LLMSettingDTO,
val questionCondensingLlmSetting: LLMSettingDTO? = null,
val questionCondensingPrompt: PromptTemplate? = null,
val questionAnsweringLlmSetting: LLMSettingDTO,
val questionAnsweringPrompt: PromptTemplate,
val emSetting: EMSettingDTO,
val indexSessionId: String? = null,
val indexName: String? = null,
val noAnswerSentence: String,
val noAnswerStoryId: String? = null,
val documentsRequired: Boolean = true,
val debugEnabled: Boolean,
val maxDocumentsRetrieved: Int,
val maxMessagesFromHistory: Int,
) {
constructor(configuration: BotRAGConfiguration) : this(
id = configuration._id.toString(),
namespace = configuration.namespace,
botId = configuration.botId,
enabled = configuration.enabled,
llmSetting = configuration.llmSetting.toDTO(),
questionCondensingLlmSetting = configuration.questionCondensingLlmSetting?.toDTO(),
questionCondensingPrompt = configuration.questionCondensingPrompt,
questionAnsweringLlmSetting = configuration.getQuestionAnsweringLLMSetting().toDTO(),
questionAnsweringPrompt = configuration.questionAnsweringPrompt
?: configuration.initQuestionAnsweringPrompt(),
emSetting = configuration.emSetting.toDTO(),
indexSessionId = configuration.indexSessionId,
indexName = configuration.generateIndexName(),
noAnswerSentence = configuration.noAnswerSentence,
noAnswerStoryId = configuration.noAnswerStoryId,
documentsRequired = configuration.documentsRequired,
debugEnabled = configuration.debugEnabled,
maxDocumentsRetrieved = configuration.maxDocumentsRetrieved,
maxMessagesFromHistory = configuration.maxMessagesFromHistory,
)

fun toBotRAGConfiguration(): BotRAGConfiguration =
Expand All @@ -62,12 +76,20 @@ data class BotRAGConfigurationDTO(
namespace = namespace,
botId = botId,
enabled = enabled,
llmSetting = LLMSettingMapper.toEntity(
questionCondensingLlmSetting = LLMSettingMapper.toEntity(
namespace = namespace,
botId = botId,
feature = Constants.GEN_AI_RAG_QUESTION_CONDENSING,
dto = questionCondensingLlmSetting!!
),
questionCondensingPrompt = questionCondensingPrompt,
questionAnsweringLlmSetting = LLMSettingMapper.toEntity(
namespace = namespace,
botId = botId,
feature = Constants.GEN_AI_RAG_QUESTION_ANSWERING,
dto = llmSetting
dto = questionAnsweringLlmSetting
),
questionAnsweringPrompt = questionAnsweringPrompt,
emSetting = EMSettingMapper.toEntity(
namespace = namespace,
botId = botId,
Expand All @@ -78,6 +100,9 @@ data class BotRAGConfigurationDTO(
noAnswerSentence = noAnswerSentence,
noAnswerStoryId = noAnswerStoryId,
documentsRequired = documentsRequired,
debugEnabled = debugEnabled,
maxDocumentsRetrieved = maxDocumentsRetrieved,
maxMessagesFromHistory = maxMessagesFromHistory,
)
}

Expand All @@ -87,6 +112,7 @@ private fun BotRAGConfiguration.generateIndexName(): String? {
namespace,
botId,
it,
maxDocumentsRetrieved,
VectorStoreService.getVectorStoreConfiguration(namespace, botId, enabled = true)
?.setting
).second
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ai.tock.bot.admin.model

import ai.tock.bot.admin.bot.sentencegeneration.BotSentenceGenerationConfiguration
import ai.tock.genai.orchestratorclient.requests.PromptTemplate
import ai.tock.genai.orchestratorcore.mappers.LLMSettingMapper
import ai.tock.genai.orchestratorcore.models.Constants
import ai.tock.genai.orchestratorcore.models.llm.LLMSettingDTO
Expand All @@ -32,6 +33,7 @@ data class BotSentenceGenerationConfigurationDTO(
val enabled: Boolean = false,
val nbSentences: Int,
val llmSetting: LLMSettingDTO,
val prompt: PromptTemplate,
) {
constructor(configuration: BotSentenceGenerationConfiguration) : this(
id = configuration._id.toString(),
Expand All @@ -40,6 +42,7 @@ data class BotSentenceGenerationConfigurationDTO(
enabled = configuration.enabled,
nbSentences = configuration.nbSentences,
llmSetting = configuration.llmSetting.toDTO(),
prompt = configuration.prompt ?: configuration.initPrompt()
)

fun toSentenceGenerationConfiguration(): BotSentenceGenerationConfiguration =
Expand All @@ -54,7 +57,8 @@ data class BotSentenceGenerationConfigurationDTO(
botId = botId,
feature = Constants.GEN_AI_COMPLETION_SENTENCE_GENERATION,
dto = llmSetting
)
),
prompt = prompt
)
}

Expand Down
12 changes: 4 additions & 8 deletions bot/admin/server/src/main/kotlin/service/CompletionService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ object CompletionService {
// Get LLM Setting and override the temperature
val llmSetting = sentenceGenerationConfig.llmSetting.copyWithTemperature(request.llmTemperature)

// Get prompt
val prompt = sentenceGenerationConfig.prompt ?: sentenceGenerationConfig.initPrompt()

// Create the inputs map
val inputs = mapOf(
"locale" to request.locale,
Expand All @@ -75,18 +78,11 @@ object CompletionService {
)
)

// Create a Jinja2 prompt template
val prompt = PromptTemplate(
formatter = Formatter.JINJA2.id,
template = llmSetting.prompt,
inputs = inputs
)

// call the completion service to generate sentences
return completionService
.generateSentences(
SentenceGenerationQuery(
llmSetting, prompt,
llmSetting, prompt.copy(inputs = inputs),
ObservabilityService.getObservabilityConfiguration(namespace, botId, enabled = true)?.setting
)
)
Expand Down
7 changes: 5 additions & 2 deletions bot/admin/server/src/main/kotlin/service/RAGService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ object RAGService {
logger.info { "Deleting the RAG Configuration [namespace: $namespace, botId: $botId]" }
ragConfigurationDAO.delete(ragConfig._id)

logger.info { "Deleting the LLM secret ..." }
ragConfig.llmSetting.apiKey?.let { SecurityUtils.deleteSecret(it) }
logger.info { "Deleting the question condensing LLM secret ..." }
ragConfig.questionCondensingLlmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
logger.info { "Deleting the question answering LLM secret ..." }
ragConfig.questionAnsweringLlmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
ragConfig.llmSetting?.apiKey?.let { SecurityUtils.deleteSecret(it) }
logger.info { "Deleting the Embedding secret ..." }
ragConfig.emSetting.apiKey?.let { SecurityUtils.deleteSecret(it) }
}
Expand Down
29 changes: 21 additions & 8 deletions bot/admin/server/src/main/kotlin/service/RAGValidationService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,24 @@ object RAGValidationService {
private val vectorStoreProviderService: VectorStoreProviderService get() = injector.provide()

fun validate(ragConfig: BotRAGConfiguration): Set<ErrorMessage> {
val observabilitySetting = ObservabilityService.getObservabilityConfiguration(
ragConfig.namespace, ragConfig.botId, enabled = true
)?.setting

return mutableSetOf<ErrorMessage>().apply {
val llmErrors = llmProviderService.checkSetting(
val questionCondensingLlmErrors = llmProviderService.checkSetting(
LLMProviderSettingStatusQuery(
ragConfig.questionCondensingLlmSetting!!,
observabilitySetting
)
).getErrors("LLM setting check failed (for question condensing)")

val questionAnsweringLlmErrors = llmProviderService.checkSetting(
LLMProviderSettingStatusQuery(
ragConfig.llmSetting,
ObservabilityService.getObservabilityConfiguration(
ragConfig.namespace, ragConfig.botId, enabled = true
)?.setting
ragConfig.questionAnsweringLlmSetting!!,
observabilitySetting
)
).getErrors("LLM setting check failed")
).getErrors("LLM setting check failed (for question answering)")

val embeddingErrors = emProviderService.checkSetting(
EMProviderSettingStatusQuery(ragConfig.emSetting)
Expand All @@ -59,7 +68,11 @@ object RAGValidationService {
)?.setting

val (_, indexName) = VectorStoreUtils.getVectorStoreElements(
ragConfig.namespace, ragConfig.botId, ragConfig.indexSessionId!!, vectorStoreSetting
ragConfig.namespace,
ragConfig.botId,
ragConfig.indexSessionId!!,
ragConfig.maxDocumentsRetrieved,
vectorStoreSetting
)

vectorStoreProviderService.checkSetting(
Expand All @@ -71,7 +84,7 @@ object RAGValidationService {
).getErrors("Vector store setting check failed")
} ?: emptySet()

addAll(llmErrors + embeddingErrors + indexSessionIdErrors + vectorStoreErrors)
addAll(questionCondensingLlmErrors + questionAnsweringLlmErrors + embeddingErrors + indexSessionIdErrors + vectorStoreErrors)
}
}

Expand Down
26 changes: 19 additions & 7 deletions bot/admin/server/src/test/kotlin/service/RAGServiceTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import ai.tock.bot.test.TFunction
import ai.tock.bot.test.TRunnable
import ai.tock.bot.test.TSupplier
import ai.tock.bot.test.TestCase
import ai.tock.genai.orchestratorclient.requests.PromptTemplate
import ai.tock.genai.orchestratorclient.responses.ProviderSettingStatusResponse
import ai.tock.genai.orchestratorclient.services.EMProviderService
import ai.tock.genai.orchestratorclient.services.LLMProviderService
Expand Down Expand Up @@ -65,25 +66,36 @@ class RAGServiceTest : AbstractTest() {
const val INDEX_SESSION_ID = "1010101"

private val DEFAULT_RAG_CONFIG = BotRAGConfigurationDTO(
id = "ragId",
id = "ragId",
namespace = NAMESPACE,
botId = BOT_ID,
enabled = false,
llmSetting = OpenAILLMSettingDTO(
questionCondensingLlmSetting = OpenAILLMSettingDTO(
apiKey = "apikey",
model = MODEL,
prompt = PROMPT,
temperature = TEMPERATURE,
baseUrl = "https://api.openai.com/v1"
),
questionCondensingPrompt = PromptTemplate(template = PROMPT),
questionAnsweringLlmSetting = OpenAILLMSettingDTO(
apiKey = "apikey",
model = MODEL,
temperature = TEMPERATURE,
baseUrl = "https://api.openai.com/v1"
),
questionAnsweringPrompt = PromptTemplate(template = PROMPT),
emSetting = AzureOpenAIEMSettingDTO(
apiKey = "apiKey",
apiVersion = "apiVersion",
deploymentName = "deployment",
model = "model",
apiBase = "url"
),
noAnswerSentence = "No answer sentence"
noAnswerSentence = "No answer sentence",
documentsRequired = true,
debugEnabled = false,
maxDocumentsRetrieved = 2,
maxMessagesFromHistory = 2,
)

private val DEFAULT_BOT_CONFIG = aApplication.copy(namespace = NAMESPACE, botId = BOT_ID)
Expand Down Expand Up @@ -186,9 +198,9 @@ class RAGServiceTest : AbstractTest() {
Assertions.assertEquals(BOT_ID, captured.botId)
Assertions.assertEquals(true, captured.enabled)
Assertions.assertEquals(NAMESPACE, captured.namespace)
Assertions.assertEquals(PROVIDER, captured.llmSetting.provider.name)
Assertions.assertEquals(TEMPERATURE, captured.llmSetting.temperature)
Assertions.assertEquals(PROMPT, captured.llmSetting.prompt)
Assertions.assertEquals(PROVIDER, captured.questionAnsweringLlmSetting!!.provider.name)
Assertions.assertEquals(TEMPERATURE, captured.questionAnsweringLlmSetting!!.temperature)
Assertions.assertEquals(PROMPT, captured.questionAnsweringPrompt!!.template)
Assertions.assertEquals(null, captured.noAnswerStoryId)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ai.tock.bot.admin.service
import ai.tock.bot.admin.bot.observability.BotObservabilityConfigurationDAO
import ai.tock.bot.admin.bot.vectorstore.BotVectorStoreConfigurationDAO
import ai.tock.bot.admin.model.BotRAGConfigurationDTO
import ai.tock.genai.orchestratorclient.requests.PromptTemplate
import ai.tock.genai.orchestratorclient.responses.ErrorInfo
import ai.tock.genai.orchestratorclient.responses.ErrorResponse
import ai.tock.genai.orchestratorclient.responses.ProviderSettingStatusResponse
Expand Down Expand Up @@ -63,7 +64,7 @@ class RAGValidationServiceTest {
}

private val openAILLMSetting = OpenAILLMSetting(
apiKey = "123-abc", model = "unavailable-model", temperature = "0.4", prompt = "How to bike in the rain",
apiKey = "123-abc", model = "unavailable-model", temperature = "0.4",
baseUrl = "https://api.openai.com/v1",
)

Expand All @@ -78,9 +79,16 @@ class RAGValidationServiceTest {
private val ragConfiguration = BotRAGConfigurationDTO(
namespace = "namespace",
botId = "botId",
llmSetting = openAILLMSetting,
questionCondensingLlmSetting = openAILLMSetting,
questionCondensingPrompt = PromptTemplate(template = "test"),
questionAnsweringLlmSetting = openAILLMSetting,
questionAnsweringPrompt = PromptTemplate(template = "How to bike in the rain"),
emSetting = azureOpenAIEMSetting,
noAnswerSentence = " No answer sentence",
documentsRequired = true,
debugEnabled = false,
maxDocumentsRetrieved = 2,
maxMessagesFromHistory = 2,
)

@Test
Expand Down Expand Up @@ -163,7 +171,7 @@ class RAGValidationServiceTest {
fun `validation of the RAG configuration when the Orchestrator returns 2 errors for LLM and 1 for Embedding model, the RAG function has not been activated`() {

// GIVEN
// - 3 errors returned by Generative AI Orchestrator for LLM (2) and EM (1)
// - 5 errors returned by Generative AI Orchestrator: 4 for LLM (2 for condensing + 2 for answering) and 1 for EM
// - RAG is not enabled
every {
llmProviderService.checkSetting(any())
Expand All @@ -187,11 +195,13 @@ class RAGValidationServiceTest {
)

// THEN :
// Check that 3 errors have been found
assertEquals(2, errors.size)
// Check that 3 groups of errors have been found
assertEquals(3, errors.size)
assertEquals("10", (((errors.elementAt(0).params) as List<*>)[0] as ErrorResponse).code)
assertEquals("20", (((errors.elementAt(0).params) as List<*>)[1] as ErrorResponse).code)
assertEquals("30", (((errors.elementAt(1).params) as List<*>)[0] as ErrorResponse).code)
assertEquals("10", (((errors.elementAt(1).params) as List<*>)[0] as ErrorResponse).code)
assertEquals("20", (((errors.elementAt(1).params) as List<*>)[1] as ErrorResponse).code)
assertEquals("30", (((errors.elementAt(2).params) as List<*>)[0] as ErrorResponse).code)
}

private fun createFakeErrorResponse(code: String) = ErrorResponse(
Expand Down
Loading