diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/commands/db.rs | 5 | ||||
| -rw-r--r-- | src/commands/meme.rs | 120 | ||||
| -rw-r--r-- | src/commands/mod.rs | 62 | ||||
| -rw-r--r-- | src/commands/playback/mod.rs | 51 | ||||
| -rw-r--r-- | src/commands/playback/types.rs | 29 | ||||
| -rw-r--r-- | src/db/mod.rs | 69 | ||||
| -rw-r--r-- | src/db/models.rs | 54 | ||||
| -rw-r--r-- | src/main.rs | 7 |
8 files changed, 304 insertions, 93 deletions
diff --git a/src/commands/db.rs b/src/commands/db.rs deleted file mode 100644 index 4e8293a..0000000 --- a/src/commands/db.rs +++ /dev/null @@ -1,5 +0,0 @@ -use super::*; - -command!(meme(_ctx, msg) { - send(msg.channel_id, "I am not yet capable of memeing", msg.tts)?; -}); diff --git a/src/commands/meme.rs b/src/commands/meme.rs new file mode 100644 index 0000000..bb7315c --- /dev/null +++ b/src/commands/meme.rs @@ -0,0 +1,120 @@ +use rand::{thread_rng, distributions::{Weighted, WeightedChoice, Distribution}}; +use serenity::http::AttachmentType; +use serenity::builder::CreateMessage; +use diesel::PgConnection; + +use super::*; +use super::playback::CtxExt; + +use ::db::*; + +#[derive(Clone, Copy, Debug)] +enum MemeType { + Text, + Image, + Audio, +} + +static mut MEME_WEIGHTS: [Weighted<MemeType>; 3] = [ + Weighted { weight: 1, item: MemeType::Text }, + Weighted { weight: 1, item: MemeType::Image }, + Weighted { weight: 1, item: MemeType::Audio }, +]; + +static mut TTS_WEIGHTS: [Weighted<bool>; 2] = [ + Weighted { weight: 4, item: false }, + Weighted { weight: 1, item: true } +]; + +command!(meme(ctx, msg, args) { + let ch = msg.channel_id; + + if args.len() == 0 { + let conn = connection()?; + + let should_audio = ctx.currently_playing() && ctx.users_listening()?; + let dist: WeightedChoice<'static, MemeType> = if should_audio { + WeightedChoice::new(unsafe { &mut MEME_WEIGHTS }) + } else { + WeightedChoice::new(unsafe { &mut MEME_WEIGHTS[..2] }) + }; + + match dist.sample(&mut thread_rng()) { + MemeType::Text => { + let mut text_meme = rand_text(&conn)?; + + let mut ctr = 0; + while !should_audio && text_meme.audio_id.is_some() { + text_meme = rand_text(&conn)?; + + ctr += 1; + if ctr > 10 { + warn!("looped 10 times trying to find a non-audio text meme"); + return Ok(()); + } + } + + send_text(ctx, &text_meme, &conn, msg)?; + }, + MemeType::Image => { + let image_meme = rand_image(&conn)?; + let image = image_meme.associated_data(&conn)?; + + send_image(&image_meme, &image, ch)?; + }, + MemeType::Audio => { + let audio = rand_audio(&conn)?.associated_data(&conn)?; + send_audio(ctx, msg, &audio)?; + } + } + } +}); + +fn send_text(ctx: &Context, t: &TextMeme, conn: &PgConnection, msg: &Message) -> Result<()> { + let (image, audio) = t.associated_data(conn)?; + + let dist = WeightedChoice::new(unsafe { &mut TTS_WEIGHTS }); + + let create_msg = |m: CreateMessage| m + .tts(dist.sample(&mut thread_rng())) + .content(&t.content); + + match image { + Some(image) => msg.channel_id.send_files(vec!(AttachmentType::Bytes((&image.data, &t.title))), create_msg)?, + None => msg.channel_id.send_message(create_msg)?, + }; + + if let Some(audio) = audio { + send_audio(ctx, msg, &audio)?; + } + + Ok(()) +} + +fn send_image(image_meme: &ImageMeme, image: &Image, ch: ChannelId) -> Result<()> { + ch.send_files(vec!(AttachmentType::Bytes((&image.data, &image_meme.title))), |m| m.content(""))?; + Ok(()) +} + +// note: slight edge-case race condition here: there could have been something queued since we +// checked whether anything was playing. not a significant negative impact and unlikely, so i'm +// not worrying about it +fn send_audio(ctx: &Context, msg: &Message, audio: &Audio) -> Result<()> { + let queue_lock = ctx.data.lock().get::<PlayQueue>().cloned().unwrap(); + let mut play_queue = queue_lock.write().unwrap(); + + play_queue.queue.push_front(PlayArgs{ + initiator: msg.author.name.clone(), + data: ::either::Right(audio.data.clone()), + sender_channel: msg.channel_id, + }); + + Ok(()) +} + + +pub fn db_fallback(ctx: &mut Context, msg: &Message, s: &str) -> Result<()> { + + + Ok(()) +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 3a9cb66..d13082f 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -3,7 +3,6 @@ use serenity::framework::StandardFramework; use serenity::model::channel::Message; use serenity::model::id::ChannelId; use serenity::prelude::*; -use serenity::voice::{LockedAudio, ytdl}; use std::thread; use std::time::Duration; @@ -13,30 +12,6 @@ mod sound; pub use self::sound::*; pub use self::playback::*; -cfg_if! { - if #[cfg(feature = "diesel")] { - mod db; - pub use self::db::*; - - fn register_db(f: StandardFramework) -> StandardFramework { - f - .command("meme", |c| c - .guild_only(true) - .help_available(false) - .cmd(meme)) - } - } else { - fn register_db(f: StandardFramework) -> StandardFramework { - f - } - } -} - -fn send(channel: ChannelId, text: &str, tts: bool) -> Result<()> { - channel.send_message(|m| m.content(text).tts(tts))?; - Ok(()) -} - pub fn register_commands(f: StandardFramework) -> StandardFramework { let f: StandardFramework = register_db(f); f @@ -80,9 +55,13 @@ pub fn register_commands(f: StandardFramework) -> StandardFramework { .cmd(volume)) .unrecognised_command(|ctx, msg, unrec| { let url = match msg.content.split_whitespace().skip(1).next() { - Some(x) => x, + Some(x) if x.starts_with("http") => x, + Some(x) => { + let _ = db_fallback(ctx, msg, x); + return; + }, None => { - info!("received unrecognized command: {}", unrec); + info!("bad command formatting: '{}'", unrec); let _ = send(msg.channel_id, "format your commands right. fuck you.", msg.tts); return; } @@ -92,4 +71,33 @@ pub fn register_commands(f: StandardFramework) -> StandardFramework { }) } +cfg_if! { + if #[cfg(feature = "diesel")] { + mod meme; + pub use self::meme::*; + + fn register_db(f: StandardFramework) -> StandardFramework { + f + .command("meme", |c| c + .guild_only(true) + .help_available(false) + .cmd(meme)) + } + } else { + fn register_db(f: StandardFramework) -> StandardFramework { + f + } + + fn db_fallback(_: &mut Context, msg: &Message, s: &str) -> Result<()> { + info!("received unrecognized command: {}", s); + let _ = send(msg.channel_id, "format your commands right. fuck you.", msg.tts)?; + Ok(()) + } + } +} + +fn send(channel: ChannelId, text: &str, tts: bool) -> Result<()> { + channel.send_message(|m| m.content(text).tts(tts))?; + Ok(()) +} diff --git a/src/commands/playback/mod.rs b/src/commands/playback/mod.rs index 1d4ee96..a13ad36 100644 --- a/src/commands/playback/mod.rs +++ b/src/commands/playback/mod.rs @@ -1,9 +1,37 @@ -use super::*; +use either::{Left, Right}; +use serenity::voice::{LockedAudio, ytdl}; +use super::*; pub use self::types::*; mod types; +pub trait CtxExt { + fn currently_playing(&self) -> bool; + fn users_listening(&self) -> Result<bool>; +} + +impl CtxExt for Context { + fn currently_playing(&self) -> bool { + let queue_lock = self.data.lock().get::<PlayQueue>().cloned().unwrap(); + let play_queue = queue_lock.read().unwrap(); + play_queue.playing.is_none() + } + + fn users_listening(&self) -> Result<bool> { + let channel_id = ChannelId(must_env_lookup::<u64>("VOICE_CHANNEL")); + let channel = channel_id.get()?; + let res = channel.guild() + .and_then(|ch| ch.read().guild()) + .map(|g| (&g.read().voice_states) + .into_iter() + .any(|(_, state)| state.channel_id == Some(channel_id))) + .unwrap_or(false); + + Ok(res) + } +} + pub fn _play(ctx: &Context, msg: &Message, url: &str) -> Result<()> { debug!("playing '{}'", url); if !url.starts_with("http") { @@ -16,16 +44,12 @@ pub fn _play(ctx: &Context, msg: &Message, url: &str) -> Result<()> { return Ok(()); } - trace!("acquiring queue lock"); - let queue_lock = ctx.data.lock().get::<PlayQueue>().cloned().unwrap(); let mut play_queue = queue_lock.write().unwrap(); - trace!("queue lock acquired"); - play_queue.queue.push_back(PlayArgs{ initiator: msg.author.name.clone(), - url: url.to_owned(), + data: Left(url.to_owned()), sender_channel: msg.channel_id, }); @@ -169,7 +193,13 @@ command!(list(ctx, msg) { Some(ref info) => { let audio = info.audio.lock(); let status = if audio.playing { "playing" } else { "paused:" }; - send(msg.channel_id, &format!("Currently {} `{}` ({})", status, info.init_args.url, info.init_args.initiator), msg.tts)?; + + let playing_info = match info.init_args.data { + Left(ref url) => format!(" `{}`", url), + Right(_) => "memeing".to_owned(), + }; + + send(msg.channel_id, &format!("Currently {} {} ({})", status, playing_info, info.init_args.initiator), msg.tts)?; }, None => { debug!("`list` called with no items in queue"); @@ -179,6 +209,11 @@ command!(list(ctx, msg) { } play_queue.queue.iter().for_each(|info| { - channel.say(&format!("`{}` ({})", info.url, info.initiator)).unwrap(); + let playing_info = match info.data { + Left(ref url) => format!("`{}`", url), + Right(_) => "meme".to_owned(), + }; + + channel.say(&format!("{} ({})", playing_info, info.initiator)).unwrap(); }); }); diff --git a/src/commands/playback/types.rs b/src/commands/playback/types.rs index 41592ec..b9e1778 100644 --- a/src/commands/playback/types.rs +++ b/src/commands/playback/types.rs @@ -2,6 +2,9 @@ use serenity::client::bridge::voice::ClientVoiceManager; use typemap::Key; use std::sync::{Arc, RwLock}; use std::collections::VecDeque; + +use either::{Either, Left, Right}; + use super::*; pub struct VoiceManager; @@ -19,7 +22,7 @@ impl VoiceManager { #[derive(Clone, Debug)] pub struct PlayArgs { - pub url: String, + pub data: Either<String, Vec<u8>>, pub initiator: String, pub sender_channel: ChannelId, } @@ -90,7 +93,7 @@ impl PlayQueue { let mut manager = voice_manager.lock(); manager.leave(*TARGET_GUILD_ID); - debug!("disconnected due to inactivity"); + debug!("disconnected because playback finished"); } continue; } @@ -98,18 +101,22 @@ impl PlayQueue { let mut queue = queue_lck.write().unwrap(); let item = queue.queue.pop_front().unwrap(); - trace!("checking ytdl for: {}", item.url); - - let src = match ytdl(&item.url) { - Ok(src) => src, - Err(e) => { - error!("bad link: {}; {:?}", &item.url, e); - let _ = send(item.sender_channel, &format!("what the fuck"), false); - continue; + let src = match item.data { + Left(ref url) => { + match ytdl(url) { + Ok(src) => src, + Err(e) => { + error!("bad link: {}; {:?}", url, e); + let _ = send(item.sender_channel, "what the fuck", false); + continue; + } + } + }, + Right(ref vec) => { + ::serenity::voice::opus(true, ::std::io::Cursor::new(vec.clone())) } }; - trace!("got ytdl item for {}", item.url); let mut manager = voice_manager.lock(); let handler = manager.join(*TARGET_GUILD_ID, must_env_lookup::<u64>("VOICE_CHANNEL")); diff --git a/src/db/mod.rs b/src/db/mod.rs index 8ce0011..e1c3a55 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,61 +1,78 @@ use std::env;
use diesel::prelude::*;
+use diesel::r2d2::{ConnectionManager, ManageConnection};
use super::{Result, Error};
pub use self::models::*;
+use self::schema::*;
mod schema;
mod models;
-fn connection() -> Result<PgConnection> {
- let database_url = env::var("DATABASE_URL")?;
- PgConnection::establish(&database_url).map_err(Error::from)
+lazy_static! {
+ static ref DB_URL: String = env::var("DATABASE_URL").expect("no database url in environment").into();
+ static ref CONN_MGR: ConnectionManager<PgConnection> = ConnectionManager::new(DB_URL.clone());
}
-pub fn find_text(search: String) -> Result<TextMeme> {
- use self::schema::text_memes::dsl::*;
+pub fn connection() -> Result<PgConnection> {
+ CONN_MGR.connect().map_err(Error::from)
+}
+
+pub trait AssociatedData {
+ type Associated;
+
+ fn associated_data(&self, conn: &PgConnection) -> Result<Self::Associated>;
+}
+pub fn find_text(conn: &PgConnection, search: String) -> Result<TextMeme> {
let format_search = format!("%{}%", search);
- let conn = connection()?;
- text_memes
- .filter(title.ilike(&format_search).or(content.ilike(&format_search)))
+ text_memes::table
+ .filter(text_memes::title.ilike(&format_search).or(text_memes::content.ilike(&format_search)))
.limit(1)
- .first::<TextMeme>(&conn)
+ .first::<TextMeme>(conn)
.map_err(Error::from)
}
-pub fn find_audio(search: String) -> Result<AudioMeme> {
- use self::schema::audio_memes::dsl::*;
-
+pub fn find_audio(conn: &PgConnection, search: String) -> Result<AudioMeme> {
let format_search = format!("%{}%", search);
- let conn = connection()?;
- audio_memes
- .filter(title.ilike(format_search))
+ audio_memes::table
+ .filter(audio_memes::title.ilike(format_search))
.limit(1)
- .first::<AudioMeme>(&conn)
+ .first::<AudioMeme>(conn)
.map_err(Error::from)
}
-pub fn rand_audio() -> Result<AudioMeme> {
- use self::schema::audio_memes::dsl::*;
+pub fn find_image(conn: &PgConnection, search: String) -> Result<ImageMeme> {
+ let format_search = format!("%{}%", search);
- let conn = connection()?;
- audio_memes
+ image_memes::table
+ .filter(image_memes::title.ilike(format_search))
+ .limit(1)
+ .first::<ImageMeme>(conn)
+ .map_err(Error::from)
+}
+
+pub fn rand_text(conn: &PgConnection) -> Result<TextMeme> {
+ text_memes::table
.order(random.desc())
- .first::<AudioMeme>(&conn)
+ .first::<TextMeme>(conn)
.map_err(Error::from)
}
-pub fn rand_text() -> Result<TextMeme> {
- use self::schema::text_memes::dsl::*;
+pub fn rand_image(conn: &PgConnection) -> Result<ImageMeme> {
+ image_memes::table
+ .order(random.desc())
+ .first::<ImageMeme>(conn)
+ .map_err(Error::from)
+}
- let conn = connection()?;
- text_memes
+pub fn rand_audio(conn: &PgConnection) -> Result<AudioMeme> {
+ audio_memes::table
.order(random.desc())
- .first::<TextMeme>(&conn)
+ .first::<AudioMeme>(conn)
.map_err(Error::from)
}
diff --git a/src/db/models.rs b/src/db/models.rs index 85cba2a..c07a12a 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -1,5 +1,9 @@ -use super::schema::*;
use chrono::naive::NaiveDateTime;
+use diesel::prelude::*;
+
+use super::schema::*;
+use super::AssociatedData;
+use ::{Result, Error};
#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="text_memes"]
@@ -12,6 +16,18 @@ pub struct TextMeme { pub metadata_id: i32,
}
+impl AssociatedData for TextMeme {
+ type Associated = (Option<Image>, Option<Audio>);
+
+ fn associated_data(&self, conn: &PgConnection) -> Result<Self::Associated> {
+ let image = self.image_id.map(|x: i32| images::table.find(x).first(conn)).transpose()?;
+ let audio = self.audio_id.map(|x: i32| audio::table.find(x).first(conn)).transpose()?;
+
+ Ok((image, audio))
+ }
+}
+
+
#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="image_memes"]
pub struct ImageMeme {
@@ -21,6 +37,15 @@ pub struct ImageMeme { pub metadata_id: i32,
}
+impl AssociatedData for ImageMeme {
+ type Associated = Image;
+
+ fn associated_data(&self, conn: &PgConnection) -> Result<Self::Associated> {
+ images::table.find(self.image_id).first(conn).map_err(Error::from)
+ }
+}
+
+
#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="audio_memes"]
pub struct AudioMeme {
@@ -30,9 +55,16 @@ pub struct AudioMeme { pub metadata_id: i32,
}
-#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug, Associations)]
-#[belongs_to(AudioMeme)]
-#[belongs_to(TextMeme)]
+impl AssociatedData for AudioMeme {
+ type Associated = Audio;
+
+ fn associated_data(&self, conn: &PgConnection) -> Result<Self::Associated> {
+ audio::table.find(self.audio_id).first(conn).map_err(Error::from)
+ }
+}
+
+
+#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="audio"]
pub struct Audio {
pub id: i32,
@@ -40,9 +72,7 @@ pub struct Audio { pub metadata_id: i32,
}
-#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug, Associations)]
-#[belongs_to(ImageMeme)]
-#[belongs_to(TextMeme)]
+#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="images"]
pub struct Image {
pub id: i32,
@@ -50,12 +80,7 @@ pub struct Image { pub metadata_id: i32,
}
-#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug, Associations)]
-#[belongs_to(Audio)]
-#[belongs_to(Image)]
-#[belongs_to(TextMeme)]
-#[belongs_to(ImageMeme)]
-#[belongs_to(TextMeme)]
+#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="metadata"]
pub struct Metadata {
pub id: i32,
@@ -63,8 +88,7 @@ pub struct Metadata { pub created_by: i64,
}
-#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug, Associations)]
-#[belongs_to(Metadata)]
+#[derive(Insertable, Queryable, Identifiable, PartialEq, AsChangeset, Debug)]
#[table_name="audit_records"]
pub struct AuditRecord {
pub id: i32,
diff --git a/src/main.rs b/src/main.rs index fba574e..551449b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(transpose_result)] + #[macro_use] extern crate cfg_if; extern crate chrono; extern crate ctrlc; @@ -10,6 +12,8 @@ extern crate fern; #[macro_use] extern crate serenity; extern crate typemap; extern crate url; +extern crate rand; +extern crate either; use commands::register_commands; use dotenv::dotenv; @@ -39,7 +43,8 @@ mod errors { Serenity(::serenity::Error); MissingVar(::std::env::VarError); DieselConn(::diesel::ConnectionError) #[cfg(feature = "diesel")]; - Diesel(::diesel::result::Error) #[cfg(feature = "diesel")]; + Diesel(::diesel::result::Error) #[cfg(feature = "diesel")]; + R2D2(::diesel::r2d2::Error) #[cfg(feature = "diesel")]; } } } |
