Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename Prominent Classes #9

Merged
merged 2 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/main/java/com/meta/chatbridge/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.meta.chatbridge.llm.LLMPlugin;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageHandler;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import com.meta.chatbridge.store.ChatStore;
import io.javalin.Javalin;
import io.javalin.http.Context;
Expand Down Expand Up @@ -44,8 +44,8 @@ void handle(Context ctx) {
List<T> messages = handler.processRequest(ctx);
// TODO: once we have a non-volatile store, on startup send stored but not replied to messages
for (T m : messages) {
MessageStack<T> stack = store.add(m);
executorService.submit(() -> execute(stack));
ThreadState<T> thread = store.add(m);
executorService.submit(() -> execute(thread));
}
}

Expand All @@ -61,10 +61,10 @@ public MessageHandler<T> messageHandler() {
return this.handler;
}

private void execute(MessageStack<T> stack) {
private void execute(ThreadState<T> thread) {
T llmResponse;
try {
llmResponse = llmPlugin.handle(stack);
llmResponse = llmPlugin.handle(thread);
} catch (IOException e) {
LOGGER.error("failed to communicate with LLM", e);
return;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/chatbridge/llm/LLMPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
package com.meta.chatbridge.llm;

import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.io.IOException;

public interface LLMPlugin<T extends Message> {

T handle(MessageStack<T> messageStack) throws IOException;
T handle(ThreadState<T> threadState) throws IOException;
}
12 changes: 6 additions & 6 deletions src/main/java/com/meta/chatbridge/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.knuddels.jtokkit.api.Encoding;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.Message.Role;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -126,8 +126,8 @@ private Optional<ArrayNode> pruneMessages(ArrayNode messages, @Nullable JsonNode
}

@Override
public T handle(MessageStack<T> messageStack) throws IOException {
T fromUser = messageStack.tail();
public T handle(ThreadState<T> threadState) throws IOException {
T fromUser = threadState.tail();

ObjectNode body = MAPPER.createObjectNode();
body.put("model", config.model().properties().name())
Expand Down Expand Up @@ -156,7 +156,7 @@ public T handle(MessageStack<T> messageStack) throws IOException {
.addObject()
.put("role", Role.SYSTEM.toString().toLowerCase())
.put("content", m));
for (T message : messageStack.messages()) {
for (T message : threadState.messages()) {
messages
.addObject()
.put("role", message.role().toString().toLowerCase())
Expand All @@ -165,7 +165,7 @@ public T handle(MessageStack<T> messageStack) throws IOException {

Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return messageStack.newMessageFromBot(
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
}
body.set("messages", prunedMessages.get());
Expand All @@ -186,6 +186,6 @@ public T handle(MessageStack<T> messageStack) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return messageStack.newMessageFromBot(timestamp, messageContent);
return threadState.newMessageFromBot(timestamp, messageContent);
}
}
6 changes: 3 additions & 3 deletions src/main/java/com/meta/chatbridge/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ enum Role {
SYSTEM
}

static Identifier conversationId(Identifier id1, Identifier id2) {
static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}

default Identifier conversationId() {
return conversationId(senderId(), recipientId());
default Identifier threadId() {
return threadId(senderId(), recipientId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MessageStack<T extends Message> {
public class ThreadState<T extends Message> {
private final List<T> messages;
private final MessageFactory<T> messageFactory;

private MessageStack(T message) {
private ThreadState(T message) {
Objects.requireNonNull(message);
this.messages = ImmutableList.of(message);
messageFactory = MessageFactory.instance(message);
}

/** Constructor that exists to support the with method */
private MessageStack(MessageStack<T> old, T newMessage) {
private ThreadState(ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
messageFactory = old.messageFactory;
Preconditions.checkArgument(
old.tail().conversationId().equals(newMessage.conversationId()),
"all messages in a stack must have the same conversation id");
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
this.messages =
Expand All @@ -46,11 +46,11 @@ private MessageStack(MessageStack<T> old, T newMessage) {

Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this message stack");
"userId and botId not consistent with this thread state");
}

public static <T extends Message> MessageStack<T> of(T message) {
return new MessageStack<>(message);
public static <T extends Message> ThreadState<T> of(T message) {
return new ThreadState<>(message);
}

public Identifier userId() {
Expand Down Expand Up @@ -78,8 +78,8 @@ public T newMessageFromUser(Instant timestamp, String message, Identifier instan
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
}

public MessageStack<T> with(T message) {
return new MessageStack<>(this, message);
public ThreadState<T> with(T message) {
return new ThreadState<>(this, message);
}

public List<T> messages() {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/chatbridge/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
package com.meta.chatbridge.store;

import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;

/**
* This class is in charge of both maintaining a chat history and managing a queue of conversations
Expand All @@ -22,5 +22,5 @@
*/
public interface ChatStore<T extends Message> {

MessageStack<T> add(T message);
ThreadState<T> add(T message);
}
12 changes: 6 additions & 6 deletions src/main/java/com/meta/chatbridge/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,32 @@
import com.google.common.cache.CacheBuilder;
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.time.Duration;

public class MemoryStore<T extends Message> implements ChatStore<T> {
private final Cache<Identifier, MessageStack<T>> store;
private final Cache<Identifier, ThreadState<T>> store;

MemoryStore(MemoryStoreConfig config) {
this.store =
CacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofHours(config.storageDurationHours()))
.maximumWeight((long) (config.storageCapacityMb() * Math.pow(2, 20))) // megabytes
.<Identifier, MessageStack<T>>weigher(
.<Identifier, ThreadState<T>>weigher(
(k, v) ->
v.messages().stream().map(m -> m.message().length()).reduce(0, Integer::sum))
.build();
}

@Override
public MessageStack<T> add(T message) {
public ThreadState<T> add(T message) {
return this.store
.asMap()
.compute(
message.conversationId(),
message.threadId(),
(k, v) -> {
if (v == null) {
return MessageStack.of(message);
return ThreadState.of(message);
}
return v.with(message);
});
Expand Down
24 changes: 12 additions & 12 deletions src/test/java/com/meta/chatbridge/llm/DummyFBMessageLLMHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,47 @@
import com.meta.chatbridge.Identifier;
import com.meta.chatbridge.message.FBMessage;
import com.meta.chatbridge.message.Message;
import com.meta.chatbridge.message.MessageStack;
import com.meta.chatbridge.message.ThreadState;
import java.time.Instant;
import java.util.concurrent.*;
import org.checkerframework.checker.nullness.qual.Nullable;

public class DummyFBMessageLLMHandler implements LLMPlugin<FBMessage> {

private final String dummyLLMResponse;
private final BlockingQueue<MessageStack<FBMessage>> receivedMessageStacks =
private final BlockingQueue<ThreadState<FBMessage>> receivedThreadStates =
new LinkedBlockingDeque<>();

public DummyFBMessageLLMHandler(String dummyLLMResponse) {
this.dummyLLMResponse = dummyLLMResponse;
}

public MessageStack<FBMessage> take(int waitMs) throws InterruptedException {
@Nullable MessageStack<FBMessage> value =
receivedMessageStacks.poll(waitMs, TimeUnit.MILLISECONDS);
public ThreadState<FBMessage> take(int waitMs) throws InterruptedException {
@Nullable ThreadState<FBMessage> value =
receivedThreadStates.poll(waitMs, TimeUnit.MILLISECONDS);
if (value == null) {
throw new RuntimeException("unable to remove item form queue in under " + waitMs + "ms");
}
return value;
}

public MessageStack<FBMessage> take() throws InterruptedException {
return receivedMessageStacks.take();
public ThreadState<FBMessage> take() throws InterruptedException {
return receivedThreadStates.take();
}

public @Nullable MessageStack<FBMessage> poll() {
return receivedMessageStacks.poll();
public @Nullable ThreadState<FBMessage> poll() {
return receivedThreadStates.poll();
}

public String dummyResponse() {
return dummyLLMResponse;
}

@Override
public FBMessage handle(MessageStack<FBMessage> messageStack) {
receivedMessageStacks.add(messageStack);
public FBMessage handle(ThreadState<FBMessage> threadState) {
receivedThreadStates.add(threadState);
FBMessage inbound =
messageStack.messages().stream()
threadState.messages().stream()
.filter(m -> m.role() == Message.Role.USER)
.findAny()
.orElseThrow();
Expand Down
38 changes: 19 additions & 19 deletions src/test/java/com/meta/chatbridge/llm/OpenAIPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public class OpenAIPluginTest {
public static final JsonNode SAMPLE_RESPONSE = MAPPER.createObjectNode();
private static final String PATH = "/";
private static final String TEST_MESSAGE = "this is a test message";
private static final MessageStack<FBMessage> STACK =
MessageStack.of(
private static final ThreadState<FBMessage> THREAD =
ThreadState.of(
MessageFactory.instance(FBMessage.class)
.newMessage(
Instant.now(),
Expand Down Expand Up @@ -110,10 +110,10 @@ void sampleValid(OpenAIModel model) throws IOException, InterruptedException {
String apiKey = UUID.randomUUID().toString();
OpenAIConfig config = OpenAIConfig.builder(model, apiKey).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
FBMessage message = plugin.handle(STACK);
FBMessage message = plugin.handle(THREAD);
assertThat(message.message()).isEqualTo(TEST_MESSAGE);
assertThat(message.role()).isSameAs(Role.ASSISTANT);
assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
assertThatCode(() -> THREAD.with(message)).doesNotThrowAnyException();
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
assertThat(or.headerMap().get("Authorization")).isNotNull().isEqualTo("Bearer " + apiKey);
Expand All @@ -139,10 +139,10 @@ void validConfigValues(OpenAIConfigTest.ConfigItem configItem)
OpenAIConfig config =
ConfigurationUtils.jsonMapper().convertValue(minimalConfig, OpenAIConfig.class);
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
FBMessage message = plugin.handle(STACK);
FBMessage message = plugin.handle(THREAD);
assertThat(message.message()).isEqualTo(TEST_MESSAGE);
assertThat(message.role()).isSameAs(Role.ASSISTANT);
assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
assertThatCode(() -> THREAD.with(message)).doesNotThrowAnyException();
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
assertThat(or.headerMap().get("Authorization"))
Expand Down Expand Up @@ -173,13 +173,13 @@ void contextTooBig() throws IOException {
OpenAIConfig config =
OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
MessageStack<FBMessage> stack =
STACK.with(
STACK.newMessageFromUser(
ThreadState<FBMessage> thread =
THREAD.with(
THREAD.newMessageFromUser(
Instant.now(),
Stream.generate(() -> "0123456789").limit(100).collect(Collectors.joining()),
Identifier.random()));
FBMessage response = plugin.handle(stack);
FBMessage response = plugin.handle(thread);
assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
assertThat(openAIRequests).hasSize(0);
}
Expand All @@ -189,8 +189,8 @@ void orderedCorrectly() throws IOException, InterruptedException {
OpenAIConfig config =
OpenAIConfig.builder(OpenAIModel.GPT35TURBO, "lkjasdlkjasdf").maxInputTokens(100).build();
OpenAIPlugin<FBMessage> plugin = new OpenAIPlugin<FBMessage>(config).endpoint(endpoint);
MessageStack<FBMessage> stack =
MessageStack.of(
ThreadState<FBMessage> thread =
ThreadState.of(
MessageFactory.instance(FBMessage.class)
.newMessage(
Instant.now(),
Expand All @@ -199,20 +199,20 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Role.SYSTEM));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
plugin.handle(stack);
thread = thread.with(thread.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
plugin.handle(thread);
@Nullable OutboundRequest or = openAIRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
JsonNode body = MAPPER.readTree(or.body());

for (int i = 0; i < stack.messages().size(); i++) {
FBMessage stackMessage = stack.messages().get(i);
for (int i = 0; i < thread.messages().size(); i++) {
FBMessage threadMessage = thread.messages().get(i);
JsonNode sentMessage = body.get("messages").get(i);
assertSoftly(
s ->
s.assertThat(stackMessage.message())
s.assertThat(threadMessage.message())
.isEqualTo(sentMessage.get("content").textValue()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ void invalidMessage(
.isEqualTo(0); // make sure the message wasn't processed and stored
assertThat(requests).hasSize(0);
} else {
MessageStack<FBMessage> stack = llmHandler.take(500);
ThreadState<FBMessage> thread = llmHandler.take(500);
JsonNode messageObject = PARSED_SAMPLE_MESSAGE.get("entry").get(0).get("messaging").get(0);
String messageText = messageObject.get("message").get("text").textValue();
String mid = messageObject.get("message").get("mid").textValue();
Identifier recipientId =
Identifier.from(messageObject.get("recipient").get("id").textValue());
Identifier senderId = Identifier.from(messageObject.get("sender").get("id").textValue());
Instant timestamp = Instant.ofEpochMilli(messageObject.get("timestamp").longValue());
assertThat(stack.messages())
assertThat(thread.messages())
.hasSize(1)
.allSatisfy(m -> assertThat(m.message()).isEqualTo(messageText))
.allSatisfy(m -> assertThat(m.instanceId().toString()).isEqualTo(mid))
Expand Down
Loading