diff options
Diffstat (limited to 'src/db/mod.rs')
| -rw-r--r-- | src/db/mod.rs | 129 |
1 files changed, 71 insertions, 58 deletions
diff --git a/src/db/mod.rs b/src/db/mod.rs index e7ff17f..04f2239 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -11,24 +11,13 @@ use chrono::{ }; use diesel::{ prelude::*, - r2d2::{ - ConnectionManager, - ManageConnection, - }, NotFound, }; - -use postgres::Client as RawPgConn; -use r2d2_postgres::{ - postgres::{ - Config, - NoTls, - }, - PostgresConnectionManager as RawPgConnMgr, -}; - +use diesel_async::pooled_connection::AsyncDieselConnectionManager; +use tokio_postgres::Client as RawPgConn; use anyhow::anyhow; -use diesel_migrations::MigrationHarness; +use diesel_async::pooled_connection::deadpool::Pool; +use deadpool_postgres::{Pool as RawPgConnMgr, PoolConfig}; use lazy_static::lazy_static; use crate::{ @@ -42,47 +31,62 @@ use self::schema::*; mod models; mod schema; -const MIGRATIONS: diesel_migrations::EmbeddedMigrations = diesel_migrations::embed_migrations!(); -static MIGRATE: std::sync::Once = std::sync::Once::new(); +const MIGRATIONS: diesel_async_migrations::EmbeddedMigrations = diesel_async_migrations::embed_migrations!(); +static MIGRATE: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::new(); lazy_static! { static ref DB_URL: String = env::var("DATABASE_URL").expect("no database url in environment"); static ref DB_CONFIG: Config = Config::from_str(&DB_URL).expect("parsing db url as config"); - static ref CONN_MGR: ConnectionManager<PgConnection> = ConnectionManager::new(DB_URL.clone()); - static ref RAW_CONN_MGR: RawPgConnMgr<NoTls> = RawPgConnMgr::new(DB_CONFIG.clone(), NoTls); + + static ref POOL: diesel_async::pooled_connection::deadpool::Pool<AsyncDieselConnectionManager<AsyncPgConnection>> = { + let cfg = AsyncDieselConnectionManager::new(DB_URL.clone()); + let pool = Pool::builder(cfg).build().unwrap(); + + pool + }; + + static ref RAW_CONN_MGR: RawPgConnMgr = { + deadpool_postgres::Config::builder(tokio_postgres::NoTls).expect("failed to init config") + .config(PoolConfig::new(8)) + .build().expect("failed to build pool") + }; } #[inline] -pub fn connection() -> Result<PgConnection> { - CONN_MGR - .connect() - .map(|mut conn| { - MIGRATE.call_once(|| { - log::info!("running migrations"); - conn.run_pending_migrations(MIGRATIONS).expect("failed running migrations"); - log::info!("migrations complete"); - }); +pub async fn connection() -> Result<AsyncPgConnection> { + let pool: &Pool<AsyncDieselConnectionManager<_>> = POOL; + + pool.get() + .then(|mut conn| async move { + if let Ok(conn) = conn { + MIGRATE.get_or_init(|| async move { + log::info!("running migrations"); + MIGRATIONS.run_pending_migrations(&mut conn).await.expect("failed running migrations"); + log::info!("migrations complete"); + }).await; + } conn }) + .await .map_err(Error::from) } #[inline] -fn raw_connection() -> Result<RawPgConn> { +async fn raw_connection() -> Result<tokio_postgres::Client> { // HACK - if !MIGRATE.is_completed() { - connection()?; + if !MIGRATE.initialized() { + connection().await?; } RAW_CONN_MGR.connect().map_err(Error::from) } -pub fn find_meme<T: AsRef<str>>(conn: &mut PgConnection, search: T) -> Result<Meme> { +pub async fn find_meme<T: AsRef<str>>(conn: &mut AsyncPgConnection, search: T) -> Result<Meme> { let search = search.as_ref(); - let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::<Meme>(conn); + let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::<Meme>(conn).await; if let Err(NotFound) = meme { let format_search = format!("%{}%", search); @@ -90,18 +94,18 @@ pub fn find_meme<T: AsRef<str>>(conn: &mut PgConnection, search: T) -> Result<Me meme = memes::table .filter(memes::title.ilike(&format_search).or(memes::content.ilike(&format_search))) .limit(1) - .first::<Meme>(conn); + .first::<Meme>(conn).await; } meme.map_err(Error::from) } -pub fn query_meme<T: AsRef<str>>( +pub async fn query_meme<T: AsRef<str>>( search: T, user_id: Option<u64>, age_desc: bool, ) -> Result<Vec<(Meme, Metadata)>> { - let mut raw_conn = raw_connection()?; + let mut raw_conn = raw_connection().await?; let search = format!("%{}%", search.as_ref()); @@ -123,7 +127,7 @@ pub fn query_meme<T: AsRef<str>>( }, ), &[&search, &(user_id.unwrap_or(0) as i64), &user_id.is_none()], - )?; + ).await?; let result = rows .iter() @@ -150,21 +154,21 @@ pub fn query_meme<T: AsRef<str>>( Ok(result) } -pub fn delete_meme<T: AsRef<str>>( - conn: &mut PgConnection, +pub async fn delete_meme<T: AsRef<str>>( + conn: &mut AsyncPgConnection, search: T, deleted_by: u64, ) -> Result<()> { conn.transaction::<(), Error, _>(|tx| { - let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::<Meme>(tx)?; + let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::<Meme>(tx).await?; - diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx)?; + diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx).await?; if let Some(image_id) = deleted.image_id { - let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx)?; + let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx).await?; if count == 0 { - diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx)?; + diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx).await?; } } @@ -172,10 +176,10 @@ pub fn delete_meme<T: AsRef<str>>( let count = memes::table .select(::diesel::dsl::count_star()) .filter(memes::audio_id.eq(audio_id)) - .execute(tx)?; + .execute(tx).await?; if count == 0 { - diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx)?; + diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx).await?; } } @@ -185,16 +189,16 @@ pub fn delete_meme<T: AsRef<str>>( meme_id: deleted.id, }; - let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx)?; + let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx).await?; Ok(()) }) } -pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> { +pub async fn rare_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme> { use rand::prelude::*; - let mut raw_conn = raw_connection()?; + let mut raw_conn = raw_connection().await?; let rows = raw_conn.query( r#" @@ -229,7 +233,7 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> { LIMIT 100; "#, &[&!audio, &audio], - )?; + ).await?; let elems = rows .iter() @@ -249,10 +253,10 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> { .ok_or_else(|| anyhow!("couldn't locate meme satisfying target probability"))? .0; - Meme::find(conn, meme_id) + Meme::find(conn, meme_id).await } -pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> { +pub async fn rand_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme> { use rand::{ seq::SliceRandom, thread_rng, @@ -268,21 +272,23 @@ pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> { .or(memes::audio_id.is_not_null()), ) .load(conn) + .await .map_err(Error::from)? } else { memes::table .select(memes::id) .filter(memes::content.is_not_null().or(memes::image_id.is_not_null())) .load(conn) + .await .map_err(Error::from)? }; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load meme"))?; - memes::table.find(id).first::<Meme>(conn).map_err(Error::from) + memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from) } -pub fn rand_audio_meme(conn: &mut PgConnection) -> Result<Meme> { +pub async fn rand_audio_meme(conn: &mut AsyncPgConnection) -> Result<Meme> { use rand::{ seq::SliceRandom, thread_rng, @@ -292,14 +298,15 @@ pub fn rand_audio_meme(conn: &mut PgConnection) -> Result<Meme> { .select(memes::id) .filter(memes::audio_id.is_not_null()) .load(conn) + .await .map_err(Error::from)?; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?; - memes::table.find(id).first::<Meme>(conn).map_err(Error::from) + memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from) } -pub fn rand_silent_meme(conn: &mut PgConnection) -> Result<Meme> { +pub async fn rand_silent_meme(conn: &mut AsyncPgConnection) -> Result<Meme> { use rand::{ seq::SliceRandom, thread_rng, @@ -309,11 +316,12 @@ pub fn rand_silent_meme(conn: &mut PgConnection) -> Result<Meme> { .select(memes::id) .filter(memes::audio_id.is_null()) .load(conn) + .await .map_err(Error::from)?; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?; - memes::table.find(id).first::<Meme>(conn).map_err(Error::from) + memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from) } #[derive(Debug, Clone)] @@ -347,7 +355,7 @@ pub struct Stats { pub most_popular_meme_overall_count: usize, } -pub fn stats(conn: &mut PgConnection) -> Result<Stats> { +pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> { use chrono::{ NaiveDate, NaiveDateTime, @@ -373,36 +381,41 @@ pub fn stats(conn: &mut PgConnection) -> Result<Stats> { .select(count(memes::image_id)) .filter(memes::image_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let audio_count: i64 = memes::table .select(count(memes::audio_id)) .filter(memes::audio_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let started_recording: NaiveDateTime = invocation_records::table .select(invocation_records::time) .order(invocation_records::time) .first(conn) + .await .map_err(Error::from)?; let started_recording = to_utc(started_recording); let total_meme_invocations: i64 = - invocation_records::table.select(count_star()).first(conn).map_err(Error::from)?; + invocation_records::table.select(count_star()).first(conn).await.map_err(Error::from)?; let audio_meme_invocations: i64 = invocation_records::table .inner_join(memes::table) .select(count_star()) .filter(memes::audio_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let random_meme_invocations: i64 = invocation_records::table .select(count_star()) .filter(invocation_records::random.eq(true)) .first(conn) + .await .map_err(Error::from)?; let mut raw_conn = raw_connection()?; |
