Skip to content

Commit abedb57

Browse files
committed
Fix #198
This commit introduces a boolean flag that tracks whether we currently execute a transaction related SQL command. We set this flag to true directly before starting the future execution and back to false afterwards. This enables us to detect the cancellation of such futures while the command is executed. In such cases we consider the connection to be broken as we do not know how much of the command was actually executed.
1 parent 35cb1ad commit abedb57

File tree

1 file changed

+83
-162
lines changed

1 file changed

+83
-162
lines changed

src/transaction_manager.rs

Lines changed: 83 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use diesel::QueryResult;
88
use scoped_futures::ScopedBoxFuture;
99
use std::borrow::Cow;
1010
use std::num::NonZeroU32;
11+
use std::sync::atomic::{AtomicBool, Ordering};
12+
use std::sync::Arc;
1113

1214
use crate::AsyncConnection;
1315
// TODO: refactor this to share more code with diesel
@@ -88,24 +90,31 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
8890
/// in an error state.
8991
#[doc(hidden)]
9092
fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
91-
match Self::transaction_manager_status_mut(conn).transaction_state() {
92-
// all transactions are closed
93-
// so we don't consider this connection broken
94-
Ok(ValidTransactionManagerStatus {
95-
in_transaction: None,
96-
..
97-
}) => false,
98-
// The transaction manager is in an error state
99-
// Therefore we consider this connection broken
100-
Err(_) => true,
101-
// The transaction manager contains a open transaction
102-
// we do consider this connection broken
103-
// if that transaction was not opened by `begin_test_transaction`
104-
Ok(ValidTransactionManagerStatus {
105-
in_transaction: Some(s),
106-
..
107-
}) => !s.test_transaction,
108-
}
93+
check_broken_transaction_state(conn)
94+
}
95+
}
96+
97+
fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
98+
where
99+
Conn: AsyncConnection,
100+
{
101+
match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
102+
// all transactions are closed
103+
// so we don't consider this connection broken
104+
Ok(ValidTransactionManagerStatus {
105+
in_transaction: None,
106+
..
107+
}) => false,
108+
// The transaction manager is in an error state
109+
// Therefore we consider this connection broken
110+
Err(_) => true,
111+
// The transaction manager contains a open transaction
112+
// we do consider this connection broken
113+
// if that transaction was not opened by `begin_test_transaction`
114+
Ok(ValidTransactionManagerStatus {
115+
in_transaction: Some(s),
116+
..
117+
}) => !s.test_transaction,
109118
}
110119
}
111120

@@ -114,147 +123,23 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
114123
#[derive(Default, Debug)]
115124
pub struct AnsiTransactionManager {
116125
pub(crate) status: TransactionManagerStatus,
126+
// this boolean flag tracks whether we are currently in the process
127+
// of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
128+
// if we ever encounter a situation where this flag is set
129+
// while the connection is returned to a pool
130+
// that means the connection is broken as someone dropped the
131+
// transaction future while these commands where executed
132+
// and we cannot know the connection state anymore
133+
//
134+
// We ensure this by wrapping all calls to `.await`
135+
// into `AnsiTransactionManager::critical_transaction_block`
136+
// below
137+
//
138+
// See https://github.com/weiznich/diesel_async/issues/198 for
139+
// details
140+
pub(crate) is_broken: Arc<AtomicBool>,
117141
}
118142

119-
// /// Status of the transaction manager
120-
// #[derive(Debug)]
121-
// pub enum TransactionManagerStatus {
122-
// /// Valid status, the manager can run operations
123-
// Valid(ValidTransactionManagerStatus),
124-
// /// Error status, probably following a broken connection. The manager will no longer run operations
125-
// InError,
126-
// }
127-
128-
// impl Default for TransactionManagerStatus {
129-
// fn default() -> Self {
130-
// TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
131-
// }
132-
// }
133-
134-
// impl TransactionManagerStatus {
135-
// /// Returns the transaction depth if the transaction manager's status is valid, or returns
136-
// /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
137-
// pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
138-
// match self {
139-
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
140-
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
141-
// }
142-
// }
143-
144-
// /// If in transaction and transaction manager is not broken, registers that the
145-
// /// connection can not be used anymore until top-level transaction is rolled back
146-
// pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
147-
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
148-
// in_transaction:
149-
// Some(InTransactionStatus {
150-
// top_level_transaction_requires_rollback,
151-
// ..
152-
// }),
153-
// }) = self
154-
// {
155-
// *top_level_transaction_requires_rollback = true;
156-
// }
157-
// }
158-
159-
// /// Sets the transaction manager status to InError
160-
// ///
161-
// /// Subsequent attempts to use transaction-related features will result in a
162-
// /// [`Error::BrokenTransactionManager`] error
163-
// pub fn set_in_error(&mut self) {
164-
// *self = TransactionManagerStatus::InError
165-
// }
166-
167-
// fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
168-
// match self {
169-
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
170-
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
171-
// }
172-
// }
173-
174-
// pub(crate) fn set_test_transaction_flag(&mut self) {
175-
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
176-
// in_transaction: Some(s),
177-
// }) = self
178-
// {
179-
// s.test_transaction = true;
180-
// }
181-
// }
182-
// }
183-
184-
// /// Valid transaction status for the manager. Can return the current transaction depth
185-
// #[allow(missing_copy_implementations)]
186-
// #[derive(Debug, Default)]
187-
// pub struct ValidTransactionManagerStatus {
188-
// in_transaction: Option<InTransactionStatus>,
189-
// }
190-
191-
// #[allow(missing_copy_implementations)]
192-
// #[derive(Debug)]
193-
// struct InTransactionStatus {
194-
// transaction_depth: NonZeroU32,
195-
// top_level_transaction_requires_rollback: bool,
196-
// test_transaction: bool,
197-
// }
198-
199-
// impl ValidTransactionManagerStatus {
200-
// /// Return the current transaction depth
201-
// ///
202-
// /// This value is `None` if no current transaction is running
203-
// /// otherwise the number of nested transactions is returned.
204-
// pub fn transaction_depth(&self) -> Option<NonZeroU32> {
205-
// self.in_transaction.as_ref().map(|it| it.transaction_depth)
206-
// }
207-
208-
// /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
209-
// /// `Ok(())`
210-
// pub fn change_transaction_depth(
211-
// &mut self,
212-
// transaction_depth_change: TransactionDepthChange,
213-
// ) -> QueryResult<()> {
214-
// match (&mut self.in_transaction, transaction_depth_change) {
215-
// (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
216-
// // Can be replaced with saturating_add directly on NonZeroU32 once
217-
// // <https://github.com/rust-lang/rust/issues/84186> is stable
218-
// in_transaction.transaction_depth =
219-
// NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
220-
// .expect("nz + nz is always non-zero");
221-
// Ok(())
222-
// }
223-
// (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
224-
// // This sets `transaction_depth` to `None` as soon as we reach zero
225-
// match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
226-
// Some(depth) => in_transaction.transaction_depth = depth,
227-
// None => self.in_transaction = None,
228-
// }
229-
// Ok(())
230-
// }
231-
// (None, TransactionDepthChange::IncreaseDepth) => {
232-
// self.in_transaction = Some(InTransactionStatus {
233-
// transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
234-
// top_level_transaction_requires_rollback: false,
235-
// test_transaction: false,
236-
// });
237-
// Ok(())
238-
// }
239-
// (None, TransactionDepthChange::DecreaseDepth) => {
240-
// // We screwed up something somewhere
241-
// // we cannot decrease the transaction count if
242-
// // we are not inside a transaction
243-
// Err(Error::NotInTransaction)
244-
// }
245-
// }
246-
// }
247-
// }
248-
249-
// /// Represents a change to apply to the depth of a transaction
250-
// #[derive(Debug, Clone, Copy)]
251-
// pub enum TransactionDepthChange {
252-
// /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
253-
// IncreaseDepth,
254-
// /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
255-
// DecreaseDepth,
256-
// }
257-
258143
impl AnsiTransactionManager {
259144
fn get_transaction_state<Conn>(
260145
conn: &mut Conn,
@@ -274,17 +159,34 @@ impl AnsiTransactionManager {
274159
where
275160
Conn: AsyncConnection<TransactionManager = Self>,
276161
{
162+
let is_broken = conn.transaction_state().is_broken.clone();
277163
let state = Self::get_transaction_state(conn)?;
278164
match state.transaction_depth() {
279165
None => {
280-
conn.batch_execute(sql).await?;
166+
Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
281167
Self::get_transaction_state(conn)?
282168
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
283169
Ok(())
284170
}
285171
Some(_depth) => Err(Error::AlreadyInTransaction),
286172
}
287173
}
174+
175+
// This function should be used to await any connection
176+
// related future in our transaction manager implementation
177+
//
178+
// It takes care of tracking entering and exiting executing the future
179+
// which in turn is used to determine if it's safe to still use
180+
// the connection in the event of a canceled transaction execution
181+
async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
182+
where
183+
F: std::future::Future,
184+
{
185+
is_broken.store(true, Ordering::Relaxed);
186+
let res = f.await;
187+
is_broken.store(false, Ordering::Relaxed);
188+
res
189+
}
288190
}
289191

290192
#[async_trait::async_trait]
@@ -308,7 +210,11 @@ where
308210
.unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
309211
conn.instrumentation()
310212
.on_connection_event(InstrumentationEvent::begin_transaction(depth));
311-
conn.batch_execute(&start_transaction_sql).await?;
213+
Self::critical_transaction_block(
214+
&conn.transaction_state().is_broken.clone(),
215+
conn.batch_execute(&start_transaction_sql),
216+
)
217+
.await?;
312218
Self::get_transaction_state(conn)?
313219
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
314220

@@ -344,7 +250,10 @@ where
344250
conn.instrumentation()
345251
.on_connection_event(InstrumentationEvent::rollback_transaction(depth));
346252

347-
match conn.batch_execute(&rollback_sql).await {
253+
let is_broken = conn.transaction_state().is_broken.clone();
254+
255+
match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
256+
{
348257
Ok(()) => {
349258
match Self::get_transaction_state(conn)?
350259
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
@@ -429,7 +338,9 @@ where
429338
conn.instrumentation()
430339
.on_connection_event(InstrumentationEvent::commit_transaction(depth));
431340

432-
match conn.batch_execute(&commit_sql).await {
341+
let is_broken = conn.transaction_state().is_broken.clone();
342+
343+
match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
433344
Ok(()) => {
434345
match Self::get_transaction_state(conn)?
435346
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
@@ -453,7 +364,12 @@ where
453364
..
454365
}) = conn.transaction_state().status
455366
{
456-
match Self::rollback_transaction(conn).await {
367+
match Self::critical_transaction_block(
368+
&is_broken,
369+
Self::rollback_transaction(conn),
370+
)
371+
.await
372+
{
457373
Ok(()) => {}
458374
Err(rollback_error) => {
459375
conn.transaction_state().status.set_in_error();
@@ -472,4 +388,9 @@ where
472388
fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
473389
&mut conn.transaction_state().status
474390
}
391+
392+
fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
393+
conn.transaction_state().is_broken.load(Ordering::Relaxed)
394+
|| check_broken_transaction_state(conn)
395+
}
475396
}

0 commit comments

Comments
 (0)