diff --git a/src/context_extraction.rs b/src/context_extraction.rs index 3efaabc..ca23350 100644 --- a/src/context_extraction.rs +++ b/src/context_extraction.rs @@ -38,7 +38,6 @@ pub struct InvocationContextSettings { impl InvocationContextSettings { pub async fn extract_context_from_message(&self, ctx: &Context, message: &Message) -> Result> { - // TODO: track which limits were exceeded let mut limit_tracker = LimitTracker::new(); let mut messages = Vec::::new(); @@ -157,6 +156,8 @@ impl InvocationContextSettings { }); messages.dedup_by_key(|m| m.id()); + // + Ok(messages) } } diff --git a/src/handler/completion.rs b/src/handler/completion.rs index bcdabfa..4b8868c 100644 --- a/src/handler/completion.rs +++ b/src/handler/completion.rs @@ -1,4 +1,7 @@ -use std::collections::HashMap; +use std::collections::{ + HashMap, + HashSet, +}; use async_openai::types::{ ChatCompletionRequestMessage, @@ -18,6 +21,7 @@ use poise::serenity_prelude::{ CreateMessage, Message, }; +use sea_orm::DatabaseConnection; use tracing::{ info, trace, @@ -256,9 +260,13 @@ async fn generate_openai_response<'a>( .join("\n"); // TODO: implement message cache to avoid fetching messages multiple times - let chat_history = context_settings.extract_context_from_message(ctx, message).await?; + // TODO: pass message cache as argument + let mut chat_history = context_settings.extract_context_from_message(ctx, message).await?; dump_extracted_messages(&chat_history); + // remove all messages for users that opted out + remove_opted_out_users(&app.db, &mut chat_history).await?; + // unpack chat history into messages, we longer need inclusion reason let chat_history = chat_history.iter().map(|m| m.into()).collect::>(); @@ -318,6 +326,40 @@ async fn generate_openai_response<'a>( Ok(()) } +async fn remove_opted_out_users(db: &DatabaseConnection, messages: &mut Vec) -> Result<()> { + // extract all user ids from messages + let authors = messages + .iter() + // convert into message and get user id + .map(|m| { + let msg: &Message = m.into(); + &msg.author + }) + .collect::>(); + + // fetch database objects to check for opt-out status + let mut opt_out_users = HashSet::new(); + for author in authors { + let user = user_from_db_or_create(db, author).await?; + + if user.opt_out_since.is_some() { + opt_out_users.insert(user.discord_user_id); + continue; + } + } + + messages.retain(|m| { + let msg: &Message = m.into(); + let retain = !opt_out_users.contains(&msg.author.id.get()); + + trace!("Removing message {} from user {} due to opt-out", msg.id, msg.author.name); + + retain + }); + + Ok(()) +} + fn dump_request_messages(messages: &Vec) { let mut lines = Vec::new(); diff --git a/src/handler/opt_out.rs b/src/handler/opt_out.rs index 561e661..9d815b8 100644 --- a/src/handler/opt_out.rs +++ b/src/handler/opt_out.rs @@ -26,6 +26,7 @@ use sea_orm::{ }; use crate::{ + message_cache::MessageCache, user_from_db_or_create, Context, }; @@ -155,6 +156,9 @@ pub async fn opt_out_dialogue(ctx: Context<'_>) -> Result<()> { return Ok(()); }; + let cache = MessageCache::new(&app.db); + cache.delete_from_user(ctx.author().id).await?; + db_user .update(&app.db) .await diff --git a/src/main.rs b/src/main.rs index b890ac6..313d88c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod context_extraction; mod gcra; mod handler; mod invocation_builder; +mod message_cache; mod rate_limit_config; use std::{ @@ -50,6 +51,7 @@ use sea_orm::{ ActiveValue::Set, ColumnTrait, ConnectOptions, + ConnectionTrait, Database, DatabaseConnection, EntityTrait, @@ -58,7 +60,6 @@ use sea_orm::{ use tera::Tera; use tokio::sync::Mutex; use tracing::{ - debug, error, info, info_span, @@ -75,6 +76,7 @@ use crate::{ completion::handle_completion, opt_out, }, + message_cache::MessageCache, rate_limit_config::{ PathRateLimits, RateLimitConfig, @@ -356,8 +358,14 @@ async fn discord_listener<'a>(ctx: &'a poise::serenity_prelude::Context, ev: &'a FullEvent::MessageUpdate { new: Some(new), .. } => { - // TODO: invalidate moderation cache for message - debug!("message {} updated, invalidating cache", new.id); + let message_cache = MessageCache::new(&app.db); + message_cache.invalidate(&new.id).await?; + }, + FullEvent::MessageDelete { + deleted_message_id, .. + } => { + let message_cache = MessageCache::new(&app.db); + message_cache.invalidate(deleted_message_id).await?; }, _ => {}, } @@ -378,7 +386,7 @@ pub async fn help( Ok(()) } -pub async fn user_from_db_or_create(db: &DatabaseConnection, user: &User) -> Result { +pub async fn user_from_db_or_create(db: &C, user: &User) -> Result { let id = user.id.get(); let name = &user.name; diff --git a/src/message_cache.rs b/src/message_cache.rs new file mode 100644 index 0000000..e85452c --- /dev/null +++ b/src/message_cache.rs @@ -0,0 +1,135 @@ +use entity::message_cache; +use log::debug; +use miette::{ + IntoDiagnostic, + Result, + WrapErr, +}; +use poise::serenity_prelude::{ + ChannelId, + Message, + MessageId, + UserId, +}; +use sea_orm::{ + ActiveModelTrait, + ActiveValue::Set, + ColumnTrait, + ConnectionTrait, + EntityTrait, + QueryFilter, +}; + +use crate::{ + user_from_db_or_create, + Context, +}; + +/// Database backed message cache. Used to minimize the amount of requests to the Discord API. Once a message has been +/// fetched, it is stored in the cache for a certain amount of time. On message updates or deletions, the cache needs to +/// be invalidated. +pub struct MessageCache<'a, C> { + db: &'a C, +} + +impl<'a, C: ConnectionTrait> MessageCache<'a, C> { + /// Creates a new handle to the message cache. + pub fn new(db: &'a C) -> Self { + Self { + db, + } + } + + pub async fn _add(&self, message: &Message) -> Result { + // ensure author is also present in database + let _ = user_from_db_or_create(self.db, &message.author).await?; + + let existing = { + message_cache::Entity::find() + .filter(message_cache::Column::DiscordMessageId.eq(message.id.get())) + .one(self.db) + .await + .into_diagnostic() + .wrap_err("failed to fetch message cache entry")? + }; + + if let Some(existing) = existing { + debug!("message {} already in cache", message.id); + return Ok(existing); + } + + let entry = message_cache::ActiveModel { + discord_message_id: Set(message.id.get()), + discord_user_id: Set(message.author.id.get()), + content: Set(message.content.clone()), + ..Default::default() + }; + + let model = entry + .insert(self.db) + .await + .into_diagnostic() + .wrap_err("failed to insert message cache entry")?; + + Ok(model) + } + + pub async fn delete_from_user(&self, user_id: UserId) -> Result<()> { + let entry = message_cache::ActiveModel { + discord_user_id: Set(user_id.get()), + ..Default::default() + }; + + entity::prelude::MessageCache::delete(entry) + .exec(self.db) + .await + .into_diagnostic() + .wrap_err("failed to delete message cache entry")?; + + Ok(()) + } + + /// Fetches a message from the cache. If the message is not in the cache, it will be loaded from the Discord API. + pub async fn _fetch( + &self, + channel_id: ChannelId, + message_id: MessageId, + ctx: &Context<'_>, + ) -> Result> { + let entry = entity::prelude::MessageCache::find() + .filter(message_cache::Column::DiscordMessageId.eq(message_id.get())) + .one(self.db) + .await + .into_diagnostic() + .wrap_err("failed to fetch message cache entry")?; + + // if not in cache, we fetch from discord and add to cache + let entry = match entry { + Some(entry) => entry, + None => { + debug!("message {} not in cache, fetching from discord", message_id); + let discord_message = ctx + .http() + .get_message(channel_id, message_id) + .await + .into_diagnostic() + .wrap_err("failed to fetch message from discord")?; + self._add(&discord_message).await? + }, + }; + + Ok(Some(entry)) + } + + /// Invalidates a message in the cache. This is used when a message is updated or deleted. + pub async fn invalidate(&self, message_id: &MessageId) -> Result<()> { + entity::prelude::MessageCache::delete_many() + .filter(message_cache::Column::DiscordMessageId.eq(message_id.get())) + .exec(self.db) + .await + .into_diagnostic() + .wrap_err("failed to delete message cache entry")?; + + Ok(()) + } +}