aboutsummaryrefslogtreecommitdiff
path: root/src/db/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/db/mod.rs')
-rw-r--r--src/db/mod.rs60
1 files changed, 58 insertions, 2 deletions
diff --git a/src/db/mod.rs b/src/db/mod.rs
index cd71e9e..c1d5ab0 100644
--- a/src/db/mod.rs
+++ b/src/db/mod.rs
@@ -13,6 +13,7 @@ use diesel::{
prelude::*,
r2d2::{ConnectionManager, ManageConnection},
};
+use postgres::Connection as RawPgConn;
use r2d2_postgres::{
PostgresConnectionManager as RawPgConnMgr,
TlsMode,
@@ -32,17 +33,22 @@ lazy_static! {
static ref RAW_CONN_MGR: RawPgConnMgr = RawPgConnMgr::new(DB_URL.clone(), TlsMode::None).unwrap();
}
+#[inline]
pub fn connection() -> Result<PgConnection> {
CONN_MGR.connect().map_err(Error::from)
}
+#[inline]
+fn raw_connection() -> Result<RawPgConn> {
+ RAW_CONN_MGR.connect().map_err(Error::from)
+}
+
pub fn find_meme<T: AsRef<str>>(conn: &PgConnection, search: T) -> Result<Meme> {
use diesel::dsl::sql;
use diesel::sql_types::Text;
let search = search.as_ref();
- // TODO: check for injection
let mut meme = memes::table
.filter(memes::title.eq(search))
.limit(1)
@@ -111,6 +117,56 @@ pub fn delete_meme<T: AsRef<str>>(conn: &PgConnection, search: T, deleted_by: u6
})
}
+pub fn rare_meme(conn: &PgConnection, audio: bool) -> Result<Meme> {
+ use rand::prelude::*;
+ use failure::err_msg;
+
+ let raw_conn = raw_connection()?;
+
+ let rows = raw_conn.query(r#"
+ SELECT agg.meme_id, (agg.time_diff / agg.ct) AS play_prop, agg.ct FROM (
+ SELECT meme_count.meme_id AS meme_id, meme_count.ct AS ct, EXTRACT(EPOCH FROM (now() - metadata.created)) AS time_diff FROM (
+ SELECT meme_id, COUNT(*) AS ct FROM invocation_records GROUP BY meme_id
+ ) AS meme_count
+ INNER JOIN memes ON memes.id = meme_count.meme_id
+ INNER JOIN metadata ON metadata.id = memes.metadata_id
+ WHERE ((memes.audio_id IS NOT NULL) = $1) OR $2
+ ) AS agg
+ ORDER BY play_prop DESC
+ LIMIT 100;
+ "#, &[&audio, &!audio])?;
+
+ let elems = rows.iter()
+ .map(|row| (row.get::<_, i32>(0), row.get::<_, f64>(1), row.get::<_, i64>(2) as usize))
+ .collect::<Vec<_>>();
+
+ if elems.len() == 0 {
+ return Err(err_msg("no rare memes found"));
+ }
+
+ let total_probability_mass: f64 = elems.iter().map(|(_, prob, _)| prob).sum();
+
+ if total_probability_mass == 0. {
+ return Err(err_msg("rare meme probability mass was 0"))
+ }
+
+ let mut rng = thread_rng();
+ let target_prob = rng.gen_range(0., total_probability_mass);
+
+ let mut cur_prob_acc = 0.;
+ let mut meme_id = elems.last().unwrap().0;
+
+ for &(m_id, prob, _) in elems.iter() {
+ cur_prob_acc += prob;
+ if cur_prob_acc > target_prob {
+ meme_id = m_id;
+ break;
+ }
+ }
+
+ Meme::find(conn, meme_id)
+}
+
pub fn rand_meme(conn: &PgConnection, audio: bool) -> Result<Meme> {
use rand::{thread_rng, seq::SliceRandom};
use failure::err_msg;
@@ -273,7 +329,7 @@ pub fn stats(conn: &PgConnection) -> Result<Stats> {
.first(conn)
.map_err(Error::from)?;
- let raw_conn = RAW_CONN_MGR.connect().map_err(Error::from)?;
+ let raw_conn = raw_connection()?;
let rows = raw_conn.query(r#"
SELECT DATE(time) as dt, COUNT(*) FROM invocation_records