diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolRuntimeHints.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolRuntimeHints.java new file mode 100644 index 00000000000..5509813ba53 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolRuntimeHints.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.aot; + +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +/** + * Registers runtime hints for the tool calling APIs. + * + * @author Thomas Vitale + */ +public class ToolRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + hints.reflection().registerType(DefaultToolCallResultConverter.class, mcs); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index ae017b66357..aebbe8d58b6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -215,6 +215,12 @@ interface ChatClientRequestSpec { ChatClientRequestSpec options(T options); + ChatClientRequestSpec tools(String... toolNames); + + ChatClientRequestSpec tools(Object... toolObjects); + + ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks); + /** * @deprecated use {@link #functions(FunctionCallback...)} instead. */ @@ -293,6 +299,12 @@ interface Builder { Builder defaultSystem(Consumer systemSpecConsumer); + Builder defaultTools(String... toolNames); + + Builder defaultTools(Object... toolObjects); + + Builder defaultToolCallbacks(FunctionCallback... toolCallbacks); + /** * @deprecated use {@link #defaultFunctions(FunctionCallback...)} instead. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 127390df7da..aec1d2f1578 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.springframework.ai.tool.ToolCallbacks; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -782,10 +783,9 @@ public Builder mutate() { builder.defaultOptions(this.chatOptions); } - // workaround to set the missing fields. - builder.defaultRequest.getMessages().addAll(this.messages); - builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks); - builder.defaultRequest.getToolContext().putAll(this.toolContext); + builder.addMessages(this.messages); + builder.addToolCallbacks(this.functionCallbacks); + builder.addToolContext(this.toolContext); return builder; } @@ -836,6 +836,30 @@ public ChatClientRequestSpec options(T options) { return this; } + @Override + public ChatClientRequestSpec tools(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + this.functionNames.addAll(List.of(toolNames)); + return this; + } + + @Override + public ChatClientRequestSpec tools(Object... toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); + this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects))); + return this; + } + + @Override + public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.functionCallbacks.addAll(Arrays.asList(toolCallbacks)); + return this; + } + @Override public ChatClientRequestSpec function(String name, String description, java.util.function.Function function) { @@ -888,10 +912,7 @@ public ChatClientRequestSpec function(String name, String description, @N } public ChatClientRequestSpec functions(String... functionBeanNames) { - Assert.notNull(functionBeanNames, "functionBeanNames cannot be null"); - Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements"); - this.functionNames.addAll(List.of(functionBeanNames)); - return this; + return tools(functionBeanNames); } public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 7f4fbdbff17..70b15124a7f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,10 +30,12 @@ import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallbacks; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -147,6 +149,24 @@ public Builder defaultSystem(Consumer systemSpecConsumer) { return this; } + @Override + public Builder defaultTools(String... toolNames) { + this.defaultRequest.functions(toolNames); + return this; + } + + @Override + public Builder defaultTools(Object... toolObjects) { + this.defaultRequest.functions(ToolCallbacks.from(toolObjects)); + return this; + } + + @Override + public Builder defaultToolCallbacks(FunctionCallback... toolCallbacks) { + this.defaultRequest.functions(toolCallbacks); + return this; + } + public Builder defaultFunction(String name, String description, java.util.function.Function function) { this.defaultRequest.function(name, description, function); return this; @@ -173,4 +193,17 @@ public Builder defaultToolContext(Map toolContext) { return this; } + void addMessages(List messages) { + this.defaultRequest.messages(messages); + } + + void addToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.defaultRequest.toolCallbacks(toolCallbacks.toArray(FunctionCallback[]::new)); + } + + void addToolContext(Map toolContext) { + this.defaultRequest.toolContext(toolContext); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java new file mode 100644 index 00000000000..15a8f015917 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -0,0 +1,458 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Default implementation of {@link ToolCallingChatOptions}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { + + private List toolCallbacks = new ArrayList<>(); + + private Set tools = new HashSet<>(); + + private Map toolContext = new HashMap<>(); + + @Nullable + private Boolean toolCallReturnDirect; + + @Nullable + private String model; + + @Nullable + private Double frequencyPenalty; + + @Nullable + private Integer maxTokens; + + @Nullable + private Double presencePenalty; + + @Nullable + private List stopSequences; + + @Nullable + private Double temperature; + + @Nullable + private Integer topK; + + @Nullable + private Double topP; + + @Override + public List getToolCallbacks() { + return List.copyOf(this.toolCallbacks); + } + + @Override + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = new ArrayList<>(toolCallbacks); + } + + @Override + public void setToolCallbacks(ToolCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + setToolCallbacks(List.of(toolCallbacks)); + } + + @Override + public Set getTools() { + return Set.copyOf(this.tools); + } + + @Override + public void setTools(Set tools) { + Assert.notNull(tools, "tools cannot be null"); + Assert.noNullElements(tools, "tools cannot contain null elements"); + tools.forEach(tool -> Assert.hasText(tool, "tools cannot contain empty elements")); + this.tools = new HashSet<>(tools); + } + + @Override + public void setTools(String... tools) { + Assert.notNull(tools, "tools cannot be null"); + setTools(Set.of(tools)); + } + + @Override + public Map getToolContext() { + return Map.copyOf(this.toolContext); + } + + @Override + public void setToolContext(Map toolContext) { + Assert.notNull(toolContext, "toolContext cannot be null"); + Assert.noNullElements(toolContext.keySet(), "toolContext cannot contain null keys"); + this.toolContext = new HashMap<>(toolContext); + } + + @Override + @Nullable + public Boolean getToolCallReturnDirect() { + return this.toolCallReturnDirect; + } + + @Override + public void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) { + this.toolCallReturnDirect = toolCallReturnDirect; + } + + @Override + public List getFunctionCallbacks() { + return getToolCallbacks().stream().map(FunctionCallback.class::cast).toList(); + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + throw new UnsupportedOperationException("Not supported. Call setToolCallbacks instead."); + } + + @Override + public Set getFunctions() { + return getTools(); + } + + @Override + public void setFunctions(Set functions) { + setTools(functions); + } + + @Override + @Nullable + public Boolean getProxyToolCalls() { + return getToolCallReturnDirect(); + } + + @Override + public void setProxyToolCalls(@Nullable Boolean proxyToolCalls) { + setToolCallReturnDirect(proxyToolCalls != null && proxyToolCalls); + } + + @Override + @Nullable + public String getModel() { + return this.model; + } + + public void setModel(@Nullable String model) { + this.model = model; + } + + @Override + @Nullable + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(@Nullable Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + @Nullable + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(@Nullable Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + @Nullable + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(@Nullable Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + @Nullable + public List getStopSequences() { + return this.stopSequences; + } + + public void setStopSequences(@Nullable List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + @Nullable + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(@Nullable Double temperature) { + this.temperature = temperature; + } + + @Override + @Nullable + public Integer getTopK() { + return this.topK; + } + + public void setTopK(@Nullable Integer topK) { + this.topK = topK; + } + + @Override + @Nullable + public Double getTopP() { + return this.topP; + } + + public void setTopP(@Nullable Double topP) { + this.topP = topP; + } + + @Override + @SuppressWarnings("unchecked") + public T copy() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + options.setToolCallbacks(getToolCallbacks()); + options.setTools(getTools()); + options.setToolContext(getToolContext()); + options.setToolCallReturnDirect(getToolCallReturnDirect()); + options.setModel(getModel()); + options.setFrequencyPenalty(getFrequencyPenalty()); + options.setMaxTokens(getMaxTokens()); + options.setPresencePenalty(getPresencePenalty()); + options.setStopSequences(getStopSequences()); + options.setTemperature(getTemperature()); + options.setTopK(getTopK()); + options.setTopP(getTopP()); + return (T) options; + } + + /** + * Merge the given {@link ChatOptions} into this instance. + */ + public ToolCallingChatOptions merge(ChatOptions options) { + ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); + builder.model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.getModel()); + builder.frequencyPenalty( + options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.getFrequencyPenalty()); + builder.maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxTokens()); + builder.presencePenalty( + options.getPresencePenalty() != null ? options.getPresencePenalty() : this.getPresencePenalty()); + builder.stopSequences(options.getStopSequences() != null ? new ArrayList<>(options.getStopSequences()) + : this.getStopSequences()); + builder.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature()); + builder.topK(options.getTopK() != null ? options.getTopK() : this.getTopK()); + builder.topP(options.getTopP() != null ? options.getTopP() : this.getTopP()); + + if (options instanceof ToolCallingChatOptions toolOptions) { + List toolCallbacks = new ArrayList<>(this.toolCallbacks); + if (!CollectionUtils.isEmpty(toolOptions.getToolCallbacks())) { + toolCallbacks.addAll(toolOptions.getToolCallbacks()); + } + builder.toolCallbacks(toolCallbacks); + + Set tools = new HashSet<>(this.tools); + if (!CollectionUtils.isEmpty(toolOptions.getTools())) { + tools.addAll(toolOptions.getTools()); + } + builder.tools(tools); + + Map toolContext = new HashMap<>(this.toolContext); + if (!CollectionUtils.isEmpty(toolOptions.getToolContext())) { + toolContext.putAll(toolOptions.getToolContext()); + } + builder.toolContext(toolContext); + + builder.toolCallReturnDirect(toolOptions.getToolCallReturnDirect() != null + ? toolOptions.getToolCallReturnDirect() : this.getToolCallReturnDirect()); + } + else { + builder.toolCallbacks(this.toolCallbacks); + builder.tools(this.tools); + builder.toolContext(this.toolContext); + builder.toolCallReturnDirect(this.toolCallReturnDirect); + } + + return builder.build(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Default implementation of {@link ToolCallingChatOptions.Builder}. + */ + public static class Builder implements ToolCallingChatOptions.Builder { + + private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + + @Override + public ToolCallingChatOptions.Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + @Override + public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + @Override + public ToolCallingChatOptions.Builder tools(Set toolNames) { + this.options.setTools(toolNames); + return this; + } + + @Override + public ToolCallingChatOptions.Builder tools(String... toolNames) { + this.options.setTools(toolNames); + return this; + } + + @Override + public ToolCallingChatOptions.Builder toolContext(Map context) { + this.options.setToolContext(context); + return this; + } + + @Override + public ToolCallingChatOptions.Builder toolContext(String key, Object value) { + Assert.hasText(key, "key cannot be null"); + Assert.notNull(value, "value cannot be null"); + Map updatedToolContext = new HashMap<>(this.options.getToolContext()); + updatedToolContext.put(key, value); + this.options.setToolContext(updatedToolContext); + return this; + } + + @Override + public ToolCallingChatOptions.Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect) { + this.options.setToolCallReturnDirect(toolCallReturnDirect); + return this; + } + + @Override + @Deprecated // Use toolCallbacks() instead + public ToolCallingChatOptions.Builder functionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + return toolCallbacks(functionCallbacks.stream().map(ToolCallback.class::cast).toList()); + } + + @Override + @Deprecated // Use toolCallbacks() instead + public ToolCallingChatOptions.Builder functionCallbacks(FunctionCallback... functionCallbacks) { + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + return functionCallbacks(List.of(functionCallbacks)); + } + + @Override + @Deprecated // Use tools() instead + public ToolCallingChatOptions.Builder functions(Set functions) { + return tools(functions); + } + + @Override + @Deprecated // Use tools() instead + public ToolCallingChatOptions.Builder function(String function) { + return tools(function); + } + + @Override + @Deprecated // Use toolCallReturnDirect() instead + public ToolCallingChatOptions.Builder proxyToolCalls(@Nullable Boolean proxyToolCalls) { + return toolCallReturnDirect(proxyToolCalls != null && proxyToolCalls); + } + + @Override + public ToolCallingChatOptions.Builder model(@Nullable String model) { + this.options.setModel(model); + return this; + } + + @Override + public ToolCallingChatOptions.Builder frequencyPenalty(@Nullable Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + @Override + public ToolCallingChatOptions.Builder maxTokens(@Nullable Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + @Override + public ToolCallingChatOptions.Builder presencePenalty(@Nullable Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + + @Override + public ToolCallingChatOptions.Builder stopSequences(@Nullable List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + @Override + public ToolCallingChatOptions.Builder temperature(@Nullable Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + @Override + public ToolCallingChatOptions.Builder topK(@Nullable Integer topK) { + this.options.setTopK(topK); + return this; + } + + @Override + public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { + this.options.setTopP(topP); + return this; + } + + @Override + public ToolCallingChatOptions build() { + return this.options; + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java new file mode 100644 index 00000000000..19ab72055c2 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -0,0 +1,181 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.lang.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A set of options that can be used to configure the interaction with a chat model, + * including tool calling. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolCallingChatOptions extends FunctionCallingOptions { + + /** + * ToolCallbacks to be registered with the ChatModel. + */ + List getToolCallbacks(); + + /** + * Set the ToolCallbacks to be registered with the ChatModel. + */ + void setToolCallbacks(List toolCallbacks); + + /** + * Set the ToolCallbacks to be registered with the ChatModel. + */ + void setToolCallbacks(ToolCallback... toolCallbacks); + + /** + * Names of the tools to register with the ChatModel. + */ + Set getTools(); + + /** + * Set the names of the tools to register with the ChatModel. + */ + void setTools(Set tools); + + /** + * Set the names of the tools to register with the ChatModel. + */ + void setTools(String... tools); + + /** + * Whether the result of each tool call should be returned directly or passed back to + * the model. It can be overridden for each {@link ToolCallback} instance via + * {@link ToolMetadata#returnDirect()}. + */ + @Nullable + Boolean getToolCallReturnDirect(); + + /** + * Set whether the result of each tool call should be returned directly or passed back + * to the model. It can be overridden for each {@link ToolCallback} instance via + * {@link ToolMetadata#returnDirect()}. + */ + void setToolCallReturnDirect(@Nullable Boolean toolCallReturnDirect); + + /** + * A builder to create a new {@link ToolCallingChatOptions} instance. + */ + static Builder builder() { + return new DefaultToolCallingChatOptions.Builder(); + } + + /** + * A builder to create a {@link ToolCallingChatOptions} instance. + */ + interface Builder extends FunctionCallingOptions.Builder { + + /** + * ToolCallbacks to be registered with the ChatModel. + */ + Builder toolCallbacks(List functionCallbacks); + + /** + * ToolCallbacks to be registered with the ChatModel. + */ + Builder toolCallbacks(ToolCallback... functionCallbacks); + + /** + * Names of the tools to register with the ChatModel. + */ + Builder tools(Set toolNames); + + /** + * Names of the tools to register with the ChatModel. + */ + Builder tools(String... toolNames); + + /** + * Whether the result of each tool call should be returned directly or passed back + * to the model. It can be overridden for each {@link ToolCallback} instance via + * {@link ToolMetadata#returnDirect()}. + */ + Builder toolCallReturnDirect(@Nullable Boolean toolCallReturnDirect); + + // FunctionCallingOptions.Builder methods + + @Override + Builder toolContext(Map context); + + @Override + Builder toolContext(String key, Object value); + + @Override + @Deprecated // Use toolCallbacks() instead + Builder functionCallbacks(List functionCallbacks); + + @Override + @Deprecated // Use toolCallbacks() instead + Builder functionCallbacks(FunctionCallback... functionCallbacks); + + @Override + @Deprecated // Use tools() instead + Builder functions(Set functions); + + @Override + @Deprecated // Use tools() instead + Builder function(String function); + + @Override + @Deprecated // Use toolCallReturnDirect() instead + Builder proxyToolCalls(@Nullable Boolean proxyToolCalls); + + // ChatOptions.Builder methods + + @Override + Builder model(@Nullable String model); + + @Override + Builder frequencyPenalty(@Nullable Double frequencyPenalty); + + @Override + Builder maxTokens(@Nullable Integer maxTokens); + + @Override + Builder presencePenalty(@Nullable Double presencePenalty); + + @Override + Builder stopSequences(@Nullable List stopSequences); + + @Override + Builder temperature(@Nullable Double temperature); + + @Override + Builder topK(@Nullable Integer topK); + + @Override + Builder topP(@Nullable Double topP); + + @Override + ToolCallingChatOptions build(); + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/package-info.java new file mode 100644 index 00000000000..1255fe628b6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.model.tool; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java new file mode 100644 index 00000000000..f26b3d06eee --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +/** + * Represents a tool whose execution can be triggered by an AI model. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolCallback extends FunctionCallback { + + /** + * Definition used by the AI model to determine when and how to call the tool. + */ + ToolDefinition getToolDefinition(); + + /** + * Metadata providing additional information on how to handle the tool. + */ + default ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + default String getName() { + return getToolDefinition().name(); + } + + @Override + default String getDescription() { + return getToolDefinition().description(); + } + + @Override + default String getInputTypeSchema() { + return getToolDefinition().inputTypeSchema(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java new file mode 100644 index 00000000000..e5e4d01319c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool; + +/** + * Provides {@link ToolCallback} instances for tools defined in different sources. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolCallbackProvider { + + ToolCallback[] getToolCallbacks(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbacks.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbacks.java new file mode 100644 index 00000000000..10d69dd7530 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallbacks.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool; + +import org.springframework.ai.tool.method.MethodToolCallbackProvider; + +/** + * Provides {@link ToolCallback} instances for tools defined in different sources. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class ToolCallbacks { + + private ToolCallbacks() { + } + + public static ToolCallback[] from(Object... sources) { + return MethodToolCallbackProvider.builder().toolObjects(sources).build().getToolCallbacks(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/Tool.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/Tool.java new file mode 100644 index 00000000000..2dc1455cffd --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/Tool.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.annotation; + +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionMode; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a method as a tool in Spring AI. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface Tool { + + /** + * The name of the tool. If not provided, the method name will be used. + */ + String name() default ""; + + /** + * The description of the tool. If not provided, the method name will be used. + */ + String value() default ""; + + /** + * How the tool should be executed. + */ + ToolExecutionMode executionMode() default ToolExecutionMode.BLOCKING; + + /** + * Whether the tool result should be returned directly or passed back to the model. + */ + boolean returnDirect() default false; + + /** + * The class to use to convert the tool call result to a String. + */ + Class resultConverter() default DefaultToolCallResultConverter.class; + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/package-info.java new file mode 100644 index 00000000000..5f31c3b576c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/annotation/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.annotation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java new file mode 100644 index 00000000000..83c9327bf12 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java @@ -0,0 +1,71 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.definition; + +import org.springframework.util.Assert; + +/** + * Default implementation of {@link ToolDefinition}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public record DefaultToolDefinition(String name, String description, String inputTypeSchema) implements ToolDefinition { + + public DefaultToolDefinition { + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null or empty"); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String name; + + private String description; + + private String inputTypeSchema; + + private Builder() { + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputTypeSchema(String inputTypeSchema) { + this.inputTypeSchema = inputTypeSchema; + return this; + } + + public DefaultToolDefinition build() { + return new DefaultToolDefinition(name, description, inputTypeSchema); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java new file mode 100644 index 00000000000..dcc42297482 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.definition; + +import org.springframework.ai.tool.util.ToolUtils; +import org.springframework.ai.util.json.JsonSchemaGenerator; + +import java.lang.reflect.Method; + +/** + * Definition used by the AI model to determine when and how to call the tool. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolDefinition { + + /** + * The tool name. Unique within the tool set provided to a model. + */ + String name(); + + /** + * The tool description, used by the AI model to determine what the tool does. + */ + String description(); + + /** + * The JSON Schema of the parameters used to call the tool. + */ + String inputTypeSchema(); + + /** + * Create a default {@link ToolDefinition} builder. + */ + static DefaultToolDefinition.Builder builder() { + return DefaultToolDefinition.builder(); + } + + /** + * Create a default {@link ToolDefinition} instance from a {@link Method}. + */ + static ToolDefinition from(Method method) { + return DefaultToolDefinition.builder() + .name(ToolUtils.getToolName(method)) + .description(ToolUtils.getToolDescription(method)) + .inputTypeSchema(JsonSchemaGenerator.generateForMethodInput(method)) + .build(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/package-info.java new file mode 100644 index 00000000000..b268ea8a9db --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.definition; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java new file mode 100644 index 00000000000..8eebbfa7ae1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import org.springframework.ai.util.json.JsonParser; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A default implementation of {@link ToolCallResultConverter}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class DefaultToolCallResultConverter implements ToolCallResultConverter { + + @Override + public String apply(@Nullable Object result, Class returnType) { + Assert.notNull(returnType, "returnType cannot be null"); + if (returnType == Void.TYPE) { + return "Done"; + } + else { + return JsonParser.toJson(result); + } + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java new file mode 100644 index 00000000000..cb08edf768b --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import org.springframework.lang.Nullable; + +import java.util.function.BiFunction; + +/** + * A functional interface to convert tool call results to a String that can be sent back + * to the AI model. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@FunctionalInterface +public interface ToolCallResultConverter extends BiFunction, String> { + + /** + * Given an Object returned by a tool, convert it to a String compatible with the + * given class type. + */ + String apply(@Nullable Object result, Class returnType); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionException.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionException.java new file mode 100644 index 00000000000..f76acd1d637 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +import org.springframework.ai.tool.definition.ToolDefinition; + +/** + * An exception thrown when a tool execution fails. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class ToolExecutionException extends RuntimeException { + + private final ToolDefinition toolDefinition; + + public ToolExecutionException(ToolDefinition toolDefinition, Throwable cause) { + super(cause.getMessage(), cause); + this.toolDefinition = toolDefinition; + } + + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionMode.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionMode.java new file mode 100644 index 00000000000..4cd546ad0a1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolExecutionMode.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.execution; + +/** + * How the tool should be executed. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public enum ToolExecutionMode { + + /** + * The tool should be executed in a blocking manner. + */ + BLOCKING; + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/package-info.java new file mode 100644 index 00000000000..528477204d3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.execution; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java new file mode 100644 index 00000000000..3d0ec5ce5d8 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/DefaultToolMetadata.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.metadata; + +import org.springframework.ai.tool.execution.ToolExecutionMode; + +/** + * Default implementation of {@link ToolMetadata}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public record DefaultToolMetadata(ToolExecutionMode executionMode, boolean returnDirect) implements ToolMetadata { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ToolExecutionMode executionMode = ToolExecutionMode.BLOCKING; + + private boolean returnDirect = false; + + private Builder() { + } + + public Builder executionMode(ToolExecutionMode executionMode) { + this.executionMode = executionMode; + return this; + } + + public Builder returnDirect(boolean returnDirect) { + this.returnDirect = returnDirect; + return this; + } + + public DefaultToolMetadata build() { + return new DefaultToolMetadata(executionMode, returnDirect); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java new file mode 100644 index 00000000000..9dad56fc1b1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.metadata; + +import org.springframework.ai.tool.execution.ToolExecutionMode; +import org.springframework.ai.tool.util.ToolUtils; + +import java.lang.reflect.Method; + +/** + * Metadata about a tool specification and execution. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ToolMetadata { + + /** + * How the tool should be executed. + */ + default ToolExecutionMode executionMode() { + return ToolExecutionMode.BLOCKING; + } + + /** + * Whether the tool result should be returned directly or passed back to the model. + */ + default boolean returnDirect() { + return false; + } + + /** + * Create a default {@link ToolMetadata} builder. + */ + static DefaultToolMetadata.Builder builder() { + return DefaultToolMetadata.builder(); + } + + /** + * Create a default {@link ToolMetadata} instance from a {@link Method}. + */ + static ToolMetadata from(Method method) { + return DefaultToolMetadata.builder() + .executionMode(ToolUtils.getToolExecutionMode(method)) + .returnDirect(ToolUtils.getToolReturnDirect(method)) + .build(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/package-info.java new file mode 100644 index 00000000000..e3dbfb3a17c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.metadata; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java new file mode 100644 index 00000000000..424105671d9 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -0,0 +1,224 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.method; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.util.json.JsonParser; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Map; +import java.util.stream.Stream; + +/** + * A {@link ToolCallback} implementation to invoke methods as tools. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class MethodToolCallback implements ToolCallback { + + private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class); + + private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter(); + + private final ToolDefinition toolDefinition; + + private final ToolMetadata toolMetadata; + + private final Method toolMethod; + + private final Object toolObject; + + private final ToolCallResultConverter toolCallResultConverter; + + public MethodToolCallback(ToolDefinition toolDefinition, ToolMetadata toolMetadata, Method toolMethod, + Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) { + Assert.notNull(toolDefinition, "toolDefinition cannot be null"); + Assert.notNull(toolMetadata, "toolMetadata cannot be null"); + Assert.notNull(toolMethod, "toolMethod cannot be null"); + Assert.notNull(toolObject, "toolObject cannot be null"); + this.toolDefinition = toolDefinition; + this.toolMetadata = toolMetadata; + this.toolMethod = toolMethod; + this.toolObject = toolObject; + this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter + : DEFAULT_RESULT_CONVERTER; + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public ToolMetadata getToolMetadata() { + return toolMetadata; + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, @Nullable ToolContext toolContext) { + Assert.hasText(toolInput, "toolInput cannot be null or empty"); + + logger.debug("Starting execution of tool: {}", toolDefinition.name()); + + validateToolContextSupport(toolContext); + + Map toolArguments = extractToolArguments(toolInput); + + Object[] methodArguments = buildMethodArguments(toolArguments, toolContext); + + Object result = callMethod(methodArguments); + + logger.debug("Successful execution of tool: {}", toolDefinition.name()); + + Class returnType = toolMethod.getReturnType(); + + return toolCallResultConverter.apply(result, returnType); + } + + private void validateToolContextSupport(@Nullable ToolContext toolContext) { + var isToolContextRequired = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()); + var isToolContextAcceptedByMethod = Stream.of(toolMethod.getParameterTypes()) + .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class)); + if (isToolContextRequired && !isToolContextAcceptedByMethod) { + throw new IllegalArgumentException("ToolContext is not supported by the method as an argument"); + } + } + + private Map extractToolArguments(String toolInput) { + return JsonParser.fromJson(toolInput, new TypeReference<>() { + }); + } + + // Based on the implementation in MethodInvokingFunctionCallback. + private Object[] buildMethodArguments(Map toolInputArguments, @Nullable ToolContext toolContext) { + return Stream.of(toolMethod.getParameters()).map(parameter -> { + if (parameter.getType().isAssignableFrom(ToolContext.class)) { + return toolContext; + } + Object rawArgument = toolInputArguments.get(parameter.getName()); + return buildTypedArgument(rawArgument, parameter.getType()); + }).toArray(); + } + + @Nullable + private Object buildTypedArgument(@Nullable Object value, Class type) { + if (value == null) { + return null; + } + return JsonParser.toTypedObject(value, type); + } + + @Nullable + private Object callMethod(Object[] methodArguments) { + if (isObjectNotPublic() || isMethodNotPublic()) { + toolMethod.setAccessible(true); + } + + Object result; + try { + result = toolMethod.invoke(toolObject, methodArguments); + } + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } + catch (InvocationTargetException ex) { + throw new ToolExecutionException(toolDefinition, ex.getCause()); + } + return result; + } + + private boolean isObjectNotPublic() { + return !Modifier.isPublic(toolObject.getClass().getModifiers()); + } + + private boolean isMethodNotPublic() { + return !Modifier.isPublic(toolMethod.getModifiers()); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ToolDefinition toolDefinition; + + private ToolMetadata toolMetadata; + + private Method toolMethod; + + private Object toolObject; + + private ToolCallResultConverter toolCallResultConverter; + + private Builder() { + } + + public Builder toolDefinition(ToolDefinition toolDefinition) { + this.toolDefinition = toolDefinition; + return this; + } + + public Builder toolMetadata(ToolMetadata toolMetadata) { + this.toolMetadata = toolMetadata; + return this; + } + + public Builder toolMethod(Method toolMethod) { + this.toolMethod = toolMethod; + return this; + } + + public Builder toolObject(Object toolObject) { + this.toolObject = toolObject; + return this; + } + + public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) { + this.toolCallResultConverter = toolCallResultConverter; + return this; + } + + public MethodToolCallback build() { + return new MethodToolCallback(toolDefinition, toolMetadata, toolMethod, toolObject, + toolCallResultConverter); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java new file mode 100644 index 00000000000..ab89f2d73ef --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java @@ -0,0 +1,126 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.method; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.util.ToolUtils; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A {@link ToolCallbackProvider} that builds {@link ToolCallback} instances from + * {@link Tool}-annotated methods. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class MethodToolCallbackProvider implements ToolCallbackProvider { + + private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class); + + private final List toolObjects; + + private MethodToolCallbackProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); + this.toolObjects = toolObjects; + } + + @Override + public ToolCallback[] getToolCallbacks() { + var toolCallbacks = toolObjects.stream() + .map(toolObject -> Stream.of(ReflectionUtils.getDeclaredMethods(toolObject.getClass())) + .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class)) + .filter(toolMethod -> !isFunctionalType(toolMethod)) + .map(toolMethod -> MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(toolMethod)) + .build()) + .toArray(ToolCallback[]::new)) + .flatMap(Stream::of) + .toArray(ToolCallback[]::new); + + validateToolCallbacks(toolCallbacks); + + return toolCallbacks; + } + + private boolean isFunctionalType(Method toolMethod) { + var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class) + || ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class) + || ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class); + + if (isFunction) { + logger.warn("Method {} is annotated with @Tool but returns a functional type. " + + "This is not supported and the method will be ignored.", toolMethod.getName()); + } + + return isFunction; + } + + private void validateToolCallbacks(ToolCallback[] toolCallbacks) { + List duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks); + if (!duplicateToolNames.isEmpty()) { + throw new IllegalStateException("Multiple tools with the same name (%s) found in sources: %s".formatted( + String.join(", ", duplicateToolNames), + toolObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", ")))); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List toolObjects; + + private Builder() { + } + + public Builder toolObjects(Object... toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = Arrays.asList(toolObjects); + return this; + } + + public MethodToolCallbackProvider build() { + return new MethodToolCallbackProvider(toolObjects); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/package-info.java new file mode 100644 index 00000000000..18245cf7b4f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.method; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/package-info.java new file mode 100644 index 00000000000..58ce69dc1bd --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java new file mode 100644 index 00000000000..084f5441b7f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java @@ -0,0 +1,93 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.util; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionMode; +import org.springframework.ai.util.ParsingUtils; +import org.springframework.util.StringUtils; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Miscellaneous tool utility methods. Mainly for internal use within the framework. + * + * @author Thomas Vitale + */ +public final class ToolUtils { + + private ToolUtils() { + } + + public static String getToolName(Method method) { + var tool = method.getAnnotation(Tool.class); + if (tool == null) { + return method.getName(); + } + return StringUtils.hasText(tool.name()) ? tool.name() : method.getName(); + } + + public static String getToolDescription(Method method) { + var tool = method.getAnnotation(Tool.class); + if (tool == null) { + return ParsingUtils.reConcatenateCamelCase(method.getName(), " "); + } + return StringUtils.hasText(tool.value()) ? tool.value() : method.getName(); + } + + public static ToolExecutionMode getToolExecutionMode(Method method) { + var tool = method.getAnnotation(Tool.class); + return tool != null ? tool.executionMode() : ToolExecutionMode.BLOCKING; + } + + public static boolean getToolReturnDirect(Method method) { + var tool = method.getAnnotation(Tool.class); + return tool != null && tool.returnDirect(); + } + + public static ToolCallResultConverter getToolCallResultConverter(Method method) { + var tool = method.getAnnotation(Tool.class); + if (tool == null) { + return new DefaultToolCallResultConverter(); + } + var type = tool.resultConverter(); + try { + return type.getDeclaredConstructor().newInstance(); + } + catch (Exception e) { + throw new IllegalArgumentException("Failed to instantiate ToolCallResultConverter: " + type, e); + } + } + + public static List getDuplicateToolNames(FunctionCallback... functionCallbacks) { + return Stream.of(functionCallbacks) + .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting())) + .entrySet() + .stream() + .filter(entry -> entry.getValue() > 1) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/package-info.java new file mode 100644 index 00000000000..6fb2e67ae1d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.util; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java new file mode 100644 index 00000000000..cf0acb2d3c5 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java @@ -0,0 +1,138 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util.json; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.json.JsonMapper; +import org.springframework.ai.util.JacksonUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Utilities to perform parsing operations between JSON and Java. + */ +public final class JsonParser { + + private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder() + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) + .addModules(JacksonUtils.instantiateAvailableModules()) + .build(); + + private JsonParser() { + } + + /** + * Returns a Jackson {@link ObjectMapper} instance tailored for JSON-parsing + * operations for tool calling and structured output. + */ + public static ObjectMapper getObjectMapper() { + return OBJECT_MAPPER; + } + + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, Class type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, type); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getName()), ex); + } + } + + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, TypeReference type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, type); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getType().getTypeName()), + ex); + } + } + + /** + * Converts a Java object to a JSON string. + */ + public static String toJson(@Nullable Object object) { + try { + return OBJECT_MAPPER.writeValueAsString(object); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from Object to JSON failed", ex); + } + } + + /** + * Convert a Java Object to a typed Object. Based on the implementation in + * MethodInvokingFunctionCallback. + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + public static Object toTypedObject(Object value, Class type) { + Assert.notNull(value, "value cannot be null"); + Assert.notNull(type, "type cannot be null"); + + var javaType = ClassUtils.resolvePrimitiveIfNecessary(type); + + if (javaType == String.class) { + return value.toString(); + } + else if (javaType == Byte.class) { + return Byte.parseByte(value.toString()); + } + else if (javaType == Integer.class) { + return Integer.parseInt(value.toString()); + } + else if (javaType == Short.class) { + return Short.parseShort(value.toString()); + } + else if (javaType == Long.class) { + return Long.parseLong(value.toString()); + } + else if (javaType == Double.class) { + return Double.parseDouble(value.toString()); + } + else if (javaType == Float.class) { + return Float.parseFloat(value.toString()); + } + else if (javaType == Boolean.class) { + return Boolean.parseBoolean(value.toString()); + } + else if (javaType.isEnum()) { + return Enum.valueOf((Class) javaType, value.toString()); + } + + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, javaType); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java new file mode 100644 index 00000000000..6196de3b05f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java @@ -0,0 +1,189 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util.json; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.victools.jsonschema.generator.Option; +import com.github.victools.jsonschema.generator.OptionPreset; +import com.github.victools.jsonschema.generator.SchemaGenerator; +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; +import com.github.victools.jsonschema.generator.SchemaVersion; +import com.github.victools.jsonschema.module.jackson.JacksonModule; +import com.github.victools.jsonschema.module.jackson.JacksonOption; +import com.github.victools.jsonschema.module.swagger2.Swagger2Module; +import org.springframework.util.Assert; + +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +/** + * Utilities to generate JSON Schemas from Java entities. + */ +public final class JsonSchemaGenerator { + + private static final SchemaGenerator TYPE_SCHEMA_GENERATOR; + + private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR; + + /* + * Initialize JSON Schema generators. + */ + static { + var schemaGeneratorConfigBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON) + .with(new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED)) + .with(new Swagger2Module()) + .with(Option.EXTRA_OPEN_API_FORMAT_VALUES) + .with(Option.PLAIN_DEFINITION_KEYS); + + var typeSchemaGeneratorConfig = schemaGeneratorConfigBuilder.without(Option.SCHEMA_VERSION_INDICATOR).build(); + TYPE_SCHEMA_GENERATOR = new SchemaGenerator(typeSchemaGeneratorConfig); + + var subtypeSchemaGeneratorConfig = schemaGeneratorConfigBuilder.build(); + SUBTYPE_SCHEMA_GENERATOR = new SchemaGenerator(subtypeSchemaGeneratorConfig); + } + + private JsonSchemaGenerator() { + } + + /** + * Generate a JSON Schema for a method's input parameters. + */ + public static String generateForMethodInput(Method method, SchemaOption... schemaOptions) { + ObjectNode schema = JsonParser.getObjectMapper().createObjectNode(); + schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier()); + schema.put("type", "object"); + + ObjectNode properties = schema.putObject("properties"); + List required = new ArrayList<>(); + + for (int i = 0; i < method.getParameterCount(); i++) { + var parameterName = method.getParameters()[i].getName(); + var parameterType = method.getGenericParameterTypes()[i]; + if (isMethodParameterRequired(method, i)) { + required.add(parameterName); + } + properties.set(parameterName, SUBTYPE_SCHEMA_GENERATOR.generateSchema(parameterType)); + } + + var requiredArray = schema.putArray("required"); + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.RESPECT_JSON_PROPERTY_REQUIRED)) { + required.forEach(requiredArray::add); + } + else { + Stream.of(method.getParameters()).map(Parameter::getName).forEach(requiredArray::add); + } + + if (Stream.of(schemaOptions) + .noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { + schema.put("additionalProperties", false); + } + + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.UPPER_CASE_TYPE_VALUES)) { + convertTypeValuesToUpperCase(schema); + } + + return schema.toPrettyString(); + } + + /** + * Generate a JSON Schema for a class type. + */ + public static String generateForType(Type type, SchemaOption... schemaOptions) { + Assert.notNull(type, "type cannot be null"); + ObjectNode schema = TYPE_SCHEMA_GENERATOR.generateSchema(type); + if (Stream.of(schemaOptions) + .noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { + schema.put("additionalProperties", false); + } + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.UPPER_CASE_TYPE_VALUES)) { + convertTypeValuesToUpperCase(schema); + } + return schema.toPrettyString(); + } + + private static boolean isMethodParameterRequired(Method method, int index) { + var jsonPropertyAnnotation = method.getParameters()[index].getAnnotation(JsonProperty.class); + if (jsonPropertyAnnotation == null) { + return false; + } + return jsonPropertyAnnotation.required(); + } + + // Based on the method in ModelOptionsUtils. + private static void convertTypeValuesToUpperCase(ObjectNode node) { + if (node.isObject()) { + node.fields().forEachRemaining(entry -> { + JsonNode value = entry.getValue(); + if (value.isObject()) { + convertTypeValuesToUpperCase((ObjectNode) value); + } + else if (value.isArray()) { + value.elements().forEachRemaining(element -> { + if (element.isObject() || element.isArray()) { + convertTypeValuesToUpperCase((ObjectNode) element); + } + }); + } + else if (value.isTextual() && entry.getKey().equals("type")) { + String oldValue = node.get("type").asText(); + node.put("type", oldValue.toUpperCase()); + } + }); + } + else if (node.isArray()) { + node.elements().forEachRemaining(element -> { + if (element.isObject() || element.isArray()) { + convertTypeValuesToUpperCase((ObjectNode) element); + } + }); + } + } + + /** + * Options for generating JSON Schemas. + */ + public enum SchemaOption { + + /** + * Properties are only required if marked as such via the Jackson annotation + * "@JsonProperty(required = true)". Beware, that OpenAI requires all properties + * to be required. + */ + RESPECT_JSON_PROPERTY_REQUIRED, + + /** + * Allow additional properties by default. Beware, that OpenAI requires additional + * properties NOT to be allowed. + */ + ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT, + + /** + * Convert all "type" values to upper case. For example, it's require in OpenAPI + * 3.0 with Vertex AI. + */ + UPPER_CASE_TYPE_VALUES; + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/package-info.java new file mode 100644 index 00000000000..5a13643a184 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.util.json; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/resources/META-INF/spring/aot.factories b/spring-ai-core/src/main/resources/META-INF/spring/aot.factories index bf3d572cdd8..05bc046104d 100644 --- a/spring-ai-core/src/main/resources/META-INF/spring/aot.factories +++ b/spring-ai-core/src/main/resources/META-INF/spring/aot.factories @@ -1,3 +1,4 @@ org.springframework.aot.hint.RuntimeHintsRegistrar=\ org.springframework.ai.aot.SpringAiCoreRuntimeHints,\ - org.springframework.ai.aot.KnuddelsRuntimeHints \ No newline at end of file + org.springframework.ai.aot.KnuddelsRuntimeHints,\ + org.springframework.ai.aot.ToolRuntimeHints \ No newline at end of file diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolRuntimeHintsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolRuntimeHintsTests.java new file mode 100644 index 00000000000..c5a37471a2d --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolRuntimeHintsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.aot.hint.RuntimeHints; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +/** + * Unit tests for {@link ToolRuntimeHints}. + */ +class ToolRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + ToolRuntimeHints toolRuntimeHints = new ToolRuntimeHints(); + toolRuntimeHints.registerHints(runtimeHints, null); + assertThat(runtimeHints).matches(reflection().onType(DefaultToolCallResultConverter.class)); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 5f7951ca5cd..e52a2ca763b 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.ai.tool.ToolCallback; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -1350,6 +1351,43 @@ void whenOptionsThenReturn() { assertThat(defaultSpec.getChatOptions()).isEqualTo(options); } + @Test + void whenToolNamesElementIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.tools("myTool", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolNames cannot contain null elements"); + } + + @Test + void whenToolNamesThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + String toolName = "myTool"; + spec = spec.tools(toolName); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionNames()).contains(toolName); + } + + @Test + void whenToolCallbacksElementIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.toolCallbacks(mock(ToolCallback.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolCallbacks cannot contain null elements"); + } + + @Test + void whenToolCallbacksThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + ToolCallback toolCallback = mock(ToolCallback.class); + spec = spec.toolCallbacks(toolCallback); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).contains(toolCallback); + } + // FunctionCallback.builder().description("description").function(null,input->"hello").inputType(String.class).build() @Test @@ -1480,7 +1518,7 @@ void whenFunctionBeanNamesElementIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); assertThatThrownBy(() -> spec.functions("myFunction", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("functionBeanNames cannot contain null elements"); + .hasMessage("toolNames cannot contain null elements"); } @Test diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java new file mode 100644 index 00000000000..154a31fa28f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link DefaultToolCallingChatOptions}. + */ +class DefaultToolCallingChatOptionsTests { + + @Test + void setToolCallbacksShouldStoreToolCallbacks() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + ToolCallback callback1 = mock(ToolCallback.class); + ToolCallback callback2 = mock(ToolCallback.class); + List callbacks = List.of(callback1, callback2); + + options.setToolCallbacks(callbacks); + + assertThat(options.getToolCallbacks()).hasSize(2).containsExactlyElementsOf(callbacks); + } + + @Test + void setToolCallbacksWithVarargsShouldStoreToolCallbacks() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + ToolCallback callback1 = mock(ToolCallback.class); + ToolCallback callback2 = mock(ToolCallback.class); + + options.setToolCallbacks(callback1, callback2); + + assertThat(options.getToolCallbacks()).hasSize(2).containsExactly(callback1, callback2); + } + + @Test + void setToolCallbacksShouldRejectNullList() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + + assertThatThrownBy(() -> options.setToolCallbacks((List) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolCallbacks cannot be null"); + } + + @Test + void setToolsShouldStoreTools() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + Set tools = Set.of("tool1", "tool2"); + + options.setTools(tools); + + assertThat(options.getTools()).hasSize(2).containsExactlyInAnyOrderElementsOf(tools); + } + + @Test + void setToolsWithVarargsShouldStoreTools() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + + options.setTools("tool1", "tool2"); + + assertThat(options.getTools()).hasSize(2).containsExactlyInAnyOrder("tool1", "tool2"); + } + + @Test + void setToolsShouldRejectNullSet() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + + assertThatThrownBy(() -> options.setTools((Set) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("tools cannot be null"); + } + + @Test + void setToolsShouldRejectNullElements() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + Set tools = new HashSet<>(); + tools.add(null); + + assertThatThrownBy(() -> options.setTools(tools)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("tools cannot contain null elements"); + } + + @Test + void setToolsShouldRejectEmptyElements() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + Set tools = new HashSet<>(); + tools.add(""); + + assertThatThrownBy(() -> options.setTools(tools)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("tools cannot contain empty elements"); + } + + @Test + void setToolContextShouldStoreContext() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + Map context = Map.of("key1", "value1", "key2", 42); + + options.setToolContext(context); + + assertThat(options.getToolContext()).hasSize(2).containsAllEntriesOf(context); + } + + @Test + void setToolContextShouldRejectNullMap() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + + assertThatThrownBy(() -> options.setToolContext(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolContext cannot be null"); + } + + @Test + void copyShouldCreateNewInstanceWithSameValues() { + DefaultToolCallingChatOptions original = new DefaultToolCallingChatOptions(); + ToolCallback callback = mock(ToolCallback.class); + original.setToolCallbacks(List.of(callback)); + original.setTools(Set.of("tool1")); + original.setToolContext(Map.of("key", "value")); + original.setToolCallReturnDirect(true); + original.setModel("gpt-4"); + original.setTemperature(0.7); + + DefaultToolCallingChatOptions copy = original.copy(); + + assertThat(copy).isNotSameAs(original).satisfies(c -> { + assertThat(c.getToolCallbacks()).isEqualTo(original.getToolCallbacks()); + assertThat(c.getTools()).isEqualTo(original.getTools()); + assertThat(c.getToolContext()).isEqualTo(original.getToolContext()); + assertThat(c.getToolCallReturnDirect()).isEqualTo(original.getToolCallReturnDirect()); + assertThat(c.getModel()).isEqualTo(original.getModel()); + assertThat(c.getTemperature()).isEqualTo(original.getTemperature()); + }); + } + + @Test + void gettersShouldReturnImmutableCollections() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + ToolCallback callback = mock(ToolCallback.class); + options.setToolCallbacks(List.of(callback)); + options.setTools(Set.of("tool1")); + options.setToolContext(Map.of("key", "value")); + + assertThatThrownBy(() -> options.getToolCallbacks().add(mock(ToolCallback.class))) + .isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> options.getTools().add("tool2")).isInstanceOf(UnsupportedOperationException.class); + assertThatThrownBy(() -> options.getToolContext().put("key2", "value2")) + .isInstanceOf(UnsupportedOperationException.class); + } + + @Test + void mergeShouldCombineWithNonToolCallingChatOptions() { + DefaultToolCallingChatOptions original = new DefaultToolCallingChatOptions(); + original.setToolCallbacks(List.of(mock(ToolCallback.class))); + original.setTools(Set.of("tool1")); + original.setModel("gpt-3.5"); + + ChatOptions toMerge = ChatOptions.builder().model("gpt-4").build(); + + ToolCallingChatOptions merged = original.merge(toMerge); + + assertThat(merged.getToolCallbacks()).hasSize(1); + assertThat(merged.getTools()).containsExactly("tool1"); + assertThat(merged.getModel()).isEqualTo("gpt-4"); + } + + @Test + void mergeShouldCombineOptionsCorrectly() { + DefaultToolCallingChatOptions original = new DefaultToolCallingChatOptions(); + original.setToolCallbacks(List.of(mock(ToolCallback.class))); + original.setTools(Set.of("tool1")); + original.setToolContext(Map.of("key1", "value1")); + original.setModel("gpt-3.5"); + + DefaultToolCallingChatOptions toMerge = new DefaultToolCallingChatOptions(); + toMerge.setToolCallbacks(List.of(mock(ToolCallback.class))); + toMerge.setTools(Set.of("tool2")); + toMerge.setToolContext(Map.of("key2", "value2")); + toMerge.setTemperature(0.8); + + ToolCallingChatOptions merged = original.merge(toMerge); + + assertThat(merged.getToolCallbacks()).hasSize(2); + assertThat(merged.getTools()).containsExactlyInAnyOrder("tool1", "tool2"); + assertThat(merged.getToolContext()).containsEntry("key1", "value1").containsEntry("key2", "value2"); + assertThat(merged.getModel()).isEqualTo("gpt-3.5"); + assertThat(merged.getTemperature()).isEqualTo(0.8); + } + + @Test + void builderShouldCreateOptionsWithAllProperties() { + ToolCallback callback = mock(ToolCallback.class); + Map context = Map.of("key", "value"); + + ToolCallingChatOptions options = DefaultToolCallingChatOptions.builder() + .toolCallbacks(List.of(callback)) + .tools(Set.of("tool1")) + .toolContext(context) + .toolCallReturnDirect(true) + .model("gpt-4") + .temperature(0.7) + .maxTokens(100) + .frequencyPenalty(0.5) + .presencePenalty(0.3) + .stopSequences(List.of("stop")) + .topK(3) + .topP(0.9) + .build(); + + assertThat(options).satisfies(o -> { + assertThat(o.getToolCallbacks()).containsExactly(callback); + assertThat(o.getTools()).containsExactly("tool1"); + assertThat(o.getToolContext()).isEqualTo(context); + assertThat(o.getToolCallReturnDirect()).isTrue(); + assertThat(o.getModel()).isEqualTo("gpt-4"); + assertThat(o.getTemperature()).isEqualTo(0.7); + assertThat(o.getMaxTokens()).isEqualTo(100); + assertThat(o.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(o.getPresencePenalty()).isEqualTo(0.3); + assertThat(o.getStopSequences()).containsExactly("stop"); + assertThat(o.getTopK()).isEqualTo(3); + assertThat(o.getTopP()).isEqualTo(0.9); + }); + } + + @Test + void builderShouldSupportToolContextAddition() { + ToolCallingChatOptions options = DefaultToolCallingChatOptions.builder() + .toolContext("key1", "value1") + .toolContext("key2", "value2") + .build(); + + assertThat(options.getToolContext()).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + void deprecatedMethodsShouldWorkCorrectly() { + DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); + FunctionCallback callback = mock(FunctionCallback.class); + + assertThatThrownBy(() -> options.setFunctionCallbacks(List.of(callback))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("Not supported. Call setToolCallbacks instead."); + + options.setTools(Set.of("tool1")); + assertThat(options.getFunctions()).containsExactly("tool1"); + + options.setFunctions(Set.of("function1")); + assertThat(options.getTools()).containsExactly("function1"); + + options.setToolCallReturnDirect(true); + assertThat(options.getProxyToolCalls()).isTrue(); + + options.setProxyToolCalls(true); + assertThat(options.getToolCallReturnDirect()).isTrue(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java new file mode 100644 index 00000000000..c823a9525ad --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java @@ -0,0 +1,42 @@ +package org.springframework.ai.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolCallback}. + * + * @author Thomas Vitale + */ +class ToolCallbackTests { + + @Test + void shouldOnlyImplementRequiredMethods() { + var testToolCallback = new TestToolCallback("test"); + assertThat(testToolCallback.getToolDefinition()).isNotNull(); + assertThat(testToolCallback.getToolMetadata()).isNotNull(); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputTypeSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return ""; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java new file mode 100644 index 00000000000..3f203e5f13f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java @@ -0,0 +1,65 @@ +package org.springframework.ai.tool.definition; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultToolDefinition}. + * + * @author Thomas Vitale + */ +class DefaultToolDefinitionTests { + + @Test + void shouldCreateDefaultToolDefinition() { + var toolDefinition = new DefaultToolDefinition("name", "description", "{}"); + assertThat(toolDefinition.name()).isEqualTo("name"); + assertThat(toolDefinition.description()).isEqualTo("description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + } + + @Test + void shouldThrowExceptionWhenNameIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition(null, "description", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenNameIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("", "description", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenDescriptionIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", null, "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenDescriptionIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenInputTypeSchemaIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputTypeSchema cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenInputTypeSchemaIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", "")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputTypeSchema cannot be null or empty"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java new file mode 100644 index 00000000000..5e03710aa4e --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java @@ -0,0 +1,58 @@ +package org.springframework.ai.tool.definition; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.annotation.Tool; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolDefinition}. + * + * @author Thomas Vitale + */ +class ToolDefinitionTests { + + @Test + void shouldCreateDefaultToolDefinitionBuilder() { + var toolDefinition = ToolDefinition.builder() + .name("name") + .description("description") + .inputTypeSchema("{}") + .build(); + assertThat(toolDefinition.name()).isEqualTo("name"); + assertThat(toolDefinition.description()).isEqualTo("description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + } + + @Test + void shouldCreateToolDefinitionFromMethod() { + var toolDefinition = ToolDefinition.from(Tools.class.getDeclaredMethods()[0]); + assertThat(toolDefinition.name()).isEqualTo("mySuperTool"); + assertThat(toolDefinition.description()).isEqualTo("Test description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualToIgnoringWhitespace(""" + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "object", + "properties" : { + "input" : { + "type" : "string" + } + }, + "required" : [ "input" ], + "additionalProperties" : false + } + """); + } + + static class Tools { + + @Tool("Test description") + public List mySuperTool(String input) { + return List.of(input); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java new file mode 100644 index 00000000000..7db22d41c9f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java @@ -0,0 +1,96 @@ +package org.springframework.ai.tool.execution; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultToolCallResultConverter}. + * + * @author Thomas Vitale + */ +class DefaultToolCallResultConverterTests { + + private final DefaultToolCallResultConverter converter = new DefaultToolCallResultConverter(); + + @Test + void convertWithNullReturnTypeShouldThrowException() { + assertThatThrownBy(() -> converter.apply(null, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("returnType cannot be null"); + } + + @Test + void convertVoidReturnTypeShouldReturnDone() { + String result = converter.apply(null, void.class); + assertThat(result).isEqualTo("Done"); + } + + @Test + void convertStringReturnTypeShouldReturnJson() { + String result = converter.apply("test", String.class); + assertThat(result).isEqualTo("\"test\""); + } + + @Test + void convertNullReturnValueShouldReturnNullJson() { + String result = converter.apply(null, String.class); + assertThat(result).isEqualTo("null"); + } + + @Test + void convertObjectReturnTypeShouldReturnJson() { + TestObject testObject = new TestObject("test", 42); + String result = converter.apply(testObject, TestObject.class); + assertThat(result).containsIgnoringWhitespaces(""" + "name": "test" + """).containsIgnoringWhitespaces(""" + "value": 42 + """); + } + + @Test + void convertCollectionReturnTypeShouldReturnJson() { + List testList = List.of("one", "two", "three"); + String result = converter.apply(testList, List.class); + assertThat(result).isEqualTo(""" + ["one","two","three"] + """.trim()); + } + + @Test + void convertMapReturnTypeShouldReturnJson() { + Map testMap = Map.of("one", 1, "two", 2); + String result = converter.apply(testMap, Map.class); + assertThat(result).containsIgnoringWhitespaces(""" + "one": 1 + """).containsIgnoringWhitespaces(""" + "two": 2 + """); + } + + static class TestObject { + + private final String name; + + private final int value; + + TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + public String getName() { + return name; + } + + public int getValue() { + return value; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/ToolExecutionExceptionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/ToolExecutionExceptionTests.java new file mode 100644 index 00000000000..42f61893955 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/ToolExecutionExceptionTests.java @@ -0,0 +1,36 @@ +package org.springframework.ai.tool.execution; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link ToolExecutionException}. + * + * @author Thomas Vitale + */ +class ToolExecutionExceptionTests { + + @Test + void constructorShouldSetCauseAndMessage() { + String errorMessage = "Test error message"; + RuntimeException cause = new RuntimeException(errorMessage); + + ToolExecutionException exception = new ToolExecutionException(mock(ToolDefinition.class), cause); + + assertThat(exception.getCause()).isEqualTo(cause); + assertThat(exception.getMessage()).isEqualTo(errorMessage); + } + + @Test + void getToolDefinitionShouldReturnToolDefinition() { + RuntimeException cause = new RuntimeException("Test error"); + ToolDefinition toolDefinition = mock(ToolDefinition.class); + ToolExecutionException exception = new ToolExecutionException(toolDefinition, cause); + + assertThat(exception.getToolDefinition()).isEqualTo(toolDefinition); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/DefaultToolMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/DefaultToolMetadataTests.java new file mode 100644 index 00000000000..8917204ace7 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/DefaultToolMetadataTests.java @@ -0,0 +1,32 @@ +package org.springframework.ai.tool.metadata; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.execution.ToolExecutionMode; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link DefaultToolMetadata}. + * + * @author Thomas Vitale + */ +class DefaultToolMetadataTests { + + @Test + void shouldCreateDefaultToolMetadataWithDefaultValues() { + var toolMetadata = DefaultToolMetadata.builder().build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isFalse(); + } + + @Test + void shouldCreateDefaultToolMetadataWithGivenValues() { + var toolMetadata = DefaultToolMetadata.builder() + .executionMode(ToolExecutionMode.BLOCKING) + .returnDirect(true) + .build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isTrue(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/ToolMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/ToolMetadataTests.java new file mode 100644 index 00000000000..51ed140ad1d --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/metadata/ToolMetadataTests.java @@ -0,0 +1,41 @@ +package org.springframework.ai.tool.metadata; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.execution.ToolExecutionMode; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolMetadata}. + * + * @author Thomas Vitale + */ +class ToolMetadataTests { + + @Test + void shouldCreateDefaultToolMetadataBuilder() { + var toolMetadata = ToolMetadata.builder().build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isFalse(); + } + + @Test + void shouldCreateToolMetadataFromMethod() { + var toolMetadata = ToolMetadata.from(Tools.class.getDeclaredMethods()[0]); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isTrue(); + } + + static class Tools { + + @Tool(value = "Test description", returnDirect = true) + public List mySuperTool(String input) { + return List.of(input); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java new file mode 100644 index 00000000000..34058a05276 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java @@ -0,0 +1,254 @@ +package org.springframework.ai.tool.method; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MethodToolCallbackProvider}. + * + * @author Thomas Vitale + */ +class MethodToolCallbackProviderTests { + + @Nested + class BuilderValidationTests { + + @Test + void shouldRejectNullToolObjects() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder().toolObjects((Object[]) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolObjects cannot be null"); + } + + @Test + void shouldRejectNullToolObjectElements() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder().toolObjects(new Tools(), null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolObjects cannot contain null elements"); + } + + @Test + void shouldAcceptEmptyToolObjects() { + var provider = MethodToolCallbackProvider.builder().toolObjects().build(); + assertThat(provider.getToolCallbacks()).isEmpty(); + } + + } + + @Test + void shouldProvideToolCallbacksFromObject() { + Tools tools = new Tools(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + + assertThat(callbacks).hasSize(2); + + var callback1 = Stream.of(callbacks).filter(c -> c.getName().equals("testMethod")).findFirst(); + assertThat(callback1).isPresent(); + assertThat(callback1.get().getName()).isEqualTo("testMethod"); + assertThat(callback1.get().getDescription()).isEqualTo("Test description"); + + var callback2 = Stream.of(callbacks).filter(c -> c.getName().equals("testStaticMethod")).findFirst(); + assertThat(callback2).isPresent(); + assertThat(callback2.get().getName()).isEqualTo("testStaticMethod"); + assertThat(callback2.get().getDescription()).isEqualTo("Test description"); + } + + @Test + void shouldProvideToolCallbacksFromMultipleObjects() { + Tools tools1 = new Tools(); + ToolsExtra tools2 = new ToolsExtra(); + + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools1, tools2).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(4); // 2 from Tools + 2 from ToolsExtra + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("testMethod", + "testStaticMethod", "extraMethod1", "extraMethod2"); + } + + @Test + void shouldEnsureUniqueToolNames() { + ToolsWithDuplicates testComponent = new ToolsWithDuplicates(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build(); + + assertThatThrownBy(provider::getToolCallbacks).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Multiple tools with the same name (testMethod) found in sources: " + + testComponent.getClass().getName()); + } + + @Test + void shouldHandleToolMethodsWithDifferentVisibility() { + ToolsWithVisibility tools = new ToolsWithVisibility(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(3); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("publicMethod", + "protectedMethod", "privateMethod"); + } + + @Test + void shouldHandleToolMethodsWithDifferentParameters() { + ToolsWithParameters tools = new ToolsWithParameters(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(3); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("noParams", "oneParam", + "multipleParams"); + } + + @Test + void shouldHandleToolMethodsWithDifferentReturnTypes() { + ToolsWithReturnTypes tools = new ToolsWithReturnTypes(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(4); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)).containsExactlyInAnyOrder("voidMethod", + "primitiveMethod", "objectMethod", "collectionMethod"); + } + + static class Tools { + + @Tool("Test description") + static List testStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List testMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + Function testFunction(String input) { + // This method should be ignored as it's a functional type + return String::length; + } + + @Tool("Test description") + Consumer testConsumer(String input) { + // This method should be ignored as it's a functional type + return System.out::println; + } + + @Tool("Test description") + Supplier testSupplier() { + // This method should be ignored as it's a functional type + return () -> "test"; + } + + void nonToolMethod() { + // This method should be ignored as it doesn't have @Tool annotation + } + + } + + static class ToolsExtra { + + @Tool("Extra method 1") + String extraMethod1() { + return "extra1"; + } + + @Tool("Extra method 2") + String extraMethod2() { + return "extra2"; + } + + } + + static class ToolsWithDuplicates { + + @Tool(name = "testMethod", value = "Test description") + List testMethod1(String input) { + return List.of(input); + } + + @Tool(name = "testMethod", value = "Test description") + List testMethod2(String input) { + return List.of(input); + } + + } + + static class ToolsWithVisibility { + + @Tool("Public method") + public String publicMethod() { + return "public"; + } + + @Tool("Protected method") + protected String protectedMethod() { + return "protected"; + } + + @Tool("Private method") + private String privateMethod() { + return "private"; + } + + } + + static class ToolsWithParameters { + + @Tool("No parameters") + String noParams() { + return "no params"; + } + + @Tool("One parameter") + String oneParam(String param) { + return param; + } + + @Tool("Multiple parameters") + String multipleParams(String param1, int param2, boolean param3) { + return param1 + param2 + param3; + } + + } + + static class ToolsWithReturnTypes { + + @Tool("Void method") + void voidMethod() { + } + + @Tool("Primitive method") + int primitiveMethod() { + return 42; + } + + @Tool("Object method") + String objectMethod() { + return "object"; + } + + @Tool("Collection method") + List collectionMethod() { + return List.of("collection"); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java new file mode 100644 index 00000000000..b363f3e8cbd --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java @@ -0,0 +1,309 @@ +package org.springframework.ai.tool.method; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.util.json.JsonParser; +import org.springframework.util.ReflectionUtils; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MethodToolCallback}. + * + * @author Thomas Vitale + */ +class MethodToolCallbackTests { + + @ParameterizedTest + @ValueSource(strings = { "publicStaticMethod", "privateStaticMethod", "packageStaticMethod", "publicMethod", + "privateMethod", "packageMethod" }) + void shouldCallToolFromPublicClass(String methodName) { + validateAssertions(methodName, new PublicTools()); + } + + @ParameterizedTest + @ValueSource(strings = { "publicStaticMethod", "privateStaticMethod", "packageStaticMethod", "publicMethod", + "privateMethod", "packageMethod" }) + void shouldCallToolFromPrivateClass(String methodName) { + validateAssertions(methodName, new PrivateTools()); + } + + @ParameterizedTest + @ValueSource(strings = { "publicStaticMethod", "privateStaticMethod", "packageStaticMethod", "publicMethod", + "privateMethod", "packageMethod" }) + void shouldCallToolFromPackageClass(String methodName) { + validateAssertions(methodName, new PackageTools()); + } + + @Test + void shouldHandleToolContextWhenSupported() { + Method toolMethod = getMethod("methodWithToolContext", ToolContextTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ToolContextTools()) + .build(); + + ToolContext toolContext = new ToolContext(Map.of("key", "value")); + String result = callback.call(""" + { + "input": "test" + } + """, toolContext); + + assertThat(result).contains("value"); + } + + @Test + void shouldThrowExceptionWhenToolContextNotSupported() { + Method toolMethod = getMethod("publicMethod", PublicTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new PublicTools()) + .build(); + + ToolContext toolContext = new ToolContext(Map.of("key", "value")); + + assertThatThrownBy(() -> callback.call(""" + { + "input": "test" + } + """, toolContext)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ToolContext is not supported"); + } + + @Test + void shouldHandleComplexArguments() { + Method toolMethod = getMethod("complexArgumentMethod", ComplexTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ComplexTools()) + .build(); + + String result = callback.call(""" + { + "stringArg": "test", + "intArg": 42, + "listArg": ["a", "b", "c"], + "optionalArg": null + } + """); + + assertThat(JsonParser.fromJson(result, new TypeReference>() { + })).containsEntry("stringValue", "test").containsEntry("intValue", 42).containsEntry("listSize", 3); + } + + @Test + void shouldHandleCustomResultConverter() { + Method toolMethod = getMethod("publicMethod", PublicTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new PublicTools()) + .toolCallResultConverter((result, type) -> "Converted: " + result) + .build(); + + String result = callback.call(""" + { + "input": "test" + } + """); + + assertThat(result).startsWith("Converted:"); + } + + @Test + void shouldThrowExceptionWhenToolExecutionFails() { + Method toolMethod = getMethod("errorMethod", ErrorTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ErrorTools()) + .build(); + + assertThatThrownBy(() -> callback.call(""" + { + "input": "test" + } + """)).isInstanceOf(ToolExecutionException.class).hasMessageContaining("Test error"); + } + + private static void validateAssertions(String methodName, Object toolObject) { + Method toolMethod = getMethod(methodName, toolObject.getClass()); + assertThat(toolMethod).isNotNull(); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(toolObject) + .build(); + + String result = callback.call(""" + { + "input": "Wingardium Leviosa" + } + """); + + assertThat(JsonParser.fromJson(result, new TypeReference>() { + })).contains("Wingardium Leviosa"); + } + + private static Method getMethod(String name, Class toolsClass) { + return Arrays.stream(ReflectionUtils.getDeclaredMethods(toolsClass)) + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(); + } + + static public class PublicTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static private class PrivateTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static class PackageTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static class ToolContextTools { + + @Tool("Test description") + public String methodWithToolContext(String input, ToolContext toolContext) { + return input + ": " + toolContext.getContext().get("key"); + } + + } + + static class ComplexTools { + + @Tool("Test description") + public Map complexArgumentMethod(String stringArg, int intArg, List listArg, + String optionalArg) { + return Map.of("stringValue", stringArg, "intValue", intArg, "listSize", listArg.size(), "optionalProvided", + optionalArg != null); + } + + } + + static class ErrorTools { + + @Tool("Test description") + public String errorMethod(String input) { + throw new IllegalArgumentException("Test error"); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java new file mode 100644 index 00000000000..ad8876d4fce --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java @@ -0,0 +1,214 @@ +package org.springframework.ai.tool.utils; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionMode; +import org.springframework.ai.tool.util.ToolUtils; + +import java.lang.reflect.Method; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ToolUtils}. + * + * @author Thomas Vitale + */ +class ToolUtilsTests { + + @Test + void shouldDetectDuplicateToolNames() { + ToolCallback callback1 = new TestToolCallback("tool_a"); + ToolCallback callback2 = new TestToolCallback("tool_a"); + ToolCallback callback3 = new TestToolCallback("tool_b"); + + List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); + + assertThat(duplicates).isNotEmpty(); + assertThat(duplicates).contains("tool_a"); + } + + @Test + void shouldNotDetectDuplicateToolNames() { + ToolCallback callback1 = new TestToolCallback("tool_a"); + ToolCallback callback2 = new TestToolCallback("tool_b"); + ToolCallback callback3 = new TestToolCallback("tool_c"); + + List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); + + assertThat(duplicates).isEmpty(); + } + + @Test + void shouldGetToolNameFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomName"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("customName"); + } + + @Test + void shouldGetMethodNameWhenNoCustomNameInAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithoutCustomName"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("toolWithoutCustomName"); + } + + @Test + void shouldGetMethodNameWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("methodWithoutAnnotation"); + } + + @Test + void shouldGetToolDescriptionFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomDescription"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("Custom description"); + } + + @Test + void shouldGetMethodNameWhenNoCustomDescriptionInAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithoutCustomDescription"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("toolWithoutCustomDescription"); + } + + @Test + void shouldGetFormattedMethodNameWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("camelCaseMethodWithoutAnnotation"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("camel case method without annotation"); + } + + @Test + void shouldGetToolExecutionModeFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomExecutionMode"); + assertThat(ToolUtils.getToolExecutionMode(method)).isEqualTo(ToolExecutionMode.BLOCKING); + } + + @Test + void shouldGetDefaultExecutionModeWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolExecutionMode(method)).isEqualTo(ToolExecutionMode.BLOCKING); + } + + @Test + void shouldGetToolReturnDirectFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithReturnDirect"); + assertThat(ToolUtils.getToolReturnDirect(method)).isTrue(); + } + + @Test + void shouldGetDefaultReturnDirectWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolReturnDirect(method)).isFalse(); + } + + @Test + void shouldGetToolCallResultConverterFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomConverter"); + ToolCallResultConverter converter = ToolUtils.getToolCallResultConverter(method); + assertThat(converter).isInstanceOf(CustomToolCallResultConverter.class); + } + + @Test + void shouldGetDefaultConverterWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + ToolCallResultConverter converter = ToolUtils.getToolCallResultConverter(method); + assertThat(converter).isInstanceOf(DefaultToolCallResultConverter.class); + } + + @Test + void shouldThrowExceptionWhenConverterCannotBeInstantiated() throws Exception { + Method method = TestTools.class.getMethod("toolWithInvalidConverter"); + assertThatThrownBy(() -> ToolUtils.getToolCallResultConverter(method)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to instantiate ToolCallResultConverter"); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputTypeSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String functionInput) { + return ""; + } + + } + + static class TestTools { + + @Tool(name = "customName") + public void toolWithCustomName() { + } + + @Tool + public void toolWithoutCustomName() { + } + + @Tool(value = "Custom description") + public void toolWithCustomDescription() { + } + + @Tool + public void toolWithoutCustomDescription() { + } + + @Tool(executionMode = ToolExecutionMode.BLOCKING) + public void toolWithCustomExecutionMode() { + } + + @Tool(returnDirect = true) + public void toolWithReturnDirect() { + } + + @Tool(resultConverter = CustomToolCallResultConverter.class) + public void toolWithCustomConverter() { + } + + @Tool(resultConverter = InvalidToolCallResultConverter.class) + public void toolWithInvalidConverter() { + } + + public void methodWithoutAnnotation() { + } + + public void camelCaseMethodWithoutAnnotation() { + } + + } + + public static class CustomToolCallResultConverter implements ToolCallResultConverter { + + @Override + public String apply(Object result, Class returnType) { + return returnType.getName(); + } + + } + + // No-public class with no-public constructor + static class InvalidToolCallResultConverter implements ToolCallResultConverter { + + private InvalidToolCallResultConverter() { + } + + @Override + public String apply(Object result, Class returnType) { + return returnType.getName(); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java new file mode 100644 index 00000000000..e3552f69407 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java @@ -0,0 +1,221 @@ +package org.springframework.ai.util.json; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for the {@link JsonParser} class. + * + * @author Thomas Vitale + */ +class JsonParserTests { + + @Test + void shouldGetObjectMapper() { + var objectMapper = JsonParser.getObjectMapper(); + assertThat(objectMapper).isNotNull(); + } + + @Test + void shouldThrowExceptionWhenJsonIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson(null, TestRecord.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("json cannot be null"); + } + + @Test + void shouldThrowExceptionWhenClassIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson("{}", (Class) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void shouldThrowExceptionWhenTypeIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson("{}", (TypeReference) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void fromJsonToObject() { + var json = """ + { + "name" : "John", + "age" : 30 + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isEqualTo(30); + } + + @Test + void fromJsonToObjectWithMissingProperty() { + var json = """ + { + "name": "John" + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isNull(); + } + + @Test + void fromJsonToObjectWithNullProperty() { + var json = """ + { + "name": "John", + "age": null + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isNull(); + } + + @Test + void fromJsonToObjectWithOtherNullProperty() { + var json = """ + { + "name": null, + "age": 21 + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isNull(); + assertThat(object.age).isEqualTo(21); + } + + @Test + void fromJsonToObjectWithUnknownProperty() { + var json = """ + { + "name": "James", + "surname": "Bond" + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("James"); + } + + @Test + void fromObjectToJson() { + var object = new TestRecord("John", 30); + var json = JsonParser.toJson(object); + assertThat(json).isEqualToIgnoringWhitespace(""" + { + "name" : "John", + "age" : 30 + } + """); + } + + @Test + void fromObjectToJsonWithNullValues() { + var object = new TestRecord("John", null); + var json = JsonParser.toJson(object); + assertThat(json).isEqualToIgnoringWhitespace(""" + { + "name" : "John", + "age" : null + } + """); + } + + @Test + void fromNullObjectToJson() { + var json = JsonParser.toJson(null); + assertThat(json).isEqualToIgnoringWhitespace("null"); + } + + @Test + void fromObjectToString() { + var value = JsonParser.toTypedObject("John", String.class); + assertThat(value).isOfAnyClassIn(String.class); + assertThat(value).isEqualTo("John"); + } + + @Test + void fromObjectToByte() { + var value = JsonParser.toTypedObject("1", Byte.class); + assertThat(value).isOfAnyClassIn(Byte.class); + assertThat(value).isEqualTo((byte) 1); + } + + @Test + void fromObjectToInteger() { + var value = JsonParser.toTypedObject("1", Integer.class); + assertThat(value).isOfAnyClassIn(Integer.class); + assertThat(value).isEqualTo(1); + } + + @Test + void fromObjectToShort() { + var value = JsonParser.toTypedObject("1", Short.class); + assertThat(value).isOfAnyClassIn(Short.class); + assertThat(value).isEqualTo((short) 1); + } + + @Test + void fromObjectToLong() { + var value = JsonParser.toTypedObject("1", Long.class); + assertThat(value).isOfAnyClassIn(Long.class); + assertThat(value).isEqualTo(1L); + } + + @Test + void fromObjectToDouble() { + var value = JsonParser.toTypedObject("1.0", Double.class); + assertThat(value).isOfAnyClassIn(Double.class); + assertThat(value).isEqualTo(1.0); + } + + @Test + void fromObjectToFloat() { + var value = JsonParser.toTypedObject("1.0", Float.class); + assertThat(value).isOfAnyClassIn(Float.class); + assertThat(value).isEqualTo(1.0f); + } + + @Test + void fromObjectToBoolean() { + var value = JsonParser.toTypedObject("true", Boolean.class); + assertThat(value).isOfAnyClassIn(Boolean.class); + assertThat(value).isEqualTo(true); + } + + @Test + void fromObjectToEnum() { + var value = JsonParser.toTypedObject("VALUE", TestEnum.class); + assertThat(value).isOfAnyClassIn(TestEnum.class); + assertThat(value).isEqualTo(TestEnum.VALUE); + } + + @Test + void fromObjectToRecord() { + var record = new TestRecord("John", 30); + var value = JsonParser.toTypedObject(record, TestRecord.class); + assertThat(value).isOfAnyClassIn(TestRecord.class); + assertThat(value).isEqualTo(new TestRecord("John", 30)); + } + + record TestRecord(String name, Integer age) { + } + + enum TestEnum { + + VALUE + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java new file mode 100644 index 00000000000..406d8cf3555 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java @@ -0,0 +1,368 @@ +package org.springframework.ai.util.json; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.Month; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link JsonSchemaGenerator}. + * + * @author Thomas Vitale + */ +class JsonSchemaGeneratorTests { + + @Test + void generateSchemaForMethodWithSimpleParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer", + "format" : "int32" + } + }, + "required": [ + "name", + "age" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithJsonPropertyAnnotations() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("annotatedMethod", String.class, String.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, + JsonSchemaGenerator.SchemaOption.RESPECT_JSON_PROPERTY_REQUIRED); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + }, + "required": [ + "password" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithAdditionalPropertiesAllowed() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, + JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT); + + JsonNode jsonNode = JsonParser.getObjectMapper().readTree(schema); + assertThat(jsonNode.has("additionalProperties")).isFalse(); + } + + @Test + void generateSchemaForMethodWithUpperCaseTypes() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "OBJECT", + "properties": { + "name": { + "type": "STRING" + }, + "age": { + "type": "INTEGER", + "format" : "int32" + } + }, + "required": [ + "name", + "age" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithComplexParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("complexMethod", List.class, TestData.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "string" + } + }, + "data": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + } + } + }, + "required": [ + "items", + "data" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithTimeParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("timeMethod", Duration.class, LocalDateTime.class, + Instant.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "duration": { + "type": "string", + "format" : "duration" + }, + "localDateTime": { + "type": "string", + "format": "date-time" + }, + "instant": { + "type": "string", + "format": "date-time" + } + }, + "required": [ + "duration", + "localDateTime", + "instant" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForSimpleType() { + String schema = JsonSchemaGenerator.generateForType(Person.class); + String expectedJsonSchema = """ + { + "type": "object", + "properties": { + "email": { + "type": "string" + }, + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForTypeWithAdditionalPropertiesAllowed() throws JsonProcessingException { + String schema = JsonSchemaGenerator.generateForType(Person.class, + JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT); + + JsonNode jsonNode = JsonParser.getObjectMapper().readTree(schema); + assertThat(jsonNode.has("additionalProperties")).isFalse(); + } + + @Test + void generateSchemaForTypeWithUpperCaseValues() { + String schema = JsonSchemaGenerator.generateForType(Person.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + String expectedJsonSchema = """ + { + "type": "OBJECT", + "properties": { + "email": { + "type": "STRING" + }, + "id": { + "type": "INTEGER", + "format" : "int32" + }, + "name": { + "type": "STRING" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForRecord() { + String schema = JsonSchemaGenerator.generateForType(TestData.class); + String expectedJsonSchema = """ + { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForEnum() { + String schema = JsonSchemaGenerator.generateForType(Month.class); + String expectedJsonSchema = """ + { + "type": "string", + "enum": [ + "JANUARY", + "FEBRUARY", + "MARCH", + "APRIL", + "MAY", + "JUNE", + "JULY", + "AUGUST", + "SEPTEMBER", + "OCTOBER", + "NOVEMBER", + "DECEMBER" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void throwExceptionWhenTypeIsNull() { + assertThatThrownBy(() -> JsonSchemaGenerator.generateForType(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + static class TestMethods { + + public void simpleMethod(String name, int age) { + } + + public void annotatedMethod(String username, @JsonProperty(required = true) String password) { + } + + public void complexMethod(List items, TestData data) { + } + + public void timeMethod(Duration duration, LocalDateTime localDateTime, Instant instant) { + } + + } + + record TestData(int id, String name) { + } + + static class Person { + + private int id; + + private String name; + + private String email; + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getEmail() { + return email; + } + + public void setEmail(String email) { + this.email = email; + } + + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java new file mode 100644 index 00000000000..23573f23dc2 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java @@ -0,0 +1,193 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.method.MethodToolCallback; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link MethodToolCallback}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class MethodToolCallbackTests { + + @Autowired + OpenAiChatModel openAiChatModel; + + Tools tools = new Tools(new BookService()); + + @Test + void chatMethodNoArgs() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome the user to the library") + .tools(tools) + .call() + .content(); + assertThat(content).isNotEmpty(); + } + + @Test + void chatMethodVoid() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome %s to the library".formatted("James Bond")) + .tools(tools) + .call() + .content(); + assertThat(content).isNotEmpty(); + } + + @Test + void chatMethodSingle() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")) + .tools(tools) + .call() + .content(); + assertThat(content).isNotEmpty() + .contains("The Hobbit") + .contains("The Lord of The Rings") + .contains("The Silmarillion"); + } + + @Test + void chatMethodList() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) + .tools(tools) + .call() + .content(); + assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis"); + } + + @Test + void chatMethodCallback() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) + .toolCallbacks(ToolCallbacks.from(tools)) + .call() + .content(); + assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis"); + } + + @Test + void chatMethodCallbackDefault() { + var content = ChatClient.builder(this.openAiChatModel) + .defaultTools(tools) + .build() + .prompt() + .user("How many books written by %s are available in the library?".formatted("J.R.R. Tolkien")) + .call() + .content(); + assertThat(content).isNotEmpty().containsAnyOf("three", "3"); + } + + static class Tools { + + private static final Logger logger = LoggerFactory.getLogger(Tools.class); + + private final BookService bookService; + + Tools(BookService bookService) { + this.bookService = bookService; + } + + @Tool("Welcome users to the library") + void welcome() { + logger.info("Welcoming users to the library"); + } + + @Tool("Welcome a specific user to the library") + void welcomeUser(String user) { + logger.info("Welcoming {} to the library", user); + } + + @Tool("Get the list of books written by the given author available in the library") + List booksByAuthor(String author) { + logger.info("Getting books by author: {}", author); + return bookService.getBooksByAuthor(new Author(author)); + } + + @Tool("Get the list of authors who wrote the given books available in the library") + List authorsByBooks(List books) { + logger.info("Getting authors by books: {}", String.join(", ", books)); + return bookService.getAuthorsByBook(books.stream().map(b -> new Book(b, "")).toList()); + } + + } + + public record Author(String name) { + } + + public record Book(String title, String author) { + } + + static class BookService { + + private static final Map books = new ConcurrentHashMap<>(); + + static { + books.put(1, new Book("His Dark Materials", "Philip Pullman")); + books.put(2, new Book("Narnia", "C.S. Lewis")); + books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); + books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); + books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); + } + + public List getBooksByAuthor(Author author) { + return books.values().stream().filter(book -> author.name().equals(book.author())).toList(); + } + + public List getAuthorsByBook(List booksToSearch) { + return books.values() + .stream() + .filter(book -> booksToSearch.stream().anyMatch(b -> b.title().equals(book.title()))) + .map(book -> new Author(book.author())) + .toList(); + } + + } + +}