Skip to content

Commit

Permalink
Use &mut self in NotificationProcessor::init (#3)
Browse files Browse the repository at this point in the history
Approved by: @jkleinknox
============
This pr modifies the `NotificationProcessor::init` method to use `&mut self`. This allows us to consume resoruces in `self` and allows the `JoinSet` to be passed in later:
```rs
pub struct StatusReporter {
    input_rx: Option<UnboundedReceiver<StatusInput>>,
    reports: Arc<Mutex<HashMap<Uuid, Status>>>,
}

impl NotificationProcessor<OurMessage> for StatusReporter {
    fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<OurMessage>> {
        // we consume self.input_rx and set it to `None`
        let rx = self.input_rx.take().expect("uninitialized input_rx");
        spawn_some_other_rx(join_set, rx);

        self.spawn_notification_rx(join_set)
    }

    fn get_topics(&self) -> &[OurTopic] {
        &[OurTopic::Event]
    }
}

```
  • Loading branch information
mkatychev authored Apr 25, 2024
1 parent 7a2322d commit e2c8a7b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 40 deletions.
6 changes: 6 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ fmt:
# Run clippy fix and rustfmt afterwards
fix *args: && fmt
cd {{invocation_directory()}}; cargo clippy --fix --all-targets --all-features {{args}}

# run cargo clippy, denying warnings
lint:
cd {{invocation_directory()}}; cargo clippy --all-targets --all-features -- -D warnings


6 changes: 3 additions & 3 deletions src/ingress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ where
In: Send + Sync + fmt::Debug + 'static,
Out: Send + Sync + fmt::Debug + 'static,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
self.init_notification_processor_with_handle(join_set)
}

Expand Down Expand Up @@ -261,7 +261,7 @@ mod tests {
let mut join_set = JoinSet::new();

let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&nw_adapter], &mut join_set);
NotificationManager::new(vec![Box::new(nw_adapter)], &mut join_set);
let notification_tx = notification_manager.init(&mut join_set);

let unknown_packet = OutPacket(b"unknown_packet".to_vec());
Expand Down Expand Up @@ -295,7 +295,7 @@ mod tests {
let mut join_set = JoinSet::new();

let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&nw_adapter], &mut join_set);
NotificationManager::new(vec![Box::new(nw_adapter)], &mut join_set);
let _notification_tx = notification_manager.init(&mut join_set);

// An unknown packet should be unrouteable
Expand Down
14 changes: 8 additions & 6 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,8 @@ where
.with_tick_rate(self.tick_rate.unwrap_or(timeout::DEFAULT_TICK_RATE));
self.notification_processors.push(Box::new(timeout_manager));
}
let processors: Vec<&dyn NotificationProcessor<K::Message>> = self
.notification_processors
.iter()
.map(|processor| processor.as_ref())
.collect();

let notification_manager = NotificationManager::new(processors.as_slice(), join_set);
let notification_manager = NotificationManager::new(self.notification_processors, join_set);
let notification_queue: UnboundedSender<Notification<K::Message>> =
notification_manager.init(join_set);

Expand Down Expand Up @@ -297,10 +292,17 @@ where
signal_queue: Arc<SignalQueue<K>>,
notification_queue: UnboundedSender<Notification<K::Message>>,
) -> Self {
let sm_count = state_machines.len();
let state_machines: HashMap<K, BoxedStateMachine<K>> = state_machines
.into_iter()
.map(|sm| (sm.get_kind(), sm))
.collect();
assert_eq!(
sm_count,
state_machines.len(),
"multiple state machines using the same kind, SMs: {sm_count}, Kinds: {}",
state_machines.len(),
);
Self {
signal_queue,
notification_queue,
Expand Down
29 changes: 15 additions & 14 deletions src/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ impl<M> NotificationManager<M>
where
M: RexMessage,
{
pub fn new(processors: &[&dyn NotificationProcessor<M>], join_set: &mut JoinSet<()>) -> Self {
pub fn new(
processors: Vec<Box<dyn NotificationProcessor<M>>>,
join_set: &mut JoinSet<()>,
) -> Self {
let processors: HashMap<M::Topic, Vec<UnboundedSender<Notification<M>>>> = processors
.iter()
.fold(HashMap::new(), |mut subscribers, processor| {
.into_iter()
.fold(HashMap::new(), |mut subscribers, mut processor| {
let subscriber_tx = processor.init(join_set);
for topic in processor.get_topics() {
subscribers
Expand Down Expand Up @@ -142,7 +145,7 @@ pub trait NotificationProcessor<M>: Send + Sync
where
M: RexMessage,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
fn get_topics(&self) -> &[M::Topic];
}

Expand Down Expand Up @@ -228,10 +231,14 @@ mod tests {
use crate::timeout::*;

let timeout_manager = TimeoutManager::test_default();
let sq1 = timeout_manager.signal_queue.clone();
let timeout_manager_two = TimeoutManager::test_default();
let sq2 = timeout_manager_two.signal_queue.clone();
let mut join_set = JoinSet::new();
let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&timeout_manager, &timeout_manager_two], &mut join_set);
let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
vec![Box::new(timeout_manager), Box::new(timeout_manager_two)],
&mut join_set,
);
let notification_tx = notification_manager.init(&mut join_set);

let test_id = StateId::new_with_u128(TestKind, 1);
Expand All @@ -243,14 +250,8 @@ mod tests {

tokio::time::sleep(Duration::from_millis(10)).await;

let timeout_one = timeout_manager
.signal_queue
.pop_front()
.expect("timeout one");
let timeout_two = timeout_manager_two
.signal_queue
.pop_front()
.expect("timeout two");
let timeout_one = sq1.pop_front().expect("timeout one");
let timeout_two = sq2.pop_front().expect("timeout two");
assert_eq!(timeout_one.id, timeout_two.id);
}
}
22 changes: 5 additions & 17 deletions src/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
time::Duration,
};

use bigerror::attachment::DisplayDuration;
use parking_lot::Mutex;
use tokio::{
sync::{mpsc, mpsc::UnboundedSender},
Expand All @@ -29,19 +30,6 @@ pub trait TimeoutMessage<K: Rex>: RexMessage + From<UnaryRequest<K, Self::Op>> {
pub const DEFAULT_TICK_RATE: Duration = Duration::from_millis(5);
const SHORT_TIMEOUT: Duration = Duration::from_secs(10);

pub struct DisplayDuration(pub Duration);
impl std::fmt::Display for DisplayDuration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", hms_string(self.0))
}
}

impl From<Duration> for DisplayDuration {
fn from(duration: Duration) -> Self {
Self(duration)
}
}

/// convert a [`Duration`] into a "0H00m00s" string
fn hms_string(duration: Duration) -> String {
if duration.is_zero() {
Expand Down Expand Up @@ -348,7 +336,7 @@ where
K::Message: TryInto<TimeoutInput<K>>,
<K::Message as TryInto<TimeoutInput<K>>>::Error: Send,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
self.init_inner_with_handle(join_set)
}

Expand Down Expand Up @@ -382,7 +370,7 @@ mod tests {

#[tokio::test]
async fn timeout_to_signal() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down Expand Up @@ -415,7 +403,7 @@ mod tests {

#[tokio::test]
async fn timeout_cancellation() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down Expand Up @@ -448,7 +436,7 @@ mod tests {
#[tokio::test]
#[tracing_test::traced_test]
async fn partial_timeout_cancellation() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down

0 comments on commit e2c8a7b

Please sign in to comment.