diff --git a/types/src/v0/impls/l1.rs b/types/src/v0/impls/l1.rs
index 2f571d71d2..0d12dbb6ad 100644
--- a/types/src/v0/impls/l1.rs
+++ b/types/src/v0/impls/l1.rs
@@ -2,10 +2,7 @@ use std::{
     cmp::{min, Ordering},
     fmt::Debug,
     num::NonZeroUsize,
-    sync::{
-        atomic::{self, AtomicBool},
-        Arc,
-    },
+    sync::Arc,
     time::Duration,
 };
 
@@ -33,7 +30,7 @@ use tracing::Instrument;
 use url::Url;
 
 use super::{L1BlockInfo, L1State, L1UpdateTask, RpcClient};
-use crate::{FeeInfo, L1Client, L1ClientOptions, L1Event, L1Snapshot, WsConn};
+use crate::{FeeInfo, L1Client, L1ClientOptions, L1Event, L1ReconnectTask, L1Snapshot};
 
 impl PartialOrd for L1BlockInfo {
     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
@@ -88,14 +85,18 @@ impl RpcClient {
 
     async fn ws(url: Url, retry_delay: Duration) -> anyhow::Result<Self> {
         Ok(Self::Ws {
-            conn: Arc::new(RwLock::new(WsConn {
-                inner: Ws::connect(url.clone()).await?,
-                resetting: AtomicBool::new(false),
-            })),
+            conn: Arc::new(RwLock::new(Ws::connect(url.clone()).await?)),
+            reconnect: Default::default(),
             retry_delay,
             url,
         })
     }
+
+    async fn stop(&self) {
+        if let Self::Ws { reconnect, .. } = self {
+            *reconnect.lock().await = L1ReconnectTask::Cancelled;
+        }
+    }
 }
 
 #[async_trait]
@@ -111,6 +112,7 @@ impl JsonRpcClient for RpcClient {
             Self::Http(client) => client.request(method, params).await?,
             Self::Ws {
                 conn,
+                reconnect,
                 url,
                 retry_delay,
             } => {
@@ -122,44 +124,53 @@ impl JsonRpcClient for RpcClient {
                     .map_err(|_| {
                         ProviderError::CustomError("connection closed; reset in progress".into())
                     })?;
-                match conn_guard.inner.request(method, params).await {
+                match conn_guard.request(method, params).await {
                     Ok(res) => res,
                     Err(err @ WsClientError::UnexpectedClose) => {
                         // If the WebSocket connection is closed, try to reopen it.
-                        if !conn_guard.resetting.swap(true, atomic::Ordering::SeqCst) {
-                            // We are the first one to try and reset this connection, so let's do
-                            // it. We spawn a separate task to do it, because resetting it might
-                            // take a long time (especially if it was closed because of a provider
-                            // outage), and it is not good to block indefinitely in this low-level
-                            // request API.
-                            let conn = conn.clone();
-                            let url = url.clone();
-                            let retry_delay = *retry_delay;
-                            let span = tracing::warn_span!("ws resetter");
-                            spawn(
-                                async move {
-                                    tracing::warn!("ws connection closed, trying to reset");
-                                    let inner = loop {
-                                        match Ws::connect(url.clone()).await {
-                                            Ok(inner) => break inner,
-                                            Err(err) => {
-                                                tracing::warn!("failed to reconnect: {err:#}");
-                                                sleep(retry_delay).await;
+                        if let Ok(mut reconnect_guard) = reconnect.try_lock() {
+                            if matches!(*reconnect_guard, L1ReconnectTask::Idle) {
+                                // No one is currently resetting this connection, so it's up to us.
+                                let conn = conn.clone();
+                                let reconnect = reconnect.clone();
+                                let url = url.clone();
+                                let retry_delay = *retry_delay;
+                                let span = tracing::warn_span!("ws resetter");
+                                *reconnect_guard = L1ReconnectTask::Reconnecting(spawn(
+                                    async move {
+                                        tracing::warn!("ws connection closed, trying to reset");
+                                        let new_conn = loop {
+                                            match Ws::connect(url.clone()).await {
+                                                Ok(conn) => break conn,
+                                                Err(err) => {
+                                                    tracing::warn!("failed to reconnect: {err:#}");
+                                                    sleep(retry_delay).await;
+                                                }
                                             }
+                                        };
+
+                                        // Reset the connection, and set the reconnect task back to
+                                        // idle, so that the connection can be reset again if
+                                        // needed.
+                                        let mut conn = conn.write().await;
+                                        let mut reconnect = reconnect.lock().await;
+                                        *conn = new_conn;
+                                        if !matches!(*reconnect, L1ReconnectTask::Cancelled) {
+                                            *reconnect = L1ReconnectTask::Idle;
                                         }
-                                    };
-                                    let new_conn = WsConn {
-                                        inner,
-                                        resetting: AtomicBool::new(false),
-                                    };
-                                    *conn.write().await = new_conn;
-                                    tracing::info!("ws connection successfully reestablished");
-                                }
-                                .instrument(span),
-                            );
+
+                                        tracing::info!("ws connection successfully reestablished");
+                                    }
+                                    .instrument(span),
+                                ));
+                            }
                         } else {
-                            // Otherwise, if we couldn't get a write lock, it is because someone
-                            // else is already resetting this connection, so we have nothing to do.
+                            // If we fail to get a lock on the reconnect task, it can only mean one
+                            // of two things:
+                            // * someone else is already preparing to reset the connection
+                            // * the entire L1 client is being shut down
+                            // In either case, we don't want/need to reset the connection ourselves,
+                            // so nothing to do here.
                         }
                         Err(err)?
                     }
@@ -190,7 +201,6 @@ impl PubsubClient for RpcClient {
                 .map_err(|_| {
                     ProviderError::CustomError("connection closed; reset in progress".into())
                 })?
-                .inner
                 .subscribe(id)?),
         }
     }
@@ -211,12 +221,20 @@ impl PubsubClient for RpcClient {
                 .map_err(|_| {
                     ProviderError::CustomError("connection closed; reset in progress".into())
                 })?
-                .inner
                 .unsubscribe(id)?),
         }
     }
 }
 
+impl Drop for L1ReconnectTask {
+    fn drop(&mut self) {
+        if let Self::Reconnecting(task) = self {
+            tracing::info!("cancelling L1 reconnect task");
+            task.abort();
+        }
+    }
+}
+
 impl Drop for L1UpdateTask {
     fn drop(&mut self) {
         if let Some(task) = self.0.get_mut().take() {
@@ -316,6 +334,7 @@ impl L1Client {
         if let Some(update_task) = self.update_task.0.lock().await.take() {
             update_task.abort();
         }
+        (*self.provider).as_ref().stop().await;
     }
 
     pub fn provider(&self) -> &impl Middleware {
@@ -1051,7 +1070,7 @@ mod test {
         setup_test();
 
         let port = pick_unused_port().unwrap();
-        let anvil = Anvil::new().block_time(1u32).port(port).spawn();
+        let mut anvil = Anvil::new().block_time(1u32).port(port).spawn();
         let provider = Provider::new(
             RpcClient::ws(anvil.ws_endpoint().parse().unwrap(), Duration::from_secs(1))
                 .await
@@ -1061,33 +1080,39 @@ mod test {
                 .unwrap(),
         );
 
         // Check the provider is working.
         assert_eq!(provider.get_chainid().await.unwrap(), 31337.into());
 
-        // Disconnect the WebSocket and reconnect it. Technically this spawns a whole new Anvil
-        // chain, but for the purposes of this test it should look to the client like an L1 server
-        // closing a WebSocket connection.
-        drop(anvil);
-        let err = provider.get_chainid().await.unwrap_err();
-        tracing::info!("L1 request failed as expected with closed connection: {err:#}");
-
-        // Let the connection stay down for a little while: Ethers internally tries to reconnect,
-        // and starting up to fast again might hit that and cause a false positive. The problem is,
-        // Ethers doesn't try very hard, and if we wait a bit, we will test the worst possible case
-        // where the internal retry logic gives up and just kills the whole provider.
-        tracing::info!("sleep 5");
-        sleep(Duration::from_secs(5)).await;
-
-        // Once a connection is reestablished, the provider will eventually work again.
-        tracing::info!("restarting L1");
-        let _anvil = Anvil::new().block_time(1u32).port(port).spawn();
-
-        // Give a bit of time for the provider to reconnect.
-        for retry in 0..5 {
-            if let Ok(chain_id) = provider.get_chainid().await {
-                assert_eq!(chain_id, 31337.into());
-                return;
+        // Test two reconnects in a row, to ensure the reconnector is reset properly after the
+        // first one.
+        'outer: for i in 0..2 {
+            tracing::info!("reconnect {i}");
+            // Disconnect the WebSocket and reconnect it. Technically this spawns a whole new Anvil
+            // chain, but for the purposes of this test it should look to the client like an L1
+            // server closing a WebSocket connection.
+            drop(anvil);
+            let err = provider.get_chainid().await.unwrap_err();
+            tracing::info!("L1 request failed as expected with closed connection: {err:#}");
+
+            // Let the connection stay down for a little while: Ethers internally tries to
+            // reconnect, and starting up too fast again might hit that and cause a false positive.
+            // The problem is, Ethers doesn't try very hard, and if we wait a bit, we will test the
+            // worst possible case where the internal retry logic gives up and just kills the whole
+            // provider.
+            tracing::info!("sleep 5");
+            sleep(Duration::from_secs(5)).await;
+
+            // Once a connection is reestablished, the provider will eventually work again.
+            tracing::info!("restarting L1");
+            anvil = Anvil::new().block_time(1u32).port(port).spawn();
+
+            // Give a bit of time for the provider to reconnect.
+            for retry in 0..5 {
+                if let Ok(chain_id) = provider.get_chainid().await {
+                    assert_eq!(chain_id, 31337.into());
+                    continue 'outer;
+                }
+                tracing::warn!(retry, "waiting for provider to reconnect");
+                sleep(Duration::from_secs(1)).await;
             }
-            tracing::warn!(retry, "waiting for provider to reconnect");
-            sleep(Duration::from_secs(1)).await;
+            panic!("request never succeeded after reconnect");
         }
-        panic!("request never succeeded after reconnect");
     }
 
     #[tokio::test(flavor = "multi_thread")]
diff --git a/types/src/v0/mod.rs b/types/src/v0/mod.rs
index b84b20fe45..59a44cf65d 100644
--- a/types/src/v0/mod.rs
+++ b/types/src/v0/mod.rs
@@ -123,7 +123,7 @@ reexport_unchanged_types!(
     ViewBasedUpgrade,
     BlockSize,
 );
-pub(crate) use v0_3::{L1Event, L1State, L1UpdateTask, RpcClient, WsConn};
+pub(crate) use v0_3::{L1Event, L1ReconnectTask, L1State, L1UpdateTask, RpcClient};
 
 #[derive(
     Clone, Copy, Debug, Default, Hash, Eq, PartialEq, PartialOrd, Ord, Deserialize, Serialize,
diff --git a/types/src/v0/v0_1/l1.rs b/types/src/v0/v0_1/l1.rs
index dca0f38f11..5259aaa9a2 100644
--- a/types/src/v0/v0_1/l1.rs
+++ b/types/src/v0/v0_1/l1.rs
@@ -7,7 +7,6 @@ use ethers::{
 };
 use lru::LruCache;
 use serde::{Deserialize, Serialize};
-use std::sync::atomic::AtomicBool;
 use std::{num::NonZeroUsize, sync::Arc, time::Duration};
 use tokio::{
     sync::{Mutex, RwLock},
@@ -118,18 +117,13 @@ pub struct L1Client {
 pub(crate) enum RpcClient {
     Http(Http),
     Ws {
-        conn: Arc<RwLock<WsConn>>,
+        conn: Arc<RwLock<Ws>>,
+        reconnect: Arc<Mutex<L1ReconnectTask>>,
         url: Url,
         retry_delay: Duration,
     },
 }
 
-#[derive(Debug)]
-pub(crate) struct WsConn {
-    pub(crate) inner: Ws,
-    pub(crate) resetting: AtomicBool,
-}
-
 /// In-memory view of the L1 state, updated asynchronously.
 #[derive(Debug)]
 pub(crate) struct L1State {
@@ -145,3 +139,11 @@ pub(crate) enum L1Event {
 
 #[derive(Debug, Default)]
 pub(crate) struct L1UpdateTask(pub(crate) Mutex<Option<JoinHandle<()>>>);
+
+#[derive(Debug, Default)]
+pub(crate) enum L1ReconnectTask {
+    Reconnecting(JoinHandle<()>),
+    #[default]
+    Idle,
+    Cancelled,
+}
diff --git a/types/src/v0/v0_3/mod.rs b/types/src/v0/v0_3/mod.rs
index c475e6d14a..86a0150a37 100644
--- a/types/src/v0/v0_3/mod.rs
+++ b/types/src/v0/v0_3/mod.rs
@@ -12,7 +12,7 @@ pub use super::v0_1::{
    UpgradeType, ViewBasedUpgrade, BLOCK_MERKLE_TREE_HEIGHT, FEE_MERKLE_TREE_HEIGHT,
    NS_ID_BYTE_LEN, NS_OFFSET_BYTE_LEN, NUM_NSS_BYTE_LEN, NUM_TXS_BYTE_LEN, TX_OFFSET_BYTE_LEN,
 };
-pub(crate) use super::v0_1::{L1Event, L1State, L1UpdateTask, RpcClient, WsConn};
+pub(crate) use super::v0_1::{L1Event, L1ReconnectTask, L1State, L1UpdateTask, RpcClient};
 
 pub const VERSION: Version = Version { major: 0, minor: 3 };
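
Note: the sketch below is a minimal, self-contained illustration of the reconnect-guard pattern this patch introduces. The names (ReconnectTask, reset_connection) are placeholders, not the crate's actual API: a three-state enum behind a Mutex, probed with try_lock, ensures at most one reconnect task is spawned at a time, while the Cancelled state lets shutdown win over a late reconnect.

    use std::sync::Arc;
    use std::time::Duration;

    use tokio::{sync::Mutex, task::JoinHandle, time::sleep};

    // Placeholder for L1ReconnectTask: Idle means no reset in flight,
    // Reconnecting owns the handle of the in-flight reset task, and
    // Cancelled is a terminal state set on shutdown.
    #[derive(Debug, Default)]
    enum ReconnectTask {
        Reconnecting(JoinHandle<()>),
        #[default]
        Idle,
        Cancelled,
    }

    impl Drop for ReconnectTask {
        fn drop(&mut self) {
            // As in the patch: dropping (or overwriting) a Reconnecting slot
            // aborts the in-flight reconnect task.
            if let Self::Reconnecting(task) = self {
                task.abort();
            }
        }
    }

    // Hypothetical stand-in for the error path in JsonRpcClient::request.
    fn reset_connection(reconnect: &Arc<Mutex<ReconnectTask>>) {
        // try_lock keeps the request path non-blocking: a failed lock means
        // someone else is already resetting, or the client is shutting down.
        if let Ok(mut guard) = reconnect.try_lock() {
            if matches!(*guard, ReconnectTask::Idle) {
                let reconnect = Arc::clone(reconnect);
                *guard = ReconnectTask::Reconnecting(tokio::spawn(async move {
                    // Stand-in for the retry loop around Ws::connect.
                    sleep(Duration::from_millis(10)).await;
                    let mut guard = reconnect.lock().await;
                    // Go back to Idle, but never clobber a shutdown that
                    // raced with us.
                    if !matches!(*guard, ReconnectTask::Cancelled) {
                        *guard = ReconnectTask::Idle;
                    }
                }));
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let reconnect = Arc::new(Mutex::new(ReconnectTask::Idle));
        // Two "concurrent" failures: only the first spawns a reconnect task.
        reset_connection(&reconnect);
        reset_connection(&reconnect);
        sleep(Duration::from_millis(50)).await;
        assert!(matches!(*reconnect.lock().await, ReconnectTask::Idle));
    }

This mirrors why stop() only has to write Cancelled into the slot: the assignment drops any Reconnecting handle (aborting the in-flight task via Drop), and the terminal state keeps a racing reconnect from flipping the client back to Idle after shutdown.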