// Discord bot wiring: serenity client + songbird voice + poise command framework.
//
// NOTE(review): this text appears to have been collapsed onto a handful of lines and to have
// lost its generic parameter lists (e.g. `Arc>`, `Mutex>`, `type_map_insert::()`,
// `data.get::()`), and the `//` comment inside `perm_check` now runs into code. It will not
// compile as shown -- restore the file from version control before making code changes.
//
// Overview of the items defined below, in order:
//
// * `HttpKey`, `VolumeKey`, `PlaybackKey` -- `TypeMapKey` markers. `HttpKey` maps to a
//   `reqwest::Client`; the other two map to `DashMap`-based values whose type parameters are
//   not visible here. `run` inserts one value per key (presumably in this same order; the key
//   type arguments at the call sites are missing -- TODO confirm).
// * `BOTNAME` -- "thulani (dev)" when `debug_assertions` is on, "thulani" otherwise.
// * `ready_guild` -- sets the bot's nickname in the guild to `BOTNAME` (a failure is only
//   logged), then gets or creates the guild's songbird call, removes all of its global events
//   and registers a `SongbirdHandler` for `TrackEvent::End`. Panics if songbird is not
//   registered on the context (`unwrap`).
// * `perm_check` -- looks up the bot's own member in the guild and compares its permissions
//   with `util::REQUIRED_PERMS` and `util::DESIRED_PERMS`: missing required permissions are
//   logged at error level, missing desired-but-not-required ones at warn level, and
//   permissions beyond the desired set at debug level. Only the member/permission lookups can
//   make it return `Err`.
// * `impl EventHandler for Handler` -- `ready` and `resume` only log; `guild_create` runs
//   `perm_check` (logging any error) and then `ready_guild`; `message_delete` removes the
//   deleted message id from `MESSAGE_WATCH`.
// * `SongbirdHandler` -- songbird event handler that leaves the call when its queue is empty;
//   the result of `leave()` is deliberately ignored. Always returns `None`.
// * `lazy_static!` globals -- `MESSAGE_WATCH` (mutex-guarded map keyed by message id, judging
//   by `message_delete`), `PREFIXES`, `RESTRICTED_PREFIXES`, `ALL_PREFIXES` (`PREFIXES`
//   followed by `RESTRICTED_PREFIXES`), and `RESTRICT_IDS`, which is read as a JSON list from
//   `CONFIG.restrict` or, if unset, "restrict.json"; if the file cannot be opened or parsed a
//   warning is logged and the set is empty.
// * `on_err` -- poise error hook. Picks a reply text per `FrameworkError` variant. For
//   `UnknownCommand` it builds a stand-in `PrefixContext` and then: replies "what?" to an
//   empty message; hands content matching `^https?://` to `commands::link_unrecognized`;
//   otherwise invokes the meme command with the message content. Successful handling returns
//   early. Every other path logs the error, reacts to the message with a cross mark and sends
//   the reply text to the channel.
// * `framework` -- builds the poise framework: the first entry of `ALL_PREFIXES` is the
//   primary prefix and the rest are additional literal prefixes; commands are
//   case-insensitive, mentions are not a prefix, bots are ignored, and the single owner comes
//   from `CONFIG.discord.owner()`.
// * `check` -- command gate. The owner is always allowed, and so is any command not invoked
//   through a restricted prefix. For restricted prefixes the command is allowed only when
//   `user_is_restricted == restrictions_flipped`, i.e. users listed in `RESTRICT_IDS` may use
//   them only on `PERMITTED_WEEKDAY` (Tuesday, local time) and everyone else only on the other
//   days. Otherwise it replies "no" and returns `Ok(false)`. The tracing span guard is dropped
//   before the reply is awaited.
// * `before_handle` / `after_handle` -- logging-only pre/post command hooks.
// * `run` -- builds the client (non-privileged intents plus `MESSAGE_CONTENT`), starts it on a
//   spawned task, and spawns a second task that every 10 seconds (missed ticks are skipped)
//   collects the UUIDs of all currently queued tracks and drops every `playback_data` entry
//   whose key is not among them. It then waits for Ctrl-C, shuts down all shards and awaits
//   the client task. With the `windows_autostart_postgres` feature on Windows it also starts
//   postgres first and stops it again on exit if this process started it.
use crate::{ PoiseContext, PoiseData, commands, config::CONFIG, err_msg, util, util::OAUTH_URL, }; use chrono::Datelike; use dashmap::DashMap; use fnv::{ FnvHashMap, FnvHashSet, }; use grate::tracing; use lazy_static::lazy_static; use poise::{ BoxFuture, FrameworkError, PrefixContext, }; use serenity::{ all::{ Guild, GuildId, ReactionType, }, builder::CreateMessage, model::{ event::ResumedEvent, gateway::Ready, id::{ ChannelId, MessageId, }, }, prelude::*, }; use songbird::{ Call, Event, EventContext, SerenityInit, TrackEvent, }; use std::{ collections::HashSet, fs::File, future::Future, path::PathBuf, pin::Pin, str::FromStr, sync::Arc, time::Duration, }; use tokio::{ sync::Mutex, time::MissedTickBehavior, }; pub struct HttpKey; impl TypeMapKey for HttpKey { type Value = reqwest::Client; } pub struct VolumeKey; impl TypeMapKey for VolumeKey { type Value = DashMap; } pub struct PlaybackKey; impl TypeMapKey for PlaybackKey { type Value = Arc>; } #[cfg(debug_assertions)] const BOTNAME: &str = "thulani (dev)"; #[cfg(not(debug_assertions))] const BOTNAME: &str = "thulani"; struct Handler; async fn ready_guild(ctx: &Context, guild_id: GuildId) { let sb = songbird::get(ctx).await.unwrap(); if let Err(e) = guild_id.edit_nickname(ctx, Some(BOTNAME)).await { tracing::error!(error = %e, %guild_id, "changing nickname"); } let c = sb.get_or_insert(guild_id); let mut call = c.lock().await; call.remove_all_global_events(); call.add_global_event(Event::Track(TrackEvent::End), SongbirdHandler(c.clone())); } async fn perm_check(ctx: &Context, guild: &Guild) -> anyhow::Result<()> { // "Requested permissions" through the discord OAuth flow don't actually apply in a way that's // reflected in the guild member permission object. There is a dedicated role type created to // protect bot perms when you do this, which I'd like to read and verify, but I suspect it's // behind BotIntegration, which you need `GUILD_MANAGE` to even read. 
I don't want to give this // permission only for this reason. let me = ctx.cache.current_user().id; let member = guild.member(&ctx, me).await?; let perms = member.permissions(ctx)?; let lacking_perms = util::REQUIRED_PERMS.difference(perms); if !lacking_perms.is_empty() { tracing::error!( guild_id = %guild.id, guild_name = %guild.name, %lacking_perms, "insufficient permissions for guild" ); } let lacking_desired_perms = util::DESIRED_PERMS.difference(*util::REQUIRED_PERMS).difference(perms); if !lacking_desired_perms.is_empty() { tracing::warn!( guild_id = %guild.id, guild_name = %guild.name, %lacking_desired_perms, "bot is lacking requested permissions" ); } let excess_perms = perms.difference(*util::DESIRED_PERMS); if !excess_perms.is_empty() { tracing::debug!( guild_id = %guild.id, guild_name = %guild.name, %excess_perms, "bot has permissions in excess of requirements" ); } Ok(()) } #[serenity::async_trait] impl EventHandler for Handler { async fn ready(&self, ctx: Context, r: Ready) { tracing::info!( join_url = %OAUTH_URL.as_str(), visible_guilds = r.guilds.len(), my_user_id = %ctx.cache.current_user().id, "connected to discord" ); } async fn guild_create(&self, ctx: Context, guild: Guild, _is_new: Option) { tracing::info!(disc_event = "guild_create", guild_id = %guild.id, guild_name = %guild.name); if let Err(e) = perm_check(&ctx, &guild).await { tracing::error!(error = %e, "checking permissions"); } ready_guild(&ctx, guild.id).await; } async fn resume(&self, _ctx: Context, _resume: ResumedEvent) { tracing::info!("reconnected to discord"); } async fn message_delete( &self, _ctx: Context, _: ChannelId, deleted_message_id: MessageId, _: Option, ) { MESSAGE_WATCH.lock().await.remove(&deleted_message_id); } } struct SongbirdHandler(Arc>); #[serenity::async_trait] impl songbird::events::EventHandler for SongbirdHandler { async fn act(&self, _ctx: &EventContext<'_>) -> Option { let mut call = self.0.lock().await; if call.queue().is_empty() { let _ = 
call.leave().await; } None } } lazy_static! { static ref MESSAGE_WATCH: Mutex> = Mutex::new(FnvHashMap::default()); static ref PREFIXES: Vec<&'static str> = vec!["!thulani ", "!thulan ", "!thulando madando ", "!thulando "]; static ref RESTRICTED_PREFIXES: Vec<&'static str> = vec!["!todd ", "!toddbert ", "!toddlani "]; static ref ALL_PREFIXES: Vec<&'static str> = { let mut all_prefixes: Vec<&'static str> = vec![]; all_prefixes.extend(PREFIXES.iter()); all_prefixes.extend(RESTRICTED_PREFIXES.iter()); all_prefixes }; static ref RESTRICT_IDS: FnvHashSet = { let default_path = PathBuf::from_str("restrict.json").unwrap(); let restrict_path = CONFIG.restrict.as_ref().unwrap_or(&default_path); let restrict_ids = File::open(restrict_path) .map_err(anyhow::Error::from) .and_then(|f| serde_json::from_reader::<_, Vec>(f).map_err(anyhow::Error::from)); if let Err(ref e) = restrict_ids { tracing::warn!(error = %e, "opening restrict file"); } let result = restrict_ids.unwrap_or_default().into_iter().collect::>(); tracing::info!(restricted_ids = ?result); result }; } fn on_err(err: FrameworkError) -> BoxFuture<()> { Box::pin(async move { let Some(msg) = err_msg(&err) else { tracing::warn!("error handler missing poise context"); return; }; let ctx = err.serenity_context(); let text = match err { FrameworkError::ArgumentParse { .. } | FrameworkError::SubcommandRequired { .. } => "format your commands right. fuck you.".to_string(), FrameworkError::CooldownHit { .. } => "slow the fuck down bitch".to_string(), FrameworkError::NotAnOwner { .. } => "who do you think you are?".to_string(), FrameworkError::GuildOnly { .. } => "what in the sam hill are you smoking".to_string(), FrameworkError::DmOnly { .. } => "take that back or i'm revoking your kitten status".to_string(), FrameworkError::UnknownCommand { ctx, msg, prefix, msg_content, trigger, invocation_data, framework, .. 
} => { let command = poise::Command { name: "meme".to_owned(), ..Default::default() }; fn noop( _ctx: PrefixContext<'_, U, E>, ) -> BoxFuture<'_, Result<(), FrameworkError<'_, U, E>>> { Box::pin(async { Ok(()) }) } let ctx = PrefixContext { serenity_context: ctx, prefix, msg, command: &command, trigger, invocation_data, parent_commands: &[], data: &(), invoked_command_name: "", action: noop, args: msg_content, framework, __non_exhaustive: (), }; let content = msg_content.trim(); if content.is_empty() { if let Err(e) = util::reply(PoiseContext::Prefix(ctx), "what?").await { tracing::error!(error = %e, "responding to empty message"); }; return; } lazy_static::lazy_static! { static ref HTTP_REGEX: regex::Regex = regex::Regex::new(r#"^https?://"#).unwrap(); } if HTTP_REGEX.is_match(content) { match util::pop_string(msg_content) .map_err(anyhow::Error::from) .and_then(|(_rest, s)| s.parse().map_err(anyhow::Error::from)) { Ok(u) => { if let Err(e) = commands::link_unrecognized(PoiseContext::Prefix(ctx), u).await { tracing::error!(error = %e, "processing audio"); "BANIC".to_string() } else { return; } }, Err(e) => { tracing::error!(error = %e, "processing unrecognized message"); "BANIC".to_string() }, } } else if let Err(e) = commands::meme::invoke::_meme( PoiseContext::Prefix(ctx), msg_content, Default::default(), ) .await { tracing::error!(error = %e, "producing meme for unrecognized"); "BANIC".to_string() } else { return; } }, _ => "BANIC".to_string(), }; tracing::error!("error encountered: {err:#?}"); if let Err(e) = msg.react(ctx, ReactionType::Unicode("❌".to_owned())).await { tracing::error!(error = %e, "reacting to failed message"); } let cm = CreateMessage::default().content(text).tts(msg.tts); if let Err(e) = msg.channel_id.send_message(ctx, cm).await { tracing::error!(error = %e, "sending error to chat"); } }) } async fn framework() -> poise::Framework { let additional_prefixes = ALL_PREFIXES.iter().skip(1).map(|x| 
poise::Prefix::Literal(x.to_owned())).collect(); poise::Framework::builder() .options(poise::FrameworkOptions { pre_command: before_handle, post_command: after_handle, on_error: on_err, command_check: Some(check), prefix_options: poise::PrefixFrameworkOptions { prefix: ALL_PREFIXES.first().map(|&x| x.to_owned()), additional_prefixes, case_insensitive_commands: true, mention_as_prefix: false, ignore_bots: true, ..Default::default() }, commands: commands::commands(), owners: HashSet::from_iter([CONFIG.discord.owner()]), initialize_owners: false, skip_checks_for_owners: false, ..Default::default() }) .setup(|_ctx, _ready, _framework| Box::pin(async move { Ok(()) })) .build() } fn check(ctx: PoiseContext) -> BoxFuture> { Box::pin(async move { let span = tracing::debug_span!( "check", name = %ctx.command().name, author = %ctx.author().name, author_id = %ctx.author().id, ) .entered(); if ctx.author().id == CONFIG.discord.owner() { tracing::info!("author is owner"); return Ok(true); } let restricted_prefix = RESTRICTED_PREFIXES.iter().any(|&prefix| ctx.prefix() == prefix); if !restricted_prefix { tracing::debug!("command isn't restricted"); return Ok(true); } const PERMITTED_WEEKDAY: chrono::Weekday = chrono::Weekday::Tue; let user_is_restricted = RESTRICT_IDS.contains(&ctx.author().id.get()); let restrictions_flipped = chrono::Local::now().weekday() == PERMITTED_WEEKDAY; if user_is_restricted == restrictions_flipped { tracing::debug!("authorized for restricted command"); return Ok(true); } let reason = if !restrictions_flipped { "restricted prefix".to_owned() } else { format!("it is {PERMITTED_WEEKDAY:?}") }; tracing::info!( %reason, "reject restricted command", ); drop(span); util::reply(ctx, "no").await?; Ok(false) }) } fn before_handle<'fut>(ctx: PoiseContext<'fut>) -> Pin + Send + 'fut>> { tracing::debug!( name = %ctx.command().name, author = %ctx.author().name, author_id = %ctx.author().id, "got command", ); Box::pin(async {}) } fn after_handle(ctx: PoiseContext) -> 
BoxFuture<()> { Box::pin(async move { tracing::trace!(name = %ctx.command().name, "command completed successfully"); }) } pub async fn run() -> anyhow::Result<()> { #[cfg(all(windows, feature = "windows_autostart_postgres"))] let started_pg = tokio::task::spawn_blocking(util::windows::ensure_postgres_started).await??; let token = &CONFIG.discord.auth.token; let sb_config = songbird::Config::default(); let playback_data = Arc::new(DashMap::new()); let mut client = Client::builder(token, GatewayIntents::non_privileged() | GatewayIntents::MESSAGE_CONTENT) .event_handler(Handler) .register_songbird_from_config(sb_config) .type_map_insert::(reqwest::Client::new()) .type_map_insert::(DashMap::new()) .type_map_insert::(playback_data.clone()) .framework(framework().await) .await?; let client_data = client.data.clone(); let shard_manager = client.shard_manager.clone(); let run_handle = tokio::spawn(async move { tracing::info!("connecting to discord"); client.start().await.expect("running discord client"); }); let mut ticker = tokio::time::interval(Duration::from_secs(10)); ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); tokio::spawn(async move { loop { ticker.tick().await; tracing::trace!("running songbird info gc"); let Some(songbird) = ({ let data = client_data.read().await; data.get::().cloned() }) else { tracing::warn!("gc songbird data: no songbird in state"); continue; }; let mut active_uuids = HashSet::new(); for (guild_id, call) in songbird.into_iter() { let tracks = { let call = call.lock().await; call.queue().current_queue() }; tracing::trace!(%guild_id, queued_tracks = tracks.len()); active_uuids.extend(tracks.into_iter().map(|track| track.uuid())); } let mut n_removed = 0; playback_data.retain(|k, _v| { let result = active_uuids.contains(k); if !result { n_removed += 1; } result }); if n_removed > 0 { tracing::debug!( queued_tracks = active_uuids.len(), n_gced = n_removed, "songbird info gc done" ); } else { tracing::trace!( queued_tracks = 
active_uuids.len(), n_gced = n_removed, "songbird info gc done" ); } } }); tokio::signal::ctrl_c().await?; tracing::warn!("got ^C, gracefully halting discord"); shard_manager.shutdown_all().await; run_handle.await?; tracing::info!("discord shutdown"); #[cfg(all(windows, feature = "windows_autostart_postgres"))] if started_pg { tracing::info!("we started postgres, stopping it before shutdown"); tokio::task::spawn_blocking(util::windows::shutdown_postgres).await??; } Ok(()) }