Skip to content

Commit

Permalink
suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterjackson committed Oct 10, 2023
1 parent a70d5a3 commit c308bd2
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 179 deletions.
7 changes: 5 additions & 2 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ public Map<Long, Double> logitBias() {
return logitBias;
}

public Optional<String> systemMessage() {
return Optional.ofNullable(systemMessage);
public String systemMessage() {
if (systemMessage == null) {
return "You're a helpful assistant.";
}
return systemMessage;
}

public long maxInputTokens() {
Expand Down
84 changes: 40 additions & 44 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,62 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;

import java.io.IOException;
import java.net.URI;
import java.time.Instant;

import org.apache.hc.client5.http.fluent.Request;
import org.apache.hc.client5.http.fluent.Response;
import org.apache.hc.core5.http.ContentType;

public class HuggingFaceLlamaPlugin<T extends Message> implements LLMPlugin<T> {

private static final ObjectMapper MAPPER = new ObjectMapper();
private final HuggingFaceConfig config;
private URI endpoint;

public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
this.config = config;
this.endpoint = this.config.endpoint();
}
private static final ObjectMapper MAPPER = new ObjectMapper();
private final HuggingFaceConfig config;
private final HuggingFaceLlamaPrompt<T> promptCreator;

@Override
public T handle(ThreadState<T> threadState) throws IOException {
T fromUser = threadState.tail();
private URI endpoint;

ObjectNode body = MAPPER.createObjectNode();
ObjectNode params = MAPPER.createObjectNode();
public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
this.config = config;
this.endpoint = this.config.endpoint();
promptCreator = new HuggingFaceLlamaPrompt<>(config);
}

config.topP().ifPresent(v -> params.put("top_p", v));
config.temperature().ifPresent(v -> params.put("temperature", v));
config.maxOutputTokens().ifPresent(v -> params.put("max_new_tokens", v));
@Override
public T handle(ThreadState<T> threadState) throws IOException {
ObjectNode body = MAPPER.createObjectNode();
ObjectNode params = MAPPER.createObjectNode();

body.set("parameters", params);
config.topP().ifPresent(v -> params.put("top_p", v));
config.temperature().ifPresent(v -> params.put("temperature", v));
config.maxOutputTokens().ifPresent(v -> params.put("max_new_tokens", v));

HuggingFaceLlamaPromptBuilder<T> promptBuilder = new HuggingFaceLlamaPromptBuilder<>();
body.set("parameters", params);

String prompt = promptBuilder.createPrompt(threadState, config);
if (prompt.equals("I'm sorry but that request was too long for me.")) {
return threadState.newMessageFromBot(
Instant.now(), prompt);
}

body.put("inputs", prompt);

String bodyString;
try {
bodyString = MAPPER.writeValueAsString(body);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
Response response =
Request.post(endpoint)
.bodyString(bodyString, ContentType.APPLICATION_JSON)
.setHeader("Authorization", "Bearer " + config.apiKey())
.execute();
String prompt = promptCreator.createPrompt(threadState);
if (prompt.equals("I'm sorry but that request was too long for me.")) {
return threadState.newMessageFromBot(Instant.now(), prompt);
}

JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
String allGeneratedText = responseBody.get(0).get("generated_text").textValue();
String llmResponse = allGeneratedText.strip().replace(prompt.strip(), "");
Instant timestamp = Instant.now();
body.put("inputs", prompt);

return threadState.newMessageFromBot(timestamp, llmResponse);
String bodyString;
try {
bodyString = MAPPER.writeValueAsString(body);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
Response response =
Request.post(endpoint)
.bodyString(bodyString, ContentType.APPLICATION_JSON)
.setHeader("Authorization", "Bearer " + config.apiKey())
.execute();

JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
String allGeneratedText = responseBody.get(0).get("generated_text").textValue();
String llmResponse = allGeneratedText.strip().replace(prompt.strip(), "");
Instant timestamp = Instant.now();

return threadState.newMessageFromBot(timestamp, llmResponse);
}
}
134 changes: 134 additions & 0 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.llm;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;
import org.checkerframework.common.returnsreceiver.qual.This;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds the text prompt sent to a Hugging Face hosted Llama model from the messages of a
 * {@link ThreadState}.
 *
 * <p>The prompt starts with the configured system message wrapped in {@code <s>[INST] <<SYS>>}
 * markup, followed by as many of the most recent thread messages as fit within the configured
 * input token budget. Token counts are measured with the tokenizer loaded from the
 * {@code llamaTokenizer.json} classpath resource.
 *
 * @param <T> the concrete message type carried by the thread
 */
public class HuggingFaceLlamaPrompt<T extends Message> {

  private static final Logger LOGGER = LoggerFactory.getLogger(HuggingFaceLlamaPrompt.class);

  /** Classpath resource holding the tokenizer definition. */
  private static final String TOKENIZER_RESOURCE = "llamaTokenizer.json";

  /**
   * Sentinel returned instead of a prompt when even the newest message does not fit. Callers
   * compare against this exact text, so it must not be reworded.
   */
  private static final String TOO_LONG_RESPONSE = "I'm sorry but that request was too long for me.";

  /** Tokens reserved up front for the closing markup at the end of the prompt. */
  private static final int CLOSING_TOKEN_ALLOWANCE = 5;

  private final String systemMessage;
  private final long maxInputTokens;
  private final HuggingFaceTokenizer tokenizer;

  /**
   * Creates a prompt factory using the system message and input token limit from {@code config}.
   *
   * @param config source of the system message and the maximum number of input tokens
   * @throws NullPointerException if the tokenizer resource is not on the classpath
   * @throws RuntimeException if the tokenizer resource cannot be resolved or loaded
   */
  public HuggingFaceLlamaPrompt(HuggingFaceConfig config) {
    this.systemMessage = config.systemMessage();
    this.maxInputTokens = config.maxInputTokens();
    URL llamaTokenizerUrl =
        Objects.requireNonNull(
            HuggingFaceLlamaPrompt.class.getClassLoader().getResource(TOKENIZER_RESOURCE),
            TOKENIZER_RESOURCE + " is missing from the classpath");
    try {
      // NOTE(review): Paths.get(URI) presumably only works while the resource is a plain file on
      // disk (not packed inside a jar) -- confirm against the deployment packaging.
      tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(llamaTokenizerUrl.toURI()));
    } catch (URISyntaxException | IOException e) {
      // this should be impossible
      throw new RuntimeException(e);
    }
  }

  /**
   * Renders the thread as a single prompt string.
   *
   * <p>Previously this method tried to instantiate the {@code PromptBuilder} interface, which has
   * no implementation and therefore cannot compile; it now delegates to the existing pruning
   * logic so the prompt is actually produced.
   *
   * @param threadState the conversation to render
   * @return the prompt text, or the fixed "too long" apology text when the newest message alone
   *     exceeds the input token budget
   */
  public String createPrompt(ThreadState<T> threadState) {
    return pruneMessages(threadState);
  }

  /**
   * Counts the tokens the tokenizer produces for {@code message}, minus one.
   *
   * <p>NOTE(review): the subtraction presumably discounts a leading special token added by the
   * tokenizer -- confirm.
   */
  private int tokenCount(String message) {
    Encoding encoding = tokenizer.encode(message);
    return encoding.getTokens().length - 1;
  }

  // TODO: move logic into promptbuilder
  /**
   * Builds the prompt from the system message plus the newest messages that fit the token budget.
   *
   * <p>Messages are visited newest-first. Each accepted message is appended in reversed form to a
   * scratch buffer, and the whole buffer is reversed once at the end, which restores both the
   * chronological order of the messages and the character order inside each message.
   *
   * @return the stripped prompt, or {@link #TOO_LONG_RESPONSE} when the budget is exceeded before
   *     any message text has been accepted
   */
  private String pruneMessages(ThreadState<T> threadState) {
    String systemPrompt = "<s>[INST] <<SYS>>\n" + systemMessage + "\n<</SYS>>\n\n";
    int totalTokens = CLOSING_TOKEN_ALLOWANCE;
    totalTokens += tokenCount(systemPrompt);

    // Role of the message that follows the one being visited; the reply being requested comes
    // from the assistant, so that is the starting value.
    Message.Role nextMessageSender = Message.Role.ASSISTANT;
    StringBuilder reversedContext = new StringBuilder();

    List<T> messages = threadState.messages();

    for (int i = messages.size() - 1; i >= 0; i--) {
      Message message = messages.get(i);
      Message.Role sender = message.role();
      StringBuilder messageText = new StringBuilder(message.message().strip());
      if (sender == Message.Role.USER && nextMessageSender == Message.Role.ASSISTANT) {
        // user turn followed by an assistant turn: close the instruction
        messageText.append(" [/INST] ");
      } else if (sender == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) {
        // assistant turn followed by a user turn: end the sequence and open a new instruction
        messageText.append(" </s><s>[INST] ");
      }
      totalTokens += tokenCount(messageText.toString());
      if (totalTokens > maxInputTokens) {
        if (reversedContext.isEmpty()) {
          return TOO_LONG_RESPONSE;
        }
        // older messages no longer fit; keep what has been collected so far
        break;
      }
      reversedContext.append(messageText.reverse());

      nextMessageSender = sender;
    }
    if (nextMessageSender == Message.Role.ASSISTANT) {
      reversedContext.append(
          " ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after
      // system prompt is not from user
    }

    return new StringBuilder(systemPrompt)
        .append(reversedContext.reverse())
        .toString()
        .strip();
  }

  // TODO: convert this to a class and implement the methods to replace pruneMessages; until
  // then createPrompt keeps using pruneMessages and nothing implements this interface.
  /** Planned fluent API for assembling a prompt one message at a time. */
  private interface PromptBuilder {

    @This
    PromptBuilder addSystem(Message message);

    @This
    PromptBuilder addAssistant(Message message);

    @This
    PromptBuilder addUser(Message message);

    String build();
  }
}
114 changes: 0 additions & 114 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java

This file was deleted.

7 changes: 5 additions & 2 deletions src/main/java/com/meta/cp4m/llm/OpenAIConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ public Map<Long, Double> logitBias() {
return logitBias;
}

public Optional<String> systemMessage() {
return Optional.ofNullable(systemMessage);
public String systemMessage() {
if (systemMessage == null) {
return "You're a helpful assistant.";
}
return systemMessage;
}

public long maxInputTokens() {
Expand Down
Loading

0 comments on commit c308bd2

Please sign in to comment.