Skip to content

Commit

Permalink
♻️ Amends to opt-out (#7)
Browse files Browse the repository at this point in the history
This commit fixes a few leftover issues with filtering messages from opted-out users.
  • Loading branch information
chrisliebaer authored Feb 17, 2024
1 parent a6534b4 commit 6f35666
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/context_extraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ pub struct InvocationContextSettings {

impl InvocationContextSettings {
pub async fn extract_context_from_message(&self, ctx: &Context, message: &Message) -> Result<Vec<ContextMessageVariant>> {
// TODO: track which limits were exceeded
let mut limit_tracker = LimitTracker::new();
let mut messages = Vec::<ContextMessageVariant>::new();

Expand Down Expand Up @@ -157,6 +156,8 @@ impl InvocationContextSettings {
});
messages.dedup_by_key(|m| m.id());

//

Ok(messages)
}
}
Expand Down
46 changes: 44 additions & 2 deletions src/handler/completion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::HashMap;
use std::collections::{
HashMap,
HashSet,
};

use async_openai::types::{
ChatCompletionRequestMessage,
Expand All @@ -18,6 +21,7 @@ use poise::serenity_prelude::{
CreateMessage,
Message,
};
use sea_orm::DatabaseConnection;
use tracing::{
info,
trace,
Expand Down Expand Up @@ -256,9 +260,13 @@ async fn generate_openai_response<'a>(
.join("\n");

// TODO: implement message cache to avoid fetching messages multiple times
let chat_history = context_settings.extract_context_from_message(ctx, message).await?;
// TODO: pass message cache as argument
let mut chat_history = context_settings.extract_context_from_message(ctx, message).await?;
dump_extracted_messages(&chat_history);

// remove all messages for users that opted out
remove_opted_out_users(&app.db, &mut chat_history).await?;

// unpack chat history into messages, we no longer need the inclusion reason
let chat_history = chat_history.iter().map(|m| m.into()).collect::<Vec<&Message>>();

Expand Down Expand Up @@ -318,6 +326,40 @@ async fn generate_openai_response<'a>(
Ok(())
}

/// Removes, in place, every message whose author has opted out.
///
/// The distinct authors of `messages` are resolved through `user_from_db_or_create`; every author whose database
/// record has `opt_out_since` set is considered opted out and all of their messages are dropped from `messages`.
///
/// # Errors
///
/// Propagates any error returned by `user_from_db_or_create`.
async fn remove_opted_out_users(db: &DatabaseConnection, messages: &mut Vec<ContextMessageVariant>) -> Result<()> {
	// collect the set of distinct authors so repeated authors are not looked up again and again
	let authors = messages
		.iter()
		// convert into message and get its author
		.map(|m| {
			let msg: &Message = m.into();
			&msg.author
		})
		.collect::<HashSet<_>>();

	// fetch database objects to check for opt-out status
	let mut opt_out_users = HashSet::new();
	for author in authors {
		let user = user_from_db_or_create(db, author).await?;

		if user.opt_out_since.is_some() {
			opt_out_users.insert(user.discord_user_id);
		}
	}

	messages.retain(|m| {
		let msg: &Message = m.into();
		let retain = !opt_out_users.contains(&msg.author.id.get());

		// only report messages that are actually dropped; logging unconditionally would claim removal of kept messages
		if !retain {
			trace!("Removing message {} from user {} due to opt-out", msg.id, msg.author.name);
		}

		retain
	});

	Ok(())
}

fn dump_request_messages(messages: &Vec<ChatCompletionRequestMessage>) {
let mut lines = Vec::new();

Expand Down
4 changes: 4 additions & 0 deletions src/handler/opt_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use sea_orm::{
};

use crate::{
message_cache::MessageCache,
user_from_db_or_create,
Context,
};
Expand Down Expand Up @@ -155,6 +156,9 @@ pub async fn opt_out_dialogue(ctx: Context<'_>) -> Result<()> {
return Ok(());
};

let cache = MessageCache::new(&app.db);
cache.delete_from_user(ctx.author().id).await?;

db_user
.update(&app.db)
.await
Expand Down
16 changes: 12 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod context_extraction;
mod gcra;
mod handler;
mod invocation_builder;
mod message_cache;
mod rate_limit_config;

use std::{
Expand Down Expand Up @@ -50,6 +51,7 @@ use sea_orm::{
ActiveValue::Set,
ColumnTrait,
ConnectOptions,
ConnectionTrait,
Database,
DatabaseConnection,
EntityTrait,
Expand All @@ -58,7 +60,6 @@ use sea_orm::{
use tera::Tera;
use tokio::sync::Mutex;
use tracing::{
debug,
error,
info,
info_span,
Expand All @@ -75,6 +76,7 @@ use crate::{
completion::handle_completion,
opt_out,
},
message_cache::MessageCache,
rate_limit_config::{
PathRateLimits,
RateLimitConfig,
Expand Down Expand Up @@ -356,8 +358,14 @@ async fn discord_listener<'a>(ctx: &'a poise::serenity_prelude::Context, ev: &'a
FullEvent::MessageUpdate {
new: Some(new), ..
} => {
// TODO: invalidate moderation cache for message
debug!("message {} updated, invalidating cache", new.id);
let message_cache = MessageCache::new(&app.db);
message_cache.invalidate(&new.id).await?;
},
FullEvent::MessageDelete {
deleted_message_id, ..
} => {
let message_cache = MessageCache::new(&app.db);
message_cache.invalidate(deleted_message_id).await?;
},
_ => {},
}
Expand All @@ -378,7 +386,7 @@ pub async fn help(
Ok(())
}

pub async fn user_from_db_or_create(db: &DatabaseConnection, user: &User) -> Result<user::Model> {
pub async fn user_from_db_or_create<C: ConnectionTrait>(db: &C, user: &User) -> Result<user::Model> {
let id = user.id.get();
let name = &user.name;

Expand Down
135 changes: 135 additions & 0 deletions src/message_cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use entity::message_cache;
use log::debug;
use miette::{
IntoDiagnostic,
Result,
WrapErr,
};
use poise::serenity_prelude::{
ChannelId,
Message,
MessageId,
UserId,
};
use sea_orm::{
ActiveModelTrait,
ActiveValue::Set,
ColumnTrait,
ConnectionTrait,
EntityTrait,
QueryFilter,
};

use crate::{
user_from_db_or_create,
Context,
};

/// Database backed message cache. Used to minimize the amount of requests to the Discord API. Once a message has been
/// fetched, it is stored in the cache for a certain amount of time. On message updates or deletions, the cache needs to
/// be invalidated.
///
/// This type is only a handle borrowing a database connection; it holds no cached data itself.
///
/// NOTE(review): no time-based expiry is visible in this file; entries appear to be removed only through explicit
/// invalidation or per-user deletion. Confirm whether expiry is implemented elsewhere.
pub struct MessageCache<'a, C> {
// Borrowed database connection on which all cache queries are executed.
db: &'a C,
}

impl<'a, C: ConnectionTrait> MessageCache<'a, C> {
	/// Creates a new handle to the message cache that runs its queries on `db`.
	pub fn new(db: &'a C) -> Self {
		Self {
			db,
		}
	}

	/// Stores `message` in the cache and returns the cached entry.
	///
	/// The author is resolved through `user_from_db_or_create` first so that the user record exists before the
	/// message entry is written. If an entry with the same Discord message id already exists, it is returned as is
	/// and nothing is inserted.
	///
	/// # Errors
	///
	/// Fails if the author lookup, the existence check or the insert fails.
	pub async fn _add(&self, message: &Message) -> Result<message_cache::Model> {
		// ensure author is also present in database
		let _ = user_from_db_or_create(self.db, &message.author).await?;

		let existing = message_cache::Entity::find()
			.filter(message_cache::Column::DiscordMessageId.eq(message.id.get()))
			.one(self.db)
			.await
			.into_diagnostic()
			.wrap_err("failed to fetch message cache entry")?;

		if let Some(existing) = existing {
			debug!("message {} already in cache", message.id);
			return Ok(existing);
		}

		let entry = message_cache::ActiveModel {
			discord_message_id: Set(message.id.get()),
			discord_user_id: Set(message.author.id.get()),
			content: Set(message.content.clone()),
			..Default::default()
		};

		entry
			.insert(self.db)
			.await
			.into_diagnostic()
			.wrap_err("failed to insert message cache entry")
	}

	/// Deletes every cached message authored by `user_id`.
	///
	/// # Errors
	///
	/// Fails if the delete statement cannot be executed.
	pub async fn delete_from_user(&self, user_id: UserId) -> Result<()> {
		// `Entity::delete` with an `ActiveModel` removes a single row identified by its primary key. A model that only
		// carries the user id cannot express "all rows of this user", so filter on the author column and delete in
		// bulk, the same way `invalidate` filters on the message id.
		entity::prelude::MessageCache::delete_many()
			.filter(message_cache::Column::DiscordUserId.eq(user_id.get()))
			.exec(self.db)
			.await
			.into_diagnostic()
			.wrap_err("failed to delete message cache entry")?;

		Ok(())
	}

	/// Fetches a message from the cache. If the message is not in the cache, it will be loaded from the Discord API
	/// and added to the cache.
	///
	/// On success the returned option is always `Some`; a message that cannot be fetched from Discord surfaces as an
	/// error rather than `None`.
	///
	/// # Errors
	///
	/// Fails if the cache lookup, the Discord request or the subsequent cache insert fails.
	pub async fn _fetch(
		&self,
		channel_id: ChannelId,
		message_id: MessageId,
		ctx: &Context<'_>,
	) -> Result<Option<message_cache::Model>> {
		let cached = entity::prelude::MessageCache::find()
			.filter(message_cache::Column::DiscordMessageId.eq(message_id.get()))
			.one(self.db)
			.await
			.into_diagnostic()
			.wrap_err("failed to fetch message cache entry")?;

		// cache hit, no need to talk to discord
		if let Some(entry) = cached {
			return Ok(Some(entry));
		}

		// if not in cache, we fetch from discord and add to cache
		debug!("message {} not in cache, fetching from discord", message_id);
		let discord_message = ctx
			.http()
			.get_message(channel_id, message_id)
			.await
			.into_diagnostic()
			.wrap_err("failed to fetch message from discord")?;
		let entry = self._add(&discord_message).await?;

		Ok(Some(entry))
	}

	/// Invalidates a message in the cache. This is used when a message is updated or deleted.
	///
	/// Removes all cache entries matching the Discord message id.
	///
	/// # Errors
	///
	/// Fails if the delete statement cannot be executed.
	pub async fn invalidate(&self, message_id: &MessageId) -> Result<()> {
		entity::prelude::MessageCache::delete_many()
			.filter(message_cache::Column::DiscordMessageId.eq(message_id.get()))
			.exec(self.db)
			.await
			.into_diagnostic()
			.wrap_err("failed to delete message cache entry")?;

		Ok(())
	}
}

0 comments on commit 6f35666

Please sign in to comment.