diff --git a/src/async_traits.rs b/src/async_traits.rs index 38c879a..3d75775 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -1,6 +1,6 @@ //! Async versions of traits for issuing Diesel queries. -use crate::connection::Connection; +use crate::{connection::Connection, error::RunError}; use async_trait::async_trait; use diesel::{ connection::{ @@ -21,7 +21,7 @@ use std::any::Any; use std::future::Future; use std::sync::Arc; use std::sync::MutexGuard; -use tokio::task::spawn_blocking; +use tokio::task::{spawn_blocking, JoinError}; /// An async variant of [`diesel::connection::SimpleConnection`]. #[async_trait] @@ -48,13 +48,13 @@ where Conn: 'static + DieselConnection + R2D2Connection, Self: Send + Sized + 'static, { - async fn ping_async(&mut self) -> diesel::result::QueryResult<()> { + async fn ping_async(&mut self) -> Result<(), RunError> { self.as_async_conn().run(|conn| conn.ping()).await } async fn is_broken_async(&mut self) -> bool { self.as_async_conn() - .run(|conn| Ok::(conn.is_broken())) + .run(|conn| Ok::(conn.is_broken())) .await .unwrap() } @@ -75,43 +75,36 @@ where fn as_async_conn(&self) -> &Connection; /// Runs the function `f` in an context where blocking is safe. - async fn run(&self, f: Func) -> Result + async fn run(&self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let connection = self.get_owned_connection(); connection.run_with_connection(f).await } #[doc(hidden)] - async fn run_with_connection(self, f: Func) -> Result + async fn run_with_connection(self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *self.as_sync_conn())) - .await - .unwrap() // Propagate panics + handle_spawn_blocking_error(spawn_blocking(move || f(&mut *self.as_sync_conn())).await) } #[doc(hidden)] - async fn run_with_shared_connection(self: &Arc, f: Func) -> Result + async fn run_with_shared_connection(self: &Arc, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let conn = self.clone(); - spawn_blocking(move || f(&mut *conn.as_sync_conn())) - .await - .unwrap() // Propagate panics + handle_spawn_blocking_error(spawn_blocking(move || f(&mut *conn.as_sync_conn())).await) } #[doc(hidden)] - async fn transaction_depth(&self) -> Result { + async fn transaction_depth(&self) -> Result { let conn = self.get_owned_connection(); Self::run_with_connection(conn, |conn| { @@ -131,9 +124,9 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing the BEGIN statement here. #[doc(hidden)] - async fn start_transaction(self: &Arc) -> Result<(), DieselError> { + async fn start_transaction(self: &Arc) -> Result<(), RunError> { if self.transaction_depth().await? != 0 { - return Err(DieselError::AlreadyInTransaction); + return Err(RunError::DieselError(DieselError::AlreadyInTransaction)); } self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) .await?; @@ -146,11 +139,11 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing our first SAVEPOINT here. #[doc(hidden)] - async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { + async fn add_retry_savepoint(self: &Arc) -> Result<(), RunError> { match self.transaction_depth().await? { - 0 => return Err(DieselError::NotInTransaction), + 0 => return Err(RunError::DieselError(DieselError::NotInTransaction)), 1 => (), - _ => return Err(DieselError::AlreadyInTransaction), + _ => return Err(RunError::DieselError(DieselError::AlreadyInTransaction)), }; self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) @@ -159,14 +152,14 @@ where } #[doc(hidden)] - async fn commit_transaction(self: &Arc) -> Result<(), DieselError> { + async fn commit_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) .await?; Ok(()) } #[doc(hidden)] - async fn rollback_transaction(self: &Arc) -> Result<(), DieselError> { + async fn rollback_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| { Conn::TransactionManager::rollback_transaction(conn) }) @@ -185,10 +178,10 @@ where &'a self, f: Func, retry: RetryFunc, - ) -> Result + ) -> Result where R: Any + Send + 'static, - Fut: FutureExt> + Send, + Fut: FutureExt> + Send, Func: (Fn(Connection) -> Fut) + Send + Sync, RetryFut: FutureExt + Send, RetryFunc: Fn() -> RetryFut + Send + Sync, @@ -221,11 +214,11 @@ where #[cfg(feature = "cockroach")] async fn transaction_async_with_retry_inner( &self, - f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, DieselError>> + f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, RunError>> + Send + Sync), retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync), - ) -> Result, DieselError> { + ) -> Result, RunError> { // Check out a connection once, and use it for the duration of the // operation. let conn = Arc::new(self.get_owned_connection()); @@ -262,20 +255,26 @@ where // // We're still in the transaction, but we at least // tried to ROLLBACK to our savepoint. - if !retryable_error(&err) || !retry().await { + let retried = match &err { + RunError::DieselError(err) => retryable_error(err) && retry().await, + RunError::RuntimeShutdown => false, + }; + if retried { + // ROLLBACK happened, we want to retry. + continue; + } else { // Bail: ROLLBACK the initial BEGIN statement too. + // In the case of the run let _ = Self::rollback_transaction(&conn).await; return Err(err); } - // ROLLBACK happened, we want to retry. - continue; } // Commit the top-level transaction too. Self::commit_transaction(&conn).await?; return Ok(value); } - Err(user_error) => { + Err(RunError::DieselError(user_error)) => { // The user-level operation failed: ROLLBACK to the retry // savepoint. if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { @@ -295,10 +294,18 @@ where // If we aren't retrying, ROLLBACK the BEGIN statement too. return match Self::rollback_transaction(&conn).await { - Ok(()) => Err(user_error), + Ok(()) => Err(RunError::DieselError(user_error)), Err(err) => Err(err), }; } + Err(RunError::RuntimeShutdown) => { + // The runtime is shutting down: attempt to ROLLBACK the + // transaction. You might think it's pointless to try and + // do this, but in reality the shutdown timeout for async + // tasks might not have expired. + let _ = Self::rollback_transaction(&conn).await; + return Err(RunError::RuntimeShutdown); + } } } } @@ -306,7 +313,7 @@ where async fn transaction_async(&'a self, f: Func) -> Result where R: Send + 'static, - E: From + Send + 'static, + E: From + Send + 'static, Fut: Future> + Send, Func: FnOnce(Connection) -> Fut + Send, { @@ -325,9 +332,8 @@ where .boxed() }); - self.transaction_async_inner(f) - .await - .map(|v| *v.downcast::().expect("Should be an 'R' type")) + let v = self.transaction_async_inner(f).await?; + Ok(*v.downcast::().expect("Should be an 'R' type")) } // NOTE: This function intentionally avoids as many generic parameters as possible @@ -340,7 +346,7 @@ where >, ) -> Result, E> where - E: From + Send + 'static, + E: From + Send + 'static, { // Check out a connection once, and use it for the duration of the // operation. @@ -351,10 +357,8 @@ where // // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. - conn.run_with_shared_connection(|conn| { - Conn::TransactionManager::begin_transaction(conn).map_err(E::from) - }) - .await?; + conn.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; // TODO: The ideal interface would pass the "async_conn" object to the // underlying function "f" by reference. @@ -372,21 +376,45 @@ where match f(async_conn).await { Ok(value) => { conn.run_with_shared_connection(|conn| { - Conn::TransactionManager::commit_transaction(conn).map_err(E::from) + Conn::TransactionManager::commit_transaction(conn) }) .await?; Ok(value) } Err(user_error) => { - match conn - .run_with_shared_connection(|conn| { - Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) - }) - .await - { - Ok(()) => Err(user_error), - Err(err) => Err(err), - } + conn.run_with_shared_connection(|conn| { + Conn::TransactionManager::rollback_transaction(conn) + }) + .await?; + Err(user_error) + } + } + } +} + +fn handle_spawn_blocking_error( + result: Result, JoinError>, +) -> Result { + match result { + Ok(Ok(v)) => Ok(v), + Ok(Err(err)) => Err(RunError::DieselError(err)), + Err(err) => { + if err.is_cancelled() { + // The only way a spawn_blocking task can be marked cancelled + // is if the runtime started shutting down _before_ + // spawn_blocking was called. + Err(RunError::RuntimeShutdown) + } else if err.is_panic() { + // Propagate panics. + std::panic::panic_any(err.into_panic()); + } else { + // Not possible to reach this as of Tokio 1.40, but maybe in + // future versions. + panic!( + "unexpected JoinError, did a new version of Tokio add \ + a source other than panics and cancellations? {:?}", + err + ); } } } @@ -398,26 +426,26 @@ pub trait AsyncRunQueryDsl where Conn: 'static + DieselConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl; - async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, @@ -431,14 +459,14 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl, { asc.run(|conn| self.execute(conn)).await } - async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -446,7 +474,7 @@ where asc.run(|conn| self.load(conn)).await } - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -454,7 +482,7 @@ where asc.run(|conn| self.get_result(conn)).await } - async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -462,7 +490,7 @@ where asc.run(|conn| self.get_results(conn)).await } - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, @@ -477,7 +505,7 @@ pub trait AsyncSaveChangesDsl where Conn: 'static + DieselConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Self: Sized, Conn: diesel::query_dsl::UpdateAndFetchResults, @@ -491,7 +519,7 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Conn: diesel::query_dsl::UpdateAndFetchResults, Output: Send + 'static, diff --git a/src/error.rs b/src/error.rs index 333c13e..184df43 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,6 +20,18 @@ pub enum ConnectionError { #[error("Failed to issue a query: {0}")] Query(#[from] DieselError), + + #[error("runtime shutting down")] + RuntimeShutdown, +} + +impl From for ConnectionError { + fn from(error: RunError) -> Self { + match error { + RunError::DieselError(e) => ConnectionError::Query(e), + RunError::RuntimeShutdown => ConnectionError::RuntimeShutdown, + } + } } /// Syntactic sugar around a Result returning an [`PoolError`]. @@ -44,6 +56,31 @@ impl OptionalExtension for Result { } } +impl OptionalExtension for Result { + fn optional(self) -> Result, ConnectionError> { + let self_as_query_result: diesel::QueryResult = match self { + Ok(value) => Ok(value), + Err(RunError::DieselError(error_kind)) => Err(error_kind), + Err(RunError::RuntimeShutdown) => return Err(ConnectionError::RuntimeShutdown), + }; + + self_as_query_result + .optional() + .map_err(ConnectionError::Query) + } +} + +/// An error encountered while running a function on a connection pool. +#[derive(Error, Debug, PartialEq)] +pub enum RunError { + /// There was a Diesel error running the query. + #[error(transparent)] + DieselError(#[from] DieselError), + + #[error("runtime shutting down")] + RuntimeShutdown, +} + /// Describes an error performing an operation from a connection pool. /// /// This is a superset of [`ConnectionError`] which also may diff --git a/src/lib.rs b/src/lib.rs index e4fbecd..e9aef5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,4 +16,6 @@ pub use async_traits::{ }; pub use connection::Connection; pub use connection_manager::ConnectionManager; -pub use error::{ConnectionError, ConnectionResult, OptionalExtension, PoolError, PoolResult}; +pub use error::{ + ConnectionError, ConnectionResult, OptionalExtension, PoolError, PoolResult, RunError, +}; diff --git a/tests/test.rs b/tests/test.rs index 31fea49..21ff106 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -4,9 +4,9 @@ use async_bb8_diesel::{ AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, ConnectionError, + OptionalExtension, RunError, }; use crdb_harness::{CockroachInstance, CockroachStarterBuilder}; -use diesel::OptionalExtension; use diesel::{pg::PgConnection, prelude::*}; table! { @@ -208,13 +208,17 @@ async fn test_transaction_automatic_retry_explicit_rollback() { if *count < 2 { eprintln!("test: Manually restarting txn"); - return Err::<(), _>(diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::SerializationFailure, - Box::new("restart transaction".to_string()), + return Err::<(), _>(RunError::DieselError( + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::SerializationFailure, + Box::new("restart transaction".to_string()), + ), )); } eprintln!("test: Manually rolling back txn"); - return Err(diesel::result::Error::RollbackTransaction); + return Err(RunError::DieselError( + diesel::result::Error::RollbackTransaction, + )); } }, || async { @@ -225,7 +229,10 @@ async fn test_transaction_automatic_retry_explicit_rollback() { .await .expect_err("Transaction should have failed"); - assert_eq!(err, diesel::result::Error::RollbackTransaction); + assert_eq!( + err, + RunError::DieselError(diesel::result::Error::RollbackTransaction) + ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); // The transaction closure should have been attempted twice, but @@ -317,12 +324,12 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() assert_eq!(conn.transaction_depth().await.unwrap(), 0); assert_eq!( conn.transaction_async_with_retry( - |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, + |_| async { Err::<(), _>(RunError::DieselError(diesel::result::Error::NotFound)) }, || async { panic!("Should not attempt to retry this operation") } ) .await .expect_err("Transaction should have failed"), - diesel::result::Error::NotFound, + RunError::DieselError(diesel::result::Error::NotFound), ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); @@ -361,7 +368,10 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { ) .await .expect_err("Nested transaction should have failed"); - assert_eq!(err, diesel::result::Error::AlreadyInTransaction); + assert_eq!( + err, + RunError::DieselError(diesel::result::Error::AlreadyInTransaction) + ); // We still want to show that control exists within the outer // transaction, so we explicitly return here. @@ -395,9 +405,9 @@ async fn test_transaction_custom_error() { Other, } - impl From for MyError { - fn from(error: diesel::result::Error) -> Self { - MyError::Db(ConnectionError::Query(error)) + impl From for MyError { + fn from(error: RunError) -> Self { + MyError::Db(ConnectionError::from(error)) } }