From 833f2bed24ab49f1c6242762b6d1e0be9192e870 Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 10 May 2024 20:17:31 -0400 Subject: wip --- src/db/mod.rs | 129 ++++++++++++++++++++++++++++++------------------------- src/db/models.rs | 56 ++++++++++++++++-------- 2 files changed, 109 insertions(+), 76 deletions(-) (limited to 'src/db') diff --git a/src/db/mod.rs b/src/db/mod.rs index e7ff17f..04f2239 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -11,24 +11,13 @@ use chrono::{ }; use diesel::{ prelude::*, - r2d2::{ - ConnectionManager, - ManageConnection, - }, NotFound, }; - -use postgres::Client as RawPgConn; -use r2d2_postgres::{ - postgres::{ - Config, - NoTls, - }, - PostgresConnectionManager as RawPgConnMgr, -}; - +use diesel_async::pooled_connection::AsyncDieselConnectionManager; +use tokio_postgres::Client as RawPgConn; use anyhow::anyhow; -use diesel_migrations::MigrationHarness; +use diesel_async::pooled_connection::deadpool::Pool; +use deadpool_postgres::{Pool as RawPgConnMgr, PoolConfig}; use lazy_static::lazy_static; use crate::{ @@ -42,47 +31,62 @@ use self::schema::*; mod models; mod schema; -const MIGRATIONS: diesel_migrations::EmbeddedMigrations = diesel_migrations::embed_migrations!(); -static MIGRATE: std::sync::Once = std::sync::Once::new(); +const MIGRATIONS: diesel_async_migrations::EmbeddedMigrations = diesel_async_migrations::embed_migrations!(); +static MIGRATE: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::new(); lazy_static! { static ref DB_URL: String = env::var("DATABASE_URL").expect("no database url in environment"); static ref DB_CONFIG: Config = Config::from_str(&DB_URL).expect("parsing db url as config"); - static ref CONN_MGR: ConnectionManager = ConnectionManager::new(DB_URL.clone()); - static ref RAW_CONN_MGR: RawPgConnMgr = RawPgConnMgr::new(DB_CONFIG.clone(), NoTls); + + static ref POOL: diesel_async::pooled_connection::deadpool::Pool> = { + let cfg = AsyncDieselConnectionManager::new(DB_URL.clone()); + let pool = Pool::builder(cfg).build().unwrap(); + + pool + }; + + static ref RAW_CONN_MGR: RawPgConnMgr = { + deadpool_postgres::Config::builder(tokio_postgres::NoTls).expect("failed to init config") + .config(PoolConfig::new(8)) + .build().expect("failed to build pool") + }; } #[inline] -pub fn connection() -> Result { - CONN_MGR - .connect() - .map(|mut conn| { - MIGRATE.call_once(|| { - log::info!("running migrations"); - conn.run_pending_migrations(MIGRATIONS).expect("failed running migrations"); - log::info!("migrations complete"); - }); +pub async fn connection() -> Result { + let pool: &Pool> = POOL; + + pool.get() + .then(|mut conn| async move { + if let Ok(conn) = conn { + MIGRATE.get_or_init(|| async move { + log::info!("running migrations"); + MIGRATIONS.run_pending_migrations(&mut conn).await.expect("failed running migrations"); + log::info!("migrations complete"); + }).await; + } conn }) + .await .map_err(Error::from) } #[inline] -fn raw_connection() -> Result { +async fn raw_connection() -> Result { // HACK - if !MIGRATE.is_completed() { - connection()?; + if !MIGRATE.initialized() { + connection().await?; } RAW_CONN_MGR.connect().map_err(Error::from) } -pub fn find_meme>(conn: &mut PgConnection, search: T) -> Result { +pub async fn find_meme>(conn: &mut AsyncPgConnection, search: T) -> Result { let search = search.as_ref(); - let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::(conn); + let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::(conn).await; if let Err(NotFound) = meme { let format_search = format!("%{}%", search); @@ -90,18 +94,18 @@ pub fn find_meme>(conn: &mut PgConnection, search: T) -> Result(conn); + .first::(conn).await; } meme.map_err(Error::from) } -pub fn query_meme>( +pub async fn query_meme>( search: T, user_id: Option, age_desc: bool, ) -> Result> { - let mut raw_conn = raw_connection()?; + let mut raw_conn = raw_connection().await?; let search = format!("%{}%", search.as_ref()); @@ -123,7 +127,7 @@ pub fn query_meme>( }, ), &[&search, &(user_id.unwrap_or(0) as i64), &user_id.is_none()], - )?; + ).await?; let result = rows .iter() @@ -150,21 +154,21 @@ pub fn query_meme>( Ok(result) } -pub fn delete_meme>( - conn: &mut PgConnection, +pub async fn delete_meme>( + conn: &mut AsyncPgConnection, search: T, deleted_by: u64, ) -> Result<()> { conn.transaction::<(), Error, _>(|tx| { - let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::(tx)?; + let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::(tx).await?; - diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx)?; + diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx).await?; if let Some(image_id) = deleted.image_id { - let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx)?; + let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx).await?; if count == 0 { - diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx)?; + diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx).await?; } } @@ -172,10 +176,10 @@ pub fn delete_meme>( let count = memes::table .select(::diesel::dsl::count_star()) .filter(memes::audio_id.eq(audio_id)) - .execute(tx)?; + .execute(tx).await?; if count == 0 { - diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx)?; + diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx).await?; } } @@ -185,16 +189,16 @@ pub fn delete_meme>( meme_id: deleted.id, }; - let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx)?; + let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx).await?; Ok(()) }) } -pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result { +pub async fn rare_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result { use rand::prelude::*; - let mut raw_conn = raw_connection()?; + let mut raw_conn = raw_connection().await?; let rows = raw_conn.query( r#" @@ -229,7 +233,7 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result { LIMIT 100; "#, &[&!audio, &audio], - )?; + ).await?; let elems = rows .iter() @@ -249,10 +253,10 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result { .ok_or_else(|| anyhow!("couldn't locate meme satisfying target probability"))? .0; - Meme::find(conn, meme_id) + Meme::find(conn, meme_id).await } -pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result { +pub async fn rand_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result { use rand::{ seq::SliceRandom, thread_rng, @@ -268,21 +272,23 @@ pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result { .or(memes::audio_id.is_not_null()), ) .load(conn) + .await .map_err(Error::from)? } else { memes::table .select(memes::id) .filter(memes::content.is_not_null().or(memes::image_id.is_not_null())) .load(conn) + .await .map_err(Error::from)? }; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load meme"))?; - memes::table.find(id).first::(conn).map_err(Error::from) + memes::table.find(id).first::(conn).await.map_err(Error::from) } -pub fn rand_audio_meme(conn: &mut PgConnection) -> Result { +pub async fn rand_audio_meme(conn: &mut AsyncPgConnection) -> Result { use rand::{ seq::SliceRandom, thread_rng, @@ -292,14 +298,15 @@ pub fn rand_audio_meme(conn: &mut PgConnection) -> Result { .select(memes::id) .filter(memes::audio_id.is_not_null()) .load(conn) + .await .map_err(Error::from)?; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?; - memes::table.find(id).first::(conn).map_err(Error::from) + memes::table.find(id).first::(conn).await.map_err(Error::from) } -pub fn rand_silent_meme(conn: &mut PgConnection) -> Result { +pub async fn rand_silent_meme(conn: &mut AsyncPgConnection) -> Result { use rand::{ seq::SliceRandom, thread_rng, @@ -309,11 +316,12 @@ pub fn rand_silent_meme(conn: &mut PgConnection) -> Result { .select(memes::id) .filter(memes::audio_id.is_null()) .load(conn) + .await .map_err(Error::from)?; let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?; - memes::table.find(id).first::(conn).map_err(Error::from) + memes::table.find(id).first::(conn).await.map_err(Error::from) } #[derive(Debug, Clone)] @@ -347,7 +355,7 @@ pub struct Stats { pub most_popular_meme_overall_count: usize, } -pub fn stats(conn: &mut PgConnection) -> Result { +pub async fn stats(conn: &mut AsyncPgConnection) -> Result { use chrono::{ NaiveDate, NaiveDateTime, @@ -373,36 +381,41 @@ pub fn stats(conn: &mut PgConnection) -> Result { .select(count(memes::image_id)) .filter(memes::image_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let audio_count: i64 = memes::table .select(count(memes::audio_id)) .filter(memes::audio_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let started_recording: NaiveDateTime = invocation_records::table .select(invocation_records::time) .order(invocation_records::time) .first(conn) + .await .map_err(Error::from)?; let started_recording = to_utc(started_recording); let total_meme_invocations: i64 = - invocation_records::table.select(count_star()).first(conn).map_err(Error::from)?; + invocation_records::table.select(count_star()).first(conn).await.map_err(Error::from)?; let audio_meme_invocations: i64 = invocation_records::table .inner_join(memes::table) .select(count_star()) .filter(memes::audio_id.is_not_null()) .first(conn) + .await .map_err(Error::from)?; let random_meme_invocations: i64 = invocation_records::table .select(count_star()) .filter(invocation_records::random.eq(true)) .first(conn) + .await .map_err(Error::from)?; let mut raw_conn = raw_connection()?; diff --git a/src/db/models.rs b/src/db/models.rs index bba25a2..f7bbf8e 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -5,6 +5,7 @@ use diesel::{ Insertable, Queryable, }; +use diesel_async::{AsyncPgConnection, RunQueryDsl}; use sha1::Digest; use crate::{ @@ -25,18 +26,27 @@ pub struct Meme { } impl Meme { - pub fn image(&self, conn: &mut PgConnection) -> Option> { + pub async fn image(&self, conn: &mut AsyncPgConnection) -> Option> { self.image_id - .map(|x: i32| images::table.filter(images::id.eq(x)).first(conn).map_err(Error::from)) + .map(|x: i32| images::table.filter(images::id.eq(x)) + .first(conn) + .await + .map_err(Error::from)) } - pub fn audio(&self, conn: &mut PgConnection) -> Option> { + pub async fn audio(&self, conn: &mut AsyncPgConnection) -> Option> { self.audio_id - .map(|x: i32| audio::table.filter(audio::id.eq(x)).first(conn).map_err(Error::from)) + .map(|x: i32| audio::table.filter(audio::id.eq(x)) + .first(conn) + .await + .map_err(Error::from)) } - pub fn find(conn: &mut PgConnection, id: i32) -> Result { - memes::table.find(id).get_result(conn).map_err(Error::from) + pub async fn find(conn: &mut AsyncPgConnection, id: i32) -> Result { + memes::table.find(id) + .get_result(conn) + .await + .map_err(Error::from) } } @@ -51,7 +61,7 @@ pub struct NewMeme { } impl NewMeme { - pub fn save(mut self, conn: &mut PgConnection, by_user: u64) -> Result { + pub async fn save(mut self, conn: &mut AsyncPgConnection, by_user: u64) -> Result { let metadata = Metadata::create(conn, by_user)?; self.metadata_id = metadata.id; @@ -59,6 +69,7 @@ impl NewMeme { diesel::insert_into(memes::table) .values(&self) .get_result::(conn) + .await .map_err(Error::from) } } @@ -73,7 +84,7 @@ pub struct Audio { } impl Audio { - pub fn create(conn: &mut PgConnection, data: Vec, by_user: u64) -> Result { + pub fn create(conn: &mut AsyncPgConnection, data: Vec, by_user: u64) -> Result { let mut data_hash = ::sha1::Sha1::new(); data_hash.update(&data); let data_hash = data_hash.finalize().to_vec(); @@ -81,7 +92,8 @@ impl Audio { let id = audio::table .select(audio::id) .filter(audio::data_hash.eq(&data_hash)) - .get_results::(conn)?; + .get_results::(conn) + .await?; if let Some(id) = id.first() { return Ok(*id); @@ -99,6 +111,7 @@ impl Audio { .values(&new_audio) .returning(audio::id) .get_result(conn) + .await .map_err(Error::from) } } @@ -123,7 +136,7 @@ pub struct Image { impl Image { pub fn create( - conn: &mut PgConnection, + conn: &mut AsyncPgConnection, filename: &str, data: Vec, by_user: u64, @@ -135,7 +148,8 @@ impl Image { let id = images::table .select(images::id) .filter(images::data_hash.eq(&data_hash)) - .get_results::(conn)?; + .get_results::(conn) + .await?; if let Some(id) = id.first() { return Ok(*id); @@ -154,6 +168,7 @@ impl Image { .values(&new_image) .returning(images::id) .get_result(conn) + .await .map_err(Error::from) } } @@ -176,17 +191,18 @@ pub struct Metadata { } impl Metadata { - pub fn create(conn: &mut PgConnection, by_user: u64) -> Result { + pub fn create(conn: &mut AsyncPgConnection, by_user: u64) -> Result { diesel::insert_into(metadata::table) .values(&NewMetadata { created_by: by_user as i64, }) .get_result::(conn) + .await .map_err(Error::from) } - pub fn find(conn: &mut PgConnection, id: i32) -> Result { - metadata::table.find(id).get_result::(conn).map_err(Error::from) + pub fn find(conn: &mut AsyncPgConnection, id: i32) -> Result { + metadata::table.find(id).get_result::(conn).await.map_err(Error::from) } } @@ -206,13 +222,14 @@ pub struct AuditRecord { } impl AuditRecord { - pub fn create(conn: &mut PgConnection, metadata: i32, by_user: u64) -> Result { + pub fn create(conn: &mut AsyncPgConnection, metadata: i32, by_user: u64) -> Result { diesel::insert_into(audit_records::table) .values(&NewAuditRecord { updated_by: by_user as i64, metadata_id: metadata, }) .get_result::(conn) + .await .map_err(Error::from) } } @@ -264,7 +281,7 @@ pub struct NewInvocationRecord { impl InvocationRecord { pub fn create( - conn: &mut PgConnection, + conn: &mut AsyncPgConnection, user_id: u64, message_id: u64, meme_id: i32, @@ -278,21 +295,24 @@ impl InvocationRecord { random, }) .get_result::(conn) + .await .map_err(Error::from) } - pub fn last(conn: &mut PgConnection) -> Result { + pub fn last(conn: &mut AsyncPgConnection) -> Result { invocation_records::table .order(invocation_records::time.desc()) .first(conn) + .await .map_err(Error::from) } - pub fn last_n(conn: &mut PgConnection, n: usize) -> Result> { + pub fn last_n(conn: &mut AsyncPgConnection, n: usize) -> Result> { invocation_records::table .order(invocation_records::time.desc()) .limit(n as i64) .load(conn) + .await .map_err(Error::from) } } -- cgit v1.3.1