Skip to content

Commit

Permalink
fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vsct-jburet committed Jul 11, 2024
1 parent 788a33a commit 15cf039
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
52 changes: 38 additions & 14 deletions bot/admin/server/src/test/kotlin/service/RAGServiceTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ai.tock.bot.admin.service
import ai.tock.bot.admin.AbstractTest
import ai.tock.bot.admin.BotAdminService
import ai.tock.bot.admin.answer.AnswerConfigurationType
import ai.tock.bot.admin.bot.observability.BotObservabilityConfigurationDAO
import ai.tock.bot.admin.bot.rag.BotRAGConfiguration
import ai.tock.bot.admin.bot.rag.BotRAGConfigurationDAO
import ai.tock.bot.admin.model.BotRAGConfigurationDTO
Expand All @@ -37,9 +38,11 @@ import ai.tock.genai.orchestratorcore.models.llm.OpenAILLMSettingDTO
import ai.tock.nlp.core.Intent
import ai.tock.shared.tockInternalInjector
import ai.tock.shared.withoutNamespace
import ai.tock.translator.I18nDAO
import com.github.salomonbrys.kodein.Kodein
import com.github.salomonbrys.kodein.KodeinInjector
import com.github.salomonbrys.kodein.bind
import com.github.salomonbrys.kodein.provider
import com.github.salomonbrys.kodein.singleton
import io.mockk.*
import org.junit.jupiter.api.AfterEach
Expand All @@ -66,13 +69,13 @@ class RAGServiceTest : AbstractTest() {
namespace = NAMESPACE,
botId = BOT_ID,
enabled = false,
llmSetting = OpenAILLMSettingDTO (
llmSetting = OpenAILLMSettingDTO(
apiKey = "apikey",
model = MODEL,
prompt = PROMPT,
temperature = TEMPERATURE
),
emSetting = AzureOpenAIEMSettingDTO (
emSetting = AzureOpenAIEMSettingDTO(
apiKey = "apiKey",
apiVersion = "apiVersion",
deploymentName = "deployment",
Expand All @@ -83,7 +86,8 @@ class RAGServiceTest : AbstractTest() {

private val DEFAULT_BOT_CONFIG = aApplication.copy(namespace = NAMESPACE, botId = BOT_ID)

private fun getRAGConfigurationDTO(enabled: Boolean, indexSessionId: String? = null) = DEFAULT_RAG_CONFIG.copy(enabled = enabled, indexSessionId = indexSessionId)
private fun getRAGConfigurationDTO(enabled: Boolean, indexSessionId: String? = null) =
DEFAULT_RAG_CONFIG.copy(enabled = enabled, indexSessionId = indexSessionId)

init {
tockInternalInjector = KodeinInjector()
Expand All @@ -92,6 +96,8 @@ class RAGServiceTest : AbstractTest() {
bind<StoryDefinitionConfigurationDAO>() with singleton { storyDao }
bind<LLMProviderService>() with singleton { llmProviderService }
bind<EMProviderService>() with singleton { emProviderService }
bind<I18nDAO>() with singleton { i18nDAO }
bind<BotObservabilityConfigurationDAO>() with provider { botObservabilityConfigurationDAO }

}.also {
tockInternalInjector.inject(Kodein {
Expand All @@ -107,6 +113,9 @@ class RAGServiceTest : AbstractTest() {
private val llmProviderService: LLMProviderService = mockk(relaxed = false)
private val emProviderService: EMProviderService = mockk(relaxed = false)

private val i18nDAO: I18nDAO = mockk(relaxed = true)
private val botObservabilityConfigurationDAO: BotObservabilityConfigurationDAO = mockk(relaxed = true)

private val slot = slot<BotRAGConfiguration>()
private val storySlot = slot<StoryDefinitionConfiguration>()
}
Expand Down Expand Up @@ -137,7 +146,10 @@ class RAGServiceTest : AbstractTest() {

val captureRagAndStoryToSave: TRunnable = {
every { storyDao.save(capture(storySlot)) } returns Unit
every { ragDao.save(capture(slot)) } returns getRAGConfigurationDTO(true, INDEX_SESSION_ID).toBotRAGConfiguration()
every { ragDao.save(capture(slot)) } returns getRAGConfigurationDTO(
true,
INDEX_SESSION_ID
).toBotRAGConfiguration()
}

val callServiceSave: TFunction<SaveFnEntry?, Unit> = {
Expand All @@ -147,7 +159,16 @@ class RAGServiceTest : AbstractTest() {

val daoSaveByFnIsCalledOnce: TRunnable = {
verify(exactly = 1) { storyDao.save(any()) }
verify(exactly = 1) { ragDao.save(eq(getRAGConfigurationDTO(true, INDEX_SESSION_ID).toBotRAGConfiguration())) }
verify(exactly = 1) {
ragDao.save(
eq(
getRAGConfigurationDTO(
true,
INDEX_SESSION_ID
).toBotRAGConfiguration()
)
)
}
}

val findCurrentUnknownFnNotCalled: TRunnable = {
Expand All @@ -171,11 +192,11 @@ class RAGServiceTest : AbstractTest() {


TestCase<SaveFnEntry, Unit>("Save valid RAG Configuration that does not exist yet").given(
"An application name and a valid request",
entry
).and(
"Rag Config not exist with request name or label and the given application name", ragNotYetExists
).and("The rag config in database is captured", captureRagAndStoryToSave)
"An application name and a valid request",
entry
).and(
"Rag Config not exist with request name or label and the given application name", ragNotYetExists
).and("The rag config in database is captured", captureRagAndStoryToSave)
.and("The LLM and EM setting are valid", checkLlmAndEmSetting)
.`when`("RagService's save method is called", callServiceSave)
.then("The dao's saveEnableRagRequest must be called exactly once", daoSaveByFnIsCalledOnce)
Expand Down Expand Up @@ -255,7 +276,8 @@ class RAGServiceTest : AbstractTest() {
"""
- no story is saved
- rag configuration is saved
""".trimIndent(), checks)
""".trimIndent(), checks
)
.run()

}
Expand Down Expand Up @@ -340,7 +362,8 @@ class RAGServiceTest : AbstractTest() {
- unknown story is saved
- unknown story is disabled
- rag configuration is saved
""".trimIndent(), checks)
""".trimIndent(), checks
)
.run()
}

Expand Down Expand Up @@ -423,11 +446,12 @@ class RAGServiceTest : AbstractTest() {
- unknown story is saved
- unknown story is enabled
- rag configuration is saved
""".trimIndent(), checks)
""".trimIndent(), checks
)
.run()
}

}


typealias SaveFnEntry = BotRAGConfigurationDTO
typealias SaveFnEntry = BotRAGConfigurationDTO
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package ai.tock.bot.admin.service

import ai.tock.bot.admin.bot.observability.BotObservabilityConfigurationDAO
import ai.tock.bot.admin.model.BotRAGConfigurationDTO
import ai.tock.genai.orchestratorclient.responses.ErrorInfo
import ai.tock.genai.orchestratorclient.responses.ErrorResponse
Expand Down Expand Up @@ -43,6 +44,7 @@ class RAGValidationServiceTest {
Kodein.Module {
bind<LLMProviderService>() with singleton { llmProviderService }
bind<EMProviderService>() with singleton { emProviderService }
bind<BotObservabilityConfigurationDAO>() with singleton { botObservabilityConfigurationDAO }
}.also {
tockInternalInjector.inject(Kodein {
import(it)
Expand All @@ -52,6 +54,7 @@ class RAGValidationServiceTest {

private val llmProviderService: LLMProviderService = mockk(relaxed = false)
private val emProviderService: EMProviderService = mockk(relaxed = false)
private val botObservabilityConfigurationDAO: BotObservabilityConfigurationDAO = mockk(relaxed = true)
}

private val openAILLMSetting = OpenAILLMSetting(
Expand Down Expand Up @@ -190,4 +193,4 @@ class RAGValidationServiceTest {
detail = "detail",
info = ErrorInfo(provider = "provider", error = "error", cause = "cause", request = "request")
)
}
}
3 changes: 3 additions & 0 deletions bot/engine/src/test/kotlin/engine/BotEngineTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ai.tock.bot.engine

import ai.tock.bot.admin.bot.BotApplicationConfiguration
import ai.tock.bot.admin.bot.BotApplicationConfigurationDAO
import ai.tock.bot.admin.bot.observability.BotObservabilityConfigurationDAO
import ai.tock.bot.admin.bot.rag.BotRAGConfigurationDAO
import ai.tock.bot.admin.story.StoryDefinitionConfigurationDAO
import ai.tock.bot.connector.Connector
Expand Down Expand Up @@ -83,6 +84,7 @@ abstract class BotEngineTest {
val translator: TranslatorEngine = mockk(relaxed = true)
val storyDefinitionConfigurationDAO: StoryDefinitionConfigurationDAO = mockk(relaxed = true)
val featureDAO: FeatureDAO = mockk(relaxed = true)
val botObservabilityConfigurationDAO : BotObservabilityConfigurationDAO = mockk(relaxed = true)

val entityA = Entity(EntityType("a"), "a")
val entityAValue = NlpEntityValue(0, 1, entityA, null, false)
Expand Down Expand Up @@ -131,6 +133,7 @@ abstract class BotEngineTest {
bind<StoryDefinitionConfigurationDAO>() with provider { storyDefinitionConfigurationDAO }
bind<FeatureDAO>() with provider { featureDAO }
bind<BotRAGConfigurationDAO>() with provider { botRAGConfigurationDAO }
bind<BotObservabilityConfigurationDAO>() with provider { botObservabilityConfigurationDAO }
}
}

Expand Down

0 comments on commit 15cf039

Please sign in to comment.