Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AdvisedRequest to avoid userText rendering #2020

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu
}
}

return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions,
return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.advisorText, inputRequest.chatOptions,
media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams,
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
inputRequest.toolContext);
Expand All @@ -138,7 +138,7 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise
return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(),
advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(),
advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(),
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), advisedRequest.advisorText(),
advisedRequest.advisorParams(), observationRegistry, customObservationConvention,
advisedRequest.toolContext());
}
Expand Down Expand Up @@ -605,20 +605,23 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
@Nullable
private String systemText;

@Nullable
private String advisorText;

@Nullable
private ChatOptions chatOptions;

/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorText, ccr.advisorParams,
ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
List<FunctionCallback> functionCallbacks, List<Message> messages, List<String> functionNames,
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors, @Nullable String advisorText,
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention,
Map<String, Object> toolContext) {
Expand Down Expand Up @@ -649,6 +652,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
this.messages.addAll(messages);
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorText = advisorText;
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.customObservationConvention = customObservationConvention != null ? customObservationConvention
Expand Down Expand Up @@ -778,6 +782,10 @@ public Builder mutate() {
builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
}

// NOTE(review): this is a second defaultSystem(...) call on the same builder, made right
// after the one for systemText/systemParams. It passes the advisor text and advisor
// params as the *system* default. Presumably this replaces or conflicts with the system
// text set just before — confirm whether a dedicated advisor-text default was intended.
if (StringUtils.hasText(this.advisorText)) {
builder.defaultSystem(s -> s.text(this.advisorText).params(this.advisorParams));
}

if (this.chatOptions != null) {
builder.defaultOptions(this.chatOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(),
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), observationRegistry,
customObservationConvention, Map.of());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ private AdvisedRequest before(AdvisedRequest request) {
AdvisedRequest advisedRequest = AdvisedRequest.from(request).messages(advisedMessages).build();

// 4. Add the new user input to the conversation memory.
UserMessage userMessage = new UserMessage(request.userText(), request.media());
String renderedUserText = request.renderUserText();
UserMessage userMessage = new UserMessage(renderedUserText, request.media());
this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage);

return advisedRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ private AdvisedRequest before(AdvisedRequest request) {
.build();

// 4. Add the new user input to the conversation memory.
UserMessage userMessage = new UserMessage(request.userText(), request.media());
String renderedUserText = request.renderUserText();
UserMessage userMessage = new UserMessage(renderedUserText, request.media());
this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage);

return advisedRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ private AdvisedRequest before(AdvisedRequest request) {
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;

// 2. Search for similar documents in the vector store.
String query = new PromptTemplate(request.userText(), request.userParams()).render();
String query = request.renderUserText();
var searchRequestToUse = SearchRequest.from(this.searchRequest)
.query(query)
.filterExpression(doGetFilterExpression(context))
Expand All @@ -236,7 +236,7 @@ private AdvisedRequest before(AdvisedRequest request) {
advisedUserParams.put("question_answer_context", documentContext);

AdvisedRequest advisedRequest = AdvisedRequest.from(request)
.userText(advisedUserText)
.advisorText(advisedUserText)
.userParams(advisedUserParams)
.adviseContext(context)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

// 0. Create a query from the user text and parameters.
Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
Query originalQuery = new Query(request.renderUserText());

// 1. Transform original user query based on a chain of query transformers.
Query transformedQuery = originalQuery;
Expand All @@ -129,10 +129,10 @@ public AdvisedRequest before(AdvisedRequest request) {
context.put(DOCUMENT_CONTEXT, documents);

// 5. Augment user query with the document contextual data.
Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);
Query augmentedQuery = this.queryAugmenter.augment(new Query(request.userText()), documents);

// 6. Update advised request with augmented prompt.
return AdvisedRequest.from(request).userText(augmentedQuery.text()).adviseContext(context).build();
return AdvisedRequest.from(request).advisorText(augmentedQuery.text()).adviseContext(context).build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ public String getName() {

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
String renderedUserText = advisedRequest.renderUserText();

if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
&& this.sensitiveWords.stream().anyMatch(renderedUserText::contains)) {

return createFailureResponse(advisedRequest);
}
Expand All @@ -86,9 +87,10 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
String renderedUserText = advisedRequest.renderUserText();

if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
&& this.sensitiveWords.stream().anyMatch(renderedUserText::contains)) {
return Flux.just(createFailureResponse(advisedRequest));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ private AdvisedRequest before(AdvisedRequest request) {
else {
advisedSystemText = this.systemTextAdvise;
}
String renderedUserText = request.renderUserText();

var searchRequest = SearchRequest.builder()
.query(request.userText())
.query(renderedUserText)
.topK(this.doGetChatMemoryRetrieveSize(request.adviseContext()))
.filterExpression(
DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(request.adviseContext()) + "'")
Expand All @@ -159,7 +160,7 @@ private AdvisedRequest before(AdvisedRequest request) {
.systemParams(advisedSystemParams)
.build();

UserMessage userMessage = new UserMessage(request.userText(), request.media());
UserMessage userMessage = new UserMessage(renderedUserText, request.media());
this.getChatMemoryStore()
.write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext())));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public record AdvisedRequest(
String userText,
@Nullable
String systemText,
String advisorText,
@Nullable
ChatOptions chatOptions,
List<Media> media,
Expand All @@ -82,8 +83,10 @@ public record AdvisedRequest(
Map<String, Object> toolContext
// @formatter:on
) {

public AdvisedRequest {
public AdvisedRequest(ChatModel chatModel, String userText, @Nullable
String systemText, @Nullable
String advisorText, @Nullable
ChatOptions chatOptions, List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages, Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors, Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages),
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
Expand Down Expand Up @@ -112,6 +115,28 @@ public record AdvisedRequest(
Assert.notNull(toolContext, "toolContext cannot be null");
Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements");
Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements");

// When no user parameters are supplied, the raw user text is moved into a parameter
// and the template becomes "{userText}", so the text itself is passed as a value
// rather than used as the template string.
// A null userText (allowed by the assertion above when messages are provided) is
// stored as-is: Map.of rejects null values and would throw a NullPointerException.
if (!CollectionUtils.isEmpty(userParams) || userText == null) {
    this.userText = userText;
    this.userParams = userParams;
}
else {
    this.userParams = Map.of("userText", userText);
    this.userText = "{userText}";
}

this.chatModel = chatModel;
this.systemText = systemText;
this.advisorText = (advisorText != null) ? advisorText : "";
this.chatOptions = chatOptions;
this.media = media;
this.functionNames = functionNames;
this.functionCallbacks = functionCallbacks;
this.messages = messages;
this.systemParams = systemParams;
this.advisors = advisors;
this.advisorParams = advisorParams;
this.adviseContext = adviseContext;
this.toolContext = toolContext;
}

public static Builder builder() {
Expand All @@ -125,6 +150,7 @@ public static Builder from(AdvisedRequest from) {
builder.chatModel = from.chatModel;
builder.userText = from.userText;
builder.systemText = from.systemText;
builder.advisorText = from.advisorText;
builder.chatOptions = from.chatOptions;
builder.media = from.media;
builder.functionNames = from.functionNames;
Expand All @@ -146,6 +172,11 @@ public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Ob
.build();
}

/**
 * Render the user text of this request.
 * <p>
 * If user parameters are present, the user text is used as a template and rendered
 * with those parameters through a {@link PromptTemplate}. Otherwise the user text
 * is returned unchanged.
 * @return the rendered user text
 */
public String renderUserText() {
    if (CollectionUtils.isEmpty(this.userParams())) {
        return this.userText();
    }
    PromptTemplate template = new PromptTemplate(this.userText(), this.userParams());
    return template.render();
}

public Prompt toPrompt() {
var messages = new ArrayList<>(this.messages());

Expand All @@ -157,22 +188,28 @@ public Prompt toPrompt() {
messages.add(new SystemMessage(processedSystemText));
}

String formatParam = (String) this.adviseContext().get("formatParam");

var processedUserText = StringUtils.hasText(formatParam)
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();

if (StringUtils.hasText(processedUserText)) {
Map<String, Object> userParams = new HashMap<>(this.userParams());
if (StringUtils.hasText(formatParam)) {
userParams.put("spring_ai_soc_format", formatParam);
String processedAdvisorText = this.advisorText();
if (StringUtils.hasText(this.advisorText())) {
Map<String, Object> advisorParams = new HashMap<>(this.advisorParams());
if (!CollectionUtils.isEmpty(this.userParams())) {
advisorParams.putAll(this.userParams());
}
if (!CollectionUtils.isEmpty(userParams)) {
processedUserText = new PromptTemplate(processedUserText, userParams).render();

if (!CollectionUtils.isEmpty(advisorParams)) {
processedAdvisorText = new PromptTemplate(processedAdvisorText, advisorParams).render();
}
messages.add(new UserMessage(processedUserText, this.media()));
} else {
processedAdvisorText = renderUserText();
}

String formatParam = (String) this.adviseContext().get("formatParam");

if (StringUtils.hasText(formatParam)) {
processedAdvisorText += System.lineSeparator() + formatParam;
}

messages.add(new UserMessage(processedAdvisorText, this.media()));

if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
if (!this.functionNames().isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(this.functionNames()));
Expand All @@ -199,6 +236,8 @@ public static final class Builder {

private String systemText;

private String advisorText;

private ChatOptions chatOptions;

private List<Media> media = List.of();
Expand Down Expand Up @@ -254,6 +293,16 @@ public Builder systemText(String systemText) {
return this;
}

/**
 * Set the advisor text, i.e. the text an advisor supplies for this request
 * separately from the user text.
 * @param advisorText the advisor text to store on this builder
 * @return this {@link Builder} instance, to allow call chaining
 */
public Builder advisorText(String advisorText) {
    this.advisorText = advisorText;
    return this;
}

/**
* Set the chat options.
* @param chatOptions the chat options
Expand Down Expand Up @@ -495,7 +544,7 @@ public Builder withToolContext(Map<String, Object> toolContext) {
* @return a new {@link AdvisedRequest} instance
*/
public AdvisedRequest build() {
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media,
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.advisorText, this.chatOptions, this.media,
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams,
this.advisors, this.advisorParams, this.adviseContext, this.toolContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1198,14 +1198,14 @@ void buildChatClientRequestSpec() {
ChatModel chatModel = mock(ChatModel.class);
DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec(
chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(),
Map.of(), ObservationRegistry.NOOP, null, Map.of());
null, Map.of(), ObservationRegistry.NOOP, null, Map.of());
assertThat(spec).isNotNull();
}

@Test
void whenChatModelIsNullThenThrow() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null,
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), null, Map.of(),
ObservationRegistry.NOOP, null, Map.of()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("chatModel cannot be null");
Expand All @@ -1214,7 +1214,7 @@ void whenChatModelIsNullThenThrow() {
@Test
void whenObservationRegistryIsNullThenThrow() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null,
Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null,
Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), null, Map.of(), null,
null, Map.of()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationRegistry cannot be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,31 @@ public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilarityS
assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery);
}

/**
 * Verifies that user text containing literal braces (here a JSON snippet) is passed
 * verbatim both to the vector store similarity search and into the user prompt sent
 * to the chat model.
 */
@Test
public void qaAdvisorTakesUserTextWithBracesIntoAccountForSimilaritySearch() {
    given(this.chatModel.call(this.promptCaptor.capture()))
        .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
                ChatResponseMetadata.builder().build()));

    given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
        .willReturn(List.of(new Document("doc1"), new Document("doc2")));

    var chatClient = ChatClient.builder(this.chatModel).build();
    var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build());

    // Braces in the text must be treated as literal characters, not template placeholders.
    String userText = "{ \"name\" : \"Chuck\" }";

    // @formatter:off
    chatClient.prompt()
        .user(userText)
        .advisors(qaAdvisor)
        .call()
        .chatResponse();
    // @formatter:on

    var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getText();
    assertThat(userPrompt).contains(userText);
    assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(userText);
}

}
Loading