Skip to content

Commit

Permalink
Add channel whitelist functionality (#4)
Browse files Browse the repository at this point in the history
Implemented a new feature that allows the bot to respond only in
whitelisted channels. The whitelist setting can be set using the
`WHITELIST_CHANNELS` environment variable.
  • Loading branch information
chrisliebaer authored Jan 15, 2024
1 parent 4b83298 commit 12c94aa
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ The project requires the following environment variables:
- `DISCORD_TOKEN`: Your Discord bot token.
- `TEMPLATE_DIR`: The directory where your Tera templates are located. Defaults to `templates`.
- `RATE_LIMIT_CONFIG`: The path to your rate limit configuration file. Defaults to `rate_limits.toml`.
- `DATABASE_URL`: The URL to your database. For example `mysql://user:password@localhost/database`.
- `WHITELIST_CHANNEL`: A comma separated list of channel IDs that the bot is allowed to respond in. If not set, the bot will respond in all channels.


## License
Expand Down
10 changes: 10 additions & 0 deletions src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ impl From<&poise::serenity_prelude::User> for UserContext {
}

pub async fn handle_completion(ctx: &poise::serenity_prelude::Context, app: &AppState, new_message: &Message) -> Result<()> {
// if whitelist is empty, assume user did not configure one
if !app.whitelist.is_empty() && !app.whitelist.contains(&new_message.channel_id) {
new_message
.reply(ctx, "This channel is not whitelisted.")
.await
.into_diagnostic()
.wrap_err("failed to send whitelist message")?;
return Ok(());
}

if !check_rate_limit(new_message, app).await? {
// prevent user from spamming us with timeout
let error_report_future = tokio::time::timeout(std::time::Duration::from_secs(10), async {
Expand Down
28 changes: 27 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod gcra;
mod invocation_builder;
mod rate_limit_config;

use std::time::Duration;
use std::{
str::FromStr,
time::Duration,
};

use async_openai::{
config::OpenAIConfig,
Expand All @@ -24,6 +27,7 @@ use migration::{
};
use poise::{
serenity_prelude::{
ChannelId,
ClientBuilder,
CreateAllowedMentions,
FullEvent,
Expand Down Expand Up @@ -82,6 +86,26 @@ struct EnvConfig {

#[envconfig(from = "RATE_LIMIT_CONFIG", default = "rate_limits.toml")]
rate_limit_config: String,

#[envconfig(from = "WHITELIST_CHANNEL", default = "")]
whitelist_channel: ChannelWhiteList,
}

struct ChannelWhiteList(Vec<ChannelId>);
impl FromStr for ChannelWhiteList {
type Err = Report;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.is_empty() {
return Ok(ChannelWhiteList(Vec::new()));
}

s.split(',')
.map(|s| s.parse().into_diagnostic().wrap_err("failed to parse channel id"))
.collect::<Result<Vec<_>, _>>()
.map(ChannelWhiteList)
.wrap_err("failed to parse channel whitelist")
}
}

struct AppState {
Expand All @@ -90,6 +114,7 @@ struct AppState {
db: DatabaseConnection,
path_rate_limits: Mutex<PathRateLimits>,
context_settings: InvocationContextSettings,
whitelist: Vec<ChannelId>,
}

#[tokio::main(flavor = "current_thread")]
Expand Down Expand Up @@ -200,6 +225,7 @@ async fn main() -> Result<()> {
reply_chain_window: Some(5),
reply_chain_max_token_count: Some(1000),
},
whitelist: env_config.whitelist_channel.0,
})
})
})
Expand Down

0 comments on commit 12c94aa

Please sign in to comment.