From 1fda857d25c3d33e593951eef3ce713fa69a7025 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Thu, 5 Apr 2018 20:53:37 -0400 Subject: start to integrate db support with commands --- src/commands/db.rs | 5 -- src/commands/meme.rs | 120 +++++++++++++++++++++++++++++++++++++++++ src/commands/mod.rs | 62 +++++++++++---------- src/commands/playback/mod.rs | 51 +++++++++++++++--- src/commands/playback/types.rs | 29 ++++++---- src/db/mod.rs | 69 +++++++++++++++--------- src/db/models.rs | 54 +++++++++++++------ src/main.rs | 7 ++- 8 files changed, 304 insertions(+), 93 deletions(-) delete mode 100644 src/commands/db.rs create mode 100644 src/commands/meme.rs (limited to 'src') diff --git a/src/commands/db.rs b/src/commands/db.rs deleted file mode 100644 index 4e8293a..0000000 --- a/src/commands/db.rs +++ /dev/null @@ -1,5 +0,0 @@ -use super::*; - -command!(meme(_ctx, msg) { - send(msg.channel_id, "I am not yet capable of memeing", msg.tts)?; -}); diff --git a/src/commands/meme.rs b/src/commands/meme.rs new file mode 100644 index 0000000..bb7315c --- /dev/null +++ b/src/commands/meme.rs @@ -0,0 +1,120 @@ +use rand::{thread_rng, distributions::{Weighted, WeightedChoice, Distribution}}; +use serenity::http::AttachmentType; +use serenity::builder::CreateMessage; +use diesel::PgConnection; + +use super::*; +use super::playback::CtxExt; + +use ::db::*; + +#[derive(Clone, Copy, Debug)] +enum MemeType { + Text, + Image, + Audio, +} + +static mut MEME_WEIGHTS: [Weighted; 3] = [ + Weighted { weight: 1, item: MemeType::Text }, + Weighted { weight: 1, item: MemeType::Image }, + Weighted { weight: 1, item: MemeType::Audio }, +]; + +static mut TTS_WEIGHTS: [Weighted; 2] = [ + Weighted { weight: 4, item: false }, + Weighted { weight: 1, item: true } +]; + +command!(meme(ctx, msg, args) { + let ch = msg.channel_id; + + if args.len() == 0 { + let conn = connection()?; + + let should_audio = ctx.currently_playing() && ctx.users_listening()?; + let dist: WeightedChoice<'static, MemeType> = if should_audio { + WeightedChoice::new(unsafe { &mut MEME_WEIGHTS }) + } else { + WeightedChoice::new(unsafe { &mut MEME_WEIGHTS[..2] }) + }; + + match dist.sample(&mut thread_rng()) { + MemeType::Text => { + let mut text_meme = rand_text(&conn)?; + + let mut ctr = 0; + while !should_audio && text_meme.audio_id.is_some() { + text_meme = rand_text(&conn)?; + + ctr += 1; + if ctr > 10 { + warn!("looped 10 times trying to find a non-audio text meme"); + return Ok(()); + } + } + + send_text(ctx, &text_meme, &conn, msg)?; + }, + MemeType::Image => { + let image_meme = rand_image(&conn)?; + let image = image_meme.associated_data(&conn)?; + + send_image(&image_meme, &image, ch)?; + }, + MemeType::Audio => { + let audio = rand_audio(&conn)?.associated_data(&conn)?; + send_audio(ctx, msg, &audio)?; + } + } + } +}); + +fn send_text(ctx: &Context, t: &TextMeme, conn: &PgConnection, msg: &Message) -> Result<()> { + let (image, audio) = t.associated_data(conn)?; + + let dist = WeightedChoice::new(unsafe { &mut TTS_WEIGHTS }); + + let create_msg = |m: CreateMessage| m + .tts(dist.sample(&mut thread_rng())) + .content(&t.content); + + match image { + Some(image) => msg.channel_id.send_files(vec!(AttachmentType::Bytes((&image.data, &t.title))), create_msg)?, + None => msg.channel_id.send_message(create_msg)?, + }; + + if let Some(audio) = audio { + send_audio(ctx, msg, &audio)?; + } + + Ok(()) +} + +fn send_image(image_meme: &ImageMeme, image: &Image, ch: ChannelId) -> Result<()> { + ch.send_files(vec!(AttachmentType::Bytes((&image.data, &image_meme.title))), |m| m.content(""))?; + Ok(()) +} + +// note: slight edge-case race condition here: there could have been something queued since we +// checked whether anything was playing. not a significant negative impact and unlikely, so i'm +// not worrying about it +fn send_audio(ctx: &Context, msg: &Message, audio: &Audio) -> Result<()> { + let queue_lock = ctx.data.lock().get::().cloned().unwrap(); + let mut play_queue = queue_lock.write().unwrap(); + + play_queue.queue.push_front(PlayArgs{ + initiator: msg.author.name.clone(), + data: ::either::Right(audio.data.clone()), + sender_channel: msg.channel_id, + }); + + Ok(()) +} + + +pub fn db_fallback(ctx: &mut Context, msg: &Message, s: &str) -> Result<()> { + + + Ok(()) +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 3a9cb66..d13082f 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -3,7 +3,6 @@ use serenity::framework::StandardFramework; use serenity::model::channel::Message; use serenity::model::id::ChannelId; use serenity::prelude::*; -use serenity::voice::{LockedAudio, ytdl}; use std::thread; use std::time::Duration; @@ -13,30 +12,6 @@ mod sound; pub use self::sound::*; pub use self::playback::*; -cfg_if! { - if #[cfg(feature = "diesel")] { - mod db; - pub use self::db::*; - - fn register_db(f: StandardFramework) -> StandardFramework { - f - .command("meme", |c| c - .guild_only(true) - .help_available(false) - .cmd(meme)) - } - } else { - fn register_db(f: StandardFramework) -> StandardFramework { - f - } - } -} - -fn send(channel: ChannelId, text: &str, tts: bool) -> Result<()> { - channel.send_message(|m| m.content(text).tts(tts))?; - Ok(()) -} - pub fn register_commands(f: StandardFramework) -> StandardFramework { let f: StandardFramework = register_db(f); f @@ -80,9 +55,13 @@ pub fn register_commands(f: StandardFramework) -> StandardFramework { .cmd(volume)) .unrecognised_command(|ctx, msg, unrec| { let url = match msg.content.split_whitespace().skip(1).next() { - Some(x) => x, + Some(x) if x.starts_with("http") => x, + Some(x) => { + let _ = db_fallback(ctx, msg, x); + return; + }, None => { - info!("received unrecognized command: {}", unrec); + info!("bad command formatting: '{}'", unrec); let _ = send(msg.channel_id, "format your commands right. fuck you.", msg.tts); return; } @@ -92,4 +71,33 @@ pub fn register_commands(f: StandardFramework) -> StandardFramework { }) } +cfg_if! { + if #[cfg(feature = "diesel")] { + mod meme; + pub use self::meme::*; + + fn register_db(f: StandardFramework) -> StandardFramework { + f + .command("meme", |c| c + .guild_only(true) + .help_available(false) + .cmd(meme)) + } + } else { + fn register_db(f: StandardFramework) -> StandardFramework { + f + } + + fn db_fallback(_: &mut Context, msg: &Message, s: &str) -> Result<()> { + info!("received unrecognized command: {}", s); + let _ = send(msg.channel_id, "format your commands right. fuck you.", msg.tts)?; + Ok(()) + } + } +} + +fn send(channel: ChannelId, text: &str, tts: bool) -> Result<()> { + channel.send_message(|m| m.content(text).tts(tts))?; + Ok(()) +} diff --git a/src/commands/playback/mod.rs b/src/commands/playback/mod.rs index 1d4ee96..a13ad36 100644 --- a/src/commands/playback/mod.rs +++ b/src/commands/playback/mod.rs @@ -1,9 +1,37 @@ -use super::*; +use either::{Left, Right}; +use serenity::voice::{LockedAudio, ytdl}; +use super::*; pub use self::types::*; mod types; +pub trait CtxExt { + fn currently_playing(&self) -> bool; + fn users_listening(&self) -> Result; +} + +impl CtxExt for Context { + fn currently_playing(&self) -> bool { + let queue_lock = self.data.lock().get::().cloned().unwrap(); + let play_queue = queue_lock.read().unwrap(); + play_queue.playing.is_none() + } + + fn users_listening(&self) -> Result { + let channel_id = ChannelId(must_env_lookup::("VOICE_CHANNEL")); + let channel = channel_id.get()?; + let res = channel.guild() + .and_then(|ch| ch.read().guild()) + .map(|g| (&g.read().voice_states) + .into_iter() + .any(|(_, state)| state.channel_id == Some(channel_id))) + .unwrap_or(false); + + Ok(res) + } +} + pub fn _play(ctx: &Context, msg: &Message, url: &str) -> Result<()> { debug!("playing '{}'", url); if !url.starts_with("http") { @@ -16,16 +44,12 @@ pub fn _play(ctx: &Context, msg: &Message, url: &str) -> Result<()> { return Ok(()); } - trace!("acquiring queue lock"); - let queue_lock = ctx.data.lock().get::().cloned().unwrap(); let mut play_queue = queue_lock.write().unwrap(); - trace!("queue lock acquired"); - play_queue.queue.push_back(PlayArgs{ initiator: msg.author.name.clone(), - url: url.to_owned(), + data: Left(url.to_owned()), sender_channel: msg.channel_id, }); @@ -169,7 +193,13 @@ command!(list(ctx, msg) { Some(ref info) => { let audio = info.audio.lock(); let status = if audio.playing { "playing" } else { "paused:" }; - send(msg.channel_id, &format!("Currently {} `{}` ({})", status, info.init_args.url, info.init_args.initiator), msg.tts)?; + + let playing_info = match info.init_args.data { + Left(ref url) => format!(" `{}`", url), + Right(_) => "memeing".to_owned(), + }; + + send(msg.channel_id, &format!("Currently {} {} ({})", status, playing_info, info.init_args.initiator), msg.tts)?; }, None => { debug!("`list` called with no items in queue"); @@ -179,6 +209,11 @@ command!(list(ctx, msg) { } play_queue.queue.iter().for_each(|info| { - channel.say(&format!("`{}` ({})", info.url, info.initiator)).unwrap(); + let playing_info = match info.data { + Left(ref url) => format!("`{}`", url), + Right(_) => "meme".to_owned(), + }; + + channel.say(&format!("{} ({})", playing_info, info.initiator)).unwrap(); }); }); diff --git a/src/commands/playback/types.rs b/src/commands/playback/types.rs index 41592ec..b9e1778 100644 --- a/src/commands/playback/types.rs +++ b/src/commands/playback/types.rs @@ -2,6 +2,9 @@ use serenity::client::bridge::voice::ClientVoiceManager; use typemap::Key; use std::sync::{Arc, RwLock}; use std::collections::VecDeque; + +use either::{Either, Left, Right}; + use super::*; pub struct VoiceManager; @@ -19,7 +22,7 @@ impl VoiceManager { #[derive(Clone, Debug)] pub struct PlayArgs { - pub url: String, + pub data: Either>, pub initiator: String, pub sender_channel: ChannelId, } @@ -90,7 +93,7 @@ impl PlayQueue { let mut manager = voice_manager.lock(); manager.leave(*TARGET_GUILD_ID); - debug!("disconnected due to inactivity"); + debug!("disconnected because playback finished"); } continue; } @@ -98,18 +101,22 @@ impl PlayQueue { let mut queue = queue_lck.write().unwrap(); let item = queue.queue.pop_front().unwrap(); - trace!("checking ytdl for: {}", item.url); - - let src = match ytdl(&item.url) { - Ok(src) => src, - Err(e) => { - error!("bad link: {}; {:?}", &item.url, e); - let _ = send(item.sender_channel, &format!("what the fuck"), false); - continue; + let src = match item.data { + Left(ref url) => { + match ytdl(url) { + Ok(src) => src, + Err(e) => { + error!("bad link: {}; {:?}", url, e); + let _ = send(item.sender_channel, "what the fuck", false); + continue; + } + } + }, + Right(ref vec) => { + ::serenity::voice::opus(true, ::std::io::Cursor::new(vec.clone())) } }; - trace!("got ytdl item for {}", item.url); let mut manager = voice_manager.lock(); let handler = manager.join(*TARGET_GUILD_ID, must_env_lookup::("VOICE_CHANNEL")); diff --git a/src/db/mod.rs b/src/db/mod.rs index 8ce0011..e1c3a55 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,61 +1,78 @@ use std::env; use diesel::prelude::*; +use diesel::r2d2::{ConnectionManager, ManageConnection}; use super::{Result, Error}; pub use self::models::*; +use self::schema::*; mod schema; mod models; -fn connection() -> Result { - let database_url = env::var("DATABASE_URL")?; - PgConnection::establish(&database_url).map_err(Error::from) +lazy_static! { + static ref DB_URL: String = env::var("DATABASE_URL").expect("no database url in environment").into(); + static ref CONN_MGR: ConnectionManager = ConnectionManager::new(DB_URL.clone()); } -pub fn find_text(search: String) -> Result { - use self::schema::text_memes::dsl::*; +pub fn connection() -> Result { + CONN_MGR.connect().map_err(Error::from) +} + +pub trait AssociatedData { + type Associated; + + fn associated_data(&self, conn: &PgConnection) -> Result; +} +pub fn find_text(conn: &PgConnection, search: String) -> Result { let format_search = format!("%{}%", search); - let conn = connection()?; - text_memes - .filter(title.ilike(&format_search).or(content.ilike(&format_search))) + text_memes::table + .filter(text_memes::title.ilike(&format_search).or(text_memes::content.ilike(&format_search))) .limit(1) - .first::(&conn) + .first::(conn) .map_err(Error::from) } -pub fn find_audio(search: String) -> Result { - use self::schema::audio_memes::dsl::*; - +pub fn find_audio(conn: &PgConnection, search: String) -> Result { let format_search = format!("%{}%", search); - let conn = connection()?; - audio_memes - .filter(title.ilike(format_search)) + audio_memes::table + .filter(audio_memes::title.ilike(format_search)) .limit(1) - .first::(&conn) + .first::(conn) .map_err(Error::from) } -pub fn rand_audio() -> Result { - use self::schema::audio_memes::dsl::*; +pub fn find_image(conn: &PgConnection, search: String) -> Result { + let format_search = format!("%{}%", search); - let conn = connection()?; - audio_memes + image_memes::table + .filter(image_memes::title.ilike(format_search)) + .limit(1) + .first::(conn) + .map_err(Error::from) +} + +pub fn rand_text(conn: &PgConnection) -> Result { + text_memes::table .order(random.desc()) - .first::(&conn) + .first::(conn) .map_err(Error::from) } -pub fn rand_text() -> Result { - use self::schema::text_memes::dsl::*; +pub fn rand_image(conn: &PgConnection) -> Result { + image_memes::table + .order(random.desc()) + .first::(conn) + .map_err(Error::from) +} - let conn = connection()?; - text_memes +pub fn rand_audio(conn: &PgConnection) -> Result { + audio_memes::table .order(random.desc()) - .first::(&conn) + .first::(conn) .map_err(Error::from) } diff --git a/src/db/models.rs b/src/db/models.rs index 85cba2a..c07a12a 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -1,5 +1,9 @@ -use super::schema::*; use chrono::naive::NaiveDateTime; +use diesel::prelude::*; + +use super::schema::*; +use super::AssociatedData; +use ::{Result, Error}; #[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)] #[table_name="text_memes"] @@ -12,6 +16,18 @@ pub struct TextMeme { pub metadata_id: i32, } +impl AssociatedData for TextMeme { + type Associated = (Option, Option