From 7e96e50f539d1df0570f4ca168c212185f71cd59 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 25 Apr 2025 10:50:57 +0200 Subject: [PATCH] Also instrument the postgres connection builder This commit fixes an issue where we did not emitt a `BeginTransaction` event for transactions created with the postgres specific connection builder. Fix #229 --- src/transaction_manager.rs | 22 ++++++++++++++-------- tests/instrumentation.rs | 30 +++++++++++++++++++++++++++++- tests/lib.rs | 8 ++++++-- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/transaction_manager.rs b/src/transaction_manager.rs index 6d4d984..cd5bc5b 100644 --- a/src/transaction_manager.rs +++ b/src/transaction_manager.rs @@ -169,15 +169,21 @@ impl AnsiTransactionManager { { let is_broken = conn.transaction_state().is_broken.clone(); let state = Self::get_transaction_state(conn)?; - match state.transaction_depth() { - None => { - Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?; - Self::get_transaction_state(conn)? - .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?; - Ok(()) - } - Some(_depth) => Err(Error::AlreadyInTransaction), + if let Some(_depth) = state.transaction_depth() { + return Err(Error::AlreadyInTransaction); } + let instrumentation_depth = NonZeroU32::new(1); + + conn.instrumentation() + .on_connection_event(InstrumentationEvent::begin_transaction( + instrumentation_depth.expect("We know that 1 is not zero"), + )); + + // Keep remainder of this method in sync with `begin_transaction()`. + Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?; + Self::get_transaction_state(conn)? + .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?; + Ok(()) } // This function should be used to await any connection diff --git a/tests/instrumentation.rs b/tests/instrumentation.rs index 039ebce..e14a0c3 100644 --- a/tests/instrumentation.rs +++ b/tests/instrumentation.rs @@ -54,9 +54,14 @@ impl From> for Event { } async fn setup_test_case() -> (Arc>>, TestConnection) { + setup_test_case_with_connection(connection_with_sean_and_tess_in_users_table().await) +} + +fn setup_test_case_with_connection( + mut conn: TestConnection, +) -> (Arc>>, TestConnection) { let events = Arc::new(Mutex::new(Vec::::new())); let events_to_check = events.clone(); - let mut conn = connection_with_sean_and_tess_in_users_table().await; conn.set_instrumentation(move |event: InstrumentationEvent<'_>| { events.lock().unwrap().push(event.into()); }); @@ -255,3 +260,26 @@ async fn check_events_transaction_nested() { assert_matches!(events[10], Event::StartQuery { .. }); assert_matches!(events[11], Event::FinishQuery { .. }); } + +#[cfg(feature = "postgres")] +#[tokio::test] +async fn check_events_transaction_builder() { + use crate::connection_without_transaction; + use diesel::result::Error; + use scoped_futures::ScopedFutureExt; + + let (events_to_check, mut conn) = + setup_test_case_with_connection(connection_without_transaction().await); + conn.build_transaction() + .run(|_tx| async move { Ok::<(), Error>(()) }.scope_boxed()) + .await + .unwrap(); + let events = events_to_check.lock().unwrap(); + assert_eq!(events.len(), 6, "{:?}", events); + assert_matches!(events[0], Event::BeginTransaction { .. }); + assert_matches!(events[1], Event::StartQuery { .. }); + assert_matches!(events[2], Event::FinishQuery { .. }); + assert_matches!(events[3], Event::CommitTransaction { .. }); + assert_matches!(events[4], Event::StartQuery { .. }); + assert_matches!(events[5], Event::FinishQuery { .. }); +} diff --git a/tests/lib.rs b/tests/lib.rs index a3cc806..c305cf3 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -203,8 +203,7 @@ async fn setup(connection: &mut TestConnection) { } async fn connection() -> TestConnection { - let db_url = std::env::var("DATABASE_URL").unwrap(); - let mut conn = TestConnection::establish(&db_url).await.unwrap(); + let mut conn = connection_without_transaction().await; if cfg!(feature = "postgres") { // postgres allows to modify the schema inside of a transaction conn.begin_test_transaction().await.unwrap(); @@ -218,3 +217,8 @@ async fn connection() -> TestConnection { } conn } + +async fn connection_without_transaction() -> TestConnection { + let db_url = std::env::var("DATABASE_URL").unwrap(); + TestConnection::establish(&db_url).await.unwrap() +}