Skip to content

Commit

Permalink
Merge pull request #9 from facebookincubator/hunter/name_changes
Browse files Browse the repository at this point in the history
Rename Prominent Classes
  • Loading branch information
hunterjackson authored Sep 25, 2023
2 parents 09f62cb + 975bd5f commit 199e8a2
Show file tree
Hide file tree
Showing 13 changed files with 85 additions and 85 deletions.
10 changes: 5 additions & 5 deletions src/main/java/com/meta/chatbridge/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.meta.chatbridge.llm.LLMPlugin;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageHandler;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import com.meta.chatbridge.store.ChatStore;
import io.javalin.Javalin;
import io.javalin.http.Context;
Expand Down Expand Up @@ -44,8 +44,8 @@ void handle(Context ctx) {
List<T> messages = handler.processRequest(ctx);
// TODO: once we have a non-volatile store, on startup send stored but not replied to messages
for (T m : messages) {
MessageStack<T> stack = store.add(m);
executorService.submit(() -> execute(stack));
ThreadState<T> thread = store.add(m);
executorService.submit(() -> execute(thread));
}
}

Expand All @@ -61,10 +61,10 @@ public MessageHandler<T> messageHandler() {
return this.handler;
}

private void execute(MessageStack<T> stack) {
private void execute(ThreadState<T> thread) {
T llmResponse;
try {
llmResponse = llmPlugin.handle(stack);
llmResponse = llmPlugin.handle(thread);
} catch (IOException e) {
LOGGER.error("failed to communicate with LLM", e);
return;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/chatbridge/llm/LLMPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
package com.meta.chatbridge.llm;

import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.io.IOException;

public interface LLMPlugin<T extends Message> {

T handle(MessageStack<T> messageStack) throws IOException;
T handle(ThreadState<T> threadState) throws IOException;
}
12 changes: 6 additions & 6 deletions src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.knuddels.jtokkit.api.Encoding;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.Message.Role;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -126,8 +126,8 @@ private Optional<ArrayNode> pruneMessages(ArrayNode messages, @Nullable JsonNode
}

@Override
public T handle(MessageStack<T> messageStack) throws IOException {
T fromUser = messageStack.tail();
public T handle(ThreadState<T> threadState) throws IOException {
T fromUser = threadState.tail();

ObjectNode body = MAPPER.createObjectNode();
body.put("model", config.model().properties().name())
Expand Down Expand Up @@ -156,7 +156,7 @@ public T handle(MessageStack<T> messageStack) throws IOException {
.addObject()
.put("role", Role.SYSTEM.toString().toLowerCase())
.put("content", m));
for (T message : messageStack.messages()) {
for (T message : threadState.messages()) {
messages
.addObject()
.put("role", message.role().toString().toLowerCase())
Expand All @@ -165,7 +165,7 @@ public T handle(MessageStack<T> messageStack) throws IOException {

Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return messageStack.newMessageFromBot(
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
}
body.set("messages", prunedMessages.get());
Expand All @@ -186,6 +186,6 @@ public T handle(MessageStack<T> messageStack) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return messageStack.newMessageFromBot(timestamp, messageContent);
return threadState.newMessageFromBot(timestamp, messageContent);
}
}
6 changes: 3 additions & 3 deletions src/main/java/com/meta/chatbridge/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ enum Role {
SYSTEM
}

static Identifier conversationId(Identifier id1, Identifier id2) {
static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}

default Identifier conversationId() {
return conversationId(senderId(), recipientId());
default Identifier threadId() {
return threadId(senderId(), recipientId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MessageStack<T extends Message> {
public class ThreadState<T extends Message> {
private final List<T> messages;
private final MessageFactory<T> messageFactory;

private MessageStack(T message) {
private ThreadState(T message) {
Objects.requireNonNull(message);
this.messages = ImmutableList.of(message);
messageFactory = MessageFactory.instance(message);
}

/** Constructor that exists to support the with method */
private MessageStack(MessageStack<T> old, T newMessage) {
private ThreadState(ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
messageFactory = old.messageFactory;
Preconditions.checkArgument(
old.tail().conversationId().equals(newMessage.conversationId()),
"all messages in a stack must have the same conversation id");
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
this.messages =
Expand All @@ -46,11 +46,11 @@ private MessageStack(MessageStack<T> old, T newMessage) {

Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this message stack");
"userId and botId not consistent with this thread state");
}

public static <T extends Message> MessageStack<T> of(T message) {
return new MessageStack<>(message);
public static <T extends Message> ThreadState<T> of(T message) {
return new ThreadState<>(message);
}

public Identifier userId() {
Expand Down Expand Up @@ -78,8 +78,8 @@ public T newMessageFromUser(Instant timestamp, String message, Identifier instan
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
}

public MessageStack<T> with(T message) {
return new MessageStack<>(this, message);
public ThreadState<T> with(T message) {
return new ThreadState<>(this, message);
}

public List<T> messages() {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/chatbridge/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
package com.meta.chatbridge.store;

import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;

/**
* This class is in charge of both maintaining a chat history and managing a queue of conversations
Expand All @@ -22,5 +22,5 @@
*/
public interface ChatStore<T extends Message> {

MessageStack<T> add(T message);
ThreadState<T> add(T message);
}
12 changes: 6 additions & 6 deletions src/main/java/com/meta/chatbridge/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,32 @@
import com.google.common.cache.CacheBuilder;
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.time.Duration;

public class MemoryStore<T extends Message> implements ChatStore<T> {
private final Cache<Identifier, MessageStack<T>> store;
private final Cache<Identifier, ThreadState<T>> store;

MemoryStore(MemoryStoreConfig config) {
this.store =
CacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofHours(config.storageDurationHours()))
.maximumWeight((long) (config.storageCapacityMb() * Math.pow(2, 20))) // megabytes
.<Identifier, MessageStack<T>>weigher(
.<Identifier, ThreadState<T>>weigher(
(k, v) ->
v.messages().stream().map(m -> m.message().length()).reduce(0, Integer::sum))
.build();
}

@Override
public MessageStack<T> add(T message) {
public ThreadState<T> add(T message) {
return this.store
.asMap()
.compute(
message.conversationId(),
message.threadId(),
(k, v) -> {
if (v == null) {
return MessageStack.of(message);
return ThreadState.of(message);
}
return v.with(message);
});
Expand Down
24 changes: 12 additions & 12 deletions src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,47 @@
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.time.Instant;
import java.util.concurrent.*;
import org.checkerframework.checker.nullness.qual.Nullable;

public class DummyFBMessageLLMHandler implements LLMPlugin<FBMessage> {

private final String dummyLLMResponse;
private final BlockingQueue<MessageStack<FBMessage>> receivedMessageStacks =
private final BlockingQueue<ThreadState<FBMessage>> receivedThreadStates =
new LinkedBlockingDeque<>();

public DummyFBMessageLLMHandler(String dummyLLMResponse) {
this.dummyLLMResponse = dummyLLMResponse;
}

public MessageStack<FBMessage> take(int waitMs) throws InterruptedException {
@Nullable MessageStack<FBMessage> value =
receivedMessageStacks.poll(waitMs, TimeUnit.MILLISECONDS);
public ThreadState<FBMessage> take(int waitMs) throws InterruptedException {
@Nullable ThreadState<FBMessage> value =
receivedThreadStates.poll(waitMs, TimeUnit.MILLISECONDS);
if (value == null) {
throw new RuntimeException("unable to remove item form queue in under " + waitMs + "ms");
}
return value;
}

public MessageStack<FBMessage> take() throws InterruptedException {
return receivedMessageStacks.take();
public ThreadState<FBMessage> take() throws InterruptedException {
return receivedThreadStates.take();
}

public @Nullable MessageStack<FBMessage> poll() {
return receivedMessageStacks.poll();
public @Nullable ThreadState<FBMessage> poll() {
return receivedThreadStates.poll();
}

public String dummyResponse() {
return dummyLLMResponse;
}

@Override
public FBMessage handle(MessageStack<FBMessage> messageStack) {
receivedMessageStacks.add(messageStack);
public FBMessage handle(ThreadState<FBMessage> threadState) {
receivedThreadStates.add(threadState);
FBMessage inbound =
messageStack.messages().stream()
threadState.messages().stream()
.filter(m -> m.role() == Message.Role.USER)
.findAny()
.orElseThrow();
Expand Down
38 changes: 19 additions & 19 deletions src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public class OpenAIPluginTest {
public static final JsonNode SAMPLE_RESPONSE = MAPPER.createObjectNode();
private static final String PATH = "/";
private static final String TEST_MESSAGE = "this is a test message";
private static final MessageStack<FBMessage> STACK =
MessageStack.of(
private static final ThreadState<FBMessage> THREAD =
ThreadState.of(
MessageFactory.instance(FBMessage.class)
.newMessage(
Instant.now(),
Expand Down Expand Up @@ -110,10 +110,10 @@ void sampleValid(OpenAIModel model) throws IOException, InterruptedException {
String apiKey = UUID.randomUUID().toString();
OpenAIConfig config = OpenAIConfig.builder(model, apiKey).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
FBMessage message = plugin.handle(STACK);
FBMessage message = plugin.handle(THREAD);
assertThat(message.message()).isEqualTo(TEST_MESSAGE);
assertThat(message.role()).isSameAs(Role.ASSISTANT);
assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
assertThatCode(() -> THREAD.with(message)).doesNotThrowAnyException();
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey);
Expand All @@ -139,10 +139,10 @@ void validConfigValues(OpenAIConfigTest.ConfigItem configItem)
OpenAIConfig config =
ConfigurationUtils.jsonMapper().convertValue(minimalConfig, OpenAIConfig.class);
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
FBMessage message = plugin.handle(STACK);
FBMessage message = plugin.handle(THREAD);
assertThat(message.message()).isEqualTo(TEST_MESSAGE);
assertThat(message.role()).isSameAs(Role.ASSISTANT);
assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
assertThatCode(() -> THREAD.with(message)).doesNotThrowAnyException();
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
assertThat(or.headerMap().get("Authorization"))
Expand Down Expand Up @@ -173,13 +173,13 @@ void contextTooBig() throws IOException {
OpenAIConfig config =
OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
MessageStack<FBMessage> stack =
STACK.with(
STACK.newMessageFromUser(
ThreadState<FBMessage> thread =
THREAD.with(
THREAD.newMessageFromUser(
Instant.now(),
Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()),
Identifier.random()));
FBMessage response = plugin.handle(stack);
FBMessage response = plugin.handle(thread);
assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
assertThat(openAIRequests).hasSize(0);
}
Expand All @@ -189,8 +189,8 @@ void orderedCorrectly() throws IOException, InterruptedException {
OpenAIConfig config =
OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
MessageStack<FBMessage> stack =
MessageStack.of(
ThreadState<FBMessage> thread =
ThreadState.of(
MessageFactory.instance(FBMessage.class)
.newMessage(
Instant.now(),
Expand All @@ -199,20 +199,20 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Role.SYSTEM));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
plugin.handle(stack);
thread = thread.with(thread.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
plugin.handle(thread);
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
JsonNode body = MAPPER.readTree(or.body());

for (int i = 0; i < stack.messages().size(); i++) {
FBMessage stackMessage = stack.messages().get(i);
for (int i = 0; i < thread.messages().size(); i++) {
FBMessage threadMessage = thread.messages().get(i);
JsonNode sentMessage = body.get("messages").get(i);
assertSoftly(
s ->
s.assertThat(stackMessage.message())
s.assertThat(threadMessage.message())
.isEqualTo(sentMessage.get("content").textValue()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ void invalidMessage(
.isEqualTo(0); // make sure the message wasn't processed and stored
assertThat(requests).hasSize(0);
} else {
MessageStack<FBMessage> stack = llmHandler.take(500);
ThreadState<FBMessage> thread = llmHandler.take(500);
JsonNode messageObject = PARSED_SAMPLE_MESSAGE.get("entry").get(0).get("messaging").get(0);
String messageText = messageObject.get("message").get("text").textValue();
String mid = messageObject.get("message").get("mid").textValue();
Identifier recipientId =
Identifier.from(messageObject.get("recipient").get("id").textValue());
Identifier senderId = Identifier.from(messageObject.get("sender").get("id").textValue());
Instant timestamp = Instant.ofEpochMilli(messageObject.get("timestamp").longValue());
assertThat(stack.messages())
assertThat(thread.messages())
.hasSize(1)
.allSatisfy(m -> assertThat(m.message()).isEqualTo(messageText))
.allSatisfy(m -> assertThat(m.instanceId().toString()).isEqualTo(mid))
Expand Down
Loading

0 comments on commit 199e8a2

Please sign in to comment.