diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java
index 65f186b..a995897 100644
--- a/src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java
+++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java
@@ -107,8 +107,11 @@ public Map<Long, Double> logitBias() {
     return logitBias;
   }
 
-  public Optional<String> systemMessage() {
-    return Optional.ofNullable(systemMessage);
+  public String systemMessage() {
+    if (systemMessage == null) {
+      return "You're a helpful assistant.";
+    }
+    return systemMessage;
   }
 
   public long maxInputTokens() {
diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
index abff7df..a504f20 100644
--- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
+++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
@@ -14,66 +14,62 @@
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.meta.cp4m.message.Message;
 import com.meta.cp4m.message.ThreadState;
-
 import java.io.IOException;
 import java.net.URI;
 import java.time.Instant;
-
 import org.apache.hc.client5.http.fluent.Request;
 import org.apache.hc.client5.http.fluent.Response;
 import org.apache.hc.core5.http.ContentType;
 
 public class HuggingFaceLlamaPlugin<T extends Message<T>> implements LLMPlugin<T> {
-    private static final ObjectMapper MAPPER = new ObjectMapper();
-    private final HuggingFaceConfig config;
-    private URI endpoint;
-
-    public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
-        this.config = config;
-        this.endpoint = this.config.endpoint();
-    }
+  private static final ObjectMapper MAPPER = new ObjectMapper();
+  private final HuggingFaceConfig config;
+  private final HuggingFaceLlamaPrompt<T> promptCreator;
 
-    @Override
-    public T handle(ThreadState<T> threadState) throws IOException {
-        T fromUser = threadState.tail();
+  private URI endpoint;
 
-        ObjectNode body = MAPPER.createObjectNode();
-        ObjectNode params = MAPPER.createObjectNode();
+  public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
+    this.config = config;
+    this.endpoint = this.config.endpoint();
+    promptCreator = new HuggingFaceLlamaPrompt<>(config);
+  }
 
-        config.topP().ifPresent(v -> params.put("top_p", v));
-        config.temperature().ifPresent(v -> params.put("temperature", v));
-        config.maxOutputTokens().ifPresent(v -> params.put("max_new_tokens", v));
+  @Override
+  public T handle(ThreadState<T> threadState) throws IOException {
+    ObjectNode body = MAPPER.createObjectNode();
+    ObjectNode params = MAPPER.createObjectNode();
 
-        body.set("parameters", params);
+    config.topP().ifPresent(v -> params.put("top_p", v));
+    config.temperature().ifPresent(v -> params.put("temperature", v));
+    config.maxOutputTokens().ifPresent(v -> params.put("max_new_tokens", v));
 
-        HuggingFaceLlamaPromptBuilder<T> promptBuilder = new HuggingFaceLlamaPromptBuilder<>();
+    body.set("parameters", params);
 
-        String prompt = promptBuilder.createPrompt(threadState, config);
-        if (prompt.equals("I'm sorry but that request was too long for me.")) {
-            return threadState.newMessageFromBot(
-                Instant.now(), prompt);
-        }
-
-        body.put("inputs", prompt);
-
-        String bodyString;
-        try {
-            bodyString = MAPPER.writeValueAsString(body);
-        } catch (JsonProcessingException e) {
-            throw new RuntimeException(e);
-        }
-        Response response =
-            Request.post(endpoint)
-                .bodyString(bodyString, ContentType.APPLICATION_JSON)
-                .setHeader("Authorization", "Bearer " + config.apiKey())
-                .execute();
+    String prompt = promptCreator.createPrompt(threadState);
+    if (prompt.equals("I'm sorry but that request was too long for me.")) {
+      return threadState.newMessageFromBot(Instant.now(), prompt);
+    }
 
-        JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
-        String allGeneratedText = responseBody.get(0).get("generated_text").textValue();
-        String llmResponse = allGeneratedText.strip().replace(prompt.strip(), "");
-        Instant timestamp = Instant.now();
+    body.put("inputs", prompt);
 
-        return threadState.newMessageFromBot(timestamp, llmResponse);
+    String bodyString;
+    try {
+      bodyString = MAPPER.writeValueAsString(body);
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException(e);
     }
+    Response response =
+        Request.post(endpoint)
+            .bodyString(bodyString, ContentType.APPLICATION_JSON)
+            .setHeader("Authorization", "Bearer " + config.apiKey())
+            .execute();
+
+    JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
+    String allGeneratedText = responseBody.get(0).get("generated_text").textValue();
+    String llmResponse = allGeneratedText.strip().replace(prompt.strip(), "");
+    Instant timestamp = Instant.now();
+
+    return threadState.newMessageFromBot(timestamp, llmResponse);
+  }
 }
diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java
new file mode 100644
index 0000000..606e4db
--- /dev/null
+++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java
@@ -0,0 +1,134 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.cp4m.llm;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import com.meta.cp4m.message.Message;
+import com.meta.cp4m.message.ThreadState;
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.nio.file.Paths;
+import java.util.*;
+import org.checkerframework.common.returnsreceiver.qual.This;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class HuggingFaceLlamaPrompt<T extends Message<T>> {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(HuggingFaceLlamaPrompt.class);
+  private final String systemMessage;
+  private final long maxInputTokens;
+  private final HuggingFaceTokenizer tokenizer;
+
+  public HuggingFaceLlamaPrompt(HuggingFaceConfig config) {
+
+    this.systemMessage = config.systemMessage();
+    this.maxInputTokens = config.maxInputTokens();
+    URL llamaTokenizerUrl =
+        Objects.requireNonNull(
+            HuggingFaceLlamaPrompt.class.getClassLoader().getResource("llamaTokenizer.json"));
+    URI llamaTokenizer;
+    try {
+      llamaTokenizer = llamaTokenizerUrl.toURI();
+      tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(llamaTokenizer));
+
+    } catch (URISyntaxException | IOException e) {
+      // this should be impossible
+      throw new RuntimeException(e);
+    }
+  }
+
+  public String createPrompt(ThreadState<T> threadState) {
+
+    PromptBuilder builder = new PromptBuilder();
+
+    for (T message : threadState.messages()) {
+      switch (message.role()) {
+        case SYSTEM -> builder.addSystem(message);
+        case USER -> builder.addUser(message);
+        case ASSISTANT -> builder.addAssistant(message);
+      }
+    }
+
+    return builder.build();
+  }
+
+  private int tokenCount(String message) {
+    Encoding encoding = tokenizer.encode(message);
+    return encoding.getTokens().length - 1;
+  }
+
+  // TODO: move logic into promptbuilder
+  private String pruneMessages(ThreadState<T> threadState) {
+
+    int totalTokens = 5; // Account for closing tokens at end of message
+    StringBuilder promptStringBuilder = new StringBuilder();
+    String systemPrompt = "[INST] <<SYS>>\n" + systemMessage + "\n<</SYS>>\n\n";
+    totalTokens += tokenCount(systemPrompt);
+    promptStringBuilder
+        .append("[INST] <<SYS>>\n")
+        .append(systemMessage)
+        .append("\n<</SYS>>\n\n");
+
+    Message.Role nextMessageSender = Message.Role.ASSISTANT;
+    StringBuilder contextStringBuilder = new StringBuilder();
+
+    List<T> messages = threadState.messages();
+
+    for (int i = messages.size() - 1; i >= 0; i--) {
+      Message<T> message = messages.get(i);
+      StringBuilder messageText = new StringBuilder();
+      String text = message.message().strip();
+      Message.Role user = message.role();
+      boolean isUser = user == Message.Role.USER;
+      messageText.append(text);
+      if (isUser && nextMessageSender == Message.Role.ASSISTANT) {
+        messageText.append(" [/INST] ");
+      } else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) {
+        messageText.append(" [INST] ");
+      }
+      totalTokens += tokenCount(messageText.toString());
+      if (totalTokens > maxInputTokens) {
+        if (contextStringBuilder.isEmpty()) {
+          return "I'm sorry but that request was too long for me.";
+        }
+        break;
+      }
+      contextStringBuilder.append(messageText.reverse());
+
+      nextMessageSender = user;
+    }
+    if (nextMessageSender == Message.Role.ASSISTANT) {
+      contextStringBuilder.append(
+          " ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after
+      // system prompt is not from user
+    }
+
+    promptStringBuilder.append(contextStringBuilder.reverse());
+    return promptStringBuilder.toString().strip();
+  }
+
+  // TODO: convert this to a class and implement the methods to replace pruneMethod
+  private interface PromptBuilder {
+
+    @This
+    PromptBuilder addSystem(Message message);
+
+    @This
+    PromptBuilder addAssistant(Message message);
+
+    @This
+    PromptBuilder addUser(Message message);
+
+    String build();
+  }
+}
diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java
deleted file mode 100644
index 974d3ef..0000000
--- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- *
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.meta.cp4m.llm;
-
-import ai.djl.huggingface.tokenizers.Encoding;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.meta.cp4m.message.Message;
-import com.meta.cp4m.message.ThreadState;
-import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.beans.beancontext.BeanContextChild;
-import java.io.IOException;
-import java.nio.file.Paths;
-import java.util.*;
-import java.net.URI;
-import java.net.URISyntaxException;
-
-
-public class HuggingFaceLlamaPromptBuilder<T extends Message<T>> {
-
-    private static final Logger LOGGER = LoggerFactory.getLogger(HuggingFaceLlamaPromptBuilder.class);
-
-    public String createPrompt(ThreadState<T> threadState, HuggingFaceConfig config) {
-
-// NEW PLAN
-// WE do the node thing
-// and then we use the other token coutning thing where we add buffer tokens for each
-// And then we pass the remaining messages to the promptbuilder
-
-        URI resource = null;
-        try {
-            resource = Objects.requireNonNull(HuggingFaceLlamaPromptBuilder.class.getClassLoader().getResource("llamaTokenizer.json")).toURI();
-        } catch (URISyntaxException e) {
-            LOGGER.error("Failed to find local llama tokenizer.json file", e);
-        }
-
-        try {
-            assert resource != null;
-            HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(resource));
-            return pruneMessages(threadState, config, tokenizer);
-        } catch (IOException e) {
-            LOGGER.error("Failed to initialize Llama2 tokenizer from local file", e);
-        }
-
-        if (config.systemMessage().isPresent()) {
-            return "[INST] <<SYS>>\n" + (config.systemMessage().get()) + "\n<</SYS>>\n\n" + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] ";
-        } else {
-            return "[INST] " + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] ";
-        }
-    }
-
-    private int tokenCount(String message, HuggingFaceTokenizer tokenizer) {
-        Encoding encoding = tokenizer.encode(message);
-        return encoding.getTokens().length - 1;
-    }
-
-    private String pruneMessages(ThreadState<T> threadState, HuggingFaceConfig config, HuggingFaceTokenizer tokenizer)
-            throws JsonProcessingException {
-
-        int totalTokens = 5; // Account for closing tokens at end of message
-        StringBuilder promptStringBuilder = new StringBuilder();
-        if (config.systemMessage().isPresent()) {
-            String systemPrompt = "[INST] <<SYS>>\n" + config.systemMessage().get() + "\n<</SYS>>\n\n";
-            totalTokens += tokenCount(systemPrompt, tokenizer);
-            promptStringBuilder.append("[INST] <<SYS>>\n").append(config.systemMessage().get()).append("\n<</SYS>>\n\n");
-        } else {
-            totalTokens += 6;
-            promptStringBuilder.append("[INST] ");
-        }
-
-        Message.Role nextMessageSender = Message.Role.ASSISTANT;
-        StringBuilder contextStringBuilder = new StringBuilder();
-
-        List<T> messages = threadState.messages();
-
-        for (int i = messages.size() - 1; i >= 0; i--) {
-            Message<T> message = messages.get(i);
-            StringBuilder messageText = new StringBuilder();
-            String text = message.message().strip();
-            Message.Role user = message.role();
-            boolean isUser = user == Message.Role.USER;
-            messageText.append(text);
-            if (isUser && nextMessageSender == Message.Role.ASSISTANT) {
-                messageText.append(" [/INST] ");
-            } else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) {
-                messageText.append(" [INST] ");
-            }
-            totalTokens += tokenCount(messageText.toString(), tokenizer);
-            if (totalTokens > config.maxInputTokens()) {
-                if (contextStringBuilder.isEmpty()) {
-                    return "I'm sorry but that request was too long for me.";
-                }
- break; - } - contextStringBuilder.append(messageText.reverse()); - - nextMessageSender = user; - } - if (nextMessageSender == Message.Role.ASSISTANT) { - contextStringBuilder.append(" ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after system prompt is not from user - } - - promptStringBuilder.append(contextStringBuilder.reverse()); - return promptStringBuilder.toString().strip(); - } -} diff --git a/src/main/java/com/meta/cp4m/llm/OpenAIConfig.java b/src/main/java/com/meta/cp4m/llm/OpenAIConfig.java index c65d252..5ba5aa2 100644 --- a/src/main/java/com/meta/cp4m/llm/OpenAIConfig.java +++ b/src/main/java/com/meta/cp4m/llm/OpenAIConfig.java @@ -106,8 +106,11 @@ public Map logitBias() { return logitBias; } - public Optional systemMessage() { - return Optional.ofNullable(systemMessage); + public String systemMessage() { + if (systemMessage == null) { + return "You're a helpful assistant."; + } + return systemMessage; } public long maxInputTokens() { diff --git a/src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java b/src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java index 3e59435..1260ca9 100644 --- a/src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java +++ b/src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java @@ -148,14 +148,10 @@ public T handle(ThreadState threadState) throws IOException { } ArrayNode messages = MAPPER.createArrayNode(); - config - .systemMessage() - .ifPresent( - m -> - messages - .addObject() - .put("role", Role.SYSTEM.toString().toLowerCase()) - .put("content", m)); + messages + .addObject() + .put("role", Role.SYSTEM.toString().toLowerCase()) + .put("content", config.systemMessage()); for (T message : threadState.messages()) { messages .addObject() diff --git a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java index 28a4bd7..7c2beb3 100644 --- a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java +++ b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java @@ -25,7 +25,6 @@ import com.meta.cp4m.store.ChatStore; import com.meta.cp4m.store.MemoryStoreConfig; import io.javalin.Javalin; - import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; @@ -40,7 +39,6 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; - import org.apache.hc.client5.http.fluent.Request; import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.net.URIBuilder; @@ -123,8 +121,8 @@ void createPayload() { HuggingFaceConfig config = HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(STACK, config); + HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config); + String createdPayload = promptBuilder.createPrompt(STACK); assertThat(createdPayload).isEqualTo(TEST_PAYLOAD); } @@ -144,8 +142,8 @@ void createPayloadWithSystemMessage() { Identifier.random(), Identifier.random(), Role.USER)); - HuggingFaceLlamaPromptBuilder promptBuilder = new HuggingFaceLlamaPromptBuilder<>(); - String createdPayload = promptBuilder.createPrompt(stack, config); + HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config); + String createdPayload = promptBuilder.createPrompt(stack); 
     assertThat(createdPayload).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM);
   }
 
@@ -195,8 +193,8 @@ void truncatesContext() throws IOException {
             Identifier.random(),
             Role.USER));
     thread = thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2)));
-    HuggingFaceLlamaPromptBuilder<FBMessage> promptBuilder = new HuggingFaceLlamaPromptBuilder<>();
-    String createdPayload = promptBuilder.createPrompt(thread, config);
+    HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config);
+    String createdPayload = promptBuilder.createPrompt(thread);
     assertThat(createdPayload).isEqualTo(TEST_PAYLOAD);
   }
 
diff --git a/src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java b/src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java
index 3acc5f8..2d13489 100644
--- a/src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java
+++ b/src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java
@@ -207,7 +207,8 @@ void orderedCorrectly() throws IOException, InterruptedException {
     assertThat(or).isNotNull();
     JsonNode body = MAPPER.readTree(or.body());
 
-    for (int i = 0; i < thread.messages().size(); i++) {
+    // first is the system message
+    for (int i = 1; i < thread.messages().size(); i++) {
       FBMessage threadMessage = thread.messages().get(i);
       JsonNode sentMessage = body.get("messages").get(i);
       assertSoftly(
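
Reviewer note (appended for context, not part of the patch): HuggingFaceLlamaPrompt.pruneMessages builds the Llama 2 chat layout ("[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant} [INST] {user} [/INST]") by walking the thread newest-to-oldest, appending each piece reversed, and reversing the whole buffer at the end, so that when the maxInputTokens budget runs out it is the oldest turns that get dropped. The sketch below restates that idea with hypothetical stand-in types (a Turn record and a whitespace word count in place of the DJL tokenizer); it is illustrative only, not the cp4m API.

import java.util.List;

final class LlamaPromptSketch {
  // Stand-in for com.meta.cp4m.message.Message; only the role and text matter here.
  record Turn(boolean fromUser, String text) {}

  static String build(String systemMessage, List<Turn> turns, int maxTokens) {
    String header = "[INST] <<SYS>>\n" + systemMessage + "\n<</SYS>>\n\n";
    int total = 5 + count(header); // allowance for closing tokens, as in the patch
    StringBuilder reversedContext = new StringBuilder();
    boolean nextFromUser = false; // role of the turn that follows the current one

    for (int i = turns.size() - 1; i >= 0; i--) {
      Turn turn = turns.get(i);
      StringBuilder piece = new StringBuilder(turn.text().strip());
      if (turn.fromUser() && !nextFromUser) {
        piece.append(" [/INST] "); // user turn answered by the assistant
      } else if (!turn.fromUser() && nextFromUser) {
        piece.append(" [INST] "); // assistant turn followed by a new user turn
      }
      total += count(piece.toString());
      if (total > maxTokens) {
        if (reversedContext.isEmpty()) {
          return "I'm sorry but that request was too long for me.";
        }
        break; // older turns no longer fit; drop them
      }
      reversedContext.append(piece.reverse()); // newest turns accumulate first, reversed
      nextFromUser = turn.fromUser();
    }
    if (!nextFromUser) {
      reversedContext.append(" ]TSNI/[ "); // reversed [/INST]: oldest kept turn is the assistant's
    }
    return (header + reversedContext.reverse()).strip();
  }

  private static int count(String s) {
    return s.isBlank() ? 0 : s.strip().split("\\s+").length; // crude tokenizer stand-in
  }
}

For turns [user "Hi", assistant "Hello!", user "How are you?"] and a generous budget, build returns "[INST] <<SYS>>\n{system}\n<</SYS>>\n\nHi [/INST] Hello! [INST] How are you? [/INST]"; shrinking the budget drops "Hi [/INST] Hello!" before it ever touches the newest user turn.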