Skip to content

Commit

Permalink
Formatting, removing comments, removing now-duplicated system message…
Browse files Browse the repository at this point in the history
… test
  • Loading branch information
colinmccloskey committed Oct 4, 2023
1 parent c802798 commit a70d5a3
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 302 deletions.
10 changes: 6 additions & 4 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;

import java.io.IOException;
import java.net.URI;
import java.time.Instant;

import org.apache.hc.client5.http.fluent.Request;
import org.apache.hc.client5.http.fluent.Response;
import org.apache.hc.core5.http.ContentType;
Expand All @@ -29,11 +31,11 @@ public class HuggingFaceLlamaPlugin<T extends Message> implements LLMPlugin<T> {

public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
this.config = config;
this.endpoint = this.config.endpoint();
this.endpoint = this.config.endpoint();
}

@Override
public T handle(ThreadState<T> threadState) throws IOException {
@Override
public T handle(ThreadState<T> threadState) throws IOException {
T fromUser = threadState.tail();

ObjectNode body = MAPPER.createObjectNode();
Expand All @@ -48,7 +50,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
HuggingFaceLlamaPromptBuilder<T> promptBuilder = new HuggingFaceLlamaPromptBuilder<>();

String prompt = promptBuilder.createPrompt(threadState, config);
if (prompt.equals("I'm sorry but that request was too long for me.")){
if (prompt.equals("I'm sorry but that request was too long for me.")) {
return threadState.newMessageFromBot(
Instant.now(), prompt);
}
Expand Down
38 changes: 10 additions & 28 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ public String createPrompt(ThreadState<T> threadState, HuggingFaceConfig config)
LOGGER.error("Failed to initialize Llama2 tokenizer from local file", e);
}

if(config.systemMessage().isPresent()){
if (config.systemMessage().isPresent()) {
return "<s>[INST] <<SYS>>\n" + (config.systemMessage().get()) + "\n<</SYS>>\n\n" + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] ";
}
else{
} else {
return "<s>[INST] " + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] ";
}
}
Expand All @@ -68,52 +67,35 @@ private String pruneMessages(ThreadState<T> threadState, HuggingFaceConfig confi

int totalTokens = 5; // Account for closing tokens at end of message
StringBuilder promptStringBuilder = new StringBuilder();
if(config.systemMessage().isPresent()){
if (config.systemMessage().isPresent()) {
String systemPrompt = "<s>[INST] <<SYS>>\n" + config.systemMessage().get() + "\n<</SYS>>\n\n";
totalTokens += tokenCount(systemPrompt, tokenizer);
promptStringBuilder.append("<s>[INST] <<SYS>>\n").append(config.systemMessage().get()).append("\n<</SYS>>\n\n");
}
else {
} else {
totalTokens += 6;
promptStringBuilder.append("<s>[INST] ");
}

// for (int i = list.size() - 1; i >= 0; i--)
// {
// // access elements by their index (position)
// System.out.println(list.get(i));
// }

// Okay, so we have a system-prompt StringBuilder and a context StringBuilder, and we add those together;
// only if the context StringBuilder is empty do we return the "too long" message



// The first user input is _not_ stripped
// boolean hasUserMessage = false;
Message.Role nextMessageSender = Message.Role.ASSISTANT;
StringBuilder contextStringBuilder = new StringBuilder();

List<T> messages = threadState.messages();

for (int i = messages.size() - 1; i >= 0; i--)
{
for (int i = messages.size() - 1; i >= 0; i--) {
Message message = messages.get(i);
StringBuilder messageText = new StringBuilder();
String text = message.message().strip();
Message.Role user = message.role();
boolean isUser = user == Message.Role.USER;
// access elements by their index (position)
messageText.append(text);
if (isUser && nextMessageSender == Message.Role.ASSISTANT){
if (isUser && nextMessageSender == Message.Role.ASSISTANT) {
messageText.append(" [/INST] ");
}
else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER){
} else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) {
messageText.append(" </s><s>[INST] ");
}
totalTokens += tokenCount(messageText.toString(), tokenizer);
if(totalTokens > config.maxInputTokens()){
if(contextStringBuilder.isEmpty()){
if (totalTokens > config.maxInputTokens()) {
if (contextStringBuilder.isEmpty()) {
return "I'm sorry but that request was too long for me.";
}
break;
Expand All @@ -122,7 +104,7 @@ else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USE

nextMessageSender = user;
}
if(nextMessageSender == Message.Role.ASSISTANT){
if (nextMessageSender == Message.Role.ASSISTANT) {
contextStringBuilder.append(" ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after system prompt is not from user
}

Expand Down
Loading

0 comments on commit a70d5a3

Please sign in to comment.