diff --git a/src/main/java/com/github/llamara/ai/internal/chat/aiservice/ChatModelAiService.java b/src/main/java/com/github/llamara/ai/internal/chat/aiservice/ChatModelAiService.java index 05975b7..8ecf954 100644 --- a/src/main/java/com/github/llamara/ai/internal/chat/aiservice/ChatModelAiService.java +++ b/src/main/java/com/github/llamara/ai/internal/chat/aiservice/ChatModelAiService.java @@ -22,6 +22,7 @@ import java.util.UUID; import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.Result; import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.UserMessage; @@ -49,9 +50,9 @@ public interface ChatModelAiService { """; @SystemMessage(SYSTEM_MESSAGE) - String chat(@MemoryId UUID sessionId, boolean history, @UserMessage String prompt); + Result chat(@MemoryId UUID sessionId, boolean history, @UserMessage String prompt); - String chatWithoutSystemMessage( + Result chatWithoutSystemMessage( @MemoryId UUID sessionId, boolean history, @UserMessage String prompt); @SystemMessage(SYSTEM_MESSAGE) diff --git a/src/main/java/com/github/llamara/ai/internal/chat/aiservice/DelegatingChatModelAiService.java b/src/main/java/com/github/llamara/ai/internal/chat/aiservice/DelegatingChatModelAiService.java index 7531534..786b990 100644 --- a/src/main/java/com/github/llamara/ai/internal/chat/aiservice/DelegatingChatModelAiService.java +++ b/src/main/java/com/github/llamara/ai/internal/chat/aiservice/DelegatingChatModelAiService.java @@ -23,6 +23,7 @@ import java.util.UUID; +import dev.langchain4j.service.Result; import dev.langchain4j.service.TokenStream; /** @@ -52,12 +53,12 @@ public ChatModelConfig.ModelConfig config() { } @Override - public String chat(UUID sessionId, boolean history, String prompt) { + public Result chat(UUID sessionId, boolean history, String prompt) { return delegate.chat(sessionId, history, prompt); } @Override - public String chatWithoutSystemMessage(UUID sessionId, boolean history, String prompt) { + public Result chatWithoutSystemMessage(UUID sessionId, boolean history, String prompt) { return delegate.chatWithoutSystemMessage(sessionId, history, prompt); } diff --git a/src/main/java/com/github/llamara/ai/internal/chat/history/HistoryInterceptingAiService.java b/src/main/java/com/github/llamara/ai/internal/chat/history/HistoryInterceptingAiService.java index 559aabb..695dada 100644 --- a/src/main/java/com/github/llamara/ai/internal/chat/history/HistoryInterceptingAiService.java +++ b/src/main/java/com/github/llamara/ai/internal/chat/history/HistoryInterceptingAiService.java @@ -29,9 +29,9 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; -import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessageType; import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.Result; import dev.langchain4j.service.TokenStream; /** @@ -42,7 +42,7 @@ */ public class HistoryInterceptingAiService extends DelegatingChatModelAiService { private final BiConsumer promptConsumer; - private final BiConsumer responseConsumer; + private final BiConsumer responseConsumer; public HistoryInterceptingAiService( ChatModelAiService delegate, @@ -68,7 +68,7 @@ public HistoryInterceptingAiService( sessionId, new ChatMessageRecord( ChatMessageType.AI, - response.text(), + response, Instant.now(), config.provider(), config.model())) @@ -77,26 +77,26 @@ public HistoryInterceptingAiService( } @Override - public String chat(UUID sessionId, boolean history, String prompt) { + public Result chat(UUID sessionId, boolean history, String prompt) { if (!history) { return super.chat(sessionId, false, prompt); } promptConsumer.accept(sessionId, prompt); - String response = super.chat(sessionId, true, prompt); - responseConsumer.accept(sessionId, new AiMessage(response)); + Result response = super.chat(sessionId, true, prompt); + responseConsumer.accept(sessionId, response.content()); return response; } @Override - public String chatWithoutSystemMessage(UUID sessionId, boolean history, String prompt) { + public Result chatWithoutSystemMessage(UUID sessionId, boolean history, String prompt) { if (!history) { return super.chatWithoutSystemMessage(sessionId, false, prompt); } promptConsumer.accept(sessionId, prompt); - String response = super.chatWithoutSystemMessage(sessionId, true, prompt); - responseConsumer.accept(sessionId, new AiMessage(response)); + Result response = super.chatWithoutSystemMessage(sessionId, true, prompt); + responseConsumer.accept(sessionId, response.content()); return response; } @@ -109,7 +109,7 @@ public TokenStream chatAndStreamResponse(UUID sessionId, boolean history, String promptConsumer.accept(sessionId, prompt); return new CompletionInterceptingTokenStream( super.chatAndStreamResponse(sessionId, true, prompt), - response -> responseConsumer.accept(sessionId, response.aiMessage())); + response -> responseConsumer.accept(sessionId, response.aiMessage().text())); } /** diff --git a/src/main/java/com/github/llamara/ai/internal/rest/ChatResource.java b/src/main/java/com/github/llamara/ai/internal/rest/ChatResource.java index 4d52480..22af880 100644 --- a/src/main/java/com/github/llamara/ai/internal/rest/ChatResource.java +++ b/src/main/java/com/github/llamara/ai/internal/rest/ChatResource.java @@ -19,6 +19,7 @@ */ package com.github.llamara.ai.internal.rest; +import com.github.llamara.ai.internal.MetadataKeys; import com.github.llamara.ai.internal.chat.ChatModelContainer; import com.github.llamara.ai.internal.chat.ChatModelNotFoundException; import com.github.llamara.ai.internal.chat.ChatModelProvider; @@ -28,6 +29,7 @@ import com.github.llamara.ai.internal.security.session.SessionManager; import com.github.llamara.ai.internal.security.session.SessionNotFoundException; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.UUID; @@ -44,6 +46,7 @@ import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.core.MediaType; +import dev.langchain4j.service.Result; import io.quarkus.security.identity.SecurityIdentity; import io.smallrye.common.annotation.Blocking; import io.smallrye.common.annotation.NonBlocking; @@ -105,16 +108,16 @@ public Collection getModels() { @POST @Path("/prompt") @Consumes(MediaType.TEXT_PLAIN) - @Produces(MediaType.TEXT_PLAIN) + @Produces(MediaType.APPLICATION_JSON) @Operation(operationId = "prompt", summary = "Send a prompt to the given chat model.") @APIResponse( responseCode = "200", description = "OK", - content = @Content(schema = @Schema(implementation = String.class))) + content = @Content(schema = @Schema(implementation = ChatResponseDTO.class))) @APIResponse( responseCode = "404", description = "No chat model or no session with given ID found.") - public String prompt( + public ChatResponseDTO prompt( @QueryParam("uid") @Parameter( name = "uid", @@ -131,13 +134,16 @@ public String prompt( throws ChatModelNotFoundException, SessionNotFoundException { sessionManager.enforceSessionValid(sessionId); ChatModelContainer chatModel = chatModelProvider.getModel(uid); + Result result; if (chatModel.config().systemPromptEnabled()) { - return chatModel.service().chat(sessionId, !identity.isAnonymous(), prompt); + result = chatModel.service().chat(sessionId, !identity.isAnonymous(), prompt); } else { - return chatModel - .service() - .chatWithoutSystemMessage(sessionId, !identity.isAnonymous(), prompt); + result = + chatModel + .service() + .chatWithoutSystemMessage(sessionId, !identity.isAnonymous(), prompt); } + return new ChatResponseDTO(result); } /* @@ -312,4 +318,24 @@ public void keepAliveAnonymousSession( throws SessionNotFoundException { sessionManager.enforceSessionValid(sessionId); } + + public static class ChatResponseDTO { + public final String response; + public final List sources = new ArrayList<>(); + + public ChatResponseDTO(Result result) { + this.response = result.content(); + result.sources().stream() + .map(dev.langchain4j.rag.content.Content::textSegment) + .forEach( + ts -> + sources.add( + new SourceRecord( + ts.metadata() + .getUUID(MetadataKeys.KNOWLEDGE_ID), + ts.text()))); + } + + public record SourceRecord(UUID knowledgeId, String content) {} + } }