diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 127390df7da..173d7938c75 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -126,7 +126,7 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu } } - return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions, + return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.advisorText, inputRequest.chatOptions, media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, inputRequest.toolContext); @@ -138,7 +138,7 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), - advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), + advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), advisedRequest.advisorText(), advisedRequest.advisorParams(), observationRegistry, customObservationConvention, advisedRequest.toolContext()); } @@ -605,20 +605,23 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe @Nullable private String systemText; + @Nullable + private String advisorText; + @Nullable private ChatOptions chatOptions; /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, - ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, + ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorText, ccr.advisorParams, ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, List functionCallbacks, List messages, List functionNames, - List media, @Nullable ChatOptions chatOptions, List advisors, + List media, @Nullable ChatOptions chatOptions, List advisors, @Nullable String advisorText, Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention customObservationConvention, Map toolContext) { @@ -649,6 +652,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.messages.addAll(messages); this.media.addAll(media); this.advisors.addAll(advisors); + this.advisorText = advisorText; this.advisorParams.putAll(advisorParams); this.observationRegistry = observationRegistry; this.customObservationConvention = customObservationConvention != null ? customObservationConvention @@ -778,6 +782,10 @@ public Builder mutate() { builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams)); } + if (StringUtils.hasText(this.advisorText)) { + builder.defaultSystem(s -> s.text(this.advisorText).params(this.advisorParams)); + } + if (this.chatOptions != null) { builder.defaultOptions(this.chatOptions); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 7f4fbdbff17..50a04f6b1da 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -63,7 +63,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, + List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), observationRegistry, customObservationConvention, Map.of()); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index cd1c53cb301..a070df1ea17 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -94,7 +94,8 @@ private AdvisedRequest before(AdvisedRequest request) { AdvisedRequest advisedRequest = AdvisedRequest.from(request).messages(advisedMessages).build(); // 4. Add the new user input to the conversation memory. - UserMessage userMessage = new UserMessage(request.userText(), request.media()); + String renderedUserText = request.renderUserText(); + UserMessage userMessage = new UserMessage(renderedUserText, request.media()); this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index ebcb5f33d8b..b72b732aa40 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -127,7 +127,8 @@ private AdvisedRequest before(AdvisedRequest request) { .build(); // 4. Add the new user input to the conversation memory. - UserMessage userMessage = new UserMessage(request.userText(), request.media()); + String renderedUserText = request.renderUserText(); + UserMessage userMessage = new UserMessage(renderedUserText, request.media()); this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 57d1406d2f9..06285f1d745 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -216,7 +216,7 @@ private AdvisedRequest before(AdvisedRequest request) { String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise; // 2. Search for similar documents in the vector store. - String query = new PromptTemplate(request.userText(), request.userParams()).render(); + String query = request.renderUserText(); var searchRequestToUse = SearchRequest.from(this.searchRequest) .query(query) .filterExpression(doGetFilterExpression(context)) @@ -236,7 +236,7 @@ private AdvisedRequest before(AdvisedRequest request) { advisedUserParams.put("question_answer_context", documentContext); AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .userText(advisedUserText) + .advisorText(advisedUserText) .userParams(advisedUserParams) .adviseContext(context) .build(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index a9f78f985c9..78642e40f84 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -103,7 +103,7 @@ public AdvisedRequest before(AdvisedRequest request) { Map context = new HashMap<>(request.adviseContext()); // 0. Create a query from the user text and parameters. - Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); + Query originalQuery = new Query(request.renderUserText()); // 1. Transform original user query based on a chain of query transformers. Query transformedQuery = originalQuery; @@ -129,10 +129,10 @@ public AdvisedRequest before(AdvisedRequest request) { context.put(DOCUMENT_CONTEXT, documents); // 5. Augment user query with the document contextual data. - Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents); + Query augmentedQuery = this.queryAugmenter.augment(new Query(request.userText()), documents); // 6. Update advised request with augmented prompt. - return AdvisedRequest.from(request).userText(augmentedQuery.text()).adviseContext(context).build(); + return AdvisedRequest.from(request).advisorText(augmentedQuery.text()).adviseContext(context).build(); } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index 6eb56a1c85b..2022c22a597 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -74,9 +74,10 @@ public String getName() { @Override public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + String renderedUserText = advisedRequest.renderUserText(); if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + && this.sensitiveWords.stream().anyMatch(renderedUserText::contains)) { return createFailureResponse(advisedRequest); } @@ -86,9 +87,10 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis @Override public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + String renderedUserText = advisedRequest.renderUserText(); if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + && this.sensitiveWords.stream().anyMatch(renderedUserText::contains)) { return Flux.just(createFailureResponse(advisedRequest)); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index f25279c96c7..32e004ce536 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -137,9 +137,10 @@ private AdvisedRequest before(AdvisedRequest request) { else { advisedSystemText = this.systemTextAdvise; } + String renderedUserText = request.renderUserText(); var searchRequest = SearchRequest.builder() - .query(request.userText()) + .query(renderedUserText) .topK(this.doGetChatMemoryRetrieveSize(request.adviseContext())) .filterExpression( DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(request.adviseContext()) + "'") @@ -159,7 +160,7 @@ private AdvisedRequest before(AdvisedRequest request) { .systemParams(advisedSystemParams) .build(); - UserMessage userMessage = new UserMessage(request.userText(), request.media()); + UserMessage userMessage = new UserMessage(renderedUserText, request.media()); this.getChatMemoryStore() .write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext()))); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index daef3980fb2..9544613ce88 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -68,6 +68,7 @@ public record AdvisedRequest( String userText, @Nullable String systemText, + String advisorText, @Nullable ChatOptions chatOptions, List media, @@ -82,8 +83,10 @@ public record AdvisedRequest( Map toolContext // @formatter:on ) { - - public AdvisedRequest { + public AdvisedRequest(ChatModel chatModel, String userText, @Nullable + String systemText, @Nullable + String advisorText, @Nullable + ChatOptions chatOptions, List media, List functionNames, List functionCallbacks, List messages, Map userParams, Map systemParams, List advisors, Map advisorParams, Map adviseContext, Map toolContext) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages), "userText cannot be null or empty unless messages are provided and contain Tool Response message."); @@ -112,6 +115,28 @@ public record AdvisedRequest( Assert.notNull(toolContext, "toolContext cannot be null"); Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements"); + + if (!CollectionUtils.isEmpty(userParams)) { + this.userText = userText; + this.userParams = userParams; + } else { + this.userParams = Map.of("userText", userText); + this.userText = "{userText}"; + } + + this.chatModel = chatModel; + this.systemText = systemText; + this.advisorText = (advisorText != null) ? advisorText : ""; + this.chatOptions = chatOptions; + this.media = media; + this.functionNames = functionNames; + this.functionCallbacks = functionCallbacks; + this.messages = messages; + this.systemParams = systemParams; + this.advisors = advisors; + this.advisorParams = advisorParams; + this.adviseContext = adviseContext; + this.toolContext = toolContext; } public static Builder builder() { @@ -125,6 +150,7 @@ public static Builder from(AdvisedRequest from) { builder.chatModel = from.chatModel; builder.userText = from.userText; builder.systemText = from.systemText; + builder.advisorText = from.advisorText; builder.chatOptions = from.chatOptions; builder.media = from.media; builder.functionNames = from.functionNames; @@ -146,6 +172,11 @@ public AdvisedRequest updateContext(Function, Map(this.messages()); @@ -157,22 +188,28 @@ public Prompt toPrompt() { messages.add(new SystemMessage(processedSystemText)); } - String formatParam = (String) this.adviseContext().get("formatParam"); - - var processedUserText = StringUtils.hasText(formatParam) - ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); - - if (StringUtils.hasText(processedUserText)) { - Map userParams = new HashMap<>(this.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); + String processedAdvisorText = this.advisorText(); + if (StringUtils.hasText(this.advisorText())) { + Map advisorParams = new HashMap<>(this.advisorParams()); + if (!CollectionUtils.isEmpty(this.userParams())) { + advisorParams.putAll(this.userParams()); } - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); + + if (!CollectionUtils.isEmpty(advisorParams)) { + processedAdvisorText = new PromptTemplate(processedAdvisorText, advisorParams).render(); } - messages.add(new UserMessage(processedUserText, this.media())); + } else { + processedAdvisorText = renderUserText(); + } + + String formatParam = (String) this.adviseContext().get("formatParam"); + + if (StringUtils.hasText(formatParam)) { + processedAdvisorText += System.lineSeparator() + formatParam; } + messages.add(new UserMessage(processedAdvisorText, this.media())); + if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { if (!this.functionNames().isEmpty()) { functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); @@ -199,6 +236,8 @@ public static final class Builder { private String systemText; + private String advisorText; + private ChatOptions chatOptions; private List media = List.of(); @@ -254,6 +293,16 @@ public Builder systemText(String systemText) { return this; } + /** + * Set the advisor text. + * @param advisorText the advisor text + * @return this {@link Builder} instance + */ + public Builder advisorText(String advisorText) { + this.advisorText = advisorText; + return this; + } + /** * Set the chat options. * @param chatOptions the chat options @@ -495,7 +544,7 @@ public Builder withToolContext(Map toolContext) { * @return a new {@link AdvisedRequest} instance */ public AdvisedRequest build() { - return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, + return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.advisorText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, this.advisors, this.advisorParams, this.adviseContext, this.toolContext); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 5f7951ca5cd..f0705fcebe3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1198,14 +1198,14 @@ void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), - Map.of(), ObservationRegistry.NOOP, null, Map.of()); + null, Map.of(), ObservationRegistry.NOOP, null, Map.of()); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null, - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), ObservationRegistry.NOOP, null, Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); @@ -1214,7 +1214,7 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null, + Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), null, null, Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index 8511409a967..270b60a4fd1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -233,4 +233,31 @@ public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilarityS assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery); } + @Test + public void qaAdvisorTakesUserTextWithBracesIntoAccountForSimilaritySearch() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))), + ChatResponseMetadata.builder().build())); + + given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) + .willReturn(List.of(new Document("doc1"), new Document("doc2"))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build()); + + String userText = "{ \"name\" : \"Chuck\" }"; + + // @formatter:off + chatClient.prompt() + .user(userText) + .advisors(qaAdvisor) + .call() + .chatResponse(); + //formatter:on + + var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getText(); + assertThat(userPrompt).contains(userText); + assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(userText); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java index e697b576a2e..8002be46454 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java @@ -123,4 +123,66 @@ void theOneWithTheDocumentRetriever() { """); } + @Test + void whenUserTextWithBracesThenDoesNotThrow() { + // Chat Model + var chatModel = mock(ChatModel.class); + var promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())).willReturn(ChatResponse.builder() + .withGenerations(List.of(new Generation(new AssistantMessage("Felix Felicis")))) + .build()); + + // Document Retriever + var documentContext = List.of(Document.builder().id("1").text("doc1").build(), + Document.builder().id("2").text("doc2").build()); + var documentRetriever = mock(DocumentRetriever.class); + var queryCaptor = ArgumentCaptor.forClass(Query.class); + given(documentRetriever.retrieve(queryCaptor.capture())).willReturn(documentContext); + + // Advisor + var advisor = RetrievalAugmentationAdvisor.builder().documentRetriever(documentRetriever).build(); + + // Chat Client + var chatClient = ChatClient.builder(chatModel) + .defaultAdvisors(advisor) + .defaultSystem("You are a wizard!") + .build(); + + // Call + String userText = "{ \"name\" : \"Chuck\" }"; + var chatResponse = chatClient.prompt() + .user(userText) + .call() + .chatResponse(); + + // Verify + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("Felix Felicis"); + assertThat(chatResponse.getMetadata().>get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT)) + .containsAll(documentContext); + + var query = queryCaptor.getValue(); + assertThat(query.text()) + .isEqualTo(userText); + + var prompt = promptCaptor.getValue(); + assertThat(prompt.getContents()).contains(""" + Context information is below. + + --------------------- + """); + assertThat(prompt.getContents()).contains(""" + --------------------- + + Given the context information and no prior knowledge, answer the query. + + Follow these rules: + + 1. If the answer is not in the context, just say that you don't know. + 2. Avoid statements like "Based on the context..." or "The provided information...". + + Query: { \"name\" : \"Chuck\" } + + Answer: + """); + } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java index 3a62b7a1fdd..83ee5c31baf 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java @@ -103,7 +103,7 @@ private void validate(String content, CapturedOutput output) { UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("Please answer my question XYZ"); - assertThat(output.getOut()).contains("request: AdvisedRequest", "userText=Please answer my question XYZ"); + assertThat(output.getOut()).contains("request: AdvisedRequest", "userText={userText}", "userParams={userText=Please answer my question XYZ}"); assertThat(output.getOut()).contains("response:", "finishReason"); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java index 4f2d4415ec2..20ce6b7b925 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java @@ -36,14 +36,14 @@ class AdvisedRequestTests { @Test void buildAdvisedRequest() { - AdvisedRequest request = new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + AdvisedRequest request = new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()); assertThat(request).isNotNull(); } @Test void whenChatModelIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(null, "user", null, null, List.of(), List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(null, "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); @@ -51,7 +51,7 @@ void whenChatModelIsNullThenThrows() { @Test void whenUserTextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage( @@ -60,7 +60,7 @@ void whenUserTextIsNullThenThrows() { @Test void whenUserTextIsEmptyThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage( @@ -69,7 +69,7 @@ void whenUserTextIsEmptyThenThrows() { @Test void whenMediaIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, null, List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("media cannot be null"); @@ -77,7 +77,7 @@ void whenMediaIsNullThenThrows() { @Test void whenFunctionNamesIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), null, + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), null, List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("functionNames cannot be null"); @@ -85,7 +85,7 @@ void whenFunctionNamesIsNullThenThrows() { @Test void whenFunctionCallbacksIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), null, List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("functionCallbacks cannot be null"); @@ -93,7 +93,7 @@ void whenFunctionCallbacksIsNullThenThrows() { @Test void whenMessagesIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), null, Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("messages cannot be null"); @@ -101,7 +101,7 @@ void whenMessagesIsNullThenThrows() { @Test void whenUserParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), null, Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("userParams cannot be null"); @@ -109,7 +109,7 @@ void whenUserParamsIsNullThenThrows() { @Test void whenSystemParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), null, List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("systemParams cannot be null"); @@ -117,7 +117,7 @@ void whenSystemParamsIsNullThenThrows() { @Test void whenAdvisorsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), null, Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("advisors cannot be null"); @@ -125,7 +125,7 @@ void whenAdvisorsIsNullThenThrows() { @Test void whenAdvisorParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("advisorParams cannot be null"); @@ -133,7 +133,7 @@ void whenAdvisorParamsIsNullThenThrows() { @Test void whenAdviseContextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), null, Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("adviseContext cannot be null"); @@ -141,7 +141,7 @@ void whenAdviseContextIsNullThenThrows() { @Test void whenToolContextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("toolContext cannot be null"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java index 31d017d749e..6ca75f72e8e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java @@ -61,7 +61,7 @@ void whenEmptyInputContentThenReturnOriginalContext() { ChatClientObservationConvention customObservationConvention = null; var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, + List.of(), List.of(), null, List.of(), null, Map.of(), observationRegistry, customObservationConvention, Map.of()); var expectedContext = ChatClientObservationContext.builder().withRequest(request).build(); @@ -78,7 +78,7 @@ void whenWithTextThenAugmentContext() { var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"), "sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null, - List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); + List.of(), null, Map.of(), observationRegistry, customObservationConvention, Map.of()); var originalContext = ChatClientObservationContext.builder().withRequest(request).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java index cf8f644248a..0855c8717c4 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/ChatClientObservationContextTests.java @@ -46,7 +46,7 @@ class ChatClientObservationContextTests { void whenMandatoryRequestOptionsThenReturn() { var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); + List.of(), List.of(), null, List.of(), null, Map.of(), ObservationRegistry.NOOP, null, Map.of()); var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 0f1e4814277..15b4cd65a65 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -114,7 +114,7 @@ public String call(String functionInput) { @BeforeEach public void beforeEach() { this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); + List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), ObservationRegistry.NOOP, null, Map.of()); } @Test @@ -161,7 +161,7 @@ void shouldHaveOptionalKeyValues() { var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(), List.of("function1", "function2"), List.of(), null, - List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"), + List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), "", Map.of("advParam1", "advisorParam1Value"), ObservationRegistry.NOOP, null, Map.of()); ChatClientObservationContext observationContext = ChatClientObservationContext.builder()