diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index adeeaa7ce8c..9a4fab867d7 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -38,8 +38,7 @@ import java.util.*; import static dev.langchain4j.data.message.AiMessage.aiMessage; -import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.output.FinishReason.*; import static java.time.Duration.ofSeconds; @@ -265,8 +264,10 @@ public void setRequired(List required) { } public static AiMessage aiMessageFrom(ChatResponseMessage chatResponseMessage) { + String text = chatResponseMessage.getContent(); + if (isNullOrEmpty(chatResponseMessage.getToolCalls())) { - return aiMessage(chatResponseMessage.getContent()); + return aiMessage(text); } else { List toolExecutionRequests = chatResponseMessage.getToolCalls() .stream() @@ -280,7 +281,9 @@ public static AiMessage aiMessageFrom(ChatResponseMessage chatResponseMessage) { .build()) .collect(toList()); - return aiMessage(toolExecutionRequests); + return isNullOrBlank(text) ? + aiMessage(toolExecutionRequests) : + aiMessage(text, toolExecutionRequests); } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java index 8dceabad31e..2a1fec68852 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelperTest.java @@ -4,18 +4,25 @@ import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIServiceVersion; import com.azure.ai.openai.models.*; +import com.azure.json.JsonOptions; +import com.azure.json.JsonReader; +import com.azure.json.implementation.DefaultJsonReader; +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.output.FinishReason; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.aiMessageFrom; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; @@ -102,4 +109,38 @@ void finishReasonFromShouldReturnCorrectFinishReason() { FinishReason finishReason = InternalAzureOpenAiHelper.finishReasonFrom(completionsFinishReason); assertThat(finishReason).isEqualTo(FinishReason.STOP); } + + @Test + void whenToolCallsAndContentAreBothPresentShouldReturnAiMessageWithToolExecutionRequestsAndText() throws IOException { + + String functionName = "current_time"; + String functionArguments = "{}"; + String responseJson = "{\n" + + " \"role\": \"ASSISTANT\",\n" + + " \"content\": \"Hello\",\n" + + " \"tool_calls\": [\n" + + " {\n" + + " \"type\": \"function\",\n" + + " \"function\": {\n" + + " \"name\": \"current_time\",\n" + + " \"arguments\": \"{}\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }"; + ChatResponseMessage responseMessage; + try (JsonReader jsonReader = DefaultJsonReader.fromString(responseJson, new JsonOptions())) { + responseMessage = ChatResponseMessage.fromJson(jsonReader); + } + + AiMessage aiMessage = aiMessageFrom(responseMessage); + + assertThat(aiMessage.text()).isEqualTo("Hello"); + assertThat(aiMessage.toolExecutionRequests()).containsExactly(ToolExecutionRequest + .builder() + .name(functionName) + .arguments(functionArguments) + .build() + ); + } }