From 0fda34c21e55e87de6d929ec2b19fa1312f1d9e6 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sat, 23 Nov 2024 02:01:42 -0800 Subject: [PATCH] feat: proper permissions for push notifications --- Cargo.lock | 32 +- .../database/src/models/messages/model.rs | 26 +- .../database/src/util/bulk_permissions.rs | 331 ++++++++++++++++++ crates/core/database/src/util/mod.rs | 1 + crates/daemons/pushd/src/main.rs | 4 +- 5 files changed, 357 insertions(+), 37 deletions(-) create mode 100644 crates/core/database/src/util/bulk_permissions.rs diff --git a/Cargo.lock b/Cargo.lock index fe742d8de..529c77ee5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -466,7 +466,7 @@ dependencies = [ "regex", "reqwest 0.11.10", "revolt_okapi", - "revolt_rocket_okapi 0.10.0", + "revolt_rocket_okapi", "rocket", "rust-argon2", "schemars", @@ -5646,7 +5646,7 @@ dependencies = [ "revolt_a2", "revolt_okapi", "revolt_optional_struct", - "revolt_rocket_okapi 0.10.0", + "revolt_rocket_okapi", "rocket", "schemars", "serde", @@ -5691,7 +5691,7 @@ dependencies = [ "revolt-permissions", "revolt-presence", "revolt-result", - "revolt_rocket_okapi 0.10.0", + "revolt_rocket_okapi", "rocket", "rocket_authifier", "rocket_cors", @@ -5844,7 +5844,7 @@ version = "0.7.18" dependencies = [ "axum", "revolt_okapi", - "revolt_rocket_okapi 0.10.0", + "revolt_rocket_okapi", "rocket", "schemars", "serde", @@ -5904,22 +5904,6 @@ dependencies = [ "syn 0.11.11", ] -[[package]] -name = "revolt_rocket_okapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "275e1e9bd3343f75225cafa64f4bfb939c8b21c5f861141180fc0e24769ff6cf" -dependencies = [ - "either", - "log", - "revolt_okapi", - "revolt_rocket_okapi_codegen", - "rocket", - "schemars", - "serde", - "serde_json", -] - [[package]] name = "revolt_rocket_okapi" version = "0.10.0" @@ -6061,14 +6045,14 @@ dependencies = [ [[package]] name = "rocket_authifier" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89a12311f60e9288833fc3ce6029bce5d5c61870ceef74d4a50668a8b520ad" +checksum = "810753b79106c44a4e76247fc7576b660663133a9e8f4b0afeb303589ec51d59" dependencies = [ "authifier", "iso8601-timestamp 0.1.10", "revolt_okapi", - "revolt_rocket_okapi 0.9.1", + "revolt_rocket_okapi", "rocket", "rocket_empty", "schemars", @@ -6115,7 +6099,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97a55000e1ef5f4a9b20ae3d9de2a0bd22620c78ebd1aa568776ae12276125a6" dependencies = [ "revolt_okapi", - "revolt_rocket_okapi 0.10.0", + "revolt_rocket_okapi", "rocket", ] diff --git a/crates/core/database/src/models/messages/model.rs b/crates/core/database/src/models/messages/model.rs index c4f42fb5a..61b6499ff 100644 --- a/crates/core/database/src/models/messages/model.rs +++ b/crates/core/database/src/models/messages/model.rs @@ -15,7 +15,7 @@ use validator::Validate; use crate::{ events::client::EventV1, tasks::{self, ack::AckEvent}, - util::idempotency::IdempotencyKey, + util::{bulk_permissions::BulkDatabasePermissionQuery, idempotency::IdempotencyKey}, Channel, Database, Emoji, File, User, AMQP, }; @@ -358,17 +358,19 @@ impl Message { mentions.retain(|m| valid_mentions.contains(m)); // quick pass, validate mentions are in the server - // Need to build a struct for bulk querying user permissions for a channel, - // as this would involve fetching all the requisite information (server permissions, channel permissions, etc) for every user. - // if !mentions.is_empty() { - // // if there are still mentions, drill down to a channel-level - // for member in valid_members.iter() { - // DatabasePermissionQuery::new(db, member.into()) - // .channel(&channel) - // .member(&member) - // . - // } - // } + if !mentions.is_empty() { + // if there are still mentions, drill down to a channel-level + let member_channel_view_perms = + BulkDatabasePermissionQuery::from_server_id(db, server) + .await + .channel(&channel) + .members(&valid_members) + .members_can_see_channel() + .await; + + mentions + .retain(|m| *member_channel_view_perms.get(m).unwrap_or(&false)); + } } else { revolt_config::capture_error(&valid_members.unwrap_err()); return Err(create_error!(InternalError)); diff --git a/crates/core/database/src/util/bulk_permissions.rs b/crates/core/database/src/util/bulk_permissions.rs new file mode 100644 index 000000000..eb9bca475 --- /dev/null +++ b/crates/core/database/src/util/bulk_permissions.rs @@ -0,0 +1,331 @@ +use std::{collections::HashMap, hash::RandomState}; + +use revolt_permissions::{ + ChannelPermission, ChannelType, Override, OverrideField, PermissionValue, ALLOW_IN_TIMEOUT, + DEFAULT_PERMISSION_DIRECT_MESSAGE, +}; + +use crate::{Channel, Database, Member, Server, User}; + +#[derive(Clone)] +pub struct BulkDatabasePermissionQuery<'a> { + #[allow(dead_code)] + database: &'a Database, + + server: Server, + channel: Option, + users: Option>, + members: Option>, + + // In case the users or members are fetched as part of the permissions checking operation + pub(crate) cached_users: Option>, + pub(crate) cached_members: Option>, + + cached_member_perms: Option>, +} + +impl<'z, 'x> BulkDatabasePermissionQuery<'x> { + pub async fn members_can_see_channel(&'z mut self) -> HashMap + where + 'z: 'x, + { + let member_perms = if self.cached_member_perms.is_some() { + // This isn't done as an if let to prevent borrow checker errors with the mut self call when the perms aren't cached. + let perms = self.cached_member_perms.as_ref().unwrap(); + perms + .iter() + .map(|(m, p)| { + ( + m.clone(), + p.has_channel_permission(ChannelPermission::ViewChannel), + ) + }) + .collect() + } else { + calculate_members_permissions(self) + .await + .iter() + .map(|(m, p)| { + ( + m.clone(), + p.has_channel_permission(ChannelPermission::ViewChannel), + ) + }) + .collect() + }; + member_perms + } +} + +impl<'z> BulkDatabasePermissionQuery<'z> { + pub fn new(database: &Database, server: Server) -> BulkDatabasePermissionQuery<'_> { + BulkDatabasePermissionQuery { + database, + server, + channel: None, + users: None, + members: None, + cached_members: None, + cached_users: None, + cached_member_perms: None, + } + } + + pub async fn from_server_id<'a>( + db: &'a Database, + server: &str, + ) -> BulkDatabasePermissionQuery<'a> { + BulkDatabasePermissionQuery { + database: db, + server: db.fetch_server(server).await.unwrap(), + channel: None, + users: None, + members: None, + cached_members: None, + cached_users: None, + cached_member_perms: None, + } + } + + pub fn channel(self, channel: &'z Channel) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + channel: Some(channel.clone()), + ..self + } + } + + pub fn members(self, members: &'z [Member]) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + members: Some(members.to_owned()), + ..self + } + } + + pub fn users(self, users: &'z [User]) -> BulkDatabasePermissionQuery { + BulkDatabasePermissionQuery { + users: Some(users.to_owned()), + ..self + } + } + + /// Get the default channel permissions + /// Group channel defaults should be mapped to an allow-only override + #[allow(dead_code)] + async fn get_default_channel_permissions(&mut self) -> Override { + if let Some(channel) = &self.channel { + match channel { + Channel::Group { permissions, .. } => Override { + allow: permissions.unwrap_or(*DEFAULT_PERMISSION_DIRECT_MESSAGE as i64) as u64, + deny: 0, + }, + Channel::TextChannel { + default_permissions, + .. + } + | Channel::VoiceChannel { + default_permissions, + .. + } => default_permissions.unwrap_or_default().into(), + _ => Default::default(), + } + } else { + Default::default() + } + } + + #[allow(dead_code)] + fn get_channel_type(&mut self) -> ChannelType { + if let Some(channel) = &self.channel { + match channel { + Channel::DirectMessage { .. } => ChannelType::DirectMessage, + Channel::Group { .. } => ChannelType::Group, + Channel::SavedMessages { .. } => ChannelType::SavedMessages, + Channel::TextChannel { .. } | Channel::VoiceChannel { .. } => { + ChannelType::ServerChannel + } + } + } else { + ChannelType::Unknown + } + } + + /// Get the ordered role overrides (from lowest to highest) for this member in this channel + #[allow(dead_code)] + async fn get_channel_role_overrides(&mut self) -> &HashMap { + if let Some(channel) = &self.channel { + match channel { + Channel::TextChannel { + role_permissions, .. + } + | Channel::VoiceChannel { + role_permissions, .. + } => role_permissions, + _ => panic!("Not supported for non-server channels"), + } + } else { + panic!("No channel added to query") + } + } +} + +/// Calculate members permissions in a server channel. +async fn calculate_members_permissions<'a>( + query: &'a mut BulkDatabasePermissionQuery<'a>, +) -> HashMap { + let mut resp = HashMap::new(); + + let (_, channel_role_permissions) = match query + .channel + .as_ref() + .expect("A channel must be assigned to calculate channel permissions") + .clone() + { + Channel::TextChannel { + id, + role_permissions, + .. + } + | Channel::VoiceChannel { + id, + role_permissions, + .. + } => (id, role_permissions), + _ => panic!("Calculation of member permissions must be done on a server channel"), + }; + + if query.users.is_none() { + let ids: Vec = query + .members + .as_ref() + .expect("No users or members added to the query") + .iter() + .map(|m| m.id.user.clone()) + .collect(); + + query.cached_users = Some( + query + .database + .fetch_users(&ids[..]) + .await + .expect("Failed to get data from the db"), + ); + + query.users = Some(query.cached_users.as_ref().unwrap().to_vec()) + } + + let users = query.users.as_ref().unwrap(); + + if query.members.is_none() { + let ids: Vec = query + .users + .as_ref() + .expect("No users or members added to the query") + .iter() + .map(|m| m.id.clone()) + .collect(); + + query.cached_members = Some( + query + .database + .fetch_members(&query.server.id, &ids[..]) + .await + .expect("Failed to get data from the db"), + ); + query.members = Some(query.cached_members.as_ref().unwrap().to_vec()) + } + + let members: HashMap<&String, &Member, RandomState> = HashMap::from_iter( + query + .members + .as_ref() + .unwrap() + .iter() + .map(|m| (&m.id.user, m)), + ); + + for user in users { + let member = members.get(&user.id); + + // User isn't a part of the server + if member.is_none() { + resp.insert(user.id.clone(), 0_u64.into()); + continue; + } + + let member = *member.unwrap(); + + if user.privileged { + resp.insert( + user.id.clone(), + PermissionValue::from(ChannelPermission::GrantAllSafe), + ); + continue; + } + + if user.id == query.server.owner { + resp.insert( + user.id.clone(), + PermissionValue::from(ChannelPermission::GrantAllSafe), + ); + continue; + } + + // Get the user's server permissions + let mut permission = calculate_server_permissions(&query.server, user, member); + + // Get the applicable role overrides + let mut roles = channel_role_permissions + .iter() + .filter(|(id, _)| member.roles.contains(id)) + .filter_map(|(id, permission)| { + query.server.roles.get(id).map(|role| { + let v: Override = (*permission).into(); + (role.rank, v) + }) + }) + .collect::>(); + + roles.sort_by(|a, b| b.0.cmp(&a.0)); + let overrides = roles.into_iter().map(|(_, v)| v); + + for role_override in overrides { + permission.apply(role_override) + } + + resp.insert(user.id.clone(), permission); + } + + resp +} + +/// Calculates a member's server permissions +fn calculate_server_permissions(server: &Server, user: &User, member: &Member) -> PermissionValue { + if user.privileged || server.owner == user.id { + return ChannelPermission::GrantAllSafe.into(); + } + + let mut permissions: PermissionValue = server.default_permissions.into(); + + let mut roles = server + .roles + .iter() + .filter(|(id, _)| member.roles.contains(id)) + .map(|(_, role)| { + let v: Override = role.permissions.into(); + (role.rank, v) + }) + .collect::>(); + + roles.sort_by(|a, b| b.0.cmp(&a.0)); + let role_overrides: Vec = roles.into_iter().map(|(_, v)| v).collect(); + + for role in role_overrides { + permissions.apply(role); + } + + if member.in_timeout() { + permissions.restrict(*ALLOW_IN_TIMEOUT); + } + + permissions +} diff --git a/crates/core/database/src/util/mod.rs b/crates/core/database/src/util/mod.rs index 26cf436e6..1baf7d8ba 100644 --- a/crates/core/database/src/util/mod.rs +++ b/crates/core/database/src/util/mod.rs @@ -1,4 +1,5 @@ pub mod bridge; +pub mod bulk_permissions; pub mod idempotency; pub mod permissions; pub mod reference; diff --git a/crates/daemons/pushd/src/main.rs b/crates/daemons/pushd/src/main.rs index 3c9da6dc1..58ac14d10 100644 --- a/crates/daemons/pushd/src/main.rs +++ b/crates/daemons/pushd/src/main.rs @@ -203,7 +203,9 @@ where routing_key, )) .await - .unwrap(); + .expect( + "This probably means the revolt.notifications exchange does not exist in rabbitmq!", + ); let args = BasicConsumeArguments::new(queue_name, "") .manual_ack(false)