Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a 'connection killer' to help make transactions panic & cancel-safe #81

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 113 additions & 28 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,56 @@ fn retryable_error(err: &DieselError) -> bool {
}
}

// Kills the connection when this object is dropped.
//
// Specifically, this sets the diesel "TransactionManager" state
// to broken, which should prevent any subsequent access to the
// connection from succceeding.
//
// This aims to help avoid leaving open transactions alive
// if an asynchronous transaction is cancelled.
struct ConnectionKiller<'a, Conn>
smklein marked this conversation as resolved.
Show resolved Hide resolved
where
Conn: DieselConnection,
{
conn: Option<&'a Connection<Conn>>,
}

impl<'a, Conn> ConnectionKiller<'a, Conn>
where
Conn: DieselConnection,
{
fn new(conn: &'a Connection<Conn>) -> ConnectionKiller<'a, Conn> {
Self { conn: Some(conn) }
}

// Prevents the connection from being killed.
//
// This should be called if a transaction has completed successfully.
fn spare(&mut self) {
self.conn.take();
}
}

impl<'a, Conn> Drop for ConnectionKiller<'a, Conn>
where
Conn: DieselConnection,
{
fn drop(&mut self) {
let Some(conn) = self.conn.take() else {
return;
};

// Ensure that non-transaction operations fail
conn.mark_broken();

// Ensure that transactions fail
let mut conn = conn.inner();
*Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) =
TransactionManagerStatus::InError;
}
}

/// An async variant of [`diesel::r2d2::R2D2Connection`].
#[async_trait]
pub trait AsyncR2D2Connection<Conn>: AsyncConnection<Conn>
Expand All @@ -54,7 +104,7 @@ where

async fn is_broken_async(&mut self) -> bool {
self.as_async_conn()
.run(|conn| Ok::<bool, ()>(conn.is_broken()))
.run(|conn| Ok::<bool, _>(conn.is_broken()))
.await
.unwrap()
}
Expand All @@ -74,24 +124,34 @@ where
#[doc(hidden)]
fn as_async_conn(&self) -> &Connection<Conn>;

// Identifies if the conneciton has been broken
smklein marked this conversation as resolved.
Show resolved Hide resolved
// by an invalid transaction. This should prevent
// future usage.
#[doc(hidden)]
fn is_broken_from_txn(&self) -> bool {
false
}

/// Runs the function `f` in an context where blocking is safe.
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
async fn run<R, Func>(&self, f: Func) -> Result<R, DieselError>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, DieselError> + Send + 'static,
{
let connection = self.get_owned_connection();
connection.run_with_connection(f).await
}

#[doc(hidden)]
async fn run_with_connection<R, E, Func>(self, f: Func) -> Result<R, E>
async fn run_with_connection<R, Func>(self, f: Func) -> Result<R, DieselError>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, DieselError> + Send + 'static,
{
if self.is_broken_from_txn() {
return Err(DieselError::BrokenTransactionManager);
}

spawn_blocking(move || f(&mut *self.as_sync_conn()))
.await
.unwrap() // Propagate panics
Expand All @@ -111,18 +171,15 @@ where
}

#[doc(hidden)]
async fn transaction_depth(&self) -> Result<u32, DieselError> {
let conn = self.get_owned_connection();
fn transaction_depth(&self) -> Result<u32, DieselError> {
let mut conn = self.as_sync_conn();

Self::run_with_connection(conn, |conn| {
match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) {
TransactionManagerStatus::Valid(status) => {
Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0))
}
TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager),
match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) {
TransactionManagerStatus::Valid(status) => {
Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0))
}
})
.await
TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager),
}
}

// Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a
Expand All @@ -132,7 +189,7 @@ where
// we're actually issuing the BEGIN statement here.
#[doc(hidden)]
async fn start_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
if self.transaction_depth().await? != 0 {
if self.transaction_depth()? != 0 {
return Err(DieselError::AlreadyInTransaction);
}
self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
Expand All @@ -147,7 +204,7 @@ where
// we're actually issuing our first SAVEPOINT here.
#[doc(hidden)]
async fn add_retry_savepoint(self: &Arc<Self>) -> Result<(), DieselError> {
match self.transaction_depth().await? {
match self.transaction_depth()? {
0 => return Err(DieselError::NotInTransaction),
1 => (),
_ => return Err(DieselError::AlreadyInTransaction),
Expand Down Expand Up @@ -230,6 +287,12 @@ where
// operation.
let conn = Arc::new(self.get_owned_connection());

// Before we start doing any transaction operations, if we drop
// this future before exiting this function cleanly, we want to
// ensure the connection is killed, rather than existing in a
// "unknown, mid-transaction" state.
let mut killer = ConnectionKiller::new(conn.as_async_conn());

// Refer to CockroachDB's guide on advanced client-side transaction
// retries for the full context:
// https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries
Expand All @@ -241,7 +304,9 @@ where
// TODO: It may be preferable to set this once per connection -- but
// that'll require more interaction with how sessions with the database
// are constructed.
Self::start_transaction(&conn).await?;
Self::start_transaction(&conn).await.inspect_err(|_| {
killer.spare();
})?;
conn.run_with_shared_connection(|conn| {
conn.batch_execute("SET LOCAL force_savepoint_restart = true")
})
Expand All @@ -251,7 +316,7 @@ where
// Add a SAVEPOINT to which we can later return.
Self::add_retry_savepoint(&conn).await?;

let async_conn = Connection(Self::as_async_conn(&conn).0.clone());
let async_conn = Self::as_async_conn(&conn).clone();
match f(async_conn).await {
Ok(value) => {
// The user-level operation succeeded: try to commit the
Expand All @@ -265,6 +330,7 @@ where
if !retryable_error(&err) || !retry().await {
// Bail: ROLLBACK the initial BEGIN statement too.
let _ = Self::rollback_transaction(&conn).await;
killer.spare();
return Err(err);
}
// ROLLBACK happened, we want to retry.
Expand All @@ -273,6 +339,7 @@ where

// Commit the top-level transaction too.
Self::commit_transaction(&conn).await?;
killer.spare();
return Ok(value);
}
Err(user_error) => {
Expand All @@ -281,10 +348,12 @@ where
if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await {
// If we fail while rolling back, prioritize returning
// the ROLLBACK error over the user errors.
return match Self::rollback_transaction(&conn).await {
let res = match Self::rollback_transaction(&conn).await {
Ok(()) => Err(first_rollback_err),
Err(second_rollback_err) => Err(second_rollback_err),
};
killer.spare();
return res;
}

// We rolled back to the retry savepoint, and now want to
Expand All @@ -294,10 +363,12 @@ where
}

// If we aren't retrying, ROLLBACK the BEGIN statement too.
return match Self::rollback_transaction(&conn).await {
let res = match Self::rollback_transaction(&conn).await {
Ok(()) => Err(user_error),
Err(err) => Err(err),
};
killer.spare();
return res;
}
}
}
Expand Down Expand Up @@ -346,6 +417,12 @@ where
// operation.
let conn = Arc::new(self.get_owned_connection());

// Before we start doing any transaction operations, if we drop
// this future before exiting this function cleanly, we want to
// ensure the connection is killed, rather than existing in a
// "unknown, mid-transaction" state.
let mut killer = ConnectionKiller::new(conn.as_async_conn());

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
Expand All @@ -354,7 +431,10 @@ where
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;
.await
.inspect_err(|_| {
killer.spare();
})?;

// TODO: The ideal interface would pass the "async_conn" object to the
// underlying function "f" by reference.
Expand All @@ -368,13 +448,16 @@ where
// enough to be referenceable by a Future, but short enough that we can
// guarantee it doesn't live persist after this function returns, feel
// free to make that change.
let async_conn = Connection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
let async_conn = Self::as_async_conn(&conn).clone();
let res = match f(async_conn).await {
Ok(value) => {
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
.await
.inspect_err(|_| {
killer.spare();
})?;
Ok(value)
}
Err(user_error) => {
Expand All @@ -388,7 +471,9 @@ where
Err(err) => Err(err),
}
}
}
};
killer.spare();
res
}
}

Expand Down
37 changes: 32 additions & 5 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! An async wrapper around a [`diesel::Connection`].

use crate::async_traits::AsyncConnection;
use async_trait::async_trait;
use diesel::r2d2::R2D2Connection;
use diesel::result::Error as DieselError;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use tokio::task;

Expand All @@ -13,19 +16,35 @@ use tokio::task;
/// All blocking methods within this type delegate to
/// [`tokio::task::spawn_blocking`], meaning they won't block
/// any asynchronous work or threads.
pub struct Connection<C>(pub(crate) Arc<Mutex<C>>);
pub struct Connection<C>(Arc<ConnectionInner<C>>);

pub struct ConnectionInner<C> {
pub(crate) inner: Mutex<C>,
pub(crate) broken: AtomicBool,
}

impl<C> Connection<C> {
pub fn new(c: C) -> Self {
Self(Arc::new(Mutex::new(c)))
Self(Arc::new(ConnectionInner {
inner: Mutex::new(c),
broken: AtomicBool::new(false),
}))
}

pub(crate) fn clone(&self) -> Self {
Self(self.0.clone())
}

pub(crate) fn mark_broken(&self) {
self.0.broken.store(true, Ordering::SeqCst);
}

// Accesses the underlying connection.
//
// As this is a blocking mutex, it's recommended to avoid invoking
// this function from an asynchronous context.
pub(crate) fn inner(&self) -> MutexGuard<'_, C> {
self.0.lock().unwrap()
self.0.inner.lock().unwrap()
}
}

Expand All @@ -36,7 +55,11 @@ where
{
#[inline]
async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
let diesel_conn = Connection(self.0.clone());
if self.is_broken_from_txn() {
return Err(DieselError::BrokenTransactionManager);
}

let diesel_conn = self.clone();
let query = query.to_string();
task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query))
.await
Expand All @@ -55,7 +78,7 @@ where
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
fn get_owned_connection(&self) -> Self {
Connection(self.0.clone())
self.clone()
}

// Accesses the connection synchronously, protected by a mutex.
Expand All @@ -68,4 +91,8 @@ where
fn as_async_conn(&self) -> &Connection<Conn> {
self
}

fn is_broken_from_txn(&self) -> bool {
self.0.broken.load(Ordering::SeqCst)
}
}
2 changes: 1 addition & 1 deletion src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ where
}

async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
let c = Connection(conn.0.clone());
let c = conn.clone();
self.run_blocking(move |m| {
m.is_valid(&mut *c.inner())?;
Ok(())
Expand Down
Loading
Loading