aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Perry <avaglir@gmail.com>2019-03-30 05:23:54 -0400
committerNathan Perry <avaglir@gmail.com>2019-03-30 05:23:54 -0400
commite7e2daa560af199bd005a8526ebcf7ff4441e3ea (patch)
treeb54be3956efa77ed729e0940e82b258798062a58
parenta6f5334f984531cf2304b6333b138944517d3703 (diff)
add rarememe command
-rw-r--r--src/commands/meme/invoke.rs42
-rw-r--r--src/commands/meme/mod.rs32
-rw-r--r--src/commands/mod.rs6
-rw-r--r--src/db/mod.rs60
4 files changed, 101 insertions, 39 deletions
diff --git a/src/commands/meme/invoke.rs b/src/commands/meme/invoke.rs
index 8deecb5..0f992c2 100644
--- a/src/commands/meme/invoke.rs
+++ b/src/commands/meme/invoke.rs
@@ -1,4 +1,7 @@
-use diesel::result::Error as DieselError;
+use diesel::{
+ NotFound,
+ result::Error as DieselError,
+};
use failure::Error;
use serenity::{
framework::standard::Args,
@@ -13,12 +16,10 @@ use crate::{
send,
},
db::{
+ self,
connection,
find_meme,
InvocationRecord,
- rand_audio_meme as db_rand_audio_meme,
- rand_meme as db_rand_meme,
- rand_silent_meme as db_rand_silent_meme,
},
Result,
};
@@ -79,9 +80,9 @@ fn rand_meme(ctx: &Context, message: &Message, audio_playback: AudioPlayback) ->
let should_audio = ctx.users_listening()?;
let mem = match audio_playback {
- AudioPlayback::Required => db_rand_audio_meme(&conn),
- AudioPlayback::Optional => db_rand_meme(&conn, should_audio),
- AudioPlayback::Prohibited => db_rand_silent_meme(&conn),
+ AudioPlayback::Required => db::rand_audio_meme(&conn),
+ AudioPlayback::Optional => db::rand_meme(&conn, should_audio),
+ AudioPlayback::Prohibited => db::rand_silent_meme(&conn),
};
match mem {
@@ -103,3 +104,30 @@ fn rand_meme(ctx: &Context, message: &Message, audio_playback: AudioPlayback) ->
},
}
}
+
+pub fn rare_meme(ctx: &mut Context, msg: &Message, _args: Args) -> Result<()> {
+ let should_audio = ctx.users_listening()?;
+
+ let conn = connection()?;
+ let meme = db::rare_meme(&conn, should_audio);
+
+ match meme {
+ Ok(meme) => {
+ InvocationRecord::create(&conn, msg.author.id.0, msg.id.0, meme.id, true)?;
+ send_meme(ctx, &meme, &conn, msg)
+ },
+ Err(e) => {
+ match e.downcast_ref::<DieselError>() {
+ Some(NotFound) => {
+ info!("rare meme not found");
+ return send(msg.channel_id, "i don't know any :(", msg.tts)
+ },
+ _ => {},
+ }
+
+ send(msg.channel_id, "THE MEME MARKET IS IN FREEFALL", msg.tts)?;
+
+ Err(e)
+ },
+ }
+}
diff --git a/src/commands/meme/mod.rs b/src/commands/meme/mod.rs
index d29a025..ce1cb48 100644
--- a/src/commands/meme/mod.rs
+++ b/src/commands/meme/mod.rs
@@ -1,46 +1,18 @@
-use std::{
- io::Read,
- process::{
- Command,
- Stdio,
- },
-};
-
-use diesel::{
- NotFound,
- PgConnection,
- result::Error as DieselError,
-};
-use failure::Error;
+use diesel::PgConnection;
use rand::{Rng, thread_rng};
use serenity::{
builder::CreateMessage,
- framework::standard::Args,
http::AttachmentType,
model::channel::Message,
prelude::*,
};
-use url::Url;
use crate::{
audio::{
- CtxExt,
- parse_times,
PlayArgs,
PlayQueue,
- ytdl_url,
- },
- commands::send,
- db::{
- Audio,
- connection,
- delete_meme,
- find_meme,
- Image,
- InvocationRecord,
- Meme,
- NewMeme,
},
+ db::Meme,
Result,
};
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 6596a32..74e6e3f 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -128,6 +128,12 @@ fn register_db(f: StandardFramework) -> StandardFramework {
.desc("history of recent messages")
.cmd(history)
)
+ .command("rarememe", |c| c
+ .known_as("rare_meme")
+ .guild_only(true)
+ .desc("deliver an underutilized meme")
+ .cmd(rare_meme)
+ )
}
#[cfg(not(feature = "diesel"))]
diff --git a/src/db/mod.rs b/src/db/mod.rs
index cd71e9e..c1d5ab0 100644
--- a/src/db/mod.rs
+++ b/src/db/mod.rs
@@ -13,6 +13,7 @@ use diesel::{
prelude::*,
r2d2::{ConnectionManager, ManageConnection},
};
+use postgres::Connection as RawPgConn;
use r2d2_postgres::{
PostgresConnectionManager as RawPgConnMgr,
TlsMode,
@@ -32,17 +33,22 @@ lazy_static! {
static ref RAW_CONN_MGR: RawPgConnMgr = RawPgConnMgr::new(DB_URL.clone(), TlsMode::None).unwrap();
}
+#[inline]
pub fn connection() -> Result<PgConnection> {
CONN_MGR.connect().map_err(Error::from)
}
+#[inline]
+fn raw_connection() -> Result<RawPgConn> {
+ RAW_CONN_MGR.connect().map_err(Error::from)
+}
+
pub fn find_meme<T: AsRef<str>>(conn: &PgConnection, search: T) -> Result<Meme> {
use diesel::dsl::sql;
use diesel::sql_types::Text;
let search = search.as_ref();
- // TODO: check for injection
let mut meme = memes::table
.filter(memes::title.eq(search))
.limit(1)
@@ -111,6 +117,56 @@ pub fn delete_meme<T: AsRef<str>>(conn: &PgConnection, search: T, deleted_by: u6
})
}
+pub fn rare_meme(conn: &PgConnection, audio: bool) -> Result<Meme> {
+ use rand::prelude::*;
+ use failure::err_msg;
+
+ let raw_conn = raw_connection()?;
+
+ let rows = raw_conn.query(r#"
+ SELECT agg.meme_id, (agg.time_diff / agg.ct) AS play_prop, agg.ct FROM (
+ SELECT meme_count.meme_id AS meme_id, meme_count.ct AS ct, EXTRACT(EPOCH FROM (now() - metadata.created)) AS time_diff FROM (
+ SELECT meme_id, COUNT(*) AS ct FROM invocation_records GROUP BY meme_id
+ ) AS meme_count
+ INNER JOIN memes ON memes.id = meme_count.meme_id
+ INNER JOIN metadata ON metadata.id = memes.metadata_id
+ WHERE ((memes.audio_id IS NOT NULL) = $1) OR $2
+ ) AS agg
+ ORDER BY play_prop DESC
+ LIMIT 100;
+ "#, &[&audio, &!audio])?;
+
+ let elems = rows.iter()
+ .map(|row| (row.get::<_, i32>(0), row.get::<_, f64>(1), row.get::<_, i64>(2) as usize))
+ .collect::<Vec<_>>();
+
+ if elems.len() == 0 {
+ return Err(err_msg("no rare memes found"));
+ }
+
+ let total_probability_mass: f64 = elems.iter().map(|(_, prob, _)| prob).sum();
+
+ if total_probability_mass == 0. {
+ return Err(err_msg("rare meme probability mass was 0"))
+ }
+
+ let mut rng = thread_rng();
+ let target_prob = rng.gen_range(0., total_probability_mass);
+
+ let mut cur_prob_acc = 0.;
+ let mut meme_id = elems.last().unwrap().0;
+
+ for &(m_id, prob, _) in elems.iter() {
+ cur_prob_acc += prob;
+ if cur_prob_acc > target_prob {
+ meme_id = m_id;
+ break;
+ }
+ }
+
+ Meme::find(conn, meme_id)
+}
+
pub fn rand_meme(conn: &PgConnection, audio: bool) -> Result<Meme> {
use rand::{thread_rng, seq::SliceRandom};
use failure::err_msg;
@@ -273,7 +329,7 @@ pub fn stats(conn: &PgConnection) -> Result<Stats> {
.first(conn)
.map_err(Error::from)?;
- let raw_conn = RAW_CONN_MGR.connect().map_err(Error::from)?;
+ let raw_conn = raw_connection()?;
let rows = raw_conn.query(r#"
SELECT DATE(time) as dt, COUNT(*) FROM invocation_records