diff --git a/pom.xml b/pom.xml
index 95afe2e..f92f650 100644
--- a/pom.xml
+++ b/pom.xml
@@ -82,6 +82,11 @@
log4j-slf4j2-impl
2.20.0
+
+ com.knuddels
+ jtokkit
+ 0.6.1
+
diff --git a/src/main/java/com/meta/chatbridge/Configuration.java b/src/main/java/com/meta/chatbridge/Configuration.java
new file mode 100644
index 0000000..59eb727
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/Configuration.java
@@ -0,0 +1,23 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge;
+
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+public class Configuration {
+ public static final ObjectMapper MAPPER =
+ new ObjectMapper()
+ .enable(DeserializationFeature.FAIL_ON_NULL_FOR_PRIMITIVES)
+ .enable(DeserializationFeature.WRAP_EXCEPTIONS)
+ .enable(DeserializationFeature.READ_ENUMS_USING_TO_STRING)
+ .enable(DeserializationFeature.USE_LONG_FOR_INTS)
+ .disable(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT)
+ .disable(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES);
+}
diff --git a/src/main/java/com/meta/chatbridge/Identifier.java b/src/main/java/com/meta/chatbridge/Identifier.java
index 05716a2..88b1af0 100644
--- a/src/main/java/com/meta/chatbridge/Identifier.java
+++ b/src/main/java/com/meta/chatbridge/Identifier.java
@@ -11,6 +11,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Objects;
+import java.util.UUID;
import org.jetbrains.annotations.NotNull;
public class Identifier implements Comparable {
@@ -21,6 +22,11 @@ private Identifier(byte[] id) {
this.id = id;
}
+ public static Identifier random() {
+ UUID uuid = UUID.randomUUID();
+ return new Identifier(uuid.toString().getBytes(StandardCharsets.UTF_8));
+ }
+
public static Identifier from(String id) {
return new Identifier(id.getBytes(StandardCharsets.UTF_8));
}
diff --git a/src/main/java/com/meta/chatbridge/Pipeline.java b/src/main/java/com/meta/chatbridge/Pipeline.java
index afb4180..0afa19e 100644
--- a/src/main/java/com/meta/chatbridge/Pipeline.java
+++ b/src/main/java/com/meta/chatbridge/Pipeline.java
@@ -8,13 +8,14 @@
package com.meta.chatbridge;
-import com.meta.chatbridge.llm.LLMHandler;
+import com.meta.chatbridge.llm.LLMPlugin;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageHandler;
+import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.store.ChatStore;
-import com.meta.chatbridge.store.MessageStack;
import io.javalin.Javalin;
import io.javalin.http.Context;
+import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.*;
@@ -27,15 +28,15 @@ public class Pipeline {
private final ExecutorService executorService = Executors.newCachedThreadPool();
private final MessageHandler handler;
private final ChatStore store;
- private final LLMHandler llmHandler;
+ private final LLMPlugin llmPlugin;
private final String path;
public Pipeline(
- ChatStore store, MessageHandler handler, LLMHandler llmHandler, String path) {
+ ChatStore store, MessageHandler handler, LLMPlugin llmPlugin, String path) {
this.handler = Objects.requireNonNull(handler);
this.store = Objects.requireNonNull(store);
- this.llmHandler = llmHandler;
+ this.llmPlugin = llmPlugin;
this.path = path;
}
@@ -61,12 +62,18 @@ public MessageHandler messageHandler() {
}
private void execute(MessageStack stack) {
- T llmResponse = llmHandler.handle(stack);
+ T llmResponse;
+ try {
+ llmResponse = llmPlugin.handle(stack);
+ } catch (IOException e) {
+ LOGGER.error("failed to communicate with LLM", e);
+ return;
+ }
store.add(llmResponse);
try {
handler.respond(llmResponse);
} catch (Exception e) {
- LOGGER.error("failed to respond to user", e);
+ // we log in the handler where we have the body context
// TODO: create transactional store add
// TODO: implement retry with exponential backoff
}
diff --git a/src/main/java/com/meta/chatbridge/PipelinesRunner.java b/src/main/java/com/meta/chatbridge/PipelinesRunner.java
index 1acdb80..127d6be 100644
--- a/src/main/java/com/meta/chatbridge/PipelinesRunner.java
+++ b/src/main/java/com/meta/chatbridge/PipelinesRunner.java
@@ -29,8 +29,7 @@ public static PipelinesRunner newInstance() {
return new PipelinesRunner();
}
- @This
- public PipelinesRunner start() {
+ public @This PipelinesRunner start() {
if (!started) {
started = true;
app.start(port);
@@ -38,8 +37,7 @@ public PipelinesRunner start() {
return this;
}
- @This
- public PipelinesRunner pipeline(Pipeline> pipeline) {
+ public @This PipelinesRunner pipeline(Pipeline> pipeline) {
Preconditions.checkState(!started, "cannot add pipeline, server already started");
if (pipelines.add(pipeline)) {
pipeline.register(app);
@@ -64,8 +62,7 @@ public int port() {
* @param port the port the server will start on
* @return this
*/
- @This
- public PipelinesRunner port(int port) {
+ public @This PipelinesRunner port(int port) {
Preconditions.checkState(!started, "cannot change port, server already started");
this.port = port;
return this;
diff --git a/src/main/java/com/meta/chatbridge/llm/LLMHandler.java b/src/main/java/com/meta/chatbridge/llm/LLMPlugin.java
similarity index 59%
rename from src/main/java/com/meta/chatbridge/llm/LLMHandler.java
rename to src/main/java/com/meta/chatbridge/llm/LLMPlugin.java
index 5c5ce2b..1cb18c1 100644
--- a/src/main/java/com/meta/chatbridge/llm/LLMHandler.java
+++ b/src/main/java/com/meta/chatbridge/llm/LLMPlugin.java
@@ -9,9 +9,10 @@
package com.meta.chatbridge.llm;
import com.meta.chatbridge.message.Message;
-import com.meta.chatbridge.store.MessageStack;
+import com.meta.chatbridge.message.MessageStack;
+import java.io.IOException;
-public interface LLMHandler {
+public interface LLMPlugin {
- T handle(MessageStack messageStack);
+ T handle(MessageStack messageStack) throws IOException;
}
diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java b/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java
new file mode 100644
index 0000000..b124786
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java
@@ -0,0 +1,257 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
+import com.google.common.base.Preconditions;
+import java.util.*;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.checkerframework.common.returnsreceiver.qual.This;
+
+@JsonDeserialize(builder = OpenAIConfig.Builder.class)
+public class OpenAIConfig {
+
+ private final OpenAIModel model;
+ private final String apiKey;
+ @Nullable private final Double temperature;
+ @Nullable private final Double topP;
+ private final List stop;
+ @Nullable private final Long maxOutputTokens;
+ @Nullable private final Double presencePenalty;
+ @Nullable private final Double frequencyPenalty;
+ private final Map logitBias;
+ private final @Nullable String systemMessage;
+
+ private final long maxInputTokens;
+
+ private OpenAIConfig(
+ OpenAIModel model,
+ String apiKey,
+ @Nullable Double temperature,
+ @Nullable Double topP,
+ List stop,
+ @Nullable Long maxOutputTokens,
+ @Nullable Double presencePenalty,
+ @Nullable Double frequencyPenalty,
+ Map logitBias,
+ @Nullable String systemMessage,
+ long maxInputTokens) {
+ this.apiKey = apiKey;
+ this.temperature = temperature;
+ this.topP = topP;
+ this.model = model;
+ this.stop = stop;
+ this.maxOutputTokens = maxOutputTokens;
+ this.presencePenalty = presencePenalty;
+ this.frequencyPenalty = frequencyPenalty;
+ this.logitBias = Collections.unmodifiableMap(logitBias);
+ this.systemMessage = systemMessage;
+ this.maxInputTokens = maxInputTokens;
+ }
+
+ public static Builder builder(OpenAIModel model, String apiKey) {
+ return new Builder().model(model).apiKey(apiKey);
+ }
+
+ public OpenAIModel model() {
+ return model;
+ }
+
+ public String apiKey() {
+ return apiKey;
+ }
+
+ public Optional temperature() {
+ return Optional.ofNullable(temperature);
+ }
+
+ public Optional topP() {
+ return Optional.ofNullable(topP);
+ }
+
+ public List stop() {
+ return stop;
+ }
+
+ public Optional maxOutputTokens() {
+ return Optional.ofNullable(maxOutputTokens);
+ }
+
+ public Optional presencePenalty() {
+ return Optional.ofNullable(presencePenalty);
+ }
+
+ public Optional frequencyPenalty() {
+ return Optional.ofNullable(frequencyPenalty);
+ }
+
+ public Map logitBias() {
+ return logitBias;
+ }
+
+ public Optional systemMessage() {
+ return Optional.ofNullable(systemMessage);
+ }
+
+ public long maxInputTokens() {
+ return maxInputTokens;
+ }
+
+ @JsonPOJOBuilder(withPrefix = "")
+ public static class Builder {
+ private @Nullable OpenAIModel model;
+
+ @JsonProperty("api_key")
+ private @Nullable String apiKey;
+
+ private @Nullable Double temperature;
+
+ @JsonProperty("top_p")
+ private @Nullable Double topP;
+
+ private List stop = List.of();
+
+ @JsonProperty("max_output_tokens")
+ private @Nullable Long maxOutputTokens;
+
+ @JsonProperty("presence_penalty")
+ private @Nullable Double presencePenalty;
+
+ @JsonProperty("frequency_penalty")
+ private @Nullable Double frequencyPenalty;
+
+ @JsonProperty("logit_bias")
+ private Map logitBias = Collections.emptyMap();
+
+ @JsonProperty("system_message")
+ private @Nullable String systemMessage;
+
+ @JsonProperty("max_input_tokens")
+ private @Nullable Long maxInputTokens;
+
+ public @This Builder model(OpenAIModel model) {
+ this.model = model;
+ return this;
+ }
+
+ public @This Builder apiKey(String apiKey) {
+ Objects.requireNonNull(apiKey);
+ Preconditions.checkArgument(!apiKey.isBlank(), "api key cannot be empty");
+ this.apiKey = apiKey;
+ return this;
+ }
+
+ public @This Builder temperature(double temperature) {
+ Preconditions.checkArgument(
+ temperature >= 0 && temperature <= 2, "temperature must be >= 0 and <= 2");
+ this.temperature = temperature;
+ return this;
+ }
+
+ public @This Builder topP(double topP) {
+ Preconditions.checkArgument(topP > 0 && topP <= 1, "top_p must be > 0 and <= 1");
+ this.topP = topP;
+ return this;
+ }
+
+ public @This Builder stop(List stop) {
+ Objects.requireNonNull(stop);
+ this.stop = Collections.unmodifiableList(stop);
+ return this;
+ }
+
+ public @This Builder maxOutputTokens(long maxOutputTokens) {
+ Preconditions.checkArgument(
+ maxOutputTokens > 0, "max_output_tokens must be greater than zero");
+ this.maxOutputTokens = maxOutputTokens;
+ return this;
+ }
+
+ public @This Builder presencePenalty(double presencePenalty) {
+ Preconditions.checkArgument(
+ presencePenalty >= -2.0 && presencePenalty <= 2.0,
+ "presence_penalty must be between -2.0 and 2.0");
+ this.presencePenalty = presencePenalty;
+ return this;
+ }
+
+ public @This Builder frequencyPenalty(double frequencyPenalty) {
+ Preconditions.checkArgument(
+ frequencyPenalty >= -2.0 && frequencyPenalty <= 2.0,
+ "frequency_penalty must be between -2.0 and 2.0");
+ this.frequencyPenalty = frequencyPenalty;
+ return this;
+ }
+
+ public @This Builder logitBias(Map logitBias) {
+ Preconditions.checkArgument(
+ logitBias.values().stream().allMatch(v -> v >= -100 && v <= 100),
+ "all values for log_bias must be between -100 and 100");
+ this.logitBias = logitBias;
+ return this;
+ }
+
+ public @This Builder systemMessage(String systemMessage) {
+ Preconditions.checkArgument(!systemMessage.isBlank(), "system_message cannot be blank");
+ this.systemMessage = systemMessage;
+ return this;
+ }
+
+ public @This Builder maxInputTokens(long maxInputTokens) {
+ Preconditions.checkArgument(maxInputTokens > 0, "max_input_tokens must be greater than zero");
+ this.maxInputTokens = maxInputTokens;
+ return this;
+ }
+
+ public OpenAIConfig build() {
+ Objects.requireNonNull(model, "model is a required parameter");
+ Objects.requireNonNull(apiKey, "api_key is a required parameter");
+ if (maxOutputTokens != null) {
+ Preconditions.checkArgument(
+ maxOutputTokens <= model.properties().tokenLimit(),
+ "max_tokens must be <= "
+ + model.properties().tokenLimit()
+ + ", the maximum tokens allowed for the selected model '"
+ + model.properties().name()
+ + "'");
+ }
+ if (maxInputTokens == null) {
+ if (maxOutputTokens == null) {
+ // set the max input size to 50% of the total context size so that there is always some
+ // room for the output
+ maxInputTokens = (long) (model.properties().tokenLimit() * 0.50);
+ } else {
+ maxInputTokens = model.properties().tokenLimit() - maxOutputTokens;
+ }
+ }
+
+ Preconditions.checkArgument(
+ maxInputTokens + (maxOutputTokens == null ? 0 : maxOutputTokens)
+ <= model.properties().tokenLimit(),
+ "max_input_tokens + max_output_tokens must total to be less than or equal to "
+ + model.properties().tokenLimit()
+ + ", the total context tokens allowed by this model");
+
+ return new OpenAIConfig(
+ model,
+ apiKey,
+ temperature,
+ topP,
+ stop,
+ maxOutputTokens,
+ presencePenalty,
+ frequencyPenalty,
+ logitBias,
+ systemMessage,
+ maxInputTokens);
+ }
+ }
+}
diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java b/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java
new file mode 100644
index 0000000..977eb2c
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java
@@ -0,0 +1,48 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import com.knuddels.jtokkit.api.ModelType;
+
+public enum OpenAIModel {
+ GPT4 {
+ OpenAIModelProperties properties() {
+ return new OpenAIModelProperties("gpt-4", 8_192, ModelType.GPT_4);
+ }
+ },
+
+ GPT432K {
+ @Override
+ OpenAIModelProperties properties() {
+ return new OpenAIModelProperties("gpt-4-32k", 32_768, ModelType.GPT_4_32K);
+ }
+ },
+ GPT35TURBO {
+ @Override
+ OpenAIModelProperties properties() {
+ return new OpenAIModelProperties("gpt-3.5-turbo", 4_096, ModelType.GPT_3_5_TURBO);
+ }
+ },
+
+ GPT35TURBO16K {
+ @Override
+ OpenAIModelProperties properties() {
+ return new OpenAIModelProperties("gpt-3.5-turbo-16k", 16_384, ModelType.GPT_3_5_TURBO_16K);
+ }
+ };
+
+ @Override
+ public String toString() {
+ return this.properties().name();
+ }
+
+ abstract OpenAIModelProperties properties();
+
+ public record OpenAIModelProperties(String name, long tokenLimit, ModelType jtokkinModel) {}
+}
diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java b/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java
new file mode 100644
index 0000000..1600fb2
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java
@@ -0,0 +1,191 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.knuddels.jtokkit.Encodings;
+import com.knuddels.jtokkit.api.Encoding;
+import com.meta.chatbridge.message.Message;
+import com.meta.chatbridge.message.Message.Role;
+import com.meta.chatbridge.message.MessageStack;
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.time.Instant;
+import java.util.Optional;
+import org.apache.hc.client5.http.fluent.Request;
+import org.apache.hc.client5.http.fluent.Response;
+import org.apache.hc.core5.http.ContentType;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.checkerframework.common.returnsreceiver.qual.This;
+import org.jetbrains.annotations.TestOnly;
+
+public class OpenAIPlugin implements LLMPlugin {
+
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+ private static final String ENDPOINT = "https://api.openai.com/v1/chat/completions";
+ private final OpenAIConfig config;
+ private final Encoding tokenEncoding;
+ private final int tokensPerMessage;
+ private final int tokensPerName;
+ private URI endpoint;
+
+ public OpenAIPlugin(OpenAIConfig config) {
+ this.config = config;
+
+ try {
+ this.endpoint = new URI(ENDPOINT);
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e); // this should be impossible
+ }
+ tokenEncoding =
+ Encodings.newDefaultEncodingRegistry()
+ .getEncodingForModel(config.model().properties().jtokkinModel());
+
+ switch (config.model()) {
+ case GPT4, GPT432K -> {
+ tokensPerMessage = 3;
+ tokensPerName = 1;
+ }
+ case GPT35TURBO, GPT35TURBO16K -> {
+ tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
+ tokensPerName = -1; // if there's a name, the role is omitted
+ }
+ default -> throw new IllegalArgumentException("Unsupported model: " + config.model());
+ }
+ }
+
+ @TestOnly
+ public @This OpenAIPlugin endpoint(URI endpoint) {
+ this.endpoint = endpoint;
+ return this;
+ }
+
+ private int tokenCount(JsonNode message) {
+ int tokenCount = tokensPerMessage;
+ tokenCount += tokenEncoding.countTokens(message.get("content").textValue());
+ tokenCount += tokenEncoding.countTokens(message.get("role").textValue());
+ @Nullable JsonNode name = message.get("name");
+ if (name != null) {
+ tokenCount += tokenEncoding.countTokens(name.textValue());
+ tokenCount += tokensPerName;
+ }
+ return tokenCount;
+ }
+
+ private Optional pruneMessages(ArrayNode messages, @Nullable JsonNode functions)
+ throws JsonProcessingException {
+
+ int functionTokens = 0;
+ if (functions != null) {
+ // This is honestly a guess, it's undocumented
+ functionTokens = tokenEncoding.countTokens(MAPPER.writeValueAsString(functions));
+ }
+
+ ArrayNode output = MAPPER.createArrayNode();
+ int totalTokens = functionTokens;
+ totalTokens += 3; // every reply is primed with <|start|>assistant<|message|>
+
+ JsonNode systemMessage = messages.get(0);
+ boolean hasSystemMessage = systemMessage.get("role").textValue().equals("system");
+ if (hasSystemMessage) {
+ // if the system message is present it's required
+ totalTokens += tokenCount(messages.get(0));
+ }
+ for (int i = messages.size() - 1; i >= 0; i--) {
+ JsonNode m = messages.get(i);
+ String role = m.get("role").textValue();
+ if (role.equals("system")) {
+ continue; // system has already been counted
+ }
+ totalTokens += tokenCount(m);
+ if (totalTokens > config.maxInputTokens()) {
+ break;
+ }
+ output.insert(0, m);
+ }
+ if (hasSystemMessage) {
+ output.insert(0, systemMessage);
+ }
+
+ if ((hasSystemMessage && output.size() <= 1) || output.isEmpty()) {
+ return Optional.empty();
+ }
+
+ return Optional.of(output);
+ }
+
+ @Override
+ public T handle(MessageStack messageStack) throws IOException {
+ T fromUser = messageStack.tail();
+
+ ObjectNode body = MAPPER.createObjectNode();
+ body.put("model", config.model().properties().name())
+ // .put("function_call", "auto") // Update when we support functions
+ .put("n", 1)
+ .put("stream", false)
+ .put("user", fromUser.senderId().toString());
+ config.topP().ifPresent(v -> body.put("top_p", v));
+ config.temperature().ifPresent(v -> body.put("temperature", v));
+ config.maxOutputTokens().ifPresent(v -> body.put("max_tokens", v));
+ config.presencePenalty().ifPresent(v -> body.put("presence_penalty", v));
+ config.frequencyPenalty().ifPresent(v -> body.put("frequency_penalty", v));
+ if (!config.logitBias().isEmpty()) {
+ body.set("logit_bias", MAPPER.valueToTree(config.logitBias()));
+ }
+ if (!config.stop().isEmpty()) {
+ body.set("stop", MAPPER.valueToTree(config.stop()));
+ }
+
+ ArrayNode messages = MAPPER.createArrayNode();
+ config
+ .systemMessage()
+ .ifPresent(
+ m ->
+ messages
+ .addObject()
+ .put("role", Role.SYSTEM.toString().toLowerCase())
+ .put("content", m));
+ for (T message : messageStack.messages()) {
+ messages
+ .addObject()
+ .put("role", message.role().toString().toLowerCase())
+ .put("content", message.message());
+ }
+
+ Optional prunedMessages = pruneMessages(messages, null);
+ if (prunedMessages.isEmpty()) {
+ return messageStack.newMessageFromBot(
+ Instant.now(), "I'm sorry but that request was too long for me.");
+ }
+ body.set("messages", prunedMessages.get());
+
+ String bodyString;
+ try {
+ bodyString = MAPPER.writeValueAsString(body);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e); // this should be impossible
+ }
+ Response response =
+ Request.post(endpoint)
+ .bodyString(bodyString, ContentType.APPLICATION_JSON)
+ .setHeader("Authorization", "Bearer " + config.apiKey())
+ .execute();
+
+ JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
+ Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
+ JsonNode choice = responseBody.get("choices").get(0);
+ String messageContent = choice.get("message").get("content").textValue();
+ return messageStack.newMessageFromBot(timestamp, messageContent);
+ }
+}
diff --git a/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java b/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java
index 51ccfc9..6a9da33 100644
--- a/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java
+++ b/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java
@@ -26,6 +26,7 @@
import java.time.Instant;
import java.util.*;
import java.util.function.Function;
+import java.util.stream.Stream;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.apache.hc.client5.http.fluent.Request;
@@ -73,11 +74,11 @@ public FBMessageHandler(String verifyToken, String pageAccessToken, String appSe
this.accessToken = pageAccessToken;
}
- @TestOnly
- @This
- FBMessageHandler baseURLFactory(Function baseURLFactory) {
- this.baseURLFactory = Objects.requireNonNull(baseURLFactory);
- return this;
+ private static Stream textChunker(String text, String regexSeparator) {
+ if (text.length() > 2000) {
+ return Arrays.stream(text.split(regexSeparator)).map(String::strip);
+ }
+ return Stream.of(text);
}
@Override
@@ -213,19 +214,37 @@ private List postHandler(Context ctx) throws JsonProcessingException
return output;
}
+ @TestOnly
+ public @This FBMessageHandler baseURLFactory(Function baseURLFactory) {
+ this.baseURLFactory = Objects.requireNonNull(baseURLFactory);
+ return this;
+ }
+
@Override
public void respond(FBMessage message) throws IOException {
+ List chunkedText =
+ Stream.of(message.message().strip())
+ .flatMap(m -> textChunker(m, "\n\n\n+"))
+ .flatMap(m -> textChunker(m, "\n\n"))
+ .flatMap(m -> textChunker(m, "\n"))
+ .flatMap(m -> textChunker(m, "\\. +"))
+ .flatMap(m -> textChunker(m, " +"))
+ .toList();
+ for (String text : chunkedText) {
+ send(text, message.recipientId(), message.senderId());
+ }
+ }
+
+ private void send(String message, Identifier recipient, Identifier sender) throws IOException {
URI url;
ObjectNode body = MAPPER.createObjectNode();
- body.put("messaging_type", "RESPONSE")
- .putObject("recipient")
- .put("id", message.recipientId().toString());
- body.putObject("message").put("text", message.message());
+ body.put("messaging_type", "RESPONSE").putObject("recipient").put("id", recipient.toString());
+ body.putObject("message").put("text", message);
String bodyString;
try {
bodyString = MAPPER.writeValueAsString(body);
url =
- new URIBuilder(baseURLFactory.apply(message.senderId()))
+ new URIBuilder(baseURLFactory.apply(sender))
.addParameter("access_token", accessToken)
.build();
} catch (JsonProcessingException | URISyntaxException e) {
@@ -237,6 +256,13 @@ public void respond(FBMessage message) throws IOException {
Request.post(url).bodyString(bodyString, ContentType.APPLICATION_JSON).execute();
HttpResponse responseContent = response.returnResponse();
if (responseContent.getCode() != 200) {
+ String errorMessage =
+ "received a "
+ + responseContent.getCode()
+ + " error code when attempting to reply. "
+ + responseContent.getReasonPhrase();
+
+ LOGGER.atError().addKeyValue("body", bodyString).setMessage(errorMessage).log();
throw new IOException(
"received a "
+ responseContent.getCode()
diff --git a/src/main/java/com/meta/chatbridge/message/Message.java b/src/main/java/com/meta/chatbridge/message/Message.java
index 7ff82b4..cdebcb3 100644
--- a/src/main/java/com/meta/chatbridge/message/Message.java
+++ b/src/main/java/com/meta/chatbridge/message/Message.java
@@ -31,10 +31,14 @@ enum Role {
SYSTEM
}
- default Identifier conversationId() {
- if (senderId().compareTo(recipientId()) <= 0) {
- return Identifier.from(senderId().toString() + '|' + recipientId());
+ static Identifier conversationId(Identifier id1, Identifier id2) {
+ if (id1.compareTo(id2) <= 0) {
+ return Identifier.from(id1.toString() + '|' + id2);
}
- return Identifier.from(recipientId().toString() + '|' + senderId());
+ return Identifier.from(id2.toString() + '|' + id1);
+ }
+
+ default Identifier conversationId() {
+ return conversationId(senderId(), recipientId());
}
}
diff --git a/src/main/java/com/meta/chatbridge/message/MessageFactory.java b/src/main/java/com/meta/chatbridge/message/MessageFactory.java
new file mode 100644
index 0000000..e5c2adb
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/message/MessageFactory.java
@@ -0,0 +1,51 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.message;
+
+import com.meta.chatbridge.Identifier;
+import com.meta.chatbridge.message.Message.Role;
+import java.time.Instant;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+@FunctionalInterface
+public interface MessageFactory {
+ Map, MessageFactory extends Message>> FACTORY_MAP =
+ Stream.>of(
+ new FactoryContainer<>(
+ FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r)))
+ .collect(
+ Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));
+
+ static MessageFactory instance(Class clazz) {
+ @SuppressWarnings("unchecked") // static map guarantees this to be true
+ MessageFactory factory = (MessageFactory) FACTORY_MAP.get(clazz);
+ Objects.requireNonNull(factory, clazz + " does not have a registered factory");
+ return factory;
+ }
+
+ static MessageFactory instance(T message) {
+ @SuppressWarnings("unchecked") // class of an object is its class :)
+ Class clazz = (Class) message.getClass();
+ return instance(clazz);
+ }
+
+ T newMessage(
+ Instant timestamp,
+ String message,
+ Identifier senderId,
+ Identifier recipientId,
+ Identifier instanceId,
+ Role role);
+
+ /** this exists to provide compiler guarantees for type matching in the FACTORY_MAP */
+ record FactoryContainer(Class clazz, MessageFactory factory) {}
+}
diff --git a/src/main/java/com/meta/chatbridge/message/MessageStack.java b/src/main/java/com/meta/chatbridge/message/MessageStack.java
new file mode 100644
index 0000000..ee4a801
--- /dev/null
+++ b/src/main/java/com/meta/chatbridge/message/MessageStack.java
@@ -0,0 +1,92 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.message;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.meta.chatbridge.Identifier;
+import com.meta.chatbridge.message.Message.Role;
+import java.time.Instant;
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class MessageStack {
+ private final List messages;
+ private final MessageFactory messageFactory;
+
+ private MessageStack(T message) {
+ Objects.requireNonNull(message);
+ this.messages = ImmutableList.of(message);
+ messageFactory = MessageFactory.instance(message);
+ }
+
+ /** Constructor that exists to support the with method */
+ private MessageStack(MessageStack old, T newMessage) {
+ Objects.requireNonNull(newMessage);
+ messageFactory = old.messageFactory;
+ Preconditions.checkArgument(
+ old.tail().conversationId().equals(newMessage.conversationId()),
+ "all messages in a stack must have the same conversation id");
+ List messages = old.messages;
+ if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
+ this.messages =
+ Stream.concat(messages.stream(), Stream.of(newMessage))
+ .sorted(Comparator.comparing(Message::timestamp))
+ .collect(Collectors.toUnmodifiableList());
+ } else {
+ this.messages = ImmutableList.builder().addAll(messages).add(newMessage).build();
+ }
+
+ Preconditions.checkArgument(
+ old.userId().equals(userId()) && old.botId().equals(botId()),
+ "userId and botId not consistent with this message stack");
+ }
+
+ public static MessageStack of(T message) {
+ return new MessageStack<>(message);
+ }
+
+ public Identifier userId() {
+ T message = tail();
+ return switch (message.role()) {
+ case ASSISTANT, SYSTEM -> message.recipientId();
+ case USER -> message.senderId();
+ };
+ }
+
+ public Identifier botId() {
+ T message = tail();
+ return switch (message.role()) {
+ case ASSISTANT, SYSTEM -> message.senderId();
+ case USER -> message.recipientId();
+ };
+ }
+
+ public T newMessageFromBot(Instant timestamp, String message) {
+ return messageFactory.newMessage(
+ timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
+ }
+
+ public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
+ return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
+ }
+
+ public MessageStack with(T message) {
+ return new MessageStack<>(this, message);
+ }
+
+ public List messages() {
+ return messages;
+ }
+
+ public T tail() {
+ return messages.get(messages.size() - 1);
+ }
+}
diff --git a/src/main/java/com/meta/chatbridge/store/ChatStore.java b/src/main/java/com/meta/chatbridge/store/ChatStore.java
index 07d09a8..428c70a 100644
--- a/src/main/java/com/meta/chatbridge/store/ChatStore.java
+++ b/src/main/java/com/meta/chatbridge/store/ChatStore.java
@@ -9,6 +9,7 @@
package com.meta.chatbridge.store;
import com.meta.chatbridge.message.Message;
+import com.meta.chatbridge.message.MessageStack;
/**
* This class is in charge of both maintaining a chat history and managing a queue of conversations
diff --git a/src/main/java/com/meta/chatbridge/store/MemoryStore.java b/src/main/java/com/meta/chatbridge/store/MemoryStore.java
index 36863f4..36385ef 100644
--- a/src/main/java/com/meta/chatbridge/store/MemoryStore.java
+++ b/src/main/java/com/meta/chatbridge/store/MemoryStore.java
@@ -12,6 +12,7 @@
import com.google.common.cache.CacheBuilder;
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.Message;
+import com.meta.chatbridge.message.MessageStack;
import java.time.Duration;
public class MemoryStore implements ChatStore {
diff --git a/src/main/java/com/meta/chatbridge/store/MessageStack.java b/src/main/java/com/meta/chatbridge/store/MessageStack.java
deleted file mode 100644
index d55607a..0000000
--- a/src/main/java/com/meta/chatbridge/store/MessageStack.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- *
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.meta.chatbridge.store;
-
-import com.google.common.collect.ImmutableList;
-import com.meta.chatbridge.message.Message;
-import java.util.*;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class MessageStack {
-
- private final List messages;
-
- private MessageStack(Collection messages) {
- Objects.requireNonNull(messages);
- this.messages =
- messages.stream()
- .sorted(Comparator.comparing(Message::timestamp))
- .collect(Collectors.toUnmodifiableList());
- }
-
- /** Constructor that exists to support the with method */
- private MessageStack(List messages, T newMessage) {
- if (!messages.isEmpty()
- && newMessage.timestamp().isBefore(messages.get(messages.size() - 1).timestamp())) {
- this.messages =
- Stream.concat(messages.stream(), Stream.of(newMessage))
- .sorted(Comparator.comparing(Message::timestamp))
- .collect(Collectors.toUnmodifiableList());
- } else {
- this.messages = ImmutableList.builder().addAll(messages).add(newMessage).build();
- }
- }
-
- public static MessageStack of(T message) {
- return new MessageStack<>(List.of(message));
- }
-
- public static MessageStack of(Collection messages) {
- return new MessageStack<>(messages);
- }
-
- public MessageStack with(T message) {
- return new MessageStack<>(messages, message);
- }
-
- public List messages() {
- return messages;
- }
-}
diff --git a/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java b/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java
index 5548b6e..689d3fe 100644
--- a/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java
+++ b/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java
@@ -11,12 +11,12 @@
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
-import com.meta.chatbridge.store.MessageStack;
+import com.meta.chatbridge.message.MessageStack;
import java.time.Instant;
import java.util.concurrent.*;
import org.checkerframework.checker.nullness.qual.Nullable;
-public class DummyFBMessageLLMHandler implements LLMHandler {
+public class DummyFBMessageLLMHandler implements LLMPlugin {
private final String dummyLLMResponse;
private final BlockingQueue> receivedMessageStacks =
@@ -51,7 +51,10 @@ public String dummyResponse() {
public FBMessage handle(MessageStack messageStack) {
receivedMessageStacks.add(messageStack);
FBMessage inbound =
- messageStack.messages().stream().filter(m -> m.role() == Message.Role.USER).findAny().get();
+ messageStack.messages().stream()
+ .filter(m -> m.role() == Message.Role.USER)
+ .findAny()
+ .orElseThrow();
return new FBMessage(
Instant.now(),
Identifier.from("test_message"),
diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java
new file mode 100644
index 0000000..99db5ec
--- /dev/null
+++ b/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java
@@ -0,0 +1,182 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.*;
+import com.google.common.collect.ImmutableList;
+import com.meta.chatbridge.Configuration;
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+class OpenAIConfigTest {
+
+ private static final ObjectMapper MAPPER = Configuration.MAPPER;
+ static final Collection CONFIG_ITEMS =
+ ImmutableList.of(
+ new ConfigItem("model", true, TextNode.valueOf("gpt-4"), List.of(TextNode.valueOf("n"))),
+ new ConfigItem(
+ "api_key",
+ true,
+ TextNode.valueOf("notempty"),
+ List.of(TextNode.valueOf(""), TextNode.valueOf(" "))),
+ new ConfigItem(
+ "temperature",
+ false,
+ DoubleNode.valueOf(1),
+ List.of(DoubleNode.valueOf(-0.1), DoubleNode.valueOf(2.1))),
+ new ConfigItem(
+ "top_p",
+ false,
+ DoubleNode.valueOf(0.5),
+ List.of(DoubleNode.valueOf(0), DoubleNode.valueOf(1.1))),
+ new ConfigItem(
+ "stop",
+ false,
+ MAPPER.createArrayNode().add("10").add("stop"),
+ List.of(
+ TextNode.valueOf("I'm a stop"),
+ DoubleNode.valueOf(10),
+ MAPPER.createArrayNode().addObject())),
+ new ConfigItem(
+ "max_output_tokens",
+ false,
+ LongNode.valueOf(100),
+ List.of(
+ LongNode.valueOf(0),
+ LongNode.valueOf(OpenAIModel.GPT4.properties().tokenLimit() + 1))),
+ new ConfigItem(
+ "presence_penalty",
+ false,
+ DoubleNode.valueOf(0),
+ List.of(DoubleNode.valueOf(-2.1), DoubleNode.valueOf(2.1))),
+ new ConfigItem(
+ "frequency_penalty",
+ false,
+ DoubleNode.valueOf(0),
+ List.of(DoubleNode.valueOf(-2.1), DoubleNode.valueOf(2.1))),
+ new ConfigItem(
+ "logit_bias",
+ false,
+ MAPPER.createObjectNode().put("1", 0.5),
+ List.of(
+ MAPPER.createObjectNode().put("1", -101),
+ MAPPER.createObjectNode().put("1", 101))),
+ new ConfigItem(
+ "system_message",
+ false,
+ TextNode.valueOf("you're a helpful assistant"),
+ List.of(TextNode.valueOf(""), TextNode.valueOf(" "))),
+ new ConfigItem(
+ "max_input_tokens",
+ false,
+ LongNode.valueOf(4000),
+ List.of(LongNode.valueOf(-1), LongNode.valueOf(100_000))));
+ private ObjectNode minimalConfig;
+
+ static Stream configItems() {
+ return CONFIG_ITEMS.stream();
+ }
+
+ static Stream invalidValues() {
+ return configItems()
+ .flatMap(c -> c.invalidValues().stream().map(t -> Arguments.of(c.key(), t)));
+ }
+
+ static Stream requiredKeys() {
+ return configItems().filter(ConfigItem::required).map(ConfigItem::key);
+ }
+
+ @BeforeEach
+ void setUp() {
+ minimalConfig = MAPPER.createObjectNode();
+ CONFIG_ITEMS.forEach(
+ t -> {
+ if (t.required()) {
+ minimalConfig.set(t.key(), t.validValue());
+ }
+ });
+ }
+
+ @Test
+ void maximalValidConfig() throws JsonProcessingException {
+ ObjectNode body = MAPPER.createObjectNode();
+ CONFIG_ITEMS.forEach(t -> body.set(t.key(), t.validValue()));
+ OpenAIConfig config = MAPPER.readValue(MAPPER.writeValueAsString(body), OpenAIConfig.class);
+ assertThat(config.model()).isEqualTo(OpenAIModel.GPT4);
+ assertThat(config.temperature().isPresent()).isTrue();
+ assertThat(config.frequencyPenalty().isPresent()).isTrue();
+ assertThat(config.topP().isPresent()).isTrue();
+ assertThat(config.maxOutputTokens().isPresent()).isTrue();
+ assertThat(config.presencePenalty().isPresent()).isTrue();
+ assertThat(config.frequencyPenalty().isPresent()).isTrue();
+ assertThat(config.logitBias().isEmpty()).isFalse();
+ }
+
+ @Test
+ void minimalValidConfig() throws JsonProcessingException {
+ ObjectNode body = MAPPER.createObjectNode();
+ CONFIG_ITEMS.forEach(
+ t -> {
+ if (t.required()) {
+ body.set(t.key(), t.validValue());
+ }
+ });
+ OpenAIConfig config = MAPPER.readValue(MAPPER.writeValueAsString(body), OpenAIConfig.class);
+ assertThat(config.model()).isEqualTo(OpenAIModel.GPT4);
+ assertThat(config.temperature().isEmpty()).isTrue();
+ assertThat(config.frequencyPenalty().isEmpty()).isTrue();
+ assertThat(config.topP().isEmpty()).isTrue();
+ assertThat(config.maxOutputTokens().isEmpty()).isTrue();
+ assertThat(config.presencePenalty().isEmpty()).isTrue();
+ assertThat(config.frequencyPenalty().isEmpty()).isTrue();
+ assertThat(config.logitBias().isEmpty()).isTrue();
+ }
+
+ @ParameterizedTest
+ @MethodSource("configItems")
+ void nullValues(ConfigItem item) throws JsonProcessingException {
+ minimalConfig.putNull(item.key());
+ String bodyString = MAPPER.writeValueAsString(minimalConfig);
+ assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class))
+ .isInstanceOf(Exception.class);
+ }
+
+ @ParameterizedTest
+ @MethodSource("invalidValues")
+ void invalidValues(String key, JsonNode value) throws JsonProcessingException {
+ minimalConfig.set(key, value);
+ String bodyString = MAPPER.writeValueAsString(minimalConfig);
+ assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class))
+ .isInstanceOf(Exception.class);
+ }
+
+ @ParameterizedTest
+ @MethodSource("requiredKeys")
+ void requiredKeysMissing(String key) throws JsonProcessingException {
+ minimalConfig.remove(key);
+ String bodyString = MAPPER.writeValueAsString(minimalConfig);
+ assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class))
+ .isInstanceOf(Exception.class);
+ }
+
+ record ConfigItem(
+ String key, boolean required, JsonNode validValue, List invalidValues) {}
+}
diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java
new file mode 100644
index 0000000..09203ed
--- /dev/null
+++ b/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java
@@ -0,0 +1,35 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.Arrays;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.junit.jupiter.api.Test;
+
+class OpenAIModelTest {
+
+ /** This is super important because the JsonDeserializer relies on it */
+ @Test
+ void namesAreUnique() {
+ Set uniqueElements =
+ Arrays.stream(OpenAIModel.values()).map(OpenAIModel::toString).collect(Collectors.toSet());
+ assertThat(OpenAIModel.values()).extracting(Enum::toString).hasSize(uniqueElements.size());
+ }
+
+ /** This is super important because the JsonDeserializer relies on it */
+ @Test
+ void toStringMatchesName() {
+ assertThat(OpenAIModel.values())
+ .allSatisfy(v -> assertThat(v.properties().name()).isEqualTo(v.toString()));
+ }
+}
diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java
new file mode 100644
index 0000000..f6174d7
--- /dev/null
+++ b/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java
@@ -0,0 +1,269 @@
+/*
+ *
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.meta.chatbridge.llm;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.SoftAssertions.assertSoftly;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.meta.chatbridge.Configuration;
+import com.meta.chatbridge.Identifier;
+import com.meta.chatbridge.Pipeline;
+import com.meta.chatbridge.PipelinesRunner;
+import com.meta.chatbridge.message.*;
+import com.meta.chatbridge.message.Message.Role;
+import com.meta.chatbridge.store.ChatStore;
+import com.meta.chatbridge.store.MemoryStore;
+import io.javalin.Javalin;
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.UnknownHostException;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.hc.client5.http.fluent.Request;
+import org.apache.hc.core5.http.HttpResponse;
+import org.apache.hc.core5.net.URIBuilder;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.MethodSource;
+
+public class OpenAIPluginTest {
+
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+ public static final JsonNode SAMPLE_RESPONSE = MAPPER.createObjectNode();
+ private static final String PATH = "/";
+ private static final String TEST_MESSAGE = "this is a test message";
+ private static final MessageStack STACK =
+ MessageStack.of(
+ MessageFactory.instance(FBMessage.class)
+ .newMessage(
+ Instant.now(),
+ "test message",
+ Identifier.random(),
+ Identifier.random(),
+ Identifier.random(),
+ Role.USER));
+
+ static {
+ ((ObjectNode) SAMPLE_RESPONSE)
+ .put("created", Instant.now().getEpochSecond())
+ .put("object", "chat.completion")
+ .put("id", UUID.randomUUID().toString())
+ .putArray("choices")
+ .addObject()
+ .put("index", 0)
+ .put("finish_reason", "stop")
+ .putObject("message")
+ .put("role", "assistant")
+ .put("content", TEST_MESSAGE);
+ }
+
+ private BlockingQueue openAIRequests;
+ private Javalin app;
+ private URI endpoint;
+ private ObjectNode minimalConfig;
+
+ static Stream modelOptions() {
+ Set non_model_options = Set.of("model", "api_key", "max_input_tokens");
+ return OpenAIConfigTest.CONFIG_ITEMS.stream().filter(c -> !non_model_options.contains(c.key()));
+ }
+
+ @BeforeEach
+ void setUp() throws UnknownHostException, URISyntaxException {
+ openAIRequests = new LinkedBlockingDeque<>();
+ app = Javalin.create();
+ app.before(
+ PATH,
+ ctx ->
+ openAIRequests.add(
+ new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap())));
+ app.post(PATH, ctx -> ctx.result(MAPPER.writeValueAsString(SAMPLE_RESPONSE)));
+ app.start(0);
+ endpoint =
+ URIBuilder.localhost().setScheme("http").appendPath(PATH).setPort(app.port()).build();
+ }
+
+ @ParameterizedTest
+ @EnumSource(OpenAIModel.class)
+ void sampleValid(OpenAIModel model) throws IOException, InterruptedException {
+ String apiKey = UUID.randomUUID().toString();
+ OpenAIConfig config = OpenAIConfig.builder(model, apiKey).build();
+ OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint);
+ FBMessage message = plugin.handle(STACK);
+ assertThat(message.message()).isEqualTo(TEST_MESSAGE);
+ assertThat(message.role()).isSameAs(Role.ASSISTANT);
+ assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
+ @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
+ assertThat(or).isNotNull();
+ assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey);
+ assertThat(MAPPER.readTree(or.body()).get("model").textValue()).isEqualTo(model.toString());
+ }
+
+ @BeforeEach
+ void setUpMinConfig() {
+ minimalConfig = MAPPER.createObjectNode();
+ OpenAIConfigTest.CONFIG_ITEMS.forEach(
+ t -> {
+ if (t.required()) {
+ minimalConfig.set(t.key(), t.validValue());
+ }
+ });
+ }
+
+ @ParameterizedTest
+ @MethodSource("modelOptions")
+ void validConfigValues(OpenAIConfigTest.ConfigItem configItem)
+ throws IOException, InterruptedException {
+ minimalConfig.set(configItem.key(), configItem.validValue());
+ OpenAIConfig config = Configuration.MAPPER.convertValue(minimalConfig, OpenAIConfig.class);
+ OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint);
+ FBMessage message = plugin.handle(STACK);
+ assertThat(message.message()).isEqualTo(TEST_MESSAGE);
+ assertThat(message.role()).isSameAs(Role.ASSISTANT);
+ assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
+ @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
+ assertThat(or).isNotNull();
+ assertThat(or.headerMap().get("Authorization"))
+ .isNotNull()
+ .isEqualTo("Bearer " + config.apiKey());
+ JsonNode body = Configuration.MAPPER.readTree(or.body());
+ assertThat(body.get("model").textValue()).isEqualTo(config.model().toString());
+ if (configItem.key().equals("system_message")) {
+ assertThat(body.get("messages"))
+ .satisfiesOnlyOnce(
+ m -> {
+ assertThat(m.get("role").textValue())
+ .isEqualTo(Role.SYSTEM.toString().toLowerCase());
+ assertThat(m.get("content").textValue())
+ .isEqualTo(configItem.validValue().textValue());
+ });
+ } else {
+ if (configItem.key().equals("max_output_tokens")) {
+ assertThat(body.get("max_tokens")).isEqualTo(minimalConfig.get(configItem.key()));
+ } else {
+ assertThat(body.get(configItem.key())).isEqualTo(minimalConfig.get(configItem.key()));
+ }
+ }
+ }
+
+ @Test
+ void contextTooBig() throws IOException {
+ OpenAIConfig config =
+ OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
+ OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint);
+ MessageStack stack =
+ STACK.with(
+ STACK.newMessageFromUser(
+ Instant.now(),
+ Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()),
+ Identifier.random()));
+ FBMessage response = plugin.handle(stack);
+ assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
+ assertThat(openAIRequests).hasSize(0);
+ }
+
+ @Test
+ void orderedCorrectly() throws IOException, InterruptedException {
+ OpenAIConfig config =
+ OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
+ OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint);
+ MessageStack stack =
+ MessageStack.of(
+ MessageFactory.instance(FBMessage.class)
+ .newMessage(
+ Instant.now(),
+ "1",
+ Identifier.random(),
+ Identifier.random(),
+ Identifier.random(),
+ Role.SYSTEM));
+ stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
+ stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
+ stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
+ plugin.handle(stack);
+ @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
+ assertThat(or).isNotNull();
+ JsonNode body = MAPPER.readTree(or.body());
+
+ for (int i = 0; i < stack.messages().size(); i++) {
+ FBMessage stackMessage = stack.messages().get(i);
+ JsonNode sentMessage = body.get("messages").get(i);
+ assertSoftly(
+ s ->
+ s.assertThat(stackMessage.message())
+ .isEqualTo(sentMessage.get("content").textValue()));
+ ;
+ }
+ }
+
+ @Test
+ void inPipeline() throws IOException, URISyntaxException, InterruptedException {
+ ChatStore store = new MemoryStore<>();
+ String appSecret = "app secret";
+ String accessToken = "access token";
+ String verifyToken = "verify token";
+
+ BlockingQueue metaRequests = new LinkedBlockingDeque<>();
+ String metaPath = "/meta";
+ URI messageReceiver =
+ URIBuilder.localhost().appendPath(metaPath).setScheme("http").setPort(app.port()).build();
+ app.post(
+ metaPath,
+ ctx ->
+ metaRequests.put(
+ new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap())));
+ FBMessageHandler handler =
+ new FBMessageHandler(verifyToken, accessToken, appSecret)
+ .baseURLFactory(ignored -> messageReceiver);
+
+ String apiKey = "api key";
+ OpenAIConfig config = OpenAIConfig.builder(OpenAIModel.GPT4, apiKey).build();
+ OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint);
+
+ String webhookPath = "/webhook";
+ Pipeline pipeline = new Pipeline<>(store, handler, plugin, webhookPath);
+ PipelinesRunner runner = PipelinesRunner.newInstance().pipeline(pipeline).port(0);
+ runner.start();
+
+ // TODO: create test harness
+ Request request =
+ FBMessageHandlerTest.createMessageRequest(FBMessageHandlerTest.SAMPLE_MESSAGE, runner);
+ HttpResponse response = request.execute().returnResponse();
+ assertThat(response.getCode()).isEqualTo(200);
+ @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
+ assertThat(or).isNotNull();
+ assertThat(or.headerMap().get("Authorization"))
+ .isNotNull()
+ .isEqualTo("Bearer " + config.apiKey());
+ JsonNode body = Configuration.MAPPER.readTree(or.body());
+ assertThat(body.get("model").textValue()).isEqualTo(config.model().toString());
+
+ or = metaRequests.poll(500, TimeUnit.MILLISECONDS);
+ // plugin output got back to meta
+ assertThat(or).isNotNull().satisfies(r -> assertThat(r.body()).contains(TEST_MESSAGE));
+ }
+
+ private record OutboundRequest(
+ String body, Map headerMap, Map> queryParamMap) {}
+}
diff --git a/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java b/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java
index 1ebb5c6..d49a53f 100644
--- a/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java
+++ b/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java
@@ -22,7 +22,6 @@
import com.meta.chatbridge.llm.DummyFBMessageLLMHandler;
import com.meta.chatbridge.message.Message.Role;
import com.meta.chatbridge.store.MemoryStore;
-import com.meta.chatbridge.store.MessageStack;
import io.javalin.Javalin;
import io.javalin.http.HandlerType;
import java.io.IOException;
@@ -36,6 +35,7 @@
import java.util.Map;
import java.util.concurrent.*;
import java.util.function.Function;
+import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.hc.client5.http.fluent.Request;
import org.apache.hc.client5.http.fluent.Response;
@@ -49,7 +49,7 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
-class FBMessageHandlerTest {
+public class FBMessageHandlerTest {
@FunctionalInterface
private interface ThrowableFunction {
@@ -59,7 +59,7 @@ private interface ThrowableFunction {
private static final ObjectMapper MAPPER = new ObjectMapper();
/** Example message collected directly from the messenger webhook */
- private static final String SAMPLE_MESSAGE =
+ public static final String SAMPLE_MESSAGE =
"{\"object\":\"page\",\"entry\":[{\"id\":\"106195825075770\",\"time\":1692813219204,\"messaging\":[{\"sender\":{\"id\":\"6357858494326947\"},\"recipient\":{\"id\":\"106195825075770\"},\"timestamp\":1692813218705,\"message\":{\"mid\":\"m_kT_mWOSYh_eK3kF8chtyCWfcD9-gomvu4mhaMFQl-gt4D3LjORi6k3BXD6_x9a-FOUt-D2LFuywJN6HfrpAnDg\",\"text\":\"asdfa\"}}]}]}";
private static final String SAMPLE_MESSAGE_HMAC =
@@ -78,7 +78,7 @@ static void beforeAll() throws JsonProcessingException {
private Javalin app;
private BlockingQueue requests;
- private HttpResponse getRequest(String path, int port, Map params)
+ private static HttpResponse getRequest(String path, int port, Map params)
throws IOException, URISyntaxException {
URIBuilder uriBuilder =
URIBuilder.loopbackAddress().setScheme("http").setPort(port).appendPath(path);
@@ -142,7 +142,7 @@ private static Request createMessageRequest(
throws IOException, URISyntaxException {
@SuppressWarnings("unchecked") // for the scope of this test this is guaranteed
Pipeline pipeline =
- (Pipeline) runner.pipelines().stream().findAny().get();
+ (Pipeline) runner.pipelines().stream().findAny().orElseThrow();
String path = pipeline.path();
// for the scope of this test this is guaranteed
@@ -160,7 +160,7 @@ private static Request createMessageRequest(
return request;
}
- private static Request createMessageRequest(String body, PipelinesRunner runner)
+ public static Request createMessageRequest(String body, PipelinesRunner runner)
throws IOException, URISyntaxException {
return createMessageRequest(body, runner, true);
}
@@ -373,4 +373,23 @@ void invalidMessage(
assertThat(body.get("message").get("text").textValue()).isEqualTo(llmHandler.dummyResponse());
}
}
+
+ @Test
+ void chunkingHappens() throws IOException {
+ app.start(0);
+ Identifier pageId = Identifier.from(106195825075770L);
+ String token = "243af3c6-9994-4869-ae13-ad61a38323f5"; // this is fake don't worry
+ String secret = "f74a638462f975e9eadfcbb84e4aa06b"; // it's been rolled don't worry
+ FBMessageHandler messageHandler =
+ new FBMessageHandler("0", token, secret).baseURLFactory(testURLFactoryFactory(pageId));
+
+ String bigText =
+ Stream.generate(() -> "0123456789.").limit(300).collect(Collectors.joining(" "));
+ FBMessage bigMessage =
+ new FBMessage(
+ Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER);
+ messageHandler.respond(bigMessage);
+ assertThat(requests.size()).isEqualTo(300);
+ assertThat(requests).allSatisfy(m -> assertThat(m.body()).contains("0123456789"));
+ }
}
diff --git a/src/test/java/com/meta/chatbridge/store/MessageStackTest.java b/src/test/java/com/meta/chatbridge/store/MessageStackTest.java
index 0ce1f94..429f823 100644
--- a/src/test/java/com/meta/chatbridge/store/MessageStackTest.java
+++ b/src/test/java/com/meta/chatbridge/store/MessageStackTest.java
@@ -8,58 +8,40 @@
package com.meta.chatbridge.store;
-import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.*;
-import com.google.common.collect.Lists;
import com.meta.chatbridge.Identifier;
+import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
+import com.meta.chatbridge.message.MessageFactory;
+import com.meta.chatbridge.message.MessageStack;
import java.time.Instant;
-import java.util.List;
import org.junit.jupiter.api.Test;
class MessageStackTest {
- record TestMessage(Instant timestamp) implements Message {
-
- @Override
- public Identifier instanceId() {
- return Identifier.from("0");
- }
-
- @Override
- public Identifier senderId() {
- return Identifier.from(0);
- }
-
- @Override
- public Identifier recipientId() {
- return Identifier.from(0);
- }
-
- @Override
- public String message() {
- return "";
- }
-
- @Override
- public Role role() {
- return Role.SYSTEM;
- }
- }
+ private static final MessageFactory FACTORY = MessageFactory.instance(FBMessage.class);
@Test
void orderPreservation() {
Instant start = Instant.now();
- TestMessage message1 = new TestMessage(start);
- TestMessage message2 = new TestMessage(start.plusSeconds(1));
- MessageStack ms = MessageStack.of(Lists.newArrayList(message1, message2));
+ FBMessage message1 =
+ FACTORY.newMessage(
+ start,
+ "sample message",
+ Identifier.random(),
+ Identifier.random(),
+ Identifier.random(),
+ Message.Role.USER);
+
+ MessageStack ms = MessageStack.of(message1);
+ FBMessage message2 = ms.newMessageFromBot(start.plusSeconds(1), "other sample message");
+ ms = ms.with(message2);
assertThat(ms.messages()).hasSize(2);
assertThat(ms.messages().get(0)).isSameAs(message1);
assertThat(ms.messages().get(1)).isSameAs(message2);
- ms = MessageStack.of(List.of());
- assertThat(ms.messages()).hasSize(0);
- ms = ms.with(message1);
+ ms = MessageStack.of(message1);
assertThat(ms.messages()).hasSize(1);
ms = ms.with(message2);
assertThat(ms.messages()).hasSize(2);
@@ -70,18 +52,105 @@ void orderPreservation() {
@Test
void orderCorrection() {
Instant start = Instant.now();
- TestMessage message1 = new TestMessage(start);
- TestMessage message2 = new TestMessage(start.plusSeconds(1));
- MessageStack ms = MessageStack.of(Lists.newArrayList(message2, message1));
+ FBMessage message2 =
+ FACTORY.newMessage(
+ start,
+ "sample message",
+ Identifier.random(),
+ Identifier.random(),
+ Identifier.random(),
+ Message.Role.USER);
+ MessageStack ms = MessageStack.of(message2);
+
+ FBMessage message1 = ms.newMessageFromBot(start.minusSeconds(1), "other sample message");
+
+ ms = ms.with(message1);
assertThat(ms.messages().get(0)).isSameAs(message1);
assertThat(ms.messages().get(1)).isSameAs(message2);
- ms = MessageStack.of(List.of());
- ms = ms.with(message2);
+ ms = MessageStack.of(message2);
assertThat(ms.messages()).hasSize(1);
ms = ms.with(message1);
assertThat(ms.messages()).hasSize(2);
assertThat(ms.messages().get(0)).isSameAs(message1);
assertThat(ms.messages().get(1)).isSameAs(message2);
}
+
+ @Test
+ void botAndUserId() {
+ Instant start = Instant.now();
+ FBMessage message1 =
+ FACTORY.newMessage(
+ start,
+ "sample message",
+ Identifier.random(),
+ Identifier.random(),
+ Identifier.random(),
+ Message.Role.USER);
+
+ MessageStack ms = MessageStack.of(message1);
+ FBMessage message2 =
+ FACTORY.newMessage(
+ start,
+ "sample message",
+ message1.recipientId(),
+ message1.senderId(),
+ Identifier.random(),
+ Message.Role.ASSISTANT);
+
+ final MessageStack finalMs = ms;
+ assertThatCode(() -> finalMs.with(message2)).doesNotThrowAnyException();
+ assertThatCode(() -> finalMs.with(finalMs.newMessageFromBot(start, "")))
+ .doesNotThrowAnyException();
+ assertThatCode(() -> finalMs.with(finalMs.newMessageFromUser(start, "", Identifier.random())))
+ .doesNotThrowAnyException();
+ ms = ms.with(message2);
+ assertThat(ms.userId()).isEqualTo(message1.senderId());
+ assertThat(ms.botId()).isEqualTo(message1.recipientId());
+ FBMessage mDifferentSenderId =
+ FACTORY.newMessage(
+ start,
+ "",
+ Identifier.random(),
+ message1.recipientId(),
+ Identifier.random(),
+ Message.Role.USER);
+
+ MessageStack finalMs1 = ms;
+ assertThatThrownBy(() -> finalMs1.with(mDifferentSenderId))
+ .isInstanceOf(IllegalArgumentException.class);
+
+ FBMessage mDifferentRecipientId =
+ FACTORY.newMessage(
+ start,
+ "",
+ message1.senderId(),
+ Identifier.random(),
+ Identifier.random(),
+ Message.Role.USER);
+ assertThatThrownBy(() -> finalMs1.with(mDifferentRecipientId))
+ .isInstanceOf(IllegalArgumentException.class);
+
+ FBMessage illegalSenderId =
+ FACTORY.newMessage(
+ start,
+ "",
+ message1.recipientId(),
+ message1.senderId(),
+ Identifier.random(),
+ Message.Role.USER);
+ assertThatThrownBy(() -> finalMs1.with(illegalSenderId))
+ .isInstanceOf(IllegalArgumentException.class);
+
+ FBMessage illegalRecipientId =
+ FACTORY.newMessage(
+ start,
+ "",
+ message1.senderId(),
+ message1.recipientId(),
+ Identifier.random(),
+ Message.Role.ASSISTANT);
+ assertThatThrownBy(() -> finalMs1.with(illegalRecipientId))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
}