Skip to content

Commit

Permalink
Make chat history available to RetrievalAugmentationAdvisor
Browse files Browse the repository at this point in the history
* Extend Query with conversation history in RetrievalAugmentationAdvisor
* Add integration tests for query compression and rewrite

Signed-off-by: Thomas Vitale <[email protected]>
  • Loading branch information
ThomasVitale committed Dec 24, 2024
1 parent d7fe07b commit 15fe2e5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,11 @@ public static Builder builder() {
public AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

// 0. Create a query from the user text and parameters.
Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
// 0. Create a query from the user text, parameters, and conversation history.
Query originalQuery = Query.builder()
.text(new PromptTemplate(request.userText(), request.userParams()).render())
.history(request.messages())
.build();

// 1. Transform original user query based on a chain of query transformers.
Query transformedQuery = originalQuery;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentReader;
Expand All @@ -34,6 +37,8 @@
import org.springframework.ai.integration.tests.TestApplication;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
Expand Down Expand Up @@ -103,6 +108,74 @@ void ragBasic() {
evaluateRelevancy(question, chatResponse);
}

/**
 * Runs a two-turn conversation through a {@link RetrievalAugmentationAdvisor} configured
 * with a {@link CompressionQueryTransformer}, alongside a {@link MessageChatMemoryAdvisor}.
 * The second question ("Did they meet any cow?") only makes sense in light of the first
 * turn, so the test expects the answer to mention "Fergus" when both turns share the same
 * conversation id.
 */
@Test
void ragWithCompression() {
	MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();

	RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
		.queryTransformers(CompressionQueryTransformer.builder()
			.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
			.build())
		.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
		.build();

	// The memory advisor is registered ahead of the RAG advisor in the default advisor list.
	ChatClient chatClient = ChatClient.builder(this.openAiChatModel)
		.defaultAdvisors(memoryAdvisor, ragAdvisor)
		.build();

	// Both turns use the same id so that they belong to one conversation.
	String conversationId = "007";

	// Turn 1: a self-contained question that names the characters explicitly.
	ChatResponse chatResponse1 = chatClient.prompt()
		.user("Where does the adventure of Anacletus and Birba take place?")
		.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
				conversationId))
		.call()
		.chatResponse();

	assertThat(chatResponse1).isNotNull();
	String response1 = chatResponse1.getResult().getOutput().getText();
	System.out.println(response1);

	// Turn 2: a follow-up whose pronoun ("they") refers back to turn 1.
	ChatResponse chatResponse2 = chatClient.prompt()
		.user("Did they meet any cow?")
		.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
				conversationId))
		.call()
		.chatResponse();

	assertThat(chatResponse2).isNotNull();
	String response2 = chatResponse2.getResult().getOutput().getText();
	System.out.println(response2);
	// containsIgnoringCase already ignores case; no default-locale toLowerCase() is needed first.
	assertThat(response2).containsIgnoringCase("Fergus");
}

/**
 * Asks a single question through a {@link RetrievalAugmentationAdvisor} configured with a
 * {@link RewriteQueryTransformer} targeting a "vector store", then checks that the answer
 * mentions "Loch of the Stars" and passes the relevancy evaluation.
 */
@Test
void ragWithRewrite() {
	String question = "Where are the main characters going?";

	// Build the individual RAG pieces first, then assemble the advisor from them.
	RewriteQueryTransformer rewriter = RewriteQueryTransformer.builder()
		.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
		.targetSearchSystem("vector store")
		.build();
	VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
		.vectorStore(this.pgVectorStore)
		.build();
	RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
		.queryTransformers(rewriter)
		.documentRetriever(retriever)
		.build();

	// The advisor is attached per request rather than as a client default.
	ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build();
	ChatResponse chatResponse = chatClient.prompt().user(question).advisors(ragAdvisor).call().chatResponse();

	assertThat(chatResponse).isNotNull();

	String answer = chatResponse.getResult().getOutput().getText();
	System.out.println(answer);
	assertThat(answer).containsIgnoringCase("Loch of the Stars");

	evaluateRelevancy(question, chatResponse);
}

@Test
void ragWithTranslation() {
String question = "Hvor finder Anacletus og Birbas eventyr sted?";
Expand Down

0 comments on commit 15fe2e5

Please sign in to comment.