Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use &mut self in NotificationProcessor::init #3

Merged
merged 4 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/ingress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ where
In: Send + Sync + fmt::Debug + 'static,
Out: Send + Sync + fmt::Debug + 'static,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
self.init_notification_processor_with_handle(join_set)
}

Expand Down Expand Up @@ -261,7 +261,7 @@ mod tests {
let mut join_set = JoinSet::new();

let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&nw_adapter], &mut join_set);
NotificationManager::new(vec![Box::new(nw_adapter)], &mut join_set);
let notification_tx = notification_manager.init(&mut join_set);

let unknown_packet = OutPacket(b"unknown_packet".to_vec());
Expand Down Expand Up @@ -295,7 +295,7 @@ mod tests {
let mut join_set = JoinSet::new();

let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&nw_adapter], &mut join_set);
NotificationManager::new(vec![Box::new(nw_adapter)], &mut join_set);
let _notification_tx = notification_manager.init(&mut join_set);

// An unknown packet should be unrouteable
Expand Down
14 changes: 8 additions & 6 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,8 @@ where
.with_tick_rate(self.tick_rate.unwrap_or(timeout::DEFAULT_TICK_RATE));
self.notification_processors.push(Box::new(timeout_manager));
}
let processors: Vec<&dyn NotificationProcessor<K::Message>> = self
.notification_processors
.iter()
.map(|processor| processor.as_ref())
.collect();

let notification_manager = NotificationManager::new(processors.as_slice(), join_set);
let notification_manager = NotificationManager::new(self.notification_processors, join_set);
let notification_queue: UnboundedSender<Notification<K::Message>> =
notification_manager.init(join_set);

Expand Down Expand Up @@ -297,10 +292,17 @@ where
signal_queue: Arc<SignalQueue<K>>,
notification_queue: UnboundedSender<Notification<K::Message>>,
) -> Self {
let sm_count = state_machines.len();
let state_machines: HashMap<K, BoxedStateMachine<K>> = state_machines
.into_iter()
.map(|sm| (sm.get_kind(), sm))
.collect();
assert_eq!(
sm_count,
state_machines.len(),
"multiple state machines using the same kind, SMs: {sm_count}, Kinds: {}",
state_machines.len(),
);
Self {
signal_queue,
notification_queue,
Expand Down
29 changes: 15 additions & 14 deletions src/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ impl<M> NotificationManager<M>
where
M: RexMessage,
{
pub fn new(processors: &[&dyn NotificationProcessor<M>], join_set: &mut JoinSet<()>) -> Self {
pub fn new(
processors: Vec<Box<dyn NotificationProcessor<M>>>,
join_set: &mut JoinSet<()>,
) -> Self {
let processors: HashMap<M::Topic, Vec<UnboundedSender<Notification<M>>>> = processors
.iter()
.fold(HashMap::new(), |mut subscribers, processor| {
.into_iter()
.fold(HashMap::new(), |mut subscribers, mut processor| {
let subscriber_tx = processor.init(join_set);
for topic in processor.get_topics() {
subscribers
Expand Down Expand Up @@ -142,7 +145,7 @@ pub trait NotificationProcessor<M>: Send + Sync
where
M: RexMessage,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
fn get_topics(&self) -> &[M::Topic];
}

Expand Down Expand Up @@ -228,10 +231,14 @@ mod tests {
use crate::timeout::*;

let timeout_manager = TimeoutManager::test_default();
let sq1 = timeout_manager.signal_queue.clone();
let timeout_manager_two = TimeoutManager::test_default();
let sq2 = timeout_manager_two.signal_queue.clone();
let mut join_set = JoinSet::new();
let notification_manager: NotificationManager<TestMsg> =
NotificationManager::new(&[&timeout_manager, &timeout_manager_two], &mut join_set);
let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
vec![Box::new(timeout_manager), Box::new(timeout_manager_two)],
&mut join_set,
);
let notification_tx = notification_manager.init(&mut join_set);

let test_id = StateId::new_with_u128(TestKind, 1);
Expand All @@ -243,14 +250,8 @@ mod tests {

tokio::time::sleep(Duration::from_millis(10)).await;

let timeout_one = timeout_manager
.signal_queue
.pop_front()
.expect("timeout one");
let timeout_two = timeout_manager_two
.signal_queue
.pop_front()
.expect("timeout two");
let timeout_one = sq1.pop_front().expect("timeout one");
let timeout_two = sq2.pop_front().expect("timeout two");
assert_eq!(timeout_one.id, timeout_two.id);
}
}
22 changes: 5 additions & 17 deletions src/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
time::Duration,
};

use bigerror::attachment::DisplayDuration;
use parking_lot::Mutex;
use tokio::{
sync::{mpsc, mpsc::UnboundedSender},
Expand All @@ -29,19 +30,6 @@ pub trait TimeoutMessage<K: Rex>: RexMessage + From<UnaryRequest<K, Self::Op>> {
pub const DEFAULT_TICK_RATE: Duration = Duration::from_millis(5);
const SHORT_TIMEOUT: Duration = Duration::from_secs(10);

pub struct DisplayDuration(pub Duration);
impl std::fmt::Display for DisplayDuration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", hms_string(self.0))
}
}

impl From<Duration> for DisplayDuration {
fn from(duration: Duration) -> Self {
Self(duration)
}
}

/// convert a [`Duration`] into a "0H00m00s" string
fn hms_string(duration: Duration) -> String {
if duration.is_zero() {
Expand Down Expand Up @@ -348,7 +336,7 @@ where
K::Message: TryInto<TimeoutInput<K>>,
<K::Message as TryInto<TimeoutInput<K>>>::Error: Send,
{
fn init(&self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
self.init_inner_with_handle(join_set)
}

Expand Down Expand Up @@ -382,7 +370,7 @@ mod tests {

#[tokio::test]
async fn timeout_to_signal() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down Expand Up @@ -415,7 +403,7 @@ mod tests {

#[tokio::test]
async fn timeout_cancellation() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down Expand Up @@ -448,7 +436,7 @@ mod tests {
#[tokio::test]
#[tracing_test::traced_test]
async fn partial_timeout_cancellation() {
let timeout_manager = TimeoutManager::test_default();
let mut timeout_manager = TimeoutManager::test_default();

let mut join_set = JoinSet::new();
let timeout_tx: UnboundedSender<Notification<TestMsg>> =
Expand Down