diff --git a/src/async_traits.rs b/src/async_traits.rs index 38c879a..3c96f32 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -41,6 +41,57 @@ fn retryable_error(err: &DieselError) -> bool { } } +// Kills the connection when this object is dropped. +// +// Specifically, this sets the diesel "TransactionManager" state +// to broken, which should prevent any subsequent access to the +// connection from succceeding. +// +// This aims to help avoid leaving open transactions alive +// if an asynchronous transaction is cancelled. +#[must_use] +struct ConnectionKiller<'a, Conn> +where + Conn: DieselConnection, +{ + conn: Option<&'a Connection>, +} + +impl<'a, Conn> ConnectionKiller<'a, Conn> +where + Conn: DieselConnection, +{ + fn new(conn: &'a Connection) -> ConnectionKiller<'a, Conn> { + Self { conn: Some(conn) } + } + + // Prevents the connection from being killed. + // + // This should be called if a transaction has completed successfully. + fn spare(&mut self) { + self.conn.take(); + } +} + +impl<'a, Conn> Drop for ConnectionKiller<'a, Conn> +where + Conn: DieselConnection, +{ + fn drop(&mut self) { + let Some(conn) = self.conn.take() else { + return; + }; + + // Ensure that non-transaction operations fail + conn.mark_broken(); + + // Ensure that transactions fail + let mut conn = conn.inner(); + *Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) = + TransactionManagerStatus::InError; + } +} + /// An async variant of [`diesel::r2d2::R2D2Connection`]. #[async_trait] pub trait AsyncR2D2Connection: AsyncConnection @@ -54,7 +105,7 @@ where async fn is_broken_async(&mut self) -> bool { self.as_async_conn() - .run(|conn| Ok::(conn.is_broken())) + .run(|conn| Ok::(conn.is_broken())) .await .unwrap() } @@ -74,24 +125,34 @@ where #[doc(hidden)] fn as_async_conn(&self) -> &Connection; + // Identifies if the connection has been broken + // by an invalid transaction. This should prevent + // future usage. + #[doc(hidden)] + fn is_broken_from_txn(&self) -> bool { + false + } + /// Runs the function `f` in an context where blocking is safe. - async fn run(&self, f: Func) -> Result + async fn run(&self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let connection = self.get_owned_connection(); connection.run_with_connection(f).await } #[doc(hidden)] - async fn run_with_connection(self, f: Func) -> Result + async fn run_with_connection(self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { + if self.is_broken_from_txn() { + return Err(DieselError::BrokenTransactionManager); + } + spawn_blocking(move || f(&mut *self.as_sync_conn())) .await .unwrap() // Propagate panics @@ -111,18 +172,15 @@ where } #[doc(hidden)] - async fn transaction_depth(&self) -> Result { - let conn = self.get_owned_connection(); + fn transaction_depth(&self) -> Result { + let mut conn = self.as_sync_conn(); - Self::run_with_connection(conn, |conn| { - match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) { - TransactionManagerStatus::Valid(status) => { - Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0)) - } - TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager), + match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) { + TransactionManagerStatus::Valid(status) => { + Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0)) } - }) - .await + TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager), + } } // Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a @@ -132,7 +190,7 @@ where // we're actually issuing the BEGIN statement here. #[doc(hidden)] async fn start_transaction(self: &Arc) -> Result<(), DieselError> { - if self.transaction_depth().await? != 0 { + if self.transaction_depth()? != 0 { return Err(DieselError::AlreadyInTransaction); } self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) @@ -147,7 +205,7 @@ where // we're actually issuing our first SAVEPOINT here. #[doc(hidden)] async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { - match self.transaction_depth().await? { + match self.transaction_depth()? { 0 => return Err(DieselError::NotInTransaction), 1 => (), _ => return Err(DieselError::AlreadyInTransaction), @@ -230,6 +288,12 @@ where // operation. let conn = Arc::new(self.get_owned_connection()); + // Before we start doing any transaction operations, if we drop + // this future before exiting this function cleanly, we want to + // ensure the connection is killed, rather than existing in a + // "unknown, mid-transaction" state. + let mut killer = ConnectionKiller::new(conn.as_async_conn()); + // Refer to CockroachDB's guide on advanced client-side transaction // retries for the full context: // https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries @@ -241,7 +305,9 @@ where // TODO: It may be preferable to set this once per connection -- but // that'll require more interaction with how sessions with the database // are constructed. - Self::start_transaction(&conn).await?; + Self::start_transaction(&conn).await.inspect_err(|_| { + killer.spare(); + })?; conn.run_with_shared_connection(|conn| { conn.batch_execute("SET LOCAL force_savepoint_restart = true") }) @@ -251,7 +317,7 @@ where // Add a SAVEPOINT to which we can later return. Self::add_retry_savepoint(&conn).await?; - let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); + let async_conn = Self::as_async_conn(&conn).clone(); match f(async_conn).await { Ok(value) => { // The user-level operation succeeded: try to commit the @@ -265,6 +331,7 @@ where if !retryable_error(&err) || !retry().await { // Bail: ROLLBACK the initial BEGIN statement too. let _ = Self::rollback_transaction(&conn).await; + killer.spare(); return Err(err); } // ROLLBACK happened, we want to retry. @@ -273,6 +340,7 @@ where // Commit the top-level transaction too. Self::commit_transaction(&conn).await?; + killer.spare(); return Ok(value); } Err(user_error) => { @@ -281,10 +349,12 @@ where if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { // If we fail while rolling back, prioritize returning // the ROLLBACK error over the user errors. - return match Self::rollback_transaction(&conn).await { + let res = match Self::rollback_transaction(&conn).await { Ok(()) => Err(first_rollback_err), Err(second_rollback_err) => Err(second_rollback_err), }; + killer.spare(); + return res; } // We rolled back to the retry savepoint, and now want to @@ -294,10 +364,12 @@ where } // If we aren't retrying, ROLLBACK the BEGIN statement too. - return match Self::rollback_transaction(&conn).await { + let res = match Self::rollback_transaction(&conn).await { Ok(()) => Err(user_error), Err(err) => Err(err), }; + killer.spare(); + return res; } } } @@ -346,6 +418,12 @@ where // operation. let conn = Arc::new(self.get_owned_connection()); + // Before we start doing any transaction operations, if we drop + // this future before exiting this function cleanly, we want to + // ensure the connection is killed, rather than existing in a + // "unknown, mid-transaction" state. + let mut killer = ConnectionKiller::new(conn.as_async_conn()); + // This function mimics the implementation of: // https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction // @@ -354,7 +432,10 @@ where conn.run_with_shared_connection(|conn| { Conn::TransactionManager::begin_transaction(conn).map_err(E::from) }) - .await?; + .await + .inspect_err(|_| { + killer.spare(); + })?; // TODO: The ideal interface would pass the "async_conn" object to the // underlying function "f" by reference. @@ -368,13 +449,16 @@ where // enough to be referenceable by a Future, but short enough that we can // guarantee it doesn't live persist after this function returns, feel // free to make that change. - let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); - match f(async_conn).await { + let async_conn = Self::as_async_conn(&conn).clone(); + let res = match f(async_conn).await { Ok(value) => { conn.run_with_shared_connection(|conn| { Conn::TransactionManager::commit_transaction(conn).map_err(E::from) }) - .await?; + .await + .inspect_err(|_| { + killer.spare(); + })?; Ok(value) } Err(user_error) => { @@ -388,7 +472,9 @@ where Err(err) => Err(err), } } - } + }; + killer.spare(); + res } } diff --git a/src/connection.rs b/src/connection.rs index a045a38..85394f4 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,7 +1,10 @@ //! An async wrapper around a [`diesel::Connection`]. +use crate::async_traits::AsyncConnection; use async_trait::async_trait; use diesel::r2d2::R2D2Connection; +use diesel::result::Error as DieselError; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, MutexGuard}; use tokio::task; @@ -13,11 +16,27 @@ use tokio::task; /// All blocking methods within this type delegate to /// [`tokio::task::spawn_blocking`], meaning they won't block /// any asynchronous work or threads. -pub struct Connection(pub(crate) Arc>); +pub struct Connection(Arc>); + +pub struct ConnectionInner { + pub(crate) inner: Mutex, + pub(crate) broken: AtomicBool, +} impl Connection { pub fn new(c: C) -> Self { - Self(Arc::new(Mutex::new(c))) + Self(Arc::new(ConnectionInner { + inner: Mutex::new(c), + broken: AtomicBool::new(false), + })) + } + + pub(crate) fn clone(&self) -> Self { + Self(self.0.clone()) + } + + pub(crate) fn mark_broken(&self) { + self.0.broken.store(true, Ordering::SeqCst); } // Accesses the underlying connection. @@ -25,7 +44,7 @@ impl Connection { // As this is a blocking mutex, it's recommended to avoid invoking // this function from an asynchronous context. pub(crate) fn inner(&self) -> MutexGuard<'_, C> { - self.0.lock().unwrap() + self.0.inner.lock().unwrap() } } @@ -36,7 +55,11 @@ where { #[inline] async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> { - let diesel_conn = Connection(self.0.clone()); + if self.is_broken_from_txn() { + return Err(DieselError::BrokenTransactionManager); + } + + let diesel_conn = self.clone(); let query = query.to_string(); task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query)) .await @@ -55,7 +78,7 @@ where Connection: crate::AsyncSimpleConnection, { fn get_owned_connection(&self) -> Self { - Connection(self.0.clone()) + self.clone() } // Accesses the connection synchronously, protected by a mutex. @@ -68,4 +91,8 @@ where fn as_async_conn(&self) -> &Connection { self } + + fn is_broken_from_txn(&self) -> bool { + self.0.broken.load(Ordering::SeqCst) + } } diff --git a/src/connection_manager.rs b/src/connection_manager.rs index 5da755a..cab9d2a 100644 --- a/src/connection_manager.rs +++ b/src/connection_manager.rs @@ -75,7 +75,7 @@ where } async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { - let c = Connection(conn.0.clone()); + let c = conn.clone(); self.run_blocking(move |m| { m.is_valid(&mut *c.inner())?; Ok(()) diff --git a/tests/test.rs b/tests/test.rs index 31fea49..0677972 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -5,6 +5,7 @@ use async_bb8_diesel::{ AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, ConnectionError, }; +use bb8::ManageConnection; use crdb_harness::{CockroachInstance, CockroachStarterBuilder}; use diesel::OptionalExtension; use diesel::{pg::PgConnection, prelude::*}; @@ -150,6 +151,104 @@ async fn test_transaction() { test_end(crdb).await; } +enum ConnectionState { + Dead, + Alive, +} + +type ConnectionType = async_bb8_diesel::Connection; + +// Check that we can (or cannot!): +// - Run a batch SQL operation +// - Run a specific SQL operation ("INSERT") +// - Create a transaction +async fn check(conn: &ConnectionType, expected: ConnectionState) { + let result = conn.batch_execute_async("SELECT 1;").await; + match expected { + ConnectionState::Alive => { + result.expect("Should have succeeded"); + } + ConnectionState::Dead => { + result.expect_err("Should have failed"); + } + }; + + use user::dsl; + let result = diesel::insert_into(dsl::user) + .values((dsl::id.eq(0), dsl::name.eq("Jim"))) + .execute_async(&*conn) + .await; + match expected { + ConnectionState::Alive => { + result.expect("Should have succeeded"); + } + ConnectionState::Dead => { + result.expect_err("Should have failed"); + } + }; + + let result = conn + .transaction_async(|_conn| async move { Ok::<(), ConnectionError>(()) }) + .await; + match expected { + ConnectionState::Alive => { + result.expect("Should have succeeded"); + } + ConnectionState::Dead => { + result.expect_err("Should have failed"); + } + }; +} + +#[tokio::test] +async fn test_transaction_cancellation() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let conn = manager.connect().await.unwrap(); + + check(&conn, ConnectionState::Alive).await; + + // Create a transaction which gets cancelled + let txn_fut = conn.transaction_async(|_conn| async move { + tokio::time::sleep(tokio::time::Duration::from_secs(1000)).await; + Ok::<(), ConnectionError>(()) + }); + tokio::time::timeout(tokio::time::Duration::from_millis(5), txn_fut) + .await + .expect_err("Should have timed out"); + + check(&conn, ConnectionState::Dead).await; + + test_end(crdb).await; +} + +#[tokio::test] +async fn test_transaction_retry_cancellation() { + let crdb = test_start().await; + + let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); + let conn = manager.connect().await.unwrap(); + + check(&conn, ConnectionState::Alive).await; + + // Create a retryable transaction which gets cancelled + let txn_fut = conn.transaction_async_with_retry( + |_conn| async move { + tokio::time::sleep(tokio::time::Duration::from_secs(1000)).await; + Ok(()) + }, + || async { panic!("Should not attempt to retry this operation") }, + ); + tokio::time::timeout(tokio::time::Duration::from_millis(5), txn_fut) + .await + .expect_err("Should have timed out"); + + check(&conn, ConnectionState::Dead).await; + + test_end(crdb).await; +} + #[tokio::test] async fn test_transaction_automatic_retry_success_case() { let crdb = test_start().await; @@ -161,10 +260,10 @@ async fn test_transaction_automatic_retry_success_case() { use user::dsl; // Transaction that can retry but does not need to. - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); conn.transaction_async_with_retry( |conn| async move { - assert!(conn.transaction_depth().await.unwrap() > 0); + assert!(conn.transaction_depth().unwrap() > 0); diesel::insert_into(dsl::user) .values((dsl::id.eq(3), dsl::name.eq("Sally"))) .execute_async(&conn) @@ -175,7 +274,7 @@ async fn test_transaction_automatic_retry_success_case() { ) .await .expect("Transaction failed"); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); test_end(crdb).await; } @@ -197,7 +296,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() { // // 1. Retries on the first call // 2. Explicitly rolls back on the second call - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); let err = conn .transaction_async_with_retry( |_conn| { @@ -226,7 +325,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() { .expect_err("Transaction should have failed"); assert_eq!(err, diesel::result::Error::RollbackTransaction); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); // The transaction closure should have been attempted twice, but // we should have only asked whether or not to retry once -- after @@ -264,7 +363,7 @@ async fn test_transaction_automatic_retry_injected_errors() { conn.batch_execute_async("SET inject_retry_errors_enabled = true") .await .expect("Failed to inject error"); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); conn.transaction_async_with_retry( |conn| { let transaction_attempted_count = transaction_attempted_count.clone(); @@ -286,7 +385,7 @@ async fn test_transaction_automatic_retry_injected_errors() { ) .await .expect("Transaction should have succeeded"); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); // The transaction closure should have been attempted twice, but // we should have only asked whether or not to retry once -- after @@ -314,7 +413,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() // Test a transaction that: // // Fails with a non-retryable error. It should exit immediately. - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); assert_eq!( conn.transaction_async_with_retry( |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, @@ -324,7 +423,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() .expect_err("Transaction should have failed"), diesel::result::Error::NotFound, ); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); test_end(crdb).await; } @@ -341,7 +440,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { struct OnlyReturnFromOuterTransaction {} // This outer transaction should succeed immediately... - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); assert_eq!( OnlyReturnFromOuterTransaction {}, conn.transaction_async_with_retry( @@ -372,7 +471,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { .await .expect("Transaction should have succeeded") ); - assert_eq!(conn.transaction_depth().await.unwrap(), 0); + assert_eq!(conn.transaction_depth().unwrap(), 0); test_end(crdb).await; }