diff options
Diffstat (limited to 'src/db/mod.rs')
| -rw-r--r-- | src/db/mod.rs | 242 |
1 files changed, 145 insertions, 97 deletions
diff --git a/src/db/mod.rs b/src/db/mod.rs index 04f2239..5aa0541 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,24 +1,33 @@ use std::{ convert::AsRef, env, - str::FromStr, }; +use anyhow::anyhow; use chrono::{ Date, DateTime, Utc, }; +use deadpool_postgres::{ + Pool as RawPgConnMgr, + PoolConfig, +}; use diesel::{ prelude::*, NotFound, }; -use diesel_async::pooled_connection::AsyncDieselConnectionManager; -use tokio_postgres::Client as RawPgConn; -use anyhow::anyhow; -use diesel_async::pooled_connection::deadpool::Pool; -use deadpool_postgres::{Pool as RawPgConnMgr, PoolConfig}; -use lazy_static::lazy_static; +use diesel_async::{ + pooled_connection::{ + deadpool::Pool, + AsyncDieselConnectionManager, + }, + scoped_futures::ScopedFutureExt, + AsyncConnection, + AsyncPgConnection, + RunQueryDsl, +}; +use serenity::FutureExt; use crate::{ Error, @@ -31,40 +40,38 @@ use self::schema::*; mod models; mod schema; -const MIGRATIONS: diesel_async_migrations::EmbeddedMigrations = diesel_async_migrations::embed_migrations!(); +const MIGRATIONS: diesel_async_migrations::EmbeddedMigrations = + diesel_async_migrations::embed_migrations!(); static MIGRATE: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::new(); -lazy_static! { +lazy_static::lazy_static! { static ref DB_URL: String = env::var("DATABASE_URL").expect("no database url in environment"); - static ref DB_CONFIG: Config = Config::from_str(&DB_URL).expect("parsing db url as config"); - static ref POOL: diesel_async::pooled_connection::deadpool::Pool<AsyncDieselConnectionManager<AsyncPgConnection>> = { + static ref POOL: Pool<AsyncPgConnection> = { let cfg = AsyncDieselConnectionManager::new(DB_URL.clone()); + + let pool = Pool::builder(cfg).build().unwrap(); pool }; static ref RAW_CONN_MGR: RawPgConnMgr = { - deadpool_postgres::Config::builder(tokio_postgres::NoTls).expect("failed to init config") + deadpool_postgres::Config::new() + .builder(tokio_postgres::NoTls).expect("failed to init config") .config(PoolConfig::new(8)) .build().expect("failed to build pool") }; } #[inline] -pub async fn connection() -> Result<AsyncPgConnection> { - let pool: &Pool<AsyncDieselConnectionManager<_>> = POOL; - - pool.get() +pub async fn connection() +-> Result<diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>> { + POOL.get() .then(|mut conn| async move { - if let Ok(conn) = conn { - MIGRATE.get_or_init(|| async move { - log::info!("running migrations"); - MIGRATIONS.run_pending_migrations(&mut conn).await.expect("failed running migrations"); - log::info!("migrations complete"); - }).await; + if let Ok(ref mut conn) = conn { + do_migrate(conn).await; } conn @@ -73,14 +80,24 @@ pub async fn connection() -> Result<AsyncPgConnection> { .map_err(Error::from) } +async fn do_migrate(conn: &mut AsyncPgConnection) { + MIGRATE + .get_or_init(|| async move { + log::info!("running migrations"); + MIGRATIONS.run_pending_migrations(conn).await.expect("failed running migrations"); + log::info!("migrations complete"); + }) + .await; +} + #[inline] -async fn raw_connection() -> Result<tokio_postgres::Client> { +async fn raw_connection() -> Result<deadpool_postgres::Object> { // HACK if !MIGRATE.initialized() { - connection().await?; + let _ = connection().await?; } - RAW_CONN_MGR.connect().map_err(Error::from) + RAW_CONN_MGR.get().await.map_err(Error::from) } pub async fn find_meme<T: AsRef<str>>(conn: &mut AsyncPgConnection, search: T) -> Result<Meme> { @@ -94,7 +111,8 @@ pub async fn find_meme<T: AsRef<str>>(conn: &mut AsyncPgConnection, search: T) - meme = memes::table .filter(memes::title.ilike(&format_search).or(memes::content.ilike(&format_search))) .limit(1) - .first::<Meme>(conn).await; + .first::<Meme>(conn) + .await; } meme.map_err(Error::from) @@ -109,9 +127,10 @@ pub async fn query_meme<T: AsRef<str>>( let search = format!("%{}%", search.as_ref()); - let rows = raw_conn.query( - &format!( - r#" + let rows = raw_conn + .query( + &format!( + r#" SELECT memes.id, title, content, image_id, audio_id, metadata_id, created, created_by FROM memes INNER JOIN metadata ON memes.metadata_id = metadata.id @@ -120,14 +139,15 @@ pub async fn query_meme<T: AsRef<str>>( ORDER BY metadata.created {} LIMIT 100 "#, - if age_desc { - "DESC" - } else { - "ASC" - }, - ), - &[&search, &(user_id.unwrap_or(0) as i64), &user_id.is_none()], - ).await?; + if age_desc { + "DESC" + } else { + "ASC" + }, + ), + &[&search, &(user_id.unwrap_or(0) as i64), &user_id.is_none()], + ) + .await?; let result = rows .iter() @@ -159,49 +179,61 @@ pub async fn delete_meme<T: AsRef<str>>( search: T, deleted_by: u64, ) -> Result<()> { + let search = search.as_ref(); + conn.transaction::<(), Error, _>(|tx| { - let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::<Meme>(tx).await?; + async move { + let deleted = memes::table.filter(memes::title.eq(search)).first::<Meme>(tx).await?; - diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx).await?; + diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx).await?; - if let Some(image_id) = deleted.image_id { - let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx).await?; + if let Some(image_id) = deleted.image_id { + let count = + memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx).await?; - if count == 0 { - diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx).await?; + if count == 0 { + diesel::delete(images::table) + .filter(images::id.eq(image_id)) + .execute(tx) + .await?; + } } - } - if let Some(audio_id) = deleted.audio_id { - let count = memes::table - .select(::diesel::dsl::count_star()) - .filter(memes::audio_id.eq(audio_id)) - .execute(tx).await?; + if let Some(audio_id) = deleted.audio_id { + let count = memes::table + .select(::diesel::dsl::count_star()) + .filter(memes::audio_id.eq(audio_id)) + .execute(tx) + .await?; - if count == 0 { - diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx).await?; + if count == 0 { + diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx).await?; + } } - } - let tombstone = NewTombstone { - deleted_by: deleted_by as i64, - metadata_id: deleted.metadata_id, - meme_id: deleted.id, - }; + let tombstone = NewTombstone { + deleted_by: deleted_by as i64, + metadata_id: deleted.metadata_id, + meme_id: deleted.id, + }; - let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx).await?; + let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx).await?; - Ok(()) + Ok(()) + } + .scope_boxed() }) + .await } pub async fn rare_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme> { use rand::prelude::*; - let mut raw_conn = raw_connection().await?; + let raw_conn = raw_connection().await?; - let rows = raw_conn.query( - r#" + let rows = raw_conn + .query( + r#" WITH meme_count AS ( SELECT @@ -232,8 +264,9 @@ pub async fn rare_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme FROM least_used LIMIT 100; "#, - &[&!audio, &audio], - ).await?; + &[&!audio, &audio], + ) + .await?; let elems = rows .iter() @@ -375,7 +408,8 @@ pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { Date::from_utc(nd, Utc {}) } - let total_count: i64 = memes::table.select(count_star()).first(conn).map_err(Error::from)?; + let total_count: i64 = + memes::table.select(count_star()).first(conn).await.map_err(Error::from)?; let image_count: i64 = memes::table .select(count(memes::image_id)) @@ -418,23 +452,26 @@ pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { .await .map_err(Error::from)?; - let mut raw_conn = raw_connection()?; + let mut raw_conn = raw_connection().await?; - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT DATE(time) as dt, COUNT(*) FROM invocation_records GROUP BY dt ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_active_day = to_utc_date(row.get(0)); let most_active_day_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT DATE(time) as dt, COUNT(*) FROM invocation_records INNER JOIN memes ON invocation_records.meme_id = memes.id WHERE memes.audio_id IS NOT NULL @@ -442,42 +479,48 @@ pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_active_audio_day = to_utc_date(row.get(0)); let most_active_audio_day_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT user_id, COUNT(*) FROM invocation_records WHERE random IS TRUE GROUP BY user_id ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_random_invoker: i64 = row.get(0); let most_random_invoker_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT user_id, COUNT(*) FROM invocation_records WHERE random IS FALSE GROUP BY user_id ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_specific_invoker: i64 = row.get(0); let most_specific_invoker_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT memes.title, COUNT(*) FROM invocation_records INNER JOIN memes ON meme_id = memes.id WHERE random IS FALSE @@ -485,14 +528,16 @@ pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_requested_meme = row.get(0); let most_requested_meme_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT memes.title, COUNT(*) FROM invocation_records INNER JOIN memes ON meme_id = memes.id WHERE random IS TRUE @@ -500,22 +545,25 @@ pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_random_meme = row.get(0); let most_random_meme_count: i64 = row.get(1); - let row = raw_conn.query_one( - r#" + let row = raw_conn + .query_one( + r#" SELECT memes.title, COUNT(*) FROM invocation_records INNER JOIN memes ON meme_id = memes.id GROUP BY memes.title ORDER BY COUNT(*) DESC LIMIT 1; "#, - &[], - )?; + &[], + ) + .await?; let most_invoked_meme = row.get(0); let most_invoked_meme_count: i64 = row.get(1); @@ -559,8 +607,8 @@ pub struct MemerInfo { pub most_used_meme_count: usize, } -pub fn memers() -> Result<Vec<MemerInfo>> { - let mut raw_conn = raw_connection()?; +pub async fn memers() -> Result<Vec<MemerInfo>> { + let raw_conn = raw_connection().await?; let rows = raw_conn.query(r#" WITH random_count AS ( @@ -599,7 +647,7 @@ pub fn memers() -> Result<Vec<MemerInfo>> { INNER JOIN specific_count ON specific_count.user_id = random_count.user_id INNER JOIN memes ON memes.id = most_memed.meme_id ORDER BY (random_count.count + specific_count.count) DESC - "#, &[])?; + "#, &[]).await?; let result = rows .iter() |
