Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronously cancel tasks #3883

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions crates/task-impls/src/consensus/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use utils::anytrace::*;

use super::ConsensusTaskState;
use crate::{
consensus::Versions,
events::HotShotEvent,
helpers::{broadcast_event, cancel_task},
consensus::Versions, events::HotShotEvent, helpers::broadcast_event,
vote_collection::handle_vote,
};

Expand Down Expand Up @@ -170,11 +168,7 @@ pub(crate) async fn handle_view_change<
});

// Cancel the old timeout task
cancel_task(std::mem::replace(
&mut task_state.timeout_task,
new_timeout_task,
))
.await;
std::mem::replace(&mut task_state.timeout_task, new_timeout_task).abort();

let consensus_reader = task_state.consensus.read().await;
consensus_reader
Expand Down
10 changes: 3 additions & 7 deletions crates/task-impls/src/consensus/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use utils::anytrace::Result;
use self::handlers::{
handle_quorum_vote_recv, handle_timeout, handle_timeout_vote_recv, handle_view_change,
};
use crate::{events::HotShotEvent, helpers::cancel_task, vote_collection::VoteCollectorsMap};
use crate::{events::HotShotEvent, vote_collection::VoteCollectorsMap};

/// Event handlers for use in the `handle` method.
mod handlers;
Expand Down Expand Up @@ -167,12 +167,8 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
}

/// Joins all subtasks.
async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
// Cancel the old timeout task
cancel_task(std::mem::replace(
&mut self.timeout_task,
tokio::spawn(async {}),
))
.await;
std::mem::replace(&mut self.timeout_task, tokio::spawn(async {})).abort();
}
}
2 changes: 1 addition & 1 deletion crates/task-impls/src/da.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,5 +359,5 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
self.handle(event, sender.clone()).await
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}
7 changes: 1 addition & 6 deletions crates/task-impls/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use hotshot_types::{
utils::{Terminator, View, ViewInner},
vote::{Certificate, HasViewNumber},
};
use tokio::{task::JoinHandle, time::timeout};
use tokio::time::timeout;
use tracing::instrument;
use utils::anytrace::*;

Expand Down Expand Up @@ -640,11 +640,6 @@ pub(crate) async fn validate_proposal_view_and_certs<
Ok(())
}

/// Cancel a task
pub async fn cancel_task<T>(task: JoinHandle<T>) {
task.abort();
}

/// Helper function to send events and log errors
pub async fn broadcast_event<E: Clone + std::fmt::Debug>(event: E, sender: &Sender<E>) {
match sender.broadcast_direct(event).await {
Expand Down
16 changes: 8 additions & 8 deletions crates/task-impls/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use std::{
use async_broadcast::{Receiver, Sender};
use async_lock::RwLock;
use async_trait::async_trait;
use futures::future::join_all;
use hotshot_task::task::TaskState;
use hotshot_types::{
consensus::Consensus,
Expand All @@ -39,7 +38,7 @@ use utils::anytrace::*;

use crate::{
events::{HotShotEvent, HotShotTaskCompleted},
helpers::{broadcast_event, cancel_task},
helpers::broadcast_event,
};

/// the network message task state
Expand Down Expand Up @@ -232,7 +231,7 @@ impl<
Ok(())
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}

impl<
Expand Down Expand Up @@ -340,13 +339,14 @@ impl<
/// Cancel all tasks for previous views
pub fn cancel_tasks(&mut self, view: TYPES::View) {
let keep = self.transmit_tasks.split_off(&view);
let mut cancel = Vec::new();

while let Some((_, tasks)) = self.transmit_tasks.pop_first() {
let mut to_cancel = tasks.into_iter().map(cancel_task).collect();
cancel.append(&mut to_cancel);
for task in tasks {
task.abort();
}
}

self.transmit_tasks = keep;
spawn(async move { join_all(cancel).await });
}

/// Parses a `HotShotEvent` and returns a tuple of: (sender's public key, `MessageKind`, `TransmitType`)
Expand Down Expand Up @@ -801,7 +801,7 @@ pub mod test {
Ok(())
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}

impl<
Expand Down
19 changes: 7 additions & 12 deletions crates/task-impls/src/quorum_proposal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use async_broadcast::{Receiver, Sender};
use async_lock::RwLock;
use async_trait::async_trait;
use either::Either;
use futures::future::join_all;
use hotshot_task::{
dependency::{AndDependency, EventDependency, OrDependency},
dependency_task::DependencyTask,
Expand All @@ -34,10 +33,7 @@ use tracing::instrument;
use utils::anytrace::*;

use self::handlers::{ProposalDependency, ProposalDependencyHandle};
use crate::{
events::HotShotEvent,
helpers::{broadcast_event, cancel_task},
};
use crate::{events::HotShotEvent, helpers::broadcast_event};

mod handlers;

Expand Down Expand Up @@ -350,7 +346,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
for view in (*self.latest_proposed_view + 1)..=(*new_view) {
if let Some(dependency) = self.proposal_dependencies.remove(&TYPES::View::new(view))
{
cancel_task(dependency).await;
dependency.abort();
}
}

Expand Down Expand Up @@ -527,21 +523,20 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
)?;
}
HotShotEvent::ViewChange(view) | HotShotEvent::Timeout(view) => {
self.cancel_tasks(*view).await;
self.cancel_tasks(*view);
}
_ => {}
}
Ok(())
}

/// Cancel all tasks the consensus tasks has spawned before the given view
pub async fn cancel_tasks(&mut self, view: TYPES::View) {
pub fn cancel_tasks(&mut self, view: TYPES::View) {
let keep = self.proposal_dependencies.split_off(&view);
let mut cancel = Vec::new();
while let Some((_, task)) = self.proposal_dependencies.pop_first() {
cancel.push(cancel_task(task));
task.abort();
}
self.proposal_dependencies = keep;
join_all(cancel).await;
}
}

Expand All @@ -560,7 +555,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
self.handle(event, receiver.clone(), sender.clone()).await
}

async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
while let Some((_, handle)) = self.proposal_dependencies.pop_first() {
handle.abort();
}
Expand Down
11 changes: 5 additions & 6 deletions crates/task-impls/src/quorum_proposal_recv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use vbs::version::Version;
use self::handlers::handle_quorum_proposal_recv;
use crate::{
events::{HotShotEvent, ProposalMissing},
helpers::{broadcast_event, cancel_task, parent_leaf_and_state},
helpers::{broadcast_event, parent_leaf_and_state},
};
/// Event handlers for this task.
mod handlers;
Expand Down Expand Up @@ -108,13 +108,12 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
/// Cancel all tasks the consensus tasks has spawned before the given view
pub fn cancel_tasks(&mut self, view: TYPES::View) {
let keep = self.spawned_tasks.split_off(&view);
let mut cancel = Vec::new();
while let Some((_, tasks)) = self.spawned_tasks.pop_first() {
let mut to_cancel = tasks.into_iter().map(cancel_task).collect();
cancel.append(&mut to_cancel);
for task in tasks {
task.abort();
}
}
self.spawned_tasks = keep;
tokio::spawn(async move { join_all(cancel).await });
}

/// Handles all consensus events relating to propose and vote-enabling events.
Expand Down Expand Up @@ -192,7 +191,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
Ok(())
}

async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
while !self.spawned_tasks.is_empty() {
let Some((_, handles)) = self.spawned_tasks.pop_first() else {
break;
Expand Down
10 changes: 5 additions & 5 deletions crates/task-impls/src/quorum_vote/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use vbs::version::StaticVersionType;

use crate::{
events::HotShotEvent,
helpers::{broadcast_event, cancel_task},
helpers::broadcast_event,
quorum_vote::handlers::{handle_quorum_proposal_validated, submit_vote, update_shared_state},
};

Expand Down Expand Up @@ -395,7 +395,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> QuorumVoteTaskS
// Cancel the old dependency tasks.
for view in *self.latest_voted_view..(*new_view) {
if let Some(dependency) = self.vote_dependencies.remove(&TYPES::View::new(view)) {
cancel_task(dependency).await;
dependency.abort();
tracing::debug!("Vote dependency removed for view {:?}", view);
}
}
Expand Down Expand Up @@ -578,7 +578,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> QuorumVoteTaskS
// cancel old tasks
let current_tasks = self.vote_dependencies.split_off(&view);
while let Some((_, task)) = self.vote_dependencies.pop_last() {
cancel_task(task).await;
task.abort();
}
self.vote_dependencies = current_tasks;
}
Expand All @@ -587,7 +587,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> QuorumVoteTaskS
// cancel old tasks
let current_tasks = self.vote_dependencies.split_off(&view);
while let Some((_, task)) = self.vote_dependencies.pop_last() {
cancel_task(task).await;
task.abort();
}
self.vote_dependencies = current_tasks;
}
Expand Down Expand Up @@ -720,7 +720,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
self.handle(event, receiver.clone(), sender.clone()).await
}

async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
while let Some((_, handle)) = self.vote_dependencies.pop_last() {
handle.abort();
}
Expand Down
4 changes: 2 additions & 2 deletions crates/task-impls/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub struct NetworkRequestState<TYPES: NodeType, I: NodeImplementation<TYPES>> {

impl<TYPES: NodeType, I: NodeImplementation<TYPES>> Drop for NetworkRequestState<TYPES, I> {
fn drop(&mut self) {
futures::executor::block_on(async move { self.cancel_subtasks().await });
self.cancel_subtasks();
}
}

Expand Down Expand Up @@ -123,7 +123,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for NetworkRequest
}
}

async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
self.shutdown_flag.store(true, Ordering::Relaxed);

while !self.spawned_tasks.is_empty() {
Expand Down
2 changes: 1 addition & 1 deletion crates/task-impls/src/rewind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<TYPES: NodeType> TaskState for RewindTaskState<TYPES> {
Ok(())
}

async fn cancel_subtasks(&mut self) {
fn cancel_subtasks(&mut self) {
tracing::info!("Node ID {} Recording {} events", self.id, self.events.len());
let filename = format!("rewind_{}.log", self.id);
let mut file = match OpenOptions::new()
Expand Down
2 changes: 1 addition & 1 deletion crates/task-impls/src/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,5 +822,5 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
self.handle(event, sender.clone()).await
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}
2 changes: 1 addition & 1 deletion crates/task-impls/src/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,5 +336,5 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
Ok(())
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}
2 changes: 1 addition & 1 deletion crates/task-impls/src/vid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,5 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for VidTaskState<T
Ok(())
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}
14 changes: 7 additions & 7 deletions crates/task-impls/src/view_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use utils::anytrace::*;

use crate::{
events::{HotShotEvent, HotShotTaskCompleted},
helpers::{broadcast_event, cancel_task},
helpers::broadcast_event,
vote_collection::{
create_vote_accumulator, AccumulatorInfo, HandleVoteEvent, VoteCollectionTaskState,
},
Expand Down Expand Up @@ -132,7 +132,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
self.handle(event, sender.clone()).await
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}

/// State of a view sync replica task
Expand Down Expand Up @@ -197,7 +197,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> TaskState
Ok(())
}

async fn cancel_subtasks(&mut self) {}
fn cancel_subtasks(&mut self) {}
}

impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions> ViewSyncTaskState<TYPES, I, V> {
Expand Down Expand Up @@ -572,7 +572,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
}

if let Some(timeout_task) = self.timeout_task.take() {
cancel_task(timeout_task).await;
timeout_task.abort();
}

self.timeout_task = Some(spawn({
Expand Down Expand Up @@ -665,7 +665,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
.await;

if let Some(timeout_task) = self.timeout_task.take() {
cancel_task(timeout_task).await;
timeout_task.abort();
}
self.timeout_task = Some(spawn({
let stream = event_stream.clone();
Expand Down Expand Up @@ -721,7 +721,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
}

if let Some(timeout_task) = self.timeout_task.take() {
cancel_task(timeout_task).await;
timeout_task.abort();
}

broadcast_event(
Expand Down Expand Up @@ -792,7 +792,7 @@ impl<TYPES: NodeType, I: NodeImplementation<TYPES>, V: Versions>
// Shouldn't ever receive a timeout for a relay higher than ours
if TYPES::View::new(*round) == self.next_view && *relay == self.relay {
if let Some(timeout_task) = self.timeout_task.take() {
cancel_task(timeout_task).await;
timeout_task.abort();
}
self.relay += 1;
match last_seen_certificate {
Expand Down
Loading