diff --git a/Cargo.toml b/Cargo.toml index a3e1bfa..98c7530 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rex-sm" -version = "0.6.0" +version = "0.7.0" edition = "2021" description = "Hierarchical state machine" license = "MIT" diff --git a/src/builder.rs b/src/builder.rs index b579c99..fd38984 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use bigerror::{ConversionError, Report}; use tokio::{ @@ -97,7 +97,7 @@ where } #[must_use] - pub fn with_timeout_manager( + pub const fn with_timeout_manager( mut self, timeout_topic: ::Topic, ) -> Self { @@ -106,7 +106,7 @@ where } #[must_use] - pub fn with_tick_rate(mut self, tick_rate: Duration) -> Self { + pub const fn with_tick_rate(mut self, tick_rate: Duration) -> Self { self.tick_rate = Some(tick_rate); self } @@ -216,9 +216,9 @@ where fn default() -> Self { Self { notification_queue: NotificationQueue::new(), - signal_queue: Default::default(), - state_machines: Default::default(), - notification_processors: Default::default(), + signal_queue: Arc::default(), + state_machines: Vec::default(), + notification_processors: Vec::default(), timeout_topic: None, tick_rate: None, outbound_tx: None, diff --git a/src/lib.rs b/src/lib.rs index 4bae00a..4717dba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(clippy::module_name_repetitions)] use std::fmt; use bigerror::reportable; @@ -32,69 +33,24 @@ pub use timeout::Timeout; /// enumerations or enumerations whose variants only contain field-less enumerations; note that /// `Copy` is a required supertrait. pub trait State: fmt::Debug + Send + PartialEq + Copy { - fn get_kind(&self) -> &dyn Kind; - fn fail(&mut self) - where - Self: Sized, - { - *self = self.get_kind().failed_state(); - } - fn complete(&mut self) - where - Self: Sized, - { - *self = self.get_kind().completed_state(); - } - fn is_completed(&self) -> bool - where - Self: Sized, - for<'a> &'a Self: PartialEq<&'a Self>, - { - self == &self.get_kind().completed_state() - } - - fn is_failed(&self) -> bool - where - Self: Sized, - for<'a> &'a Self: PartialEq<&'a Self>, - { - self == &self.get_kind().failed_state() - } - fn is_new(&self) -> bool - where - Self: Sized, - for<'a> &'a Self: PartialEq<&'a Self>, - { - self == &self.get_kind().new_state() - } - - /// represents a state that will no longer change - fn is_terminal(&self) -> bool - where - Self: Sized, - { - self.is_failed() || self.is_completed() - } - - /// `&dyn Kind` cannot do direct partial comparison - /// due to type opacity - /// so `State::new_state(self)` is called to allow a vtable lookup - fn kind_eq(&self, kind: &dyn Kind) -> bool - where - Self: Sized, - { - self.get_kind().new_state() == kind.new_state() - } + type Input: Send + Sync + 'static + fmt::Debug; } /// Acts as a discriminant between various [`State`] enumerations, similar to /// [`std::mem::Discriminant`]. /// Used to define the scope for [`Signal`]s cycled through a [`StateMachineManager`]. -pub trait Kind: fmt::Debug + Send { - type State: State; +pub trait Kind: fmt::Debug + Send + Sized { + type State: State + AsRef; + type Input: Send + Sync + 'static + fmt::Debug; + fn new_state(&self) -> Self::State; fn failed_state(&self) -> Self::State; fn completed_state(&self) -> Self::State; + // /// represents a state that will no longer change + fn is_terminal(state: Self::State) -> bool { + let kind = state.as_ref(); + kind.completed_state() == state || kind.failed_state() == state + } } /// Titular trait of the library that enables Hierarchical State Machine (HSM for short) behaviour. @@ -109,13 +65,15 @@ pub trait Kind: fmt::Debug + Send { /// ```text /// /// Kind -> Rex::Message -/// :: :: :: -/// State Input Topic +/// :: :: +/// State::Input Topic /// ``` -pub trait Rex: Kind + HashKind { - type Input: Send + Sync + 'static + fmt::Debug; +pub trait Rex: Kind + HashKind +where + Self::State: AsRef, +{ type Message: RexMessage; - fn state_input(&self, state: ::State) -> Option; + fn state_input(&self, state: Self::State) -> Option; fn timeout_input(&self, instant: Instant) -> Option; } @@ -147,7 +105,7 @@ where f, "{:?}<{}>", self.kind, - self.is_nil() + (!self.is_nil()) .then(|| bs58::encode(self.uuid).into_string()) .unwrap_or_else(|| "NIL".to_string()) ) @@ -155,7 +113,7 @@ where } impl StateId { - pub fn new(kind: K, uuid: Uuid) -> Self { + pub const fn new(kind: K, uuid: Uuid) -> Self { Self { kind, uuid } } @@ -163,7 +121,7 @@ impl StateId { Self::new(kind, Uuid::new_v4()) } - pub fn nil(kind: K) -> Self { + pub const fn nil(kind: K) -> Self { Self::new(kind, Uuid::nil()) } pub fn is_nil(&self) -> bool { @@ -172,7 +130,7 @@ impl StateId { // for testing purposes, easily distinguish UUIDs // by numerical value #[cfg(test)] - pub fn new_with_u128(kind: K, v: u128) -> Self { + pub const fn new_with_u128(kind: K, v: u128) -> Self { Self { kind, uuid: Uuid::from_u128(v), diff --git a/src/manager.rs b/src/manager.rs index 92fd49c..020f733 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -60,7 +60,7 @@ use crate::{ queue::StreamableDeque, storage::{StateStore, Tree}, timeout::{RetainItem, TimeoutInput, TimeoutMessage}, - Kind, Rex, State, StateId, + Kind, Rex, StateId, }; pub trait HashKind: Kind + fmt::Debug + Hash + Eq + PartialEq + 'static + Copy @@ -138,6 +138,10 @@ impl SmContext { self.notification_queue.send(notification); } + pub fn signal_self(&self, input: K::Input) { + self.signal_queue.push_front(Signal { id: self.id, input }); + } + pub fn get_state(&self) -> Option { let tree = self.state_store.get_tree(self.id)?; let guard = tree.lock(); @@ -147,6 +151,7 @@ impl SmContext { pub fn get_tree(&self) -> Option> { self.state_store.get_tree(self.id) } + pub fn has_state(&self) -> bool { self.state_store.get_tree(self.id).is_some() } @@ -157,6 +162,10 @@ impl SmContext { guard.get_parent_id(self.id) }) } + + pub fn has_parent(&self) -> bool { + self.get_parent_id().is_some() + } } impl Clone for SmContext { fn clone(&self) -> Self { @@ -278,7 +287,7 @@ where fn new_child(&self, ctx: &SmContext, child_id: StateId) { let id = ctx.id; let tree = ctx.state_store.get_tree(id).unwrap(); - ctx.state_store.insert_ref(id, tree.clone()); + ctx.state_store.insert_ref(child_id, tree.clone()); let mut tree = tree.lock(); tree.insert(Insert { parent_id: Some(ctx.id), @@ -318,11 +327,6 @@ where self.update_state_and_signal(ctx, id.completed_state()) } - /// represents a state that will no longer change - fn terminal_state(state: K::State) -> bool { - state.is_failed() || state.is_completed() - } - /// update state is meant to be used to signal a parent state of a child state /// _if_ a parent exists, this function makes no assumptions of the potential /// structure of a state hierarchy and _should_ be just as performant on a single @@ -335,9 +339,9 @@ where tracing::error!(%id, "Tree not found!"); panic!("missing SmTree"); }; - let mut guard = tree.lock(); - if let Some(id) = guard.update_and_get_parent_id(Update { id, state }) { + let parent_id = tree.lock().update_and_get_parent_id(Update { id, state }); + if let Some(id) = parent_id { ctx.signal_queue.signal_state_change(id, state); return Some(id); @@ -387,12 +391,10 @@ mod tests { use super::*; use crate::{ - node::{Insert, Node}, notification::GetTopic, - storage::StateStore, test_support::Hold, timeout::{Timeout, TimeoutMessage, TimeoutTopic, TEST_TICK_RATE, TEST_TIMEOUT}, - Rex, RexBuilder, RexMessage, + Rex, RexBuilder, RexMessage, State, }; impl From> for GameMsg { @@ -433,7 +435,7 @@ mod tests { } #[derive(Clone, Debug, derive_more::From)] - pub enum Input { + pub enum GameInput { Ping(PingInput), Pong(PongInput), Menu(MenuInput), @@ -441,10 +443,10 @@ mod tests { // determines whether Ping or Pong will await before packet send // - #[derive(Copy, Clone, PartialEq, Debug)] + #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct WhoHolds(Option); - #[derive(Clone, PartialEq, Debug)] + #[derive(Clone, PartialEq, Eq, Debug)] pub enum MenuInput { Play(WhoHolds), PingPongComplete, @@ -452,7 +454,7 @@ mod tests { FailedPong, } - #[derive(Copy, Clone, PartialEq, Default, Debug)] + #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)] pub enum MenuState { #[default] Ready, @@ -460,7 +462,7 @@ mod tests { Failed, } - #[derive(Copy, Clone, PartialEq, Default, Debug)] + #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)] pub enum PingState { #[default] Ready, @@ -478,7 +480,7 @@ mod tests { RecvTimeout(Instant), } - #[derive(Copy, Clone, PartialEq, Default, Debug)] + #[derive(Copy, Clone, PartialEq, Eq, Default, Debug)] pub enum PongState { #[default] Ready, @@ -495,7 +497,7 @@ mod tests { Returned(Hold), } - #[derive(Copy, Clone, PartialEq, Debug)] + #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum GameState { Ping(PingState), Pong(PongState), @@ -503,11 +505,14 @@ mod tests { } impl State for GameState { - fn get_kind(&self) -> &dyn Kind { + type Input = GameInput; + } + impl AsRef for GameState { + fn as_ref(&self) -> &Game { match self { - GameState::Ping(_) => &Game::Ping, - GameState::Pong(_) => &Game::Pong, - GameState::Menu(_) => &Game::Menu, + Self::Ping(_) => &Game::Ping, + Self::Pong(_) => &Game::Pong, + Self::Menu(_) => &Game::Menu, } } } @@ -520,11 +525,10 @@ mod tests { } impl Rex for Game { - type Input = Input; type Message = GameMsg; fn state_input(&self, state: ::State) -> Option { - if *self != Game::Menu { + if *self != Self::Menu { return None; } @@ -534,14 +538,14 @@ mod tests { GameState::Pong(PongState::Failed) => Some(MenuInput::FailedPong), _ => None, } - .map(|i| i.into()) + .map(std::convert::Into::into) } fn timeout_input(&self, instant: Instant) -> Option { match self { - Game::Ping => Some(PingInput::RecvTimeout(instant).into()), - Game::Pong => Some(PongInput::RecvTimeout(instant).into()), - Game::Menu => None, + Self::Ping => Some(PingInput::RecvTimeout(instant).into()), + Self::Pong => Some(PongInput::RecvTimeout(instant).into()), + Self::Menu => None, } } } @@ -549,37 +553,38 @@ mod tests { impl Timeout for Game { fn return_item(&self, item: RetainItem) -> Option { match self { - Game::Ping => Some(Input::Ping(item.into())), - Game::Pong => Some(Input::Pong(item.into())), - Game::Menu => None, + Self::Ping => Some(GameInput::Ping(item.into())), + Self::Pong => Some(GameInput::Pong(item.into())), + Self::Menu => None, } } } impl Kind for Game { type State = GameState; + type Input = GameInput; fn new_state(&self) -> Self::State { match self { - Game::Ping => GameState::Ping(PingState::default()), - Game::Pong => GameState::Pong(PongState::default()), - Game::Menu => GameState::Menu(MenuState::default()), + Self::Ping => GameState::Ping(PingState::default()), + Self::Pong => GameState::Pong(PongState::default()), + Self::Menu => GameState::Menu(MenuState::default()), } } fn failed_state(&self) -> Self::State { match self { - Game::Ping => GameState::Ping(PingState::Failed), - Game::Pong => GameState::Pong(PongState::Failed), - Game::Menu => GameState::Menu(MenuState::Failed), + Self::Ping => GameState::Ping(PingState::Failed), + Self::Pong => GameState::Pong(PongState::Failed), + Self::Menu => GameState::Menu(MenuState::Failed), } } fn completed_state(&self) -> Self::State { match self { - Game::Ping => GameState::Ping(PingState::Done), - Game::Pong => GameState::Pong(PongState::Done), - Game::Menu => GameState::Menu(MenuState::Done), + Self::Ping => GameState::Ping(PingState::Done), + Self::Pong => GameState::Pong(PongState::Done), + Self::Menu => GameState::Menu(MenuState::Done), } } } @@ -590,15 +595,15 @@ mod tests { impl StateMachine for MenuStateMachine { #[instrument(name = "menu", skip_all)] - fn process(&self, ctx: SmContext, input: Input) { + fn process(&self, ctx: SmContext, input: GameInput) { let id = ctx.id; - let Input::Menu(input) = input else { + let GameInput::Menu(input) = input else { error!(input = ?input, "invalid input!"); return; }; let state = ctx.get_state(); - if let Some(true) = state.map(Self::terminal_state) { + if state.map(Game::is_terminal) == Some(true) { warn!(%id, ?state, "Ignoring input due to invalid state"); return; } @@ -608,25 +613,13 @@ mod tests { let ping_id = StateId::new_rand(Game::Ping); let pong_id = StateId::new_rand(Game::Pong); // Menu + Ping + Pong - let menu_tree = Node::new(id) - .into_insert(Insert { - parent_id: Some(id), - id: ping_id, - }) - .into_insert(Insert { - parent_id: Some(id), - id: pong_id, - }); - - let tree = StateStore::new_tree(menu_tree); - for id in [id, ping_id, pong_id] { - ctx.state_store.insert_ref(id, tree.clone()); - } - + self.create_tree(&ctx); + self.new_child(&ctx, ping_id); + self.new_child(&ctx, pong_id); // signal to Ping state machine ctx.signal_queue.push_back(Signal { id: ping_id, - input: Input::Ping(PingInput::StartSending(pong_id, who_holds)), + input: GameInput::Ping(PingInput::StartSending(pong_id, who_holds)), }); } MenuInput::PingPongComplete => { @@ -635,10 +628,9 @@ mod tests { } failure @ (MenuInput::FailedPing | MenuInput::FailedPong) => { let tree = ctx.get_tree().unwrap(); - let mut guard = tree.lock(); // set all states to failed state - guard.update_all_fn(|mut z| { - z.node.state.fail(); + tree.lock().update_all_fn(|mut z| { + z.node.state = z.node.state.as_ref().failed_state(); let id = z.node.id; ctx.notification_queue .priority_send(Notification(TimeoutInput::cancel_timeout(id).into())); @@ -668,14 +660,15 @@ mod tests { impl StateMachine for PingStateMachine { #[instrument(name = "ping", skip_all)] - fn process(&self, ctx: SmContext, input: Input) { + fn process(&self, ctx: SmContext, input: GameInput) { let id = ctx.id; - let Input::Ping(input) = input else { + let GameInput::Ping(input) = input else { error!(?input, "invalid input!"); return; }; + assert!(ctx.get_parent_id().is_some()); let state = ctx.get_state().unwrap(); - if Self::terminal_state(state) { + if Game::is_terminal(state) { warn!(%id, ?state, "Ignoring input due to invalid state"); return; } @@ -686,7 +679,7 @@ mod tests { info!(msg = 0, "PINGING"); ctx.signal_queue.push_back(Signal { id: pong_id, - input: Input::Pong(PongInput::Packet(Packet { + input: GameInput::Pong(PongInput::Packet(Packet { msg: 0, sender: id, who_holds, @@ -703,7 +696,7 @@ mod tests { self.set_timeout(&ctx, TEST_TIMEOUT); packet.msg += 5; - if let WhoHolds(Some(Game::Ping)) = packet.who_holds { + if packet.who_holds == WhoHolds(Some(Game::Ping)) { info!(msg = packet.msg, "HOLDING"); // hold for half theduration of the message let hold_for = Duration::from_millis(packet.msg); @@ -735,16 +728,17 @@ mod tests { impl StateMachine for PongStateMachine { #[instrument(name = "pong", skip_all, fields(id = %ctx.id))] - fn process(&self, ctx: SmContext, input: Input) { - let Input::Pong(input) = input else { + fn process(&self, ctx: SmContext, input: GameInput) { + let GameInput::Pong(input) = input else { error!(?input, "invalid input!"); return; }; let state = ctx.get_state().unwrap(); - if Self::terminal_state(state) { + if Game::is_terminal(state) { warn!(?state, "Ignoring input due to invalid state"); return; } + assert!(ctx.get_parent_id().is_some()); match input { PongInput::Packet(Packet { @@ -759,7 +753,7 @@ mod tests { self.cancel_timeout(&ctx); ctx.signal_queue.push_back(Signal { id: sender, - input: Input::Ping(PingInput::Packet(Packet { + input: GameInput::Ping(PingInput::Packet(Packet { msg, sender: ctx.id, who_holds, @@ -773,7 +767,7 @@ mod tests { } packet.msg += 5; - if let WhoHolds(Some(Game::Pong)) = packet.who_holds { + if packet.who_holds == WhoHolds(Some(Game::Pong)) { info!(msg = packet.msg, "HOLDING"); // hold for half the duration of the message let hold_for = Duration::from_millis(packet.msg); @@ -824,7 +818,7 @@ mod tests { let menu_id = StateId::new_rand(Game::Menu); ctx.signal_queue.push_back(Signal { id: menu_id, - input: Input::Menu(MenuInput::Play(WhoHolds(None))), + input: GameInput::Menu(MenuInput::Play(WhoHolds(None))), }); tokio::time::sleep(Duration::from_millis(1)).await; @@ -852,9 +846,8 @@ mod tests { drop(node); let tree = ctx.state_store.get_tree(pong_id).unwrap(); - let node = tree.lock(); - let state = node.get_state(pong_id).unwrap(); - assert_eq!(GameState::Pong(PongState::Done), *state); + let state = tree.lock().get_state(pong_id).copied().unwrap(); + assert_eq!(GameState::Pong(PongState::Done), state); } #[tracing_test::traced_test] @@ -873,7 +866,7 @@ mod tests { let menu_id = StateId::new_rand(Game::Menu); ctx.signal_queue.push_back(Signal { id: menu_id, - input: Input::Menu(MenuInput::Play(WhoHolds(Some(Game::Ping)))), + input: GameInput::Menu(MenuInput::Play(WhoHolds(Some(Game::Ping)))), }); tokio::time::sleep(TEST_TIMEOUT * 4).await; @@ -881,12 +874,12 @@ mod tests { { let tree = ctx.state_store.get_tree(menu_id).unwrap(); let node = tree.lock(); - let ping_node = &node.children[0]; - let pong_node = &node.children[1]; + let ping = &node.children[0]; + let pong = &node.children[1]; assert_eq!(menu_id, node.id); assert_eq!(GameState::Menu(MenuState::Failed), node.state); - assert_eq!(GameState::Ping(PingState::Failed), ping_node.state); - assert_eq!(GameState::Pong(PongState::Failed), pong_node.state); + assert_eq!(GameState::Ping(PingState::Failed), ping.state); + assert_eq!(GameState::Pong(PongState::Failed), pong.state); // !!NOTE!! ============================================================ // we are trying to acquire another lock... @@ -918,7 +911,7 @@ mod tests { let menu_id = StateId::new_rand(Game::Menu); ctx.signal_queue.push_back(Signal { id: menu_id, - input: Input::Menu(MenuInput::Play(WhoHolds(Some(Game::Pong)))), + input: GameInput::Menu(MenuInput::Play(WhoHolds(Some(Game::Pong)))), }); tokio::time::sleep(TEST_TIMEOUT * 4).await; @@ -932,6 +925,7 @@ mod tests { assert_eq!(GameState::Ping(PingState::Failed), ping_node.state); assert_eq!(GameState::Pong(PongState::Failed), pong_node.state); // Ensure that our Menu failed due to Ping + drop(node); assert_eq!(MenuInput::FailedPing, *menu_failures.get(&menu_id).unwrap()); } } diff --git a/src/node.rs b/src/node.rs index 0414a3b..5388316 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,25 +1,8 @@ -use std::{collections::HashSet, fmt, hash::Hash}; +use std::collections::HashSet; -use crate::{Kind, State}; - -impl Kind for Id -where - K: Kind, - Id: std::ops::Deref + Send + fmt::Debug, -{ - type State = K::State; - fn new_state(&self) -> Self::State { - self.deref().new_state() - } - fn failed_state(&self) -> Self::State { - self.deref().failed_state() - } - - fn completed_state(&self) -> Self::State { - self.deref().completed_state() - } -} +use crate::{HashKind, Kind, StateId}; +#[derive(Debug)] pub struct Insert { pub parent_id: Option, pub id: Id, @@ -38,23 +21,22 @@ pub struct Node { pub children: Vec>, } -impl Node +impl Node, K::State> where - S: State, - Id: Copy + Eq + PartialEq + Hash + fmt::Display + Kind + fmt::Debug, + K: Kind + HashKind, { #[must_use] - pub fn new(id: Id) -> Self { + pub fn new(id: StateId) -> Self { Self { - id, state: id.new_state(), + id, descendant_keys: HashSet::new(), children: Vec::new(), } } #[must_use] - pub fn zipper(self) -> Zipper { + pub const fn zipper(self) -> Zipper, K::State> { Zipper { node: self, parent: None, @@ -63,7 +45,7 @@ where } #[must_use] - pub fn get(&self, id: Id) -> Option<&Node> { + pub fn get(&self, id: StateId) -> Option<&Self> { if self.id == id { return Some(self); } @@ -79,20 +61,20 @@ where } #[must_use] - pub fn get_state(&self, id: Id) -> Option<&S> { + pub fn get_state(&self, id: StateId) -> Option<&K::State> { self.get(id).map(|n| &n.state) } #[must_use] - pub fn child(&self, id: Id) -> Option<&Node> { + pub fn child(&self, id: StateId) -> Option<&Self> { self.children .iter() .find(|node| node.id == id || node.descendant_keys.contains(&id)) } - // get array index by of node with Id in self.descendant_keys + // get array index by of node with StateId in self.descendant_keys #[must_use] - pub fn child_idx(&self, id: Id) -> Option { + pub fn child_idx(&self, id: StateId) -> Option { self.children .iter() .enumerate() @@ -100,13 +82,13 @@ where .map(|(idx, _)| idx) } - pub fn insert(&mut self, insert: Insert) { + pub fn insert(&mut self, insert: Insert>) { // temporary allocation to allow a drop in &mut implementation // // this can be optimized later but right now allocation impact // is non existent since Node::new // does not grow its `?Sized` types - let mut swap_node = Node::new(self.id); + let mut swap_node = Self::new(self.id); std::mem::swap(&mut swap_node, self); swap_node = swap_node.into_insert(insert); @@ -116,7 +98,10 @@ where /// inserts a new node using self by value #[must_use] - pub fn into_insert(self, Insert { parent_id, id }: Insert) -> Node { + pub fn into_insert( + self, + Insert { parent_id, id }: Insert>, + ) -> Self { // inserts at this point should be guaranteed Some(id) // ince a parent_id.is_none() should be handled by the node // store through a new graph creation @@ -129,7 +114,7 @@ where } #[must_use] - pub fn get_parent_id(&self, id: Id) -> Option { + pub fn get_parent_id(&self, id: StateId) -> Option> { // root_node edge case if !self.descendant_keys.contains(&id) { return None; @@ -147,9 +132,9 @@ where None } - pub fn update(&mut self, update: Update) { + pub fn update(&mut self, update: Update, K::State>) { // see Node::insert - let mut swap_node = Node::new(self.id); + let mut swap_node = Self::new(self.id); std::mem::swap(&mut swap_node, self); swap_node = swap_node.into_update(update); @@ -158,9 +143,12 @@ where } /// update a given node's state and return the parent ID if it exists - pub fn update_and_get_parent_id(&mut self, Update { id, state }: Update) -> Option { + pub fn update_and_get_parent_id( + &mut self, + Update { id, state }: Update, K::State>, + ) -> Option> { // see Node::insert - let mut swap_node = Node::new(self.id); + let mut swap_node = Self::new(self.id); std::mem::swap(&mut swap_node, self); let (parent_id, mut swap_node) = swap_node @@ -177,10 +165,10 @@ where // apply a closure to all nodes in a tree pub fn update_all_fn(&mut self, f: F) where - F: Fn(Zipper) -> Node + Clone, + F: Fn(Zipper, K::State>) -> Self + Clone, { // see Node::insert - let mut swap_node = Node::new(self.id); + let mut swap_node = Self::new(self.id); std::mem::swap(&mut swap_node, self); swap_node = swap_node.zipper().finish_update_fn(f); @@ -189,7 +177,10 @@ where } #[must_use] - pub fn into_update(self, Update { id, state }: Update) -> Node { + pub fn into_update( + self, + Update { id, state }: Update, K::State>, + ) -> Self { self.zipper().by_id(id).set_state(state).finish_update() } } @@ -211,12 +202,13 @@ pub struct Zipper { self_idx: usize, } -impl Zipper +type ZipperNode = Node, ::State>; + +impl Zipper, K::State> where - S: State, - Id: Copy + Eq + PartialEq + Hash + fmt::Display + Kind + fmt::Debug, + K: Kind + HashKind, { - fn by_id(mut self, id: Id) -> Zipper { + fn by_id(mut self, id: StateId) -> Self { let mut contains_id = self.node.descendant_keys.contains(&id); while contains_id { let idx = self.node.child_idx(id).unwrap(); @@ -230,7 +222,7 @@ where self } - fn child(mut self, idx: usize) -> Zipper { + fn child(mut self, idx: usize) -> Self { // Remove the specified child from the node's children. // Zipper should avoid having a parent reference // since parents will be mutated during node refocusing. @@ -238,27 +230,27 @@ where let child = self.node.children.swap_remove(idx); // Return a new Zipper focused on the specified child. - Zipper { + Self { node: child, parent: Some(Box::new(self)), self_idx: idx, } } - fn set_state(mut self, state: S) -> Zipper { + const fn set_state(mut self, state: K::State) -> Self { self.node.state = state; self } - fn insert_child(mut self, id: Id) -> Zipper { + fn insert_child(mut self, id: StateId) -> Self { self.node.children.push(Node::new(id)); self } - fn parent(self) -> Zipper { + fn parent(self) -> Self { // Destructure this Zipper // https://github.com/rust-lang/rust/issues/16293#issuecomment-185906859 - let Zipper { + let Self { node, parent, self_idx, @@ -275,7 +267,7 @@ where parent.node.children.swap(self_idx, last_idx); // Return a new Zipper focused on the parent. - Zipper { + Self { node: parent.node, parent: parent.parent, self_idx: parent.self_idx, @@ -283,7 +275,7 @@ where } // try something like Iterator::fold - fn finish_insert(mut self, id: Id) -> Node { + fn finish_insert(mut self, id: StateId) -> ZipperNode { self.node.descendant_keys.insert(id); while self.parent.is_some() { self = self.parent(); @@ -294,7 +286,7 @@ where } #[must_use] - pub fn finish_update(mut self) -> Node { + pub fn finish_update(mut self) -> ZipperNode { while self.parent.is_some() { self = self.parent(); } @@ -303,15 +295,15 @@ where } // only act on parent nodes - fn finish_update_parent_id(self) -> (Option, Node) { + fn finish_update_parent_id(self) -> (Option>, ZipperNode) { let parent_id = self.parent.as_ref().map(|z| z.node.id); (parent_id, self.finish_update()) } // act on all nodes - fn finish_update_fn(mut self, f: F) -> Node + fn finish_update_fn(mut self, f: F) -> ZipperNode where - F: Fn(Zipper) -> Node + Clone, + F: Fn(Self) -> ZipperNode + Clone, { self.node.children = self .node @@ -420,8 +412,9 @@ mod tests { // ...except for Dave, he is in "Completed" state // ================================================= tree = tree.zipper().finish_update_fn(|mut z| { - if !z.node.state.is_completed() { - z.node.state.fail(); + let kind: NodeKind = *z.node.state.as_ref(); + if !(z.node.state == kind.completed_state()) { + z.node.state = kind.failed_state(); } z.finish_update() }); diff --git a/src/notification.rs b/src/notification.rs index 0d5f106..371a811 100644 --- a/src/notification.rs +++ b/src/notification.rs @@ -161,7 +161,7 @@ where K: HashKind, O: Operation, { - pub fn new(id: StateId, op: O) -> Self { + pub const fn new(id: StateId, op: O) -> Self { Self { id, op } } } diff --git a/src/queue.rs b/src/queue.rs index f926041..c155a47 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -45,7 +45,7 @@ struct RawDeque { } impl RawDeque { - fn new() -> Self { + const fn new() -> Self { Self { front_values: VecDeque::new(), back_values: VecDeque::new(), @@ -108,7 +108,7 @@ impl StreamableDeque { /// Returns a stream of items using `pop_front()` /// This opens us up to handle a `back_stream()` as well - pub fn stream(&self) -> StreamReceiver { + pub const fn stream(&self) -> StreamReceiver { StreamReceiver { queue: self, awake: None, @@ -162,6 +162,7 @@ impl<'a, T> Stream for StreamReceiver<'a, T> { awake: awake.clone(), }); self.awake = Some(awake); + drop(inner); Poll::Pending } } @@ -172,7 +173,7 @@ impl<'a, T> Drop for StreamReceiver<'a, T> { fn drop(&mut self) { let awake = self.awake.take().map(|w| w.load(Ordering::Relaxed)); - if let Some(true) = awake { + if awake == Some(true) { let mut queue_wakers = self.queue.inner.lock(); // StreamReceiver was woken by a None, notify another if let Some(n) = queue_wakers.rx_notifiers.pop_front() { diff --git a/src/storage.rs b/src/storage.rs index c5ab8e4..40e15ae 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -10,7 +10,7 @@ use crate::{node::Node, Kind, Rex, StateId}; /// this allows separate state hirearchies to be acted upon concurrently /// while making operations in a particular tree blocking pub struct StateStore { - pub trees: DashMap>>>, + trees: DashMap>>>, } impl Default for StateStore, K::State> @@ -37,21 +37,37 @@ impl StateStore, K::State> { } } + /// # Panics + /// + /// Will panic if [`StateId`] is nil pub fn new_tree(node: Node, K::State>) -> Tree { + assert!(!node.id.is_nil()); Arc::new(FairMutex::new(node)) } - // insert node creates a new reference to the same node + /// insert node creates a new reference to the same node + /// # Panics + /// + /// Will panic if [`StateId`] is nil pub fn insert_ref(&self, id: StateId, node: Tree) { + assert!(!id.is_nil()); self.trees.insert(id, node); } // decrements the reference count on a given `Node` + /// # Panics + /// + /// Will panic if [`StateId`] is nil pub fn remove_ref(&self, id: StateId) { + assert!(!id.is_nil()); self.trees.remove(&id); } + /// # Panics + /// + /// Will panic if [`StateId`] is nil pub fn get_tree(&self, id: StateId) -> Option> { + assert!(!id.is_nil()); let node = self.trees.get(&id); node.map(|n| n.value().clone()) } diff --git a/src/test_support.rs b/src/test_support.rs index 09744c9..df6b92b 100644 --- a/src/test_support.rs +++ b/src/test_support.rs @@ -38,15 +38,18 @@ macro_rules! node_state { } impl State for NodeState { - fn get_kind(&self) -> &dyn Kind { + type Input = (); + } + impl AsRef for NodeState { + fn as_ref(&self) -> &NodeKind { match self { - $( NodeState::$name(_) => &NodeKind::$name, )* + $( Self::$name(_) => &NodeKind::$name, )* } } } - impl Kind for NodeKind { type State = NodeState; + type Input = (); fn new_state(&self) -> Self::State { match self { @@ -98,9 +101,9 @@ impl TimeoutMessage for TestMsg { impl GetTopic for TestMsg { fn get_topic(&self) -> TestTopic { match self { - TestMsg::TimeoutInput(_) => TestTopic::Timeout, - TestMsg::Ingress(_) => TestTopic::Ingress, - TestMsg::Other => TestTopic::Other, + Self::TimeoutInput(_) => TestTopic::Timeout, + Self::Ingress(_) => TestTopic::Ingress, + Self::Other => TestTopic::Other, } } } @@ -115,7 +118,11 @@ pub enum TestState { } impl State for TestState { - fn get_kind(&self) -> &dyn Kind { + type Input = TestInput; +} + +impl AsRef for TestState { + fn as_ref(&self) -> &TestKind { &TestKind } } @@ -125,6 +132,7 @@ pub struct TestKind; impl Kind for TestKind { type State = TestState; + type Input = TestInput; fn new_state(&self) -> Self::State { TestState::New @@ -153,7 +161,7 @@ impl TryFrom for TestInput { } impl Timeout for TestKind {} -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum TestInput { Timeout(Instant), Packet(InPacket), @@ -162,11 +170,10 @@ pub enum TestInput { #[derive(Clone, Debug, PartialEq, Eq)] pub struct OutPacket(pub Vec); -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct InPacket(pub Vec); impl Rex for TestKind { - type Input = TestInput; type Message = TestMsg; fn state_input(&self, _state: ::State) -> Option { @@ -198,7 +205,7 @@ impl StateRouter for TestStateRouter { impl<'a> TryFrom<&'a InPacket> for TestKind { type Error = Report; fn try_from(_value: &'a InPacket) -> Result { - Ok(TestKind) + Ok(Self) } } @@ -226,12 +233,12 @@ impl TryInto> for TestMsg { impl From for TestMsg { fn from(val: OutPacket) -> Self { - TestMsg::Ingress(val) + Self::Ingress(val) } } impl From> for TestMsg { fn from(value: TimeoutInput) -> Self { - TestMsg::TimeoutInput(value) + Self::TimeoutInput(value) } } diff --git a/src/timeout.rs b/src/timeout.rs index 0cfaa56..2f44b96 100644 --- a/src/timeout.rs +++ b/src/timeout.rs @@ -178,9 +178,9 @@ pub enum Operation { impl std::fmt::Display for Operation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let op = match self { - Operation::Cancel => "timeout::Cancel", - Operation::Set(_) => "timeout::Set", - Operation::Retain(_, _) => "timeout::Retain", + Self::Cancel => "timeout::Cancel", + Self::Set(_) => "timeout::Set", + Self::Retain(_, _) => "timeout::Retain", }; write!(f, "{op}") } @@ -222,7 +222,7 @@ where } } - pub fn cancel_timeout(id: StateId) -> Self { + pub const fn cancel_timeout(id: StateId) -> Self { Self { id, op: Operation::Cancel, @@ -237,11 +237,11 @@ where } #[cfg(test)] - fn with_id(&self, id: StateId) -> Self { + const fn with_id(&self, id: StateId) -> Self { Self { id, ..*self } } #[cfg(test)] - fn with_op(&self, op: TimeoutOp) -> Self { + const fn with_op(&self, op: TimeoutOp) -> Self { Self { op, ..*self } } } @@ -313,7 +313,7 @@ where ledger.set_timeout(id, instant); } Operation::Retain(item, instant) => { - ledger.retain(id, instant, item) + ledger.retain(id, instant, item); } } } @@ -338,15 +338,10 @@ where let now = Instant::now(); let mut ledger = timer_ledger.lock(); // Get all instants where `instant <= now` - let expired: Vec = - ledger.timers.range(..=now).map(|(k, _)| *k).collect(); - - for id in expired - .iter() - .filter_map(|t| ledger.timers.remove(t)) - .flat_map(IntoIterator::into_iter) - .collect::>() - { + let mut release = ledger.timers.split_off(&now); + std::mem::swap(&mut release, &mut ledger.timers); + + for id in release.into_values().flat_map(IntoIterator::into_iter) { warn!(%id, "timed out"); ledger.ids.remove(&id); if let Some(input) = id.timeout_input(now) { @@ -360,11 +355,8 @@ where let mut release = ledger.retainer.split_off(&now); std::mem::swap(&mut release, &mut ledger.retainer); - for (id, item) in release - .into_values() - .flat_map(IntoIterator::into_iter) - .collect::>() - { + drop(ledger); + for (id, item) in release.into_values().flat_map(IntoIterator::into_iter) { if let Some(input) = id.return_item(item) { // caveat with this push_front setup is // that later timeouts will be on top of the stack @@ -415,7 +407,7 @@ mod tests { impl TestDefault for TimeoutManager { fn test_default() -> Self { let signal_queue = SignalQueue::default(); - TimeoutManager::new(signal_queue, TestTopic::Timeout).with_tick_rate(TEST_TICK_RATE) + Self::new(signal_queue, TestTopic::Timeout).with_tick_rate(TEST_TICK_RATE) } }