aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNathan Perry <np@nathanperry.dev>2024-05-10 20:17:31 -0400
committerNathan Perry <np@nathanperry.dev>2024-05-10 20:17:31 -0400
commit833f2bed24ab49f1c6242762b6d1e0be9192e870 (patch)
tree6addb09b277a3559a0049b31d602b80cd977bdbc /src
parentc56dea4fbd53fde13efaf742ab9ee9d56575128a (diff)
wip
Diffstat (limited to 'src')
-rw-r--r--src/commands/meme/create.rs8
-rw-r--r--src/commands/meme/history.rs14
-rw-r--r--src/commands/meme/mod.rs4
-rw-r--r--src/db/mod.rs129
-rw-r--r--src/db/models.rs56
5 files changed, 124 insertions, 87 deletions
diff --git a/src/commands/meme/create.rs b/src/commands/meme/create.rs
index 2cd9465..160de7b 100644
--- a/src/commands/meme/create.rs
+++ b/src/commands/meme/create.rs
@@ -72,7 +72,7 @@ pub async fn addmeme(ctx: &Context, msg: &Message, args: Args) -> CommandResult
if let Some(att) = image {
let data = att.download().await?;
- image_id = Some(Image::create(&mut conn, &att.filename, data, msg.author.id.get())?);
+ image_id = Some(Image::create(&mut conn, &att.filename, data, msg.author.id.get()).await?);
};
let save_result = NewMeme {
@@ -83,6 +83,7 @@ pub async fn addmeme(ctx: &Context, msg: &Message, args: Args) -> CommandResult
metadata_id: 0,
}
.save(&mut conn, msg.author.id.get())
+ .await
.map(|_| {});
use diesel::result::DatabaseErrorKind;
@@ -182,7 +183,7 @@ pub async fn addaudiomeme(ctx: &Context, msg: &Message, args: Args) -> CommandRe
if let Ok(att) = image_att {
let data = att.download().await?;
- image_id = Image::create(&mut conn, &att.filename, data, msg.author.id.get())?.pipe(Some);
+ image_id = Image::create(&mut conn, &att.filename, data, msg.author.id.get())?.await.pipe(Some);
}
let mut audio_data = Vec::new();
@@ -196,7 +197,7 @@ pub async fn addaudiomeme(ctx: &Context, msg: &Message, args: Args) -> CommandRe
.await;
}
- let audio_id = Audio::create(&mut conn, audio_data, msg.author.id.get())?;
+ let audio_id = Audio::create(&mut conn, audio_data, msg.author.id.get()).await?;
let save_result = NewMeme {
title,
@@ -206,6 +207,7 @@ pub async fn addaudiomeme(ctx: &Context, msg: &Message, args: Args) -> CommandRe
metadata_id: 0,
}
.save(&mut conn, msg.author.id.get())
+ .await
.map(|_| {});
use diesel::result::DatabaseErrorKind;
diff --git a/src/commands/meme/history.rs b/src/commands/meme/history.rs
index e2953d1..e5b3d33 100644
--- a/src/commands/meme/history.rs
+++ b/src/commands/meme/history.rs
@@ -60,7 +60,7 @@ static CLEAN_DATE_FORMAT: &str = "%b %-e %Y";
pub async fn wat(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
let mut conn = connection()?;
- let record = match InvocationRecord::last(&mut conn) {
+ let record = match InvocationRecord::last(&mut conn).await {
Ok(x) => x,
Err(e) => {
if let Some(NotFound) = e.downcast_ref::<DieselError>() {
@@ -75,11 +75,11 @@ pub async fn wat(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
},
};
- let meme = Meme::find(&mut conn, record.meme_id);
+ let meme = Meme::find(&mut conn, record.meme_id).await;
match meme {
Ok(ref meme) => {
- let metadata = Metadata::find(&mut conn, meme.metadata_id)?;
+ let metadata = Metadata::find(&mut conn, meme.metadata_id).await?;
let author = CONFIG.discord.guild().member(&ctx, metadata.created_by as u64).await?;
util::send(
@@ -125,7 +125,7 @@ pub async fn history(ctx: &Context, msg: &Message, mut args: Args) -> CommandRes
let records = {
let mut conn = connection()?;
- InvocationRecord::last_n(&mut conn, n)?
+ InvocationRecord::last_n(&mut conn, n).await?
};
if records.is_empty() {
@@ -150,7 +150,9 @@ pub async fn history(ctx: &Context, msg: &Message, mut args: Args) -> CommandRes
""
};
- let meme = Meme::find(&mut conn, rec.meme_id).and_then(|meme| {
+ let meme = Meme::find(&mut conn, rec.meme_id).await;
+
+ .and_then(|meme| {
Metadata::find(&mut conn, meme.metadata_id).map(|metadata| (metadata, meme))
});
@@ -214,7 +216,7 @@ pub async fn stats(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
};
let mut conn = connection()?;
- let stats = db::stats(&mut conn)?;
+ let stats = db::stats(&mut conn).await?;
debug!("reporting stats");
diff --git a/src/commands/meme/mod.rs b/src/commands/meme/mod.rs
index c40e80a..b5a3c98 100644
--- a/src/commands/meme/mod.rs
+++ b/src/commands/meme/mod.rs
@@ -1,4 +1,4 @@
-use diesel::PgConnection;
+use diesel_async::AsyncPgConnection;
use log::debug;
use rand::random;
use serenity::{
@@ -70,7 +70,7 @@ struct Memes;
async fn send_meme(
ctx: &Context,
t: &Meme,
- conn: &mut PgConnection,
+ conn: &mut AsyncPgConnection,
msg: &Message,
) -> CommandResult {
let should_tts =
diff --git a/src/db/mod.rs b/src/db/mod.rs
index e7ff17f..04f2239 100644
--- a/src/db/mod.rs
+++ b/src/db/mod.rs
@@ -11,24 +11,13 @@ use chrono::{
};
use diesel::{
prelude::*,
- r2d2::{
- ConnectionManager,
- ManageConnection,
- },
NotFound,
};
-
-use postgres::Client as RawPgConn;
-use r2d2_postgres::{
- postgres::{
- Config,
- NoTls,
- },
- PostgresConnectionManager as RawPgConnMgr,
-};
-
+use diesel_async::pooled_connection::AsyncDieselConnectionManager;
+use tokio_postgres::Client as RawPgConn;
use anyhow::anyhow;
-use diesel_migrations::MigrationHarness;
+use diesel_async::pooled_connection::deadpool::Pool;
+use deadpool_postgres::{Pool as RawPgConnMgr, PoolConfig};
use lazy_static::lazy_static;
use crate::{
@@ -42,47 +31,62 @@ use self::schema::*;
mod models;
mod schema;
-const MIGRATIONS: diesel_migrations::EmbeddedMigrations = diesel_migrations::embed_migrations!();
-static MIGRATE: std::sync::Once = std::sync::Once::new();
+const MIGRATIONS: diesel_async_migrations::EmbeddedMigrations = diesel_async_migrations::embed_migrations!();
+static MIGRATE: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::new();
lazy_static! {
static ref DB_URL: String =
env::var("DATABASE_URL").expect("no database url in environment");
static ref DB_CONFIG: Config = Config::from_str(&DB_URL).expect("parsing db url as config");
- static ref CONN_MGR: ConnectionManager<PgConnection> = ConnectionManager::new(DB_URL.clone());
- static ref RAW_CONN_MGR: RawPgConnMgr<NoTls> = RawPgConnMgr::new(DB_CONFIG.clone(), NoTls);
+
+ static ref POOL: diesel_async::pooled_connection::deadpool::Pool<AsyncDieselConnectionManager<AsyncPgConnection>> = {
+ let cfg = AsyncDieselConnectionManager::new(DB_URL.clone());
+ let pool = Pool::builder(cfg).build().unwrap();
+
+ pool
+ };
+
+ static ref RAW_CONN_MGR: RawPgConnMgr = {
+ deadpool_postgres::Config::builder(tokio_postgres::NoTls).expect("failed to init config")
+ .config(PoolConfig::new(8))
+ .build().expect("failed to build pool")
+ };
}
#[inline]
-pub fn connection() -> Result<PgConnection> {
- CONN_MGR
- .connect()
- .map(|mut conn| {
- MIGRATE.call_once(|| {
- log::info!("running migrations");
- conn.run_pending_migrations(MIGRATIONS).expect("failed running migrations");
- log::info!("migrations complete");
- });
+pub async fn connection() -> Result<AsyncPgConnection> {
+ let pool: &Pool<AsyncDieselConnectionManager<_>> = POOL;
+
+ pool.get()
+ .then(|mut conn| async move {
+ if let Ok(conn) = conn {
+ MIGRATE.get_or_init(|| async move {
+ log::info!("running migrations");
+ MIGRATIONS.run_pending_migrations(&mut conn).await.expect("failed running migrations");
+ log::info!("migrations complete");
+ }).await;
+ }
conn
})
+ .await
.map_err(Error::from)
}
#[inline]
-fn raw_connection() -> Result<RawPgConn> {
+async fn raw_connection() -> Result<tokio_postgres::Client> {
// HACK
- if !MIGRATE.is_completed() {
- connection()?;
+ if !MIGRATE.initialized() {
+ connection().await?;
}
RAW_CONN_MGR.connect().map_err(Error::from)
}
-pub fn find_meme<T: AsRef<str>>(conn: &mut PgConnection, search: T) -> Result<Meme> {
+pub async fn find_meme<T: AsRef<str>>(conn: &mut AsyncPgConnection, search: T) -> Result<Meme> {
let search = search.as_ref();
- let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::<Meme>(conn);
+ let mut meme = memes::table.filter(memes::title.eq(search)).limit(1).first::<Meme>(conn).await;
if let Err(NotFound) = meme {
let format_search = format!("%{}%", search);
@@ -90,18 +94,18 @@ pub fn find_meme<T: AsRef<str>>(conn: &mut PgConnection, search: T) -> Result<Me
meme = memes::table
.filter(memes::title.ilike(&format_search).or(memes::content.ilike(&format_search)))
.limit(1)
- .first::<Meme>(conn);
+ .first::<Meme>(conn).await;
}
meme.map_err(Error::from)
}
-pub fn query_meme<T: AsRef<str>>(
+pub async fn query_meme<T: AsRef<str>>(
search: T,
user_id: Option<u64>,
age_desc: bool,
) -> Result<Vec<(Meme, Metadata)>> {
- let mut raw_conn = raw_connection()?;
+ let mut raw_conn = raw_connection().await?;
let search = format!("%{}%", search.as_ref());
@@ -123,7 +127,7 @@ pub fn query_meme<T: AsRef<str>>(
},
),
&[&search, &(user_id.unwrap_or(0) as i64), &user_id.is_none()],
- )?;
+ ).await?;
let result = rows
.iter()
@@ -150,21 +154,21 @@ pub fn query_meme<T: AsRef<str>>(
Ok(result)
}
-pub fn delete_meme<T: AsRef<str>>(
- conn: &mut PgConnection,
+pub async fn delete_meme<T: AsRef<str>>(
+ conn: &mut AsyncPgConnection,
search: T,
deleted_by: u64,
) -> Result<()> {
conn.transaction::<(), Error, _>(|tx| {
- let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::<Meme>(tx)?;
+ let deleted = memes::table.filter(memes::title.eq(search.as_ref())).first::<Meme>(tx).await?;
- diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx)?;
+ diesel::delete(memes::table).filter(memes::id.eq(deleted.id)).execute(tx).await?;
if let Some(image_id) = deleted.image_id {
- let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx)?;
+ let count = memes::table.filter(memes::image_id.eq(image_id)).count().execute(tx).await?;
if count == 0 {
- diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx)?;
+ diesel::delete(images::table).filter(images::id.eq(image_id)).execute(tx).await?;
}
}
@@ -172,10 +176,10 @@ pub fn delete_meme<T: AsRef<str>>(
let count = memes::table
.select(::diesel::dsl::count_star())
.filter(memes::audio_id.eq(audio_id))
- .execute(tx)?;
+ .execute(tx).await?;
if count == 0 {
- diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx)?;
+ diesel::delete(audio::table).filter(audio::id.eq(audio_id)).execute(tx).await?;
}
}
@@ -185,16 +189,16 @@ pub fn delete_meme<T: AsRef<str>>(
meme_id: deleted.id,
};
- let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx)?;
+ let _ = diesel::insert_into(tombstones::table).values(&tombstone).execute(tx).await?;
Ok(())
})
}
-pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> {
+pub async fn rare_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme> {
use rand::prelude::*;
- let mut raw_conn = raw_connection()?;
+ let mut raw_conn = raw_connection().await?;
let rows = raw_conn.query(
r#"
@@ -229,7 +233,7 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> {
LIMIT 100;
"#,
&[&!audio, &audio],
- )?;
+ ).await?;
let elems = rows
.iter()
@@ -249,10 +253,10 @@ pub fn rare_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> {
.ok_or_else(|| anyhow!("couldn't locate meme satisfying target probability"))?
.0;
- Meme::find(conn, meme_id)
+ Meme::find(conn, meme_id).await
}
-pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> {
+pub async fn rand_meme(conn: &mut AsyncPgConnection, audio: bool) -> Result<Meme> {
use rand::{
seq::SliceRandom,
thread_rng,
@@ -268,21 +272,23 @@ pub fn rand_meme(conn: &mut PgConnection, audio: bool) -> Result<Meme> {
.or(memes::audio_id.is_not_null()),
)
.load(conn)
+ .await
.map_err(Error::from)?
} else {
memes::table
.select(memes::id)
.filter(memes::content.is_not_null().or(memes::image_id.is_not_null()))
.load(conn)
+ .await
.map_err(Error::from)?
};
let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load meme"))?;
- memes::table.find(id).first::<Meme>(conn).map_err(Error::from)
+ memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from)
}
-pub fn rand_audio_meme(conn: &mut PgConnection) -> Result<Meme> {
+pub async fn rand_audio_meme(conn: &mut AsyncPgConnection) -> Result<Meme> {
use rand::{
seq::SliceRandom,
thread_rng,
@@ -292,14 +298,15 @@ pub fn rand_audio_meme(conn: &mut PgConnection) -> Result<Meme> {
.select(memes::id)
.filter(memes::audio_id.is_not_null())
.load(conn)
+ .await
.map_err(Error::from)?;
let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?;
- memes::table.find(id).first::<Meme>(conn).map_err(Error::from)
+ memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from)
}
-pub fn rand_silent_meme(conn: &mut PgConnection) -> Result<Meme> {
+pub async fn rand_silent_meme(conn: &mut AsyncPgConnection) -> Result<Meme> {
use rand::{
seq::SliceRandom,
thread_rng,
@@ -309,11 +316,12 @@ pub fn rand_silent_meme(conn: &mut PgConnection) -> Result<Meme> {
.select(memes::id)
.filter(memes::audio_id.is_null())
.load(conn)
+ .await
.map_err(Error::from)?;
let id = ids.choose(&mut thread_rng()).ok_or_else(|| anyhow!("couldn't load audio meme"))?;
- memes::table.find(id).first::<Meme>(conn).map_err(Error::from)
+ memes::table.find(id).first::<Meme>(conn).await.map_err(Error::from)
}
#[derive(Debug, Clone)]
@@ -347,7 +355,7 @@ pub struct Stats {
pub most_popular_meme_overall_count: usize,
}
-pub fn stats(conn: &mut PgConnection) -> Result<Stats> {
+pub async fn stats(conn: &mut AsyncPgConnection) -> Result<Stats> {
use chrono::{
NaiveDate,
NaiveDateTime,
@@ -373,36 +381,41 @@ pub fn stats(conn: &mut PgConnection) -> Result<Stats> {
.select(count(memes::image_id))
.filter(memes::image_id.is_not_null())
.first(conn)
+ .await
.map_err(Error::from)?;
let audio_count: i64 = memes::table
.select(count(memes::audio_id))
.filter(memes::audio_id.is_not_null())
.first(conn)
+ .await
.map_err(Error::from)?;
let started_recording: NaiveDateTime = invocation_records::table
.select(invocation_records::time)
.order(invocation_records::time)
.first(conn)
+ .await
.map_err(Error::from)?;
let started_recording = to_utc(started_recording);
let total_meme_invocations: i64 =
- invocation_records::table.select(count_star()).first(conn).map_err(Error::from)?;
+ invocation_records::table.select(count_star()).first(conn).await.map_err(Error::from)?;
let audio_meme_invocations: i64 = invocation_records::table
.inner_join(memes::table)
.select(count_star())
.filter(memes::audio_id.is_not_null())
.first(conn)
+ .await
.map_err(Error::from)?;
let random_meme_invocations: i64 = invocation_records::table
.select(count_star())
.filter(invocation_records::random.eq(true))
.first(conn)
+ .await
.map_err(Error::from)?;
let mut raw_conn = raw_connection()?;
diff --git a/src/db/models.rs b/src/db/models.rs
index bba25a2..f7bbf8e 100644
--- a/src/db/models.rs
+++ b/src/db/models.rs
@@ -5,6 +5,7 @@ use diesel::{
Insertable,
Queryable,
};
+use diesel_async::{AsyncPgConnection, RunQueryDsl};
use sha1::Digest;
use crate::{
@@ -25,18 +26,27 @@ pub struct Meme {
}
impl Meme {
- pub fn image(&self, conn: &mut PgConnection) -> Option<Result<Image>> {
+ pub async fn image(&self, conn: &mut AsyncPgConnection) -> Option<Result<Image>> {
self.image_id
- .map(|x: i32| images::table.filter(images::id.eq(x)).first(conn).map_err(Error::from))
+ .map(|x: i32| images::table.filter(images::id.eq(x))
+ .first(conn)
+ .await
+ .map_err(Error::from))
}
- pub fn audio(&self, conn: &mut PgConnection) -> Option<Result<Audio>> {
+ pub async fn audio(&self, conn: &mut AsyncPgConnection) -> Option<Result<Audio>> {
self.audio_id
- .map(|x: i32| audio::table.filter(audio::id.eq(x)).first(conn).map_err(Error::from))
+ .map(|x: i32| audio::table.filter(audio::id.eq(x))
+ .first(conn)
+ .await
+ .map_err(Error::from))
}
- pub fn find(conn: &mut PgConnection, id: i32) -> Result<Meme> {
- memes::table.find(id).get_result(conn).map_err(Error::from)
+ pub async fn find(conn: &mut AsyncPgConnection, id: i32) -> Result<Meme> {
+ memes::table.find(id)
+ .get_result(conn)
+ .await
+ .map_err(Error::from)
}
}
@@ -51,7 +61,7 @@ pub struct NewMeme {
}
impl NewMeme {
- pub fn save(mut self, conn: &mut PgConnection, by_user: u64) -> Result<Meme> {
+ pub async fn save(mut self, conn: &mut AsyncPgConnection, by_user: u64) -> Result<Meme> {
let metadata = Metadata::create(conn, by_user)?;
self.metadata_id = metadata.id;
@@ -59,6 +69,7 @@ impl NewMeme {
diesel::insert_into(memes::table)
.values(&self)
.get_result::<Meme>(conn)
+ .await
.map_err(Error::from)
}
}
@@ -73,7 +84,7 @@ pub struct Audio {
}
impl Audio {
- pub fn create(conn: &mut PgConnection, data: Vec<u8>, by_user: u64) -> Result<i32> {
+ pub fn create(conn: &mut AsyncPgConnection, data: Vec<u8>, by_user: u64) -> Result<i32> {
let mut data_hash = ::sha1::Sha1::new();
data_hash.update(&data);
let data_hash = data_hash.finalize().to_vec();
@@ -81,7 +92,8 @@ impl Audio {
let id = audio::table
.select(audio::id)
.filter(audio::data_hash.eq(&data_hash))
- .get_results::<i32>(conn)?;
+ .get_results::<i32>(conn)
+ .await?;
if let Some(id) = id.first() {
return Ok(*id);
@@ -99,6 +111,7 @@ impl Audio {
.values(&new_audio)
.returning(audio::id)
.get_result(conn)
+ .await
.map_err(Error::from)
}
}
@@ -123,7 +136,7 @@ pub struct Image {
impl Image {
pub fn create(
- conn: &mut PgConnection,
+ conn: &mut AsyncPgConnection,
filename: &str,
data: Vec<u8>,
by_user: u64,
@@ -135,7 +148,8 @@ impl Image {
let id = images::table
.select(images::id)
.filter(images::data_hash.eq(&data_hash))
- .get_results::<i32>(conn)?;
+ .get_results::<i32>(conn)
+ .await?;
if let Some(id) = id.first() {
return Ok(*id);
@@ -154,6 +168,7 @@ impl Image {
.values(&new_image)
.returning(images::id)
.get_result(conn)
+ .await
.map_err(Error::from)
}
}
@@ -176,17 +191,18 @@ pub struct Metadata {
}
impl Metadata {
- pub fn create(conn: &mut PgConnection, by_user: u64) -> Result<Metadata> {
+ pub fn create(conn: &mut AsyncPgConnection, by_user: u64) -> Result<Metadata> {
diesel::insert_into(metadata::table)
.values(&NewMetadata {
created_by: by_user as i64,
})
.get_result::<Metadata>(conn)
+ .await
.map_err(Error::from)
}
- pub fn find(conn: &mut PgConnection, id: i32) -> Result<Metadata> {
- metadata::table.find(id).get_result::<Metadata>(conn).map_err(Error::from)
+ pub fn find(conn: &mut AsyncPgConnection, id: i32) -> Result<Metadata> {
+ metadata::table.find(id).get_result::<Metadata>(conn).await.map_err(Error::from)
}
}
@@ -206,13 +222,14 @@ pub struct AuditRecord {
}
impl AuditRecord {
- pub fn create(conn: &mut PgConnection, metadata: i32, by_user: u64) -> Result<AuditRecord> {
+ pub fn create(conn: &mut AsyncPgConnection, metadata: i32, by_user: u64) -> Result<AuditRecord> {
diesel::insert_into(audit_records::table)
.values(&NewAuditRecord {
updated_by: by_user as i64,
metadata_id: metadata,
})
.get_result::<AuditRecord>(conn)
+ .await
.map_err(Error::from)
}
}
@@ -264,7 +281,7 @@ pub struct NewInvocationRecord {
impl InvocationRecord {
pub fn create(
- conn: &mut PgConnection,
+ conn: &mut AsyncPgConnection,
user_id: u64,
message_id: u64,
meme_id: i32,
@@ -278,21 +295,24 @@ impl InvocationRecord {
random,
})
.get_result::<InvocationRecord>(conn)
+ .await
.map_err(Error::from)
}
- pub fn last(conn: &mut PgConnection) -> Result<Self> {
+ pub fn last(conn: &mut AsyncPgConnection) -> Result<Self> {
invocation_records::table
.order(invocation_records::time.desc())
.first(conn)
+ .await
.map_err(Error::from)
}
- pub fn last_n(conn: &mut PgConnection, n: usize) -> Result<Vec<Self>> {
+ pub fn last_n(conn: &mut AsyncPgConnection, n: usize) -> Result<Vec<Self>> {
invocation_records::table
.order(invocation_records::time.desc())
.limit(n as i64)
.load(conn)
+ .await
.map_err(Error::from)
}
}