From 15fe2e522466ff7fa6d30be1a9459cbacfe76c05 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Tue, 24 Dec 2024 15:04:58 +0100 Subject: [PATCH] Make chat history available to RetrievalAugmentationAdvisor * Extend Query with conversation history in RetrievalAugmentationAdvisor * Add integration tests for query compression and rewrite Signed-off-by: Thomas Vitale --- .../advisor/RetrievalAugmentationAdvisor.java | 7 +- .../RetrievalAugmentationAdvisorIT.java | 73 +++++++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index a9f78f985c9..6f570265c61 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -102,8 +102,11 @@ public static Builder builder() { public AdvisedRequest before(AdvisedRequest request) { Map context = new HashMap<>(request.adviseContext()); - // 0. Create a query from the user text and parameters. - Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); + // 0. Create a query from the user text, parameters, and conversation history. + Query originalQuery = Query.builder() + .text(new PromptTemplate(request.userText(), request.userParams()).render()) + .history(request.messages()) + .build(); // 1. Transform original user query based on a chain of query transformers. Query transformedQuery = originalQuery; diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java index 757d5ff758d..cf0d6ff5374 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java @@ -24,7 +24,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentReader; @@ -34,6 +37,8 @@ import org.springframework.ai.integration.tests.TestApplication; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer; import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.reader.markdown.MarkdownDocumentReader; @@ -103,6 +108,74 @@ void ragBasic() { evaluateRelevancy(question, chatResponse); } + @Test + void ragWithCompression() { + MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build(); + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .queryTransformers(CompressionQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(this.openAiChatModel)) + .build()) + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build()) + .build(); + + ChatClient chatClient = ChatClient.builder(this.openAiChatModel) + .defaultAdvisors(memoryAdvisor, ragAdvisor) + .build(); + + String conversationId = "007"; + + ChatResponse chatResponse1 = chatClient.prompt() + .user("Where does the adventure of Anacletus and Birba take place?") + .advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, + conversationId)) + .call() + .chatResponse(); + + assertThat(chatResponse1).isNotNull(); + String response1 = chatResponse1.getResult().getOutput().getText(); + System.out.println(response1); + + ChatResponse chatResponse2 = chatClient.prompt() + .user("Did they meet any cow?") + .advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, + conversationId)) + .call() + .chatResponse(); + + assertThat(chatResponse2).isNotNull(); + String response2 = chatResponse2.getResult().getOutput().getText(); + System.out.println(response2); + assertThat(response2.toLowerCase()).containsIgnoringCase("Fergus"); + } + + @Test + void ragWithRewrite() { + String question = "Where are the main characters going?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .queryTransformers(RewriteQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(this.openAiChatModel)) + .targetSearchSystem("vector store") + .build()) + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build()) + .build(); + + ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel).build().prompt() + .user(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getText(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Loch of the Stars"); + + evaluateRelevancy(question, chatResponse); + } + @Test void ragWithTranslation() { String question = "Hvor finder Anacletus og Birbas eventyr sted?";