Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose transaction_depth through get_transaction_depth() method #3427

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
20 changes: 20 additions & 0 deletions sqlx-core/src/any/connection/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {

fn start_rollback(&mut self);

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(&self) -> usize {
unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later.");
}

/// Checks if the connection is currently in a transaction.
///
/// This method returns `true` if the current transaction depth is greater than 0,
/// indicating that a transaction is active. It returns `false` if the transaction depth is 0,
/// meaning no transaction is active.
#[inline]
fn is_in_transaction(&self) -> bool {
self.get_transaction_depth() != 0
}

/// The number of statements currently cached in the connection.
fn cached_statements_size(&self) -> usize {
0
Expand Down
4 changes: 4 additions & 0 deletions sqlx-core/src/any/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ impl Connection for AnyConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.backend.get_transaction_depth()
}

fn cached_statements_size(&self) -> usize {
self.backend.cached_statements_size()
}
Expand Down
5 changes: 5 additions & 0 deletions sqlx-core/src/any/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use futures_util::future::BoxFuture;

use crate::any::{Any, AnyConnection};
use crate::database::Database;
use crate::error::Error;
use crate::transaction::TransactionManager;

Expand All @@ -24,4 +25,8 @@ impl TransactionManager for AnyTransactionManager {
fn start_rollback(conn: &mut AnyConnection) {
conn.backend.start_rollback()
}

fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
conn.backend.get_transaction_depth()
}
}
23 changes: 22 additions & 1 deletion sqlx-core/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::database::{Database, HasStatementCache};
use crate::error::Error;

use crate::transaction::Transaction;
use crate::transaction::{Transaction, TransactionManager};
use futures_core::future::BoxFuture;
use log::LevelFilter;
use std::fmt::Debug;
Expand Down Expand Up @@ -49,6 +49,27 @@ pub trait Connection: Send {
where
Self: Sized;

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(&self) -> usize {
// Fallback implementation to avoid breaking changes
<Self::Database as Database>::TransactionManager::get_transaction_depth(self)
}

/// Checks if the connection is currently in a transaction.
///
/// This method returns `true` if the current transaction depth is greater than 0,
/// indicating that a transaction is active. It returns `false` if the transaction depth is 0,
/// meaning no transaction is active.
#[inline]
fn is_in_transaction(&self) -> bool {
self.get_transaction_depth() != 0
}

/// Execute the function inside a transaction.
///
/// If the function returns an error, the transaction will be rolled back. If it does not
Expand Down
8 changes: 8 additions & 0 deletions sqlx-core/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ pub trait TransactionManager {

/// Starts to abort the active transaction or restore from the most recent snapshot.
fn start_rollback(conn: &mut <Self::Database as Database>::Connection);

/// Returns the current transaction depth.
///
/// Transaction depth indicates the level of nested transactions:
/// - Level 0: No active transaction.
/// - Level 1: A transaction is active.
/// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it.
fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize;
}

/// An in-progress database transaction or savepoint.
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl AnyConnectionBackend for MySqlConnection {
MySqlTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
MySqlTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
Connection::shrink_buffers(self);
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ impl Connection for MySqlConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.inner.transaction_depth
}

fn shrink_buffers(&mut self) {
self.inner.stream.shrink_buffers();
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-mysql/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ impl TransactionManager for MySqlTransactionManager {
conn.inner.transaction_depth = depth - 1;
}
}

fn get_transaction_depth(conn: &MySqlConnection) -> usize {
conn.inner.transaction_depth
}
}
4 changes: 4 additions & 0 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl AnyConnectionBackend for PgConnection {
PgTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
PgTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
Connection::shrink_buffers(self);
}
Expand Down
4 changes: 4 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ impl Connection for PgConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.transaction_depth
}

fn cached_statements_size(&self) -> usize {
self.cache_statement.len()
}
Expand Down
5 changes: 5 additions & 0 deletions sqlx-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;

use crate::error::Error;
use crate::executor::Executor;
Expand Down Expand Up @@ -59,6 +60,10 @@ impl TransactionManager for PgTransactionManager {
conn.transaction_depth -= 1;
}
}

fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
conn.transaction_depth
}
}

struct Rollback<'c> {
Expand Down
4 changes: 4 additions & 0 deletions sqlx-sqlite/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl AnyConnectionBackend for SqliteConnection {
SqliteTransactionManager::start_rollback(self)
}

fn get_transaction_depth(&self) -> usize {
SqliteTransactionManager::get_transaction_depth(self)
}

fn shrink_buffers(&mut self) {
// NO-OP.
}
Expand Down
1 change: 0 additions & 1 deletion sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ impl EstablishParams {
Ok(ConnectionState {
handle,
statements: Statements::new(self.statement_cache_capacity),
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
Expand Down
12 changes: 5 additions & 7 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ unsafe impl Send for UpdateHookHandler {}
pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

// transaction status
pub(crate) transaction_depth: usize,

pub(crate) statements: Statements,

log_settings: LogSettings,
Expand Down Expand Up @@ -210,11 +207,12 @@ impl Connection for SqliteConnection {
Transaction::begin(self)
}

fn get_transaction_depth(&self) -> usize {
self.worker.shared.get_transaction_depth()
}

fn cached_statements_size(&self) -> usize {
self.worker
.shared
.cached_statements_size
.load(std::sync::atomic::Ordering::Acquire)
self.worker.shared.get_cached_statements_size()
}

fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Expand Down
28 changes: 20 additions & 8 deletions sqlx-sqlite/src/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,21 @@ pub(crate) struct ConnectionWorker {
}

pub(crate) struct WorkerSharedState {
pub(crate) cached_statements_size: AtomicUsize,
transaction_depth: AtomicUsize,
cached_statements_size: AtomicUsize,
pub(crate) conn: Mutex<ConnectionState>,
}

impl WorkerSharedState {
pub(crate) fn get_transaction_depth(&self) -> usize {
self.transaction_depth.load(Ordering::Acquire)
}

pub(crate) fn get_cached_statements_size(&self) -> usize {
self.cached_statements_size.load(Ordering::Acquire)
}
}

enum Command {
Prepare {
query: Box<str>,
Expand Down Expand Up @@ -93,6 +104,7 @@ impl ConnectionWorker {
};

let shared = Arc::new(WorkerSharedState {
transaction_depth: AtomicUsize::new(0),
cached_statements_size: AtomicUsize::new(0),
// note: must be fair because in `Command::UnlockDb` we unlock the mutex
// and then immediately try to relock it; an unfair mutex would immediately
Expand Down Expand Up @@ -181,12 +193,12 @@ impl ConnectionWorker {
update_cached_statements_size(&conn, &shared.cached_statements_size);
}
Command::Begin { tx } => {
let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);
let res =
conn.handle
.exec(begin_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth += 1;
shared.transaction_depth.fetch_add(1, Ordering::Release);
});
let res_ok = res.is_ok();

Expand All @@ -199,7 +211,7 @@ impl ConnectionWorker {
.handle
.exec(rollback_ansi_transaction_sql(depth + 1))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
{
// The rollback failed. To prevent leaving the connection
Expand All @@ -211,13 +223,13 @@ impl ConnectionWorker {
}
}
Command::Commit { tx } => {
let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);

let res = if depth > 0 {
conn.handle
.exec(commit_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
} else {
Ok(())
Expand All @@ -237,13 +249,13 @@ impl ConnectionWorker {
continue;
}

let depth = conn.transaction_depth;
let depth = shared.transaction_depth.load(Ordering::Acquire);

let res = if depth > 0 {
conn.handle
.exec(rollback_ansi_transaction_sql(depth))
.map(|_| {
conn.transaction_depth -= 1;
shared.transaction_depth.fetch_sub(1, Ordering::Release);
})
} else {
Ok(())
Expand Down
8 changes: 6 additions & 2 deletions sqlx-sqlite/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use futures_core::future::BoxFuture;

use crate::{Sqlite, SqliteConnection};
use sqlx_core::error::Error;
use sqlx_core::transaction::TransactionManager;

use crate::{Sqlite, SqliteConnection};

/// Implementation of [`TransactionManager`] for SQLite.
pub struct SqliteTransactionManager;

Expand All @@ -25,4 +25,8 @@ impl TransactionManager for SqliteTransactionManager {
fn start_rollback(conn: &mut SqliteConnection) {
conn.worker.start_rollback().ok();
}

fn get_transaction_depth(conn: &SqliteConnection) -> usize {
conn.worker.shared.get_transaction_depth()
}
}
5 changes: 5 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
assert_eq!(conn.get_transaction_depth(), 0);

conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)")
.await?;
Expand All @@ -523,6 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// begin
let mut tx = conn.begin().await?; // transaction
assert_eq!(conn.get_transaction_depth(), 1);

// insert a user
sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)")
Expand All @@ -532,6 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// begin once more
let mut tx2 = tx.begin().await?; // savepoint
assert_eq!(conn.get_transaction_depth(), 2);

// insert another user
sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)")
Expand All @@ -541,6 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// never mind, rollback
tx2.rollback().await?; // roll that one back
assert_eq!(conn.get_transaction_depth(), 1);

// did we really?
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")
Expand All @@ -551,6 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> {

// actually, commit
tx.commit().await?;
assert_eq!(conn.get_transaction_depth(), 0);

// did we really?
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523")
Expand Down