Skip to content

Commit

Permalink
Local tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
colinmccloskey committed Oct 3, 2023
1 parent 6491ed2 commit 3ade3a0
Show file tree
Hide file tree
Showing 4 changed files with 93,451 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@
<artifactId>jtokkit</artifactId>
<version>0.6.1</version>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.23.0</version>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import org.checkerframework.checker.nullness.qual.Nullable;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Optional;

Expand Down Expand Up @@ -75,56 +76,59 @@ public String createPrompt(MessageStack<T> messageStack, HuggingFaceConfig confi
}

private int tokenCount(JsonNode message) {
int tokenCount = tokensPerMessage;
tokenCount += tokenEncoding.countTokens(message.get("content").textValue());
tokenCount += tokenEncoding.countTokens(message.get("role").textValue());
@Nullable JsonNode name = message.get("name");
if (name != null) {
tokenCount += tokenEncoding.countTokens(name.textValue());
tokenCount += tokensPerName;
}
return tokenCount;
// int tokenCount = tokensPerMessage;
// tokenCount += tokenEncoding.countTokens(message.get("content").textValue());
// tokenCount += tokenEncoding.countTokens(message.get("role").textValue());
// @Nullable JsonNode name = message.get("name");
// if (name != null) {
// tokenCount += tokenEncoding.countTokens(name.textValue());
// tokenCount += tokensPerName;
// }
// return tokenCount;
return 100;
}

private Optional<ArrayNode> pruneMessages(ArrayNode messages, @Nullable JsonNode functions)
throws JsonProcessingException {

int functionTokens = 0;
if (functions != null) {
// This is honestly a guess, it's undocumented
functionTokens = tokenEncoding.countTokens(MAPPER.writeValueAsString(functions));
}

ArrayNode output = MAPPER.createArrayNode();
int totalTokens = functionTokens;
totalTokens += 3; // every reply is primed with <|start|>assistant<|message|>

JsonNode systemMessage = messages.get(0);
boolean hasSystemMessage = systemMessage.get("role").textValue().equals("system");
if (hasSystemMessage) {
// if the system message is present it's required
totalTokens += tokenCount(messages.get(0));
}
for (int i = messages.size() - 1; i >= 0; i--) {
JsonNode m = messages.get(i);
String role = m.get("role").textValue();
if (role.equals("system")) {
continue; // system has already been counted
}
totalTokens += tokenCount(m);
if (totalTokens > MAX_TOTAL_TOKENS) {
break;
}
output.insert(0, m);
}
if (hasSystemMessage) {
output.insert(0, systemMessage);
}

if ((hasSystemMessage && output.size() <= 1) || output.isEmpty()) {
return Optional.empty();
}

return Optional.of(output);
}
// private Optional<ArrayNode> pruneMessages(ArrayNode messages, @Nullable JsonNode functions)
// throws JsonProcessingException {
//
// HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("meta-llama/Llama-2-7b-chat-hf");
//
// int functionTokens = 0;
// if (functions != null) {
// // This is honestly a guess, it's undocumented
// functionTokens = tokenEncoding.countTokens(MAPPER.writeValueAsString(functions));
// }
//
// ArrayNode output = MAPPER.createArrayNode();
// int totalTokens = functionTokens;
// totalTokens += 3; // every reply is primed with <|start|>assistant<|message|>
//
// JsonNode systemMessage = messages.get(0);
// boolean hasSystemMessage = systemMessage.get("role").textValue().equals("system");
// if (hasSystemMessage) {
// // if the system message is present it's required
// totalTokens += tokenCount(messages.get(0));
// }
// for (int i = messages.size() - 1; i >= 0; i--) {
// JsonNode m = messages.get(i);
// String role = m.get("role").textValue();
// if (role.equals("system")) {
// continue; // system has already been counted
// }
// totalTokens += tokenCount(m);
// if (totalTokens > MAX_TOTAL_TOKENS) {
// break;
// }
// output.insert(0, m);
// }
// if (hasSystemMessage) {
// output.insert(0, systemMessage);
// }
//
// if ((hasSystemMessage && output.size() <= 1) || output.isEmpty()) {
// return Optional.empty();
// }
//
// return Optional.of(output);
// }
}
Loading

0 comments on commit 3ade3a0

Please sign in to comment.