diff --git a/pom.xml b/pom.xml index 95afe2e..f92f650 100644 --- a/pom.xml +++ b/pom.xml @@ -82,6 +82,11 @@ log4j-slf4j2-impl 2.20.0 + + com.knuddels + jtokkit + 0.6.1 + diff --git a/src/main/java/com/meta/chatbridge/Configuration.java b/src/main/java/com/meta/chatbridge/Configuration.java new file mode 100644 index 0000000..59eb727 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/Configuration.java @@ -0,0 +1,23 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class Configuration { + public static final ObjectMapper MAPPER = + new ObjectMapper() + .enable(DeserializationFeature.FAIL_ON_NULL_FOR_PRIMITIVES) + .enable(DeserializationFeature.WRAP_EXCEPTIONS) + .enable(DeserializationFeature.READ_ENUMS_USING_TO_STRING) + .enable(DeserializationFeature.USE_LONG_FOR_INTS) + .disable(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT) + .disable(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES); +} diff --git a/src/main/java/com/meta/chatbridge/Identifier.java b/src/main/java/com/meta/chatbridge/Identifier.java index 05716a2..88b1af0 100644 --- a/src/main/java/com/meta/chatbridge/Identifier.java +++ b/src/main/java/com/meta/chatbridge/Identifier.java @@ -11,6 +11,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Objects; +import java.util.UUID; import org.jetbrains.annotations.NotNull; public class Identifier implements Comparable { @@ -21,6 +22,11 @@ private Identifier(byte[] id) { this.id = id; } + public static Identifier random() { + UUID uuid = UUID.randomUUID(); + return new Identifier(uuid.toString().getBytes(StandardCharsets.UTF_8)); + } + public static Identifier from(String id) { return new Identifier(id.getBytes(StandardCharsets.UTF_8)); } diff --git a/src/main/java/com/meta/chatbridge/Pipeline.java b/src/main/java/com/meta/chatbridge/Pipeline.java index afb4180..0afa19e 100644 --- a/src/main/java/com/meta/chatbridge/Pipeline.java +++ b/src/main/java/com/meta/chatbridge/Pipeline.java @@ -8,13 +8,14 @@ package com.meta.chatbridge; -import com.meta.chatbridge.llm.LLMHandler; +import com.meta.chatbridge.llm.LLMPlugin; import com.meta.chatbridge.message.Message; import com.meta.chatbridge.message.MessageHandler; +import com.meta.chatbridge.message.MessageStack; import com.meta.chatbridge.store.ChatStore; -import com.meta.chatbridge.store.MessageStack; import io.javalin.Javalin; import io.javalin.http.Context; +import java.io.IOException; import java.util.List; import java.util.Objects; import java.util.concurrent.*; @@ -27,15 +28,15 @@ public class Pipeline { private final ExecutorService executorService = Executors.newCachedThreadPool(); private final MessageHandler handler; private final ChatStore store; - private final LLMHandler llmHandler; + private final LLMPlugin llmPlugin; private final String path; public Pipeline( - ChatStore store, MessageHandler handler, LLMHandler llmHandler, String path) { + ChatStore store, MessageHandler handler, LLMPlugin llmPlugin, String path) { this.handler = Objects.requireNonNull(handler); this.store = Objects.requireNonNull(store); - this.llmHandler = llmHandler; + this.llmPlugin = llmPlugin; this.path = path; } @@ -61,12 +62,18 @@ public MessageHandler messageHandler() { } private void execute(MessageStack stack) { - T llmResponse = llmHandler.handle(stack); + T llmResponse; + try { + llmResponse = llmPlugin.handle(stack); + } catch (IOException e) { + LOGGER.error("failed to communicate with LLM", e); + return; + } store.add(llmResponse); try { handler.respond(llmResponse); } catch (Exception e) { - LOGGER.error("failed to respond to user", e); + // we log in the handler where we have the body context // TODO: create transactional store add // TODO: implement retry with exponential backoff } diff --git a/src/main/java/com/meta/chatbridge/PipelinesRunner.java b/src/main/java/com/meta/chatbridge/PipelinesRunner.java index 1acdb80..127d6be 100644 --- a/src/main/java/com/meta/chatbridge/PipelinesRunner.java +++ b/src/main/java/com/meta/chatbridge/PipelinesRunner.java @@ -29,8 +29,7 @@ public static PipelinesRunner newInstance() { return new PipelinesRunner(); } - @This - public PipelinesRunner start() { + public @This PipelinesRunner start() { if (!started) { started = true; app.start(port); @@ -38,8 +37,7 @@ public PipelinesRunner start() { return this; } - @This - public PipelinesRunner pipeline(Pipeline pipeline) { + public @This PipelinesRunner pipeline(Pipeline pipeline) { Preconditions.checkState(!started, "cannot add pipeline, server already started"); if (pipelines.add(pipeline)) { pipeline.register(app); @@ -64,8 +62,7 @@ public int port() { * @param port the port the server will start on * @return this */ - @This - public PipelinesRunner port(int port) { + public @This PipelinesRunner port(int port) { Preconditions.checkState(!started, "cannot change port, server already started"); this.port = port; return this; diff --git a/src/main/java/com/meta/chatbridge/llm/LLMHandler.java b/src/main/java/com/meta/chatbridge/llm/LLMPlugin.java similarity index 59% rename from src/main/java/com/meta/chatbridge/llm/LLMHandler.java rename to src/main/java/com/meta/chatbridge/llm/LLMPlugin.java index 5c5ce2b..1cb18c1 100644 --- a/src/main/java/com/meta/chatbridge/llm/LLMHandler.java +++ b/src/main/java/com/meta/chatbridge/llm/LLMPlugin.java @@ -9,9 +9,10 @@ package com.meta.chatbridge.llm; import com.meta.chatbridge.message.Message; -import com.meta.chatbridge.store.MessageStack; +import com.meta.chatbridge.message.MessageStack; +import java.io.IOException; -public interface LLMHandler { +public interface LLMPlugin { - T handle(MessageStack messageStack); + T handle(MessageStack messageStack) throws IOException; } diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java b/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java new file mode 100644 index 0000000..b124786 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/OpenAIConfig.java @@ -0,0 +1,257 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; +import com.google.common.base.Preconditions; +import java.util.*; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.common.returnsreceiver.qual.This; + +@JsonDeserialize(builder = OpenAIConfig.Builder.class) +public class OpenAIConfig { + + private final OpenAIModel model; + private final String apiKey; + @Nullable private final Double temperature; + @Nullable private final Double topP; + private final List stop; + @Nullable private final Long maxOutputTokens; + @Nullable private final Double presencePenalty; + @Nullable private final Double frequencyPenalty; + private final Map logitBias; + private final @Nullable String systemMessage; + + private final long maxInputTokens; + + private OpenAIConfig( + OpenAIModel model, + String apiKey, + @Nullable Double temperature, + @Nullable Double topP, + List stop, + @Nullable Long maxOutputTokens, + @Nullable Double presencePenalty, + @Nullable Double frequencyPenalty, + Map logitBias, + @Nullable String systemMessage, + long maxInputTokens) { + this.apiKey = apiKey; + this.temperature = temperature; + this.topP = topP; + this.model = model; + this.stop = stop; + this.maxOutputTokens = maxOutputTokens; + this.presencePenalty = presencePenalty; + this.frequencyPenalty = frequencyPenalty; + this.logitBias = Collections.unmodifiableMap(logitBias); + this.systemMessage = systemMessage; + this.maxInputTokens = maxInputTokens; + } + + public static Builder builder(OpenAIModel model, String apiKey) { + return new Builder().model(model).apiKey(apiKey); + } + + public OpenAIModel model() { + return model; + } + + public String apiKey() { + return apiKey; + } + + public Optional temperature() { + return Optional.ofNullable(temperature); + } + + public Optional topP() { + return Optional.ofNullable(topP); + } + + public List stop() { + return stop; + } + + public Optional maxOutputTokens() { + return Optional.ofNullable(maxOutputTokens); + } + + public Optional presencePenalty() { + return Optional.ofNullable(presencePenalty); + } + + public Optional frequencyPenalty() { + return Optional.ofNullable(frequencyPenalty); + } + + public Map logitBias() { + return logitBias; + } + + public Optional systemMessage() { + return Optional.ofNullable(systemMessage); + } + + public long maxInputTokens() { + return maxInputTokens; + } + + @JsonPOJOBuilder(withPrefix = "") + public static class Builder { + private @Nullable OpenAIModel model; + + @JsonProperty("api_key") + private @Nullable String apiKey; + + private @Nullable Double temperature; + + @JsonProperty("top_p") + private @Nullable Double topP; + + private List stop = List.of(); + + @JsonProperty("max_output_tokens") + private @Nullable Long maxOutputTokens; + + @JsonProperty("presence_penalty") + private @Nullable Double presencePenalty; + + @JsonProperty("frequency_penalty") + private @Nullable Double frequencyPenalty; + + @JsonProperty("logit_bias") + private Map logitBias = Collections.emptyMap(); + + @JsonProperty("system_message") + private @Nullable String systemMessage; + + @JsonProperty("max_input_tokens") + private @Nullable Long maxInputTokens; + + public @This Builder model(OpenAIModel model) { + this.model = model; + return this; + } + + public @This Builder apiKey(String apiKey) { + Objects.requireNonNull(apiKey); + Preconditions.checkArgument(!apiKey.isBlank(), "api key cannot be empty"); + this.apiKey = apiKey; + return this; + } + + public @This Builder temperature(double temperature) { + Preconditions.checkArgument( + temperature >= 0 && temperature <= 2, "temperature must be >= 0 and <= 2"); + this.temperature = temperature; + return this; + } + + public @This Builder topP(double topP) { + Preconditions.checkArgument(topP > 0 && topP <= 1, "top_p must be > 0 and <= 1"); + this.topP = topP; + return this; + } + + public @This Builder stop(List stop) { + Objects.requireNonNull(stop); + this.stop = Collections.unmodifiableList(stop); + return this; + } + + public @This Builder maxOutputTokens(long maxOutputTokens) { + Preconditions.checkArgument( + maxOutputTokens > 0, "max_output_tokens must be greater than zero"); + this.maxOutputTokens = maxOutputTokens; + return this; + } + + public @This Builder presencePenalty(double presencePenalty) { + Preconditions.checkArgument( + presencePenalty >= -2.0 && presencePenalty <= 2.0, + "presence_penalty must be between -2.0 and 2.0"); + this.presencePenalty = presencePenalty; + return this; + } + + public @This Builder frequencyPenalty(double frequencyPenalty) { + Preconditions.checkArgument( + frequencyPenalty >= -2.0 && frequencyPenalty <= 2.0, + "frequency_penalty must be between -2.0 and 2.0"); + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public @This Builder logitBias(Map logitBias) { + Preconditions.checkArgument( + logitBias.values().stream().allMatch(v -> v >= -100 && v <= 100), + "all values for log_bias must be between -100 and 100"); + this.logitBias = logitBias; + return this; + } + + public @This Builder systemMessage(String systemMessage) { + Preconditions.checkArgument(!systemMessage.isBlank(), "system_message cannot be blank"); + this.systemMessage = systemMessage; + return this; + } + + public @This Builder maxInputTokens(long maxInputTokens) { + Preconditions.checkArgument(maxInputTokens > 0, "max_input_tokens must be greater than zero"); + this.maxInputTokens = maxInputTokens; + return this; + } + + public OpenAIConfig build() { + Objects.requireNonNull(model, "model is a required parameter"); + Objects.requireNonNull(apiKey, "api_key is a required parameter"); + if (maxOutputTokens != null) { + Preconditions.checkArgument( + maxOutputTokens <= model.properties().tokenLimit(), + "max_tokens must be <= " + + model.properties().tokenLimit() + + ", the maximum tokens allowed for the selected model '" + + model.properties().name() + + "'"); + } + if (maxInputTokens == null) { + if (maxOutputTokens == null) { + // set the max input size to 50% of the total context size so that there is always some + // room for the output + maxInputTokens = (long) (model.properties().tokenLimit() * 0.50); + } else { + maxInputTokens = model.properties().tokenLimit() - maxOutputTokens; + } + } + + Preconditions.checkArgument( + maxInputTokens + (maxOutputTokens == null ? 0 : maxOutputTokens) + <= model.properties().tokenLimit(), + "max_input_tokens + max_output_tokens must total to be less than or equal to " + + model.properties().tokenLimit() + + ", the total context tokens allowed by this model"); + + return new OpenAIConfig( + model, + apiKey, + temperature, + topP, + stop, + maxOutputTokens, + presencePenalty, + frequencyPenalty, + logitBias, + systemMessage, + maxInputTokens); + } + } +} diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java b/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java new file mode 100644 index 0000000..977eb2c --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/OpenAIModel.java @@ -0,0 +1,48 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import com.knuddels.jtokkit.api.ModelType; + +public enum OpenAIModel { + GPT4 { + OpenAIModelProperties properties() { + return new OpenAIModelProperties("gpt-4", 8_192, ModelType.GPT_4); + } + }, + + GPT432K { + @Override + OpenAIModelProperties properties() { + return new OpenAIModelProperties("gpt-4-32k", 32_768, ModelType.GPT_4_32K); + } + }, + GPT35TURBO { + @Override + OpenAIModelProperties properties() { + return new OpenAIModelProperties("gpt-3.5-turbo", 4_096, ModelType.GPT_3_5_TURBO); + } + }, + + GPT35TURBO16K { + @Override + OpenAIModelProperties properties() { + return new OpenAIModelProperties("gpt-3.5-turbo-16k", 16_384, ModelType.GPT_3_5_TURBO_16K); + } + }; + + @Override + public String toString() { + return this.properties().name(); + } + + abstract OpenAIModelProperties properties(); + + public record OpenAIModelProperties(String name, long tokenLimit, ModelType jtokkinModel) {} +} diff --git a/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java b/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java new file mode 100644 index 0000000..1600fb2 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java @@ -0,0 +1,191 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.meta.chatbridge.message.Message; +import com.meta.chatbridge.message.Message.Role; +import com.meta.chatbridge.message.MessageStack; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Instant; +import java.util.Optional; +import org.apache.hc.client5.http.fluent.Request; +import org.apache.hc.client5.http.fluent.Response; +import org.apache.hc.core5.http.ContentType; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.common.returnsreceiver.qual.This; +import org.jetbrains.annotations.TestOnly; + +public class OpenAIPlugin implements LLMPlugin { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String ENDPOINT = "https://api.openai.com/v1/chat/completions"; + private final OpenAIConfig config; + private final Encoding tokenEncoding; + private final int tokensPerMessage; + private final int tokensPerName; + private URI endpoint; + + public OpenAIPlugin(OpenAIConfig config) { + this.config = config; + + try { + this.endpoint = new URI(ENDPOINT); + } catch (URISyntaxException e) { + throw new RuntimeException(e); // this should be impossible + } + tokenEncoding = + Encodings.newDefaultEncodingRegistry() + .getEncodingForModel(config.model().properties().jtokkinModel()); + + switch (config.model()) { + case GPT4, GPT432K -> { + tokensPerMessage = 3; + tokensPerName = 1; + } + case GPT35TURBO, GPT35TURBO16K -> { + tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n + tokensPerName = -1; // if there's a name, the role is omitted + } + default -> throw new IllegalArgumentException("Unsupported model: " + config.model()); + } + } + + @TestOnly + public @This OpenAIPlugin endpoint(URI endpoint) { + this.endpoint = endpoint; + return this; + } + + private int tokenCount(JsonNode message) { + int tokenCount = tokensPerMessage; + tokenCount += tokenEncoding.countTokens(message.get("content").textValue()); + tokenCount += tokenEncoding.countTokens(message.get("role").textValue()); + @Nullable JsonNode name = message.get("name"); + if (name != null) { + tokenCount += tokenEncoding.countTokens(name.textValue()); + tokenCount += tokensPerName; + } + return tokenCount; + } + + private Optional pruneMessages(ArrayNode messages, @Nullable JsonNode functions) + throws JsonProcessingException { + + int functionTokens = 0; + if (functions != null) { + // This is honestly a guess, it's undocumented + functionTokens = tokenEncoding.countTokens(MAPPER.writeValueAsString(functions)); + } + + ArrayNode output = MAPPER.createArrayNode(); + int totalTokens = functionTokens; + totalTokens += 3; // every reply is primed with <|start|>assistant<|message|> + + JsonNode systemMessage = messages.get(0); + boolean hasSystemMessage = systemMessage.get("role").textValue().equals("system"); + if (hasSystemMessage) { + // if the system message is present it's required + totalTokens += tokenCount(messages.get(0)); + } + for (int i = messages.size() - 1; i >= 0; i--) { + JsonNode m = messages.get(i); + String role = m.get("role").textValue(); + if (role.equals("system")) { + continue; // system has already been counted + } + totalTokens += tokenCount(m); + if (totalTokens > config.maxInputTokens()) { + break; + } + output.insert(0, m); + } + if (hasSystemMessage) { + output.insert(0, systemMessage); + } + + if ((hasSystemMessage && output.size() <= 1) || output.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(output); + } + + @Override + public T handle(MessageStack messageStack) throws IOException { + T fromUser = messageStack.tail(); + + ObjectNode body = MAPPER.createObjectNode(); + body.put("model", config.model().properties().name()) + // .put("function_call", "auto") // Update when we support functions + .put("n", 1) + .put("stream", false) + .put("user", fromUser.senderId().toString()); + config.topP().ifPresent(v -> body.put("top_p", v)); + config.temperature().ifPresent(v -> body.put("temperature", v)); + config.maxOutputTokens().ifPresent(v -> body.put("max_tokens", v)); + config.presencePenalty().ifPresent(v -> body.put("presence_penalty", v)); + config.frequencyPenalty().ifPresent(v -> body.put("frequency_penalty", v)); + if (!config.logitBias().isEmpty()) { + body.set("logit_bias", MAPPER.valueToTree(config.logitBias())); + } + if (!config.stop().isEmpty()) { + body.set("stop", MAPPER.valueToTree(config.stop())); + } + + ArrayNode messages = MAPPER.createArrayNode(); + config + .systemMessage() + .ifPresent( + m -> + messages + .addObject() + .put("role", Role.SYSTEM.toString().toLowerCase()) + .put("content", m)); + for (T message : messageStack.messages()) { + messages + .addObject() + .put("role", message.role().toString().toLowerCase()) + .put("content", message.message()); + } + + Optional prunedMessages = pruneMessages(messages, null); + if (prunedMessages.isEmpty()) { + return messageStack.newMessageFromBot( + Instant.now(), "I'm sorry but that request was too long for me."); + } + body.set("messages", prunedMessages.get()); + + String bodyString; + try { + bodyString = MAPPER.writeValueAsString(body); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); // this should be impossible + } + Response response = + Request.post(endpoint) + .bodyString(bodyString, ContentType.APPLICATION_JSON) + .setHeader("Authorization", "Bearer " + config.apiKey()) + .execute(); + + JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes()); + Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue()); + JsonNode choice = responseBody.get("choices").get(0); + String messageContent = choice.get("message").get("content").textValue(); + return messageStack.newMessageFromBot(timestamp, messageContent); + } +} diff --git a/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java b/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java index 51ccfc9..6a9da33 100644 --- a/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java +++ b/src/main/java/com/meta/chatbridge/message/FBMessageHandler.java @@ -26,6 +26,7 @@ import java.time.Instant; import java.util.*; import java.util.function.Function; +import java.util.stream.Stream; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import org.apache.hc.client5.http.fluent.Request; @@ -73,11 +74,11 @@ public FBMessageHandler(String verifyToken, String pageAccessToken, String appSe this.accessToken = pageAccessToken; } - @TestOnly - @This - FBMessageHandler baseURLFactory(Function baseURLFactory) { - this.baseURLFactory = Objects.requireNonNull(baseURLFactory); - return this; + private static Stream textChunker(String text, String regexSeparator) { + if (text.length() > 2000) { + return Arrays.stream(text.split(regexSeparator)).map(String::strip); + } + return Stream.of(text); } @Override @@ -213,19 +214,37 @@ private List postHandler(Context ctx) throws JsonProcessingException return output; } + @TestOnly + public @This FBMessageHandler baseURLFactory(Function baseURLFactory) { + this.baseURLFactory = Objects.requireNonNull(baseURLFactory); + return this; + } + @Override public void respond(FBMessage message) throws IOException { + List chunkedText = + Stream.of(message.message().strip()) + .flatMap(m -> textChunker(m, "\n\n\n+")) + .flatMap(m -> textChunker(m, "\n\n")) + .flatMap(m -> textChunker(m, "\n")) + .flatMap(m -> textChunker(m, "\\. +")) + .flatMap(m -> textChunker(m, " +")) + .toList(); + for (String text : chunkedText) { + send(text, message.recipientId(), message.senderId()); + } + } + + private void send(String message, Identifier recipient, Identifier sender) throws IOException { URI url; ObjectNode body = MAPPER.createObjectNode(); - body.put("messaging_type", "RESPONSE") - .putObject("recipient") - .put("id", message.recipientId().toString()); - body.putObject("message").put("text", message.message()); + body.put("messaging_type", "RESPONSE").putObject("recipient").put("id", recipient.toString()); + body.putObject("message").put("text", message); String bodyString; try { bodyString = MAPPER.writeValueAsString(body); url = - new URIBuilder(baseURLFactory.apply(message.senderId())) + new URIBuilder(baseURLFactory.apply(sender)) .addParameter("access_token", accessToken) .build(); } catch (JsonProcessingException | URISyntaxException e) { @@ -237,6 +256,13 @@ public void respond(FBMessage message) throws IOException { Request.post(url).bodyString(bodyString, ContentType.APPLICATION_JSON).execute(); HttpResponse responseContent = response.returnResponse(); if (responseContent.getCode() != 200) { + String errorMessage = + "received a " + + responseContent.getCode() + + " error code when attempting to reply. " + + responseContent.getReasonPhrase(); + + LOGGER.atError().addKeyValue("body", bodyString).setMessage(errorMessage).log(); throw new IOException( "received a " + responseContent.getCode() diff --git a/src/main/java/com/meta/chatbridge/message/Message.java b/src/main/java/com/meta/chatbridge/message/Message.java index 7ff82b4..cdebcb3 100644 --- a/src/main/java/com/meta/chatbridge/message/Message.java +++ b/src/main/java/com/meta/chatbridge/message/Message.java @@ -31,10 +31,14 @@ enum Role { SYSTEM } - default Identifier conversationId() { - if (senderId().compareTo(recipientId()) <= 0) { - return Identifier.from(senderId().toString() + '|' + recipientId()); + static Identifier conversationId(Identifier id1, Identifier id2) { + if (id1.compareTo(id2) <= 0) { + return Identifier.from(id1.toString() + '|' + id2); } - return Identifier.from(recipientId().toString() + '|' + senderId()); + return Identifier.from(id2.toString() + '|' + id1); + } + + default Identifier conversationId() { + return conversationId(senderId(), recipientId()); } } diff --git a/src/main/java/com/meta/chatbridge/message/MessageFactory.java b/src/main/java/com/meta/chatbridge/message/MessageFactory.java new file mode 100644 index 0000000..e5c2adb --- /dev/null +++ b/src/main/java/com/meta/chatbridge/message/MessageFactory.java @@ -0,0 +1,51 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.message; + +import com.meta.chatbridge.Identifier; +import com.meta.chatbridge.message.Message.Role; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +@FunctionalInterface +public interface MessageFactory { + Map, MessageFactory> FACTORY_MAP = + Stream.>of( + new FactoryContainer<>( + FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r))) + .collect( + Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory)); + + static MessageFactory instance(Class clazz) { + @SuppressWarnings("unchecked") // static map guarantees this to be true + MessageFactory factory = (MessageFactory) FACTORY_MAP.get(clazz); + Objects.requireNonNull(factory, clazz + " does not have a registered factory"); + return factory; + } + + static MessageFactory instance(T message) { + @SuppressWarnings("unchecked") // class of an object is its class :) + Class clazz = (Class) message.getClass(); + return instance(clazz); + } + + T newMessage( + Instant timestamp, + String message, + Identifier senderId, + Identifier recipientId, + Identifier instanceId, + Role role); + + /** this exists to provide compiler guarantees for type matching in the FACTORY_MAP */ + record FactoryContainer(Class clazz, MessageFactory factory) {} +} diff --git a/src/main/java/com/meta/chatbridge/message/MessageStack.java b/src/main/java/com/meta/chatbridge/message/MessageStack.java new file mode 100644 index 0000000..ee4a801 --- /dev/null +++ b/src/main/java/com/meta/chatbridge/message/MessageStack.java @@ -0,0 +1,92 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.message; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.meta.chatbridge.Identifier; +import com.meta.chatbridge.message.Message.Role; +import java.time.Instant; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MessageStack { + private final List messages; + private final MessageFactory messageFactory; + + private MessageStack(T message) { + Objects.requireNonNull(message); + this.messages = ImmutableList.of(message); + messageFactory = MessageFactory.instance(message); + } + + /** Constructor that exists to support the with method */ + private MessageStack(MessageStack old, T newMessage) { + Objects.requireNonNull(newMessage); + messageFactory = old.messageFactory; + Preconditions.checkArgument( + old.tail().conversationId().equals(newMessage.conversationId()), + "all messages in a stack must have the same conversation id"); + List messages = old.messages; + if (newMessage.timestamp().isBefore(old.tail().timestamp())) { + this.messages = + Stream.concat(messages.stream(), Stream.of(newMessage)) + .sorted(Comparator.comparing(Message::timestamp)) + .collect(Collectors.toUnmodifiableList()); + } else { + this.messages = ImmutableList.builder().addAll(messages).add(newMessage).build(); + } + + Preconditions.checkArgument( + old.userId().equals(userId()) && old.botId().equals(botId()), + "userId and botId not consistent with this message stack"); + } + + public static MessageStack of(T message) { + return new MessageStack<>(message); + } + + public Identifier userId() { + T message = tail(); + return switch (message.role()) { + case ASSISTANT, SYSTEM -> message.recipientId(); + case USER -> message.senderId(); + }; + } + + public Identifier botId() { + T message = tail(); + return switch (message.role()) { + case ASSISTANT, SYSTEM -> message.senderId(); + case USER -> message.recipientId(); + }; + } + + public T newMessageFromBot(Instant timestamp, String message) { + return messageFactory.newMessage( + timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT); + } + + public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) { + return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER); + } + + public MessageStack with(T message) { + return new MessageStack<>(this, message); + } + + public List messages() { + return messages; + } + + public T tail() { + return messages.get(messages.size() - 1); + } +} diff --git a/src/main/java/com/meta/chatbridge/store/ChatStore.java b/src/main/java/com/meta/chatbridge/store/ChatStore.java index 07d09a8..428c70a 100644 --- a/src/main/java/com/meta/chatbridge/store/ChatStore.java +++ b/src/main/java/com/meta/chatbridge/store/ChatStore.java @@ -9,6 +9,7 @@ package com.meta.chatbridge.store; import com.meta.chatbridge.message.Message; +import com.meta.chatbridge.message.MessageStack; /** * This class is in charge of both maintaining a chat history and managing a queue of conversations diff --git a/src/main/java/com/meta/chatbridge/store/MemoryStore.java b/src/main/java/com/meta/chatbridge/store/MemoryStore.java index 36863f4..36385ef 100644 --- a/src/main/java/com/meta/chatbridge/store/MemoryStore.java +++ b/src/main/java/com/meta/chatbridge/store/MemoryStore.java @@ -12,6 +12,7 @@ import com.google.common.cache.CacheBuilder; import com.meta.chatbridge.Identifier; import com.meta.chatbridge.message.Message; +import com.meta.chatbridge.message.MessageStack; import java.time.Duration; public class MemoryStore implements ChatStore { diff --git a/src/main/java/com/meta/chatbridge/store/MessageStack.java b/src/main/java/com/meta/chatbridge/store/MessageStack.java deleted file mode 100644 index d55607a..0000000 --- a/src/main/java/com/meta/chatbridge/store/MessageStack.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.meta.chatbridge.store; - -import com.google.common.collect.ImmutableList; -import com.meta.chatbridge.message.Message; -import java.util.*; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class MessageStack { - - private final List messages; - - private MessageStack(Collection messages) { - Objects.requireNonNull(messages); - this.messages = - messages.stream() - .sorted(Comparator.comparing(Message::timestamp)) - .collect(Collectors.toUnmodifiableList()); - } - - /** Constructor that exists to support the with method */ - private MessageStack(List messages, T newMessage) { - if (!messages.isEmpty() - && newMessage.timestamp().isBefore(messages.get(messages.size() - 1).timestamp())) { - this.messages = - Stream.concat(messages.stream(), Stream.of(newMessage)) - .sorted(Comparator.comparing(Message::timestamp)) - .collect(Collectors.toUnmodifiableList()); - } else { - this.messages = ImmutableList.builder().addAll(messages).add(newMessage).build(); - } - } - - public static MessageStack of(T message) { - return new MessageStack<>(List.of(message)); - } - - public static MessageStack of(Collection messages) { - return new MessageStack<>(messages); - } - - public MessageStack with(T message) { - return new MessageStack<>(messages, message); - } - - public List messages() { - return messages; - } -} diff --git a/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java b/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java index 5548b6e..689d3fe 100644 --- a/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java +++ b/src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java @@ -11,12 +11,12 @@ import com.meta.chatbridge.Identifier; import com.meta.chatbridge.message.FBMessage; import com.meta.chatbridge.message.Message; -import com.meta.chatbridge.store.MessageStack; +import com.meta.chatbridge.message.MessageStack; import java.time.Instant; import java.util.concurrent.*; import org.checkerframework.checker.nullness.qual.Nullable; -public class DummyFBMessageLLMHandler implements LLMHandler { +public class DummyFBMessageLLMHandler implements LLMPlugin { private final String dummyLLMResponse; private final BlockingQueue> receivedMessageStacks = @@ -51,7 +51,10 @@ public String dummyResponse() { public FBMessage handle(MessageStack messageStack) { receivedMessageStacks.add(messageStack); FBMessage inbound = - messageStack.messages().stream().filter(m -> m.role() == Message.Role.USER).findAny().get(); + messageStack.messages().stream() + .filter(m -> m.role() == Message.Role.USER) + .findAny() + .orElseThrow(); return new FBMessage( Instant.now(), Identifier.from("test_message"), diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java new file mode 100644 index 0000000..99db5ec --- /dev/null +++ b/src/test/java/com/meta/chatbridge/llm/OpenAIConfigTest.java @@ -0,0 +1,182 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.*; +import com.google.common.collect.ImmutableList; +import com.meta.chatbridge.Configuration; +import java.util.Collection; +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class OpenAIConfigTest { + + private static final ObjectMapper MAPPER = Configuration.MAPPER; + static final Collection CONFIG_ITEMS = + ImmutableList.of( + new ConfigItem("model", true, TextNode.valueOf("gpt-4"), List.of(TextNode.valueOf("n"))), + new ConfigItem( + "api_key", + true, + TextNode.valueOf("notempty"), + List.of(TextNode.valueOf(""), TextNode.valueOf(" "))), + new ConfigItem( + "temperature", + false, + DoubleNode.valueOf(1), + List.of(DoubleNode.valueOf(-0.1), DoubleNode.valueOf(2.1))), + new ConfigItem( + "top_p", + false, + DoubleNode.valueOf(0.5), + List.of(DoubleNode.valueOf(0), DoubleNode.valueOf(1.1))), + new ConfigItem( + "stop", + false, + MAPPER.createArrayNode().add("10").add("stop"), + List.of( + TextNode.valueOf("I'm a stop"), + DoubleNode.valueOf(10), + MAPPER.createArrayNode().addObject())), + new ConfigItem( + "max_output_tokens", + false, + LongNode.valueOf(100), + List.of( + LongNode.valueOf(0), + LongNode.valueOf(OpenAIModel.GPT4.properties().tokenLimit() + 1))), + new ConfigItem( + "presence_penalty", + false, + DoubleNode.valueOf(0), + List.of(DoubleNode.valueOf(-2.1), DoubleNode.valueOf(2.1))), + new ConfigItem( + "frequency_penalty", + false, + DoubleNode.valueOf(0), + List.of(DoubleNode.valueOf(-2.1), DoubleNode.valueOf(2.1))), + new ConfigItem( + "logit_bias", + false, + MAPPER.createObjectNode().put("1", 0.5), + List.of( + MAPPER.createObjectNode().put("1", -101), + MAPPER.createObjectNode().put("1", 101))), + new ConfigItem( + "system_message", + false, + TextNode.valueOf("you're a helpful assistant"), + List.of(TextNode.valueOf(""), TextNode.valueOf(" "))), + new ConfigItem( + "max_input_tokens", + false, + LongNode.valueOf(4000), + List.of(LongNode.valueOf(-1), LongNode.valueOf(100_000)))); + private ObjectNode minimalConfig; + + static Stream configItems() { + return CONFIG_ITEMS.stream(); + } + + static Stream invalidValues() { + return configItems() + .flatMap(c -> c.invalidValues().stream().map(t -> Arguments.of(c.key(), t))); + } + + static Stream requiredKeys() { + return configItems().filter(ConfigItem::required).map(ConfigItem::key); + } + + @BeforeEach + void setUp() { + minimalConfig = MAPPER.createObjectNode(); + CONFIG_ITEMS.forEach( + t -> { + if (t.required()) { + minimalConfig.set(t.key(), t.validValue()); + } + }); + } + + @Test + void maximalValidConfig() throws JsonProcessingException { + ObjectNode body = MAPPER.createObjectNode(); + CONFIG_ITEMS.forEach(t -> body.set(t.key(), t.validValue())); + OpenAIConfig config = MAPPER.readValue(MAPPER.writeValueAsString(body), OpenAIConfig.class); + assertThat(config.model()).isEqualTo(OpenAIModel.GPT4); + assertThat(config.temperature().isPresent()).isTrue(); + assertThat(config.frequencyPenalty().isPresent()).isTrue(); + assertThat(config.topP().isPresent()).isTrue(); + assertThat(config.maxOutputTokens().isPresent()).isTrue(); + assertThat(config.presencePenalty().isPresent()).isTrue(); + assertThat(config.frequencyPenalty().isPresent()).isTrue(); + assertThat(config.logitBias().isEmpty()).isFalse(); + } + + @Test + void minimalValidConfig() throws JsonProcessingException { + ObjectNode body = MAPPER.createObjectNode(); + CONFIG_ITEMS.forEach( + t -> { + if (t.required()) { + body.set(t.key(), t.validValue()); + } + }); + OpenAIConfig config = MAPPER.readValue(MAPPER.writeValueAsString(body), OpenAIConfig.class); + assertThat(config.model()).isEqualTo(OpenAIModel.GPT4); + assertThat(config.temperature().isEmpty()).isTrue(); + assertThat(config.frequencyPenalty().isEmpty()).isTrue(); + assertThat(config.topP().isEmpty()).isTrue(); + assertThat(config.maxOutputTokens().isEmpty()).isTrue(); + assertThat(config.presencePenalty().isEmpty()).isTrue(); + assertThat(config.frequencyPenalty().isEmpty()).isTrue(); + assertThat(config.logitBias().isEmpty()).isTrue(); + } + + @ParameterizedTest + @MethodSource("configItems") + void nullValues(ConfigItem item) throws JsonProcessingException { + minimalConfig.putNull(item.key()); + String bodyString = MAPPER.writeValueAsString(minimalConfig); + assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class)) + .isInstanceOf(Exception.class); + } + + @ParameterizedTest + @MethodSource("invalidValues") + void invalidValues(String key, JsonNode value) throws JsonProcessingException { + minimalConfig.set(key, value); + String bodyString = MAPPER.writeValueAsString(minimalConfig); + assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class)) + .isInstanceOf(Exception.class); + } + + @ParameterizedTest + @MethodSource("requiredKeys") + void requiredKeysMissing(String key) throws JsonProcessingException { + minimalConfig.remove(key); + String bodyString = MAPPER.writeValueAsString(minimalConfig); + assertThatThrownBy(() -> MAPPER.readValue(bodyString, OpenAIConfig.class)) + .isInstanceOf(Exception.class); + } + + record ConfigItem( + String key, boolean required, JsonNode validValue, List invalidValues) {} +} diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java new file mode 100644 index 0000000..09203ed --- /dev/null +++ b/src/test/java/com/meta/chatbridge/llm/OpenAIModelTest.java @@ -0,0 +1,35 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +class OpenAIModelTest { + + /** This is super important because the JsonDeserializer relies on it */ + @Test + void namesAreUnique() { + Set uniqueElements = + Arrays.stream(OpenAIModel.values()).map(OpenAIModel::toString).collect(Collectors.toSet()); + assertThat(OpenAIModel.values()).extracting(Enum::toString).hasSize(uniqueElements.size()); + } + + /** This is super important because the JsonDeserializer relies on it */ + @Test + void toStringMatchesName() { + assertThat(OpenAIModel.values()) + .allSatisfy(v -> assertThat(v.properties().name()).isEqualTo(v.toString())); + } +} diff --git a/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java b/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java new file mode 100644 index 0000000..f6174d7 --- /dev/null +++ b/src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java @@ -0,0 +1,269 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.chatbridge.llm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.SoftAssertions.assertSoftly; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.meta.chatbridge.Configuration; +import com.meta.chatbridge.Identifier; +import com.meta.chatbridge.Pipeline; +import com.meta.chatbridge.PipelinesRunner; +import com.meta.chatbridge.message.*; +import com.meta.chatbridge.message.Message.Role; +import com.meta.chatbridge.store.ChatStore; +import com.meta.chatbridge.store.MemoryStore; +import io.javalin.Javalin; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hc.client5.http.fluent.Request; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.net.URIBuilder; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; + +public class OpenAIPluginTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + public static final JsonNode SAMPLE_RESPONSE = MAPPER.createObjectNode(); + private static final String PATH = "/"; + private static final String TEST_MESSAGE = "this is a test message"; + private static final MessageStack STACK = + MessageStack.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + "test message", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.USER)); + + static { + ((ObjectNode) SAMPLE_RESPONSE) + .put("created", Instant.now().getEpochSecond()) + .put("object", "chat.completion") + .put("id", UUID.randomUUID().toString()) + .putArray("choices") + .addObject() + .put("index", 0) + .put("finish_reason", "stop") + .putObject("message") + .put("role", "assistant") + .put("content", TEST_MESSAGE); + } + + private BlockingQueue openAIRequests; + private Javalin app; + private URI endpoint; + private ObjectNode minimalConfig; + + static Stream modelOptions() { + Set non_model_options = Set.of("model", "api_key", "max_input_tokens"); + return OpenAIConfigTest.CONFIG_ITEMS.stream().filter(c -> !non_model_options.contains(c.key())); + } + + @BeforeEach + void setUp() throws UnknownHostException, URISyntaxException { + openAIRequests = new LinkedBlockingDeque<>(); + app = Javalin.create(); + app.before( + PATH, + ctx -> + openAIRequests.add( + new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); + app.post(PATH, ctx -> ctx.result(MAPPER.writeValueAsString(SAMPLE_RESPONSE))); + app.start(0); + endpoint = + URIBuilder.localhost().setScheme("http").appendPath(PATH).setPort(app.port()).build(); + } + + @ParameterizedTest + @EnumSource(OpenAIModel.class) + void sampleValid(OpenAIModel model) throws IOException, InterruptedException { + String apiKey = UUID.randomUUID().toString(); + OpenAIConfig config = OpenAIConfig.builder(model, apiKey).build(); + OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint); + FBMessage message = plugin.handle(STACK); + assertThat(message.message()).isEqualTo(TEST_MESSAGE); + assertThat(message.role()).isSameAs(Role.ASSISTANT); + assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); + @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey); + assertThat(MAPPER.readTree(or.body()).get("model").textValue()).isEqualTo(model.toString()); + } + + @BeforeEach + void setUpMinConfig() { + minimalConfig = MAPPER.createObjectNode(); + OpenAIConfigTest.CONFIG_ITEMS.forEach( + t -> { + if (t.required()) { + minimalConfig.set(t.key(), t.validValue()); + } + }); + } + + @ParameterizedTest + @MethodSource("modelOptions") + void validConfigValues(OpenAIConfigTest.ConfigItem configItem) + throws IOException, InterruptedException { + minimalConfig.set(configItem.key(), configItem.validValue()); + OpenAIConfig config = Configuration.MAPPER.convertValue(minimalConfig, OpenAIConfig.class); + OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint); + FBMessage message = plugin.handle(STACK); + assertThat(message.message()).isEqualTo(TEST_MESSAGE); + assertThat(message.role()).isSameAs(Role.ASSISTANT); + assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); + @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + assertThat(or.headerMap().get("Authorization")) + .isNotNull() + .isEqualTo("Bearer " + config.apiKey()); + JsonNode body = Configuration.MAPPER.readTree(or.body()); + assertThat(body.get("model").textValue()).isEqualTo(config.model().toString()); + if (configItem.key().equals("system_message")) { + assertThat(body.get("messages")) + .satisfiesOnlyOnce( + m -> { + assertThat(m.get("role").textValue()) + .isEqualTo(Role.SYSTEM.toString().toLowerCase()); + assertThat(m.get("content").textValue()) + .isEqualTo(configItem.validValue().textValue()); + }); + } else { + if (configItem.key().equals("max_output_tokens")) { + assertThat(body.get("max_tokens")).isEqualTo(minimalConfig.get(configItem.key())); + } else { + assertThat(body.get(configItem.key())).isEqualTo(minimalConfig.get(configItem.key())); + } + } + } + + @Test + void contextTooBig() throws IOException { + OpenAIConfig config = + OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build(); + OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint); + MessageStack stack = + STACK.with( + STACK.newMessageFromUser( + Instant.now(), + Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()), + Identifier.random())); + FBMessage response = plugin.handle(stack); + assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me."); + assertThat(openAIRequests).hasSize(0); + } + + @Test + void orderedCorrectly() throws IOException, InterruptedException { + OpenAIConfig config = + OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build(); + OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint); + MessageStack stack = + MessageStack.of( + MessageFactory.instance(FBMessage.class) + .newMessage( + Instant.now(), + "1", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Role.SYSTEM)); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2))); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3))); + stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4))); + plugin.handle(stack); + @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + JsonNode body = MAPPER.readTree(or.body()); + + for (int i = 0; i < stack.messages().size(); i++) { + FBMessage stackMessage = stack.messages().get(i); + JsonNode sentMessage = body.get("messages").get(i); + assertSoftly( + s -> + s.assertThat(stackMessage.message()) + .isEqualTo(sentMessage.get("content").textValue())); + ; + } + } + + @Test + void inPipeline() throws IOException, URISyntaxException, InterruptedException { + ChatStore store = new MemoryStore<>(); + String appSecret = "app secret"; + String accessToken = "access token"; + String verifyToken = "verify token"; + + BlockingQueue metaRequests = new LinkedBlockingDeque<>(); + String metaPath = "/meta"; + URI messageReceiver = + URIBuilder.localhost().appendPath(metaPath).setScheme("http").setPort(app.port()).build(); + app.post( + metaPath, + ctx -> + metaRequests.put( + new OutboundRequest(ctx.body(), ctx.headerMap(), ctx.queryParamMap()))); + FBMessageHandler handler = + new FBMessageHandler(verifyToken, accessToken, appSecret) + .baseURLFactory(ignored -> messageReceiver); + + String apiKey = "api key"; + OpenAIConfig config = OpenAIConfig.builder(OpenAIModel.GPT4, apiKey).build(); + OpenAIPlugin plugin = new OpenAIPlugin(config).endpoint(endpoint); + + String webhookPath = "/webhook"; + Pipeline pipeline = new Pipeline<>(store, handler, plugin, webhookPath); + PipelinesRunner runner = PipelinesRunner.newInstance().pipeline(pipeline).port(0); + runner.start(); + + // TODO: create test harness + Request request = + FBMessageHandlerTest.createMessageRequest(FBMessageHandlerTest.SAMPLE_MESSAGE, runner); + HttpResponse response = request.execute().returnResponse(); + assertThat(response.getCode()).isEqualTo(200); + @Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS); + assertThat(or).isNotNull(); + assertThat(or.headerMap().get("Authorization")) + .isNotNull() + .isEqualTo("Bearer " + config.apiKey()); + JsonNode body = Configuration.MAPPER.readTree(or.body()); + assertThat(body.get("model").textValue()).isEqualTo(config.model().toString()); + + or = metaRequests.poll(500, TimeUnit.MILLISECONDS); + // plugin output got back to meta + assertThat(or).isNotNull().satisfies(r -> assertThat(r.body()).contains(TEST_MESSAGE)); + } + + private record OutboundRequest( + String body, Map headerMap, Map> queryParamMap) {} +} diff --git a/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java b/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java index 1ebb5c6..d49a53f 100644 --- a/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java +++ b/src/test/java/com/meta/chatbridge/message/FBMessageHandlerTest.java @@ -22,7 +22,6 @@ import com.meta.chatbridge.llm.DummyFBMessageLLMHandler; import com.meta.chatbridge.message.Message.Role; import com.meta.chatbridge.store.MemoryStore; -import com.meta.chatbridge.store.MessageStack; import io.javalin.Javalin; import io.javalin.http.HandlerType; import java.io.IOException; @@ -36,6 +35,7 @@ import java.util.Map; import java.util.concurrent.*; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.hc.client5.http.fluent.Request; import org.apache.hc.client5.http.fluent.Response; @@ -49,7 +49,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -class FBMessageHandlerTest { +public class FBMessageHandlerTest { @FunctionalInterface private interface ThrowableFunction { @@ -59,7 +59,7 @@ private interface ThrowableFunction { private static final ObjectMapper MAPPER = new ObjectMapper(); /** Example message collected directly from the messenger webhook */ - private static final String SAMPLE_MESSAGE = + public static final String SAMPLE_MESSAGE = "{\"object\":\"page\",\"entry\":[{\"id\":\"106195825075770\",\"time\":1692813219204,\"messaging\":[{\"sender\":{\"id\":\"6357858494326947\"},\"recipient\":{\"id\":\"106195825075770\"},\"timestamp\":1692813218705,\"message\":{\"mid\":\"m_kT_mWOSYh_eK3kF8chtyCWfcD9-gomvu4mhaMFQl-gt4D3LjORi6k3BXD6_x9a-FOUt-D2LFuywJN6HfrpAnDg\",\"text\":\"asdfa\"}}]}]}"; private static final String SAMPLE_MESSAGE_HMAC = @@ -78,7 +78,7 @@ static void beforeAll() throws JsonProcessingException { private Javalin app; private BlockingQueue requests; - private HttpResponse getRequest(String path, int port, Map params) + private static HttpResponse getRequest(String path, int port, Map params) throws IOException, URISyntaxException { URIBuilder uriBuilder = URIBuilder.loopbackAddress().setScheme("http").setPort(port).appendPath(path); @@ -142,7 +142,7 @@ private static Request createMessageRequest( throws IOException, URISyntaxException { @SuppressWarnings("unchecked") // for the scope of this test this is guaranteed Pipeline pipeline = - (Pipeline) runner.pipelines().stream().findAny().get(); + (Pipeline) runner.pipelines().stream().findAny().orElseThrow(); String path = pipeline.path(); // for the scope of this test this is guaranteed @@ -160,7 +160,7 @@ private static Request createMessageRequest( return request; } - private static Request createMessageRequest(String body, PipelinesRunner runner) + public static Request createMessageRequest(String body, PipelinesRunner runner) throws IOException, URISyntaxException { return createMessageRequest(body, runner, true); } @@ -373,4 +373,23 @@ void invalidMessage( assertThat(body.get("message").get("text").textValue()).isEqualTo(llmHandler.dummyResponse()); } } + + @Test + void chunkingHappens() throws IOException { + app.start(0); + Identifier pageId = Identifier.from(106195825075770L); + String token = "243af3c6-9994-4869-ae13-ad61a38323f5"; // this is fake don't worry + String secret = "f74a638462f975e9eadfcbb84e4aa06b"; // it's been rolled don't worry + FBMessageHandler messageHandler = + new FBMessageHandler("0", token, secret).baseURLFactory(testURLFactoryFactory(pageId)); + + String bigText = + Stream.generate(() -> "0123456789.").limit(300).collect(Collectors.joining(" ")); + FBMessage bigMessage = + new FBMessage( + Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER); + messageHandler.respond(bigMessage); + assertThat(requests.size()).isEqualTo(300); + assertThat(requests).allSatisfy(m -> assertThat(m.body()).contains("0123456789")); + } } diff --git a/src/test/java/com/meta/chatbridge/store/MessageStackTest.java b/src/test/java/com/meta/chatbridge/store/MessageStackTest.java index 0ce1f94..429f823 100644 --- a/src/test/java/com/meta/chatbridge/store/MessageStackTest.java +++ b/src/test/java/com/meta/chatbridge/store/MessageStackTest.java @@ -8,58 +8,40 @@ package com.meta.chatbridge.store; -import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.*; -import com.google.common.collect.Lists; import com.meta.chatbridge.Identifier; +import com.meta.chatbridge.message.FBMessage; import com.meta.chatbridge.message.Message; +import com.meta.chatbridge.message.MessageFactory; +import com.meta.chatbridge.message.MessageStack; import java.time.Instant; -import java.util.List; import org.junit.jupiter.api.Test; class MessageStackTest { - record TestMessage(Instant timestamp) implements Message { - - @Override - public Identifier instanceId() { - return Identifier.from("0"); - } - - @Override - public Identifier senderId() { - return Identifier.from(0); - } - - @Override - public Identifier recipientId() { - return Identifier.from(0); - } - - @Override - public String message() { - return ""; - } - - @Override - public Role role() { - return Role.SYSTEM; - } - } + private static final MessageFactory FACTORY = MessageFactory.instance(FBMessage.class); @Test void orderPreservation() { Instant start = Instant.now(); - TestMessage message1 = new TestMessage(start); - TestMessage message2 = new TestMessage(start.plusSeconds(1)); - MessageStack ms = MessageStack.of(Lists.newArrayList(message1, message2)); + FBMessage message1 = + FACTORY.newMessage( + start, + "sample message", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Message.Role.USER); + + MessageStack ms = MessageStack.of(message1); + FBMessage message2 = ms.newMessageFromBot(start.plusSeconds(1), "other sample message"); + ms = ms.with(message2); assertThat(ms.messages()).hasSize(2); assertThat(ms.messages().get(0)).isSameAs(message1); assertThat(ms.messages().get(1)).isSameAs(message2); - ms = MessageStack.of(List.of()); - assertThat(ms.messages()).hasSize(0); - ms = ms.with(message1); + ms = MessageStack.of(message1); assertThat(ms.messages()).hasSize(1); ms = ms.with(message2); assertThat(ms.messages()).hasSize(2); @@ -70,18 +52,105 @@ void orderPreservation() { @Test void orderCorrection() { Instant start = Instant.now(); - TestMessage message1 = new TestMessage(start); - TestMessage message2 = new TestMessage(start.plusSeconds(1)); - MessageStack ms = MessageStack.of(Lists.newArrayList(message2, message1)); + FBMessage message2 = + FACTORY.newMessage( + start, + "sample message", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Message.Role.USER); + MessageStack ms = MessageStack.of(message2); + + FBMessage message1 = ms.newMessageFromBot(start.minusSeconds(1), "other sample message"); + + ms = ms.with(message1); assertThat(ms.messages().get(0)).isSameAs(message1); assertThat(ms.messages().get(1)).isSameAs(message2); - ms = MessageStack.of(List.of()); - ms = ms.with(message2); + ms = MessageStack.of(message2); assertThat(ms.messages()).hasSize(1); ms = ms.with(message1); assertThat(ms.messages()).hasSize(2); assertThat(ms.messages().get(0)).isSameAs(message1); assertThat(ms.messages().get(1)).isSameAs(message2); } + + @Test + void botAndUserId() { + Instant start = Instant.now(); + FBMessage message1 = + FACTORY.newMessage( + start, + "sample message", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Message.Role.USER); + + MessageStack ms = MessageStack.of(message1); + FBMessage message2 = + FACTORY.newMessage( + start, + "sample message", + message1.recipientId(), + message1.senderId(), + Identifier.random(), + Message.Role.ASSISTANT); + + final MessageStack finalMs = ms; + assertThatCode(() -> finalMs.with(message2)).doesNotThrowAnyException(); + assertThatCode(() -> finalMs.with(finalMs.newMessageFromBot(start, ""))) + .doesNotThrowAnyException(); + assertThatCode(() -> finalMs.with(finalMs.newMessageFromUser(start, "", Identifier.random()))) + .doesNotThrowAnyException(); + ms = ms.with(message2); + assertThat(ms.userId()).isEqualTo(message1.senderId()); + assertThat(ms.botId()).isEqualTo(message1.recipientId()); + FBMessage mDifferentSenderId = + FACTORY.newMessage( + start, + "", + Identifier.random(), + message1.recipientId(), + Identifier.random(), + Message.Role.USER); + + MessageStack finalMs1 = ms; + assertThatThrownBy(() -> finalMs1.with(mDifferentSenderId)) + .isInstanceOf(IllegalArgumentException.class); + + FBMessage mDifferentRecipientId = + FACTORY.newMessage( + start, + "", + message1.senderId(), + Identifier.random(), + Identifier.random(), + Message.Role.USER); + assertThatThrownBy(() -> finalMs1.with(mDifferentRecipientId)) + .isInstanceOf(IllegalArgumentException.class); + + FBMessage illegalSenderId = + FACTORY.newMessage( + start, + "", + message1.recipientId(), + message1.senderId(), + Identifier.random(), + Message.Role.USER); + assertThatThrownBy(() -> finalMs1.with(illegalSenderId)) + .isInstanceOf(IllegalArgumentException.class); + + FBMessage illegalRecipientId = + FACTORY.newMessage( + start, + "", + message1.senderId(), + message1.recipientId(), + Identifier.random(), + Message.Role.ASSISTANT); + assertThatThrownBy(() -> finalMs1.with(illegalRecipientId)) + .isInstanceOf(IllegalArgumentException.class); + } }