Skip to content

Commit

Permalink
fix bug: AiMessage text content is not copied when toolCalls are pres…
Browse files Browse the repository at this point in the history
…ent (langchain4j#1576)

## Issue
<!-- Please specify the ID of the issue this PR is addressing. For
example: "Closes langchain4j#1234" or "Fixes langchain4j#1234" -->
langchain4j#986
pr to solve same issue:
langchain4j#1069

## Change
When Azure-OpenAI returns both content and tool_calls, keep them all
instead of just keeping the tool_calls.


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [X] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [X] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
hrhrng authored Aug 16, 2024
1 parent ac7c2b9 commit 5ec127d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
import java.util.*;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.time.Duration.ofSeconds;
Expand Down Expand Up @@ -265,8 +264,10 @@ public void setRequired(List<String> required) {
}

public static AiMessage aiMessageFrom(ChatResponseMessage chatResponseMessage) {
String text = chatResponseMessage.getContent();

if (isNullOrEmpty(chatResponseMessage.getToolCalls())) {
return aiMessage(chatResponseMessage.getContent());
return aiMessage(text);
} else {
List<ToolExecutionRequest> toolExecutionRequests = chatResponseMessage.getToolCalls()
.stream()
Expand All @@ -280,7 +281,9 @@ public static AiMessage aiMessageFrom(ChatResponseMessage chatResponseMessage) {
.build())
.collect(toList());

return aiMessage(toolExecutionRequests);
return isNullOrBlank(text) ?
aiMessage(toolExecutionRequests) :
aiMessage(text, toolExecutionRequests);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.ai.openai.models.*;
import com.azure.json.JsonOptions;
import com.azure.json.JsonReader;
import com.azure.json.implementation.DefaultJsonReader;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.output.FinishReason;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.aiMessageFrom;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
Expand Down Expand Up @@ -102,4 +109,38 @@ void finishReasonFromShouldReturnCorrectFinishReason() {
FinishReason finishReason = InternalAzureOpenAiHelper.finishReasonFrom(completionsFinishReason);
assertThat(finishReason).isEqualTo(FinishReason.STOP);
}

@Test
void whenToolCallsAndContentAreBothPresentShouldReturnAiMessageWithToolExecutionRequestsAndText() throws IOException {

String functionName = "current_time";
String functionArguments = "{}";
String responseJson = "{\n" +
" \"role\": \"ASSISTANT\",\n" +
" \"content\": \"Hello\",\n" +
" \"tool_calls\": [\n" +
" {\n" +
" \"type\": \"function\",\n" +
" \"function\": {\n" +
" \"name\": \"current_time\",\n" +
" \"arguments\": \"{}\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }";
ChatResponseMessage responseMessage;
try (JsonReader jsonReader = DefaultJsonReader.fromString(responseJson, new JsonOptions())) {
responseMessage = ChatResponseMessage.fromJson(jsonReader);
}

AiMessage aiMessage = aiMessageFrom(responseMessage);

assertThat(aiMessage.text()).isEqualTo("Hello");
assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest
.builder()
.name(functionName)
.arguments(functionArguments)
.build()
);
}
}

0 comments on commit 5ec127d

Please sign in to comment.