Skip to content

Commit fbcc064

Browse files
Merge pull request #8 from sine-fdn/connection-closed
Connection closed
2 parents 9c47c2a + e743673 commit fbcc064

File tree

3 files changed

+216
-21
lines changed

3 files changed

+216
-21
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ jobs:
2222
~/.rustup
2323
key: ${{ env.cache-name }}-${{ hashFiles('**/Cargo.toml') }}
2424
- run: cargo build --all-features
25-
- run: cargo test --all-features -- --skip session
25+
- run: cargo test --all-features -- --skip session --skip quit_and_rejoin_session
2626
- run: cargo clippy -- -Dwarnings

src/main.rs

+84-20
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use futures::StreamExt;
33
use libp2p::{
44
gossipsub, noise,
55
swarm::{NetworkBehaviour, SwarmEvent},
6-
upnp, yamux, Multiaddr,
6+
upnp, yamux, Multiaddr, PeerId,
77
};
88
use log::{error, info};
99
use rsa::signature::SignatureEncoding;
@@ -90,13 +90,15 @@ struct MyBehaviour {
9090
enum Event {
9191
Upnp(upnp::Event),
9292
StdIn(String),
93-
Msg(Msg),
93+
Msg(Msg, PeerId),
94+
ConnectionClosed(PeerId),
9495
}
9596

9697
#[derive(Debug, Clone, Serialize, Deserialize)]
9798
enum Msg {
9899
Join(PublicKey, String),
99-
Participants(HashMap<PublicKey, String>),
100+
Quit(PeerId, String),
101+
Participants(HashMap<PublicKey, (String, PeerId)>),
100102
LobbyNowClosed,
101103
Share {
102104
from: PublicKey,
@@ -120,7 +122,17 @@ enum Phase {
120122
SendingShares,
121123
}
122124

123-
fn print_results(results: &BTreeMap<String, i64>, participants: &HashMap<PublicKey, String>) {
125+
fn print_participants(participants: &HashMap<PublicKey, (String, PeerId)>) {
126+
println!("\n-- Participants --");
127+
for (pub_key, (name, _)) in participants {
128+
println!("{pub_key} - {name}");
129+
}
130+
}
131+
132+
fn print_results(
133+
results: &BTreeMap<String, i64>,
134+
participants: &HashMap<PublicKey, (String, PeerId)>,
135+
) {
124136
println!("\nAverage results:");
125137
for (key, result) in results.iter() {
126138
let avg = (*result as f64 / participants.len() as f64) / 100.00;
@@ -200,7 +212,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
200212

201213
let mut phase = Phase::WaitingForParticipants;
202214
let mut stdin = io::BufReader::new(io::stdin()).lines();
203-
let mut participants = HashMap::<PublicKey, String>::new();
215+
let mut participants = HashMap::<PublicKey, (String, PeerId)>::new();
204216
let mut sent_shares = HashMap::<PublicKey, HashMap<&String, i64>>::new();
205217
let mut received_shares = HashMap::<PublicKey, Vec<u8>>::new();
206218
let mut sums = HashMap::<PublicKey, HashMap<String, i64>>::new();
@@ -363,20 +375,13 @@ async fn main() -> Result<(), Box<dyn Error>> {
363375
received_shares.insert(from, share);
364376
}
365377
}
366-
Event::Msg(msg)
378+
Event::Msg(msg, propagation_source)
367379
},
368380
SwarmEvent::IncomingConnectionError { .. } => {
369381
eprintln!("Error while establishing incoming connection");
370382
continue;
371383
},
372-
SwarmEvent::ConnectionClosed { .. } => {
373-
if result.is_none() {
374-
eprintln!("Connection has been closed by one of the participants");
375-
std::process::exit(1);
376-
} else {
377-
std::process::exit(0);
378-
}
379-
},
384+
SwarmEvent::ConnectionClosed { peer_id, .. } => Event::ConnectionClosed(peer_id),
380385
ev => {
381386
info!("{ev:?}");
382387
continue;
@@ -433,7 +438,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
433438
println!("{pub_key} - {name}");
434439
}
435440
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
436-
participants.insert(pub_key.clone(), name.clone());
441+
participants.insert(pub_key.clone(), (name.clone(), *swarm.local_peer_id()));
437442
}
438443
(_, Event::Upnp(upnp::Event::GatewayNotFound)) => {
439444
error!("Gateway does not support UPnP");
@@ -444,20 +449,62 @@ async fn main() -> Result<(), Box<dyn Error>> {
444449
break;
445450
}
446451
(_, Event::Upnp(ev)) => info!("{ev:?}"),
447-
(Phase::WaitingForParticipants, Event::Msg(msg)) => match msg {
452+
(Phase::WaitingForParticipants, Event::ConnectionClosed(peer_id)) => {
453+
if result.is_none() {
454+
let Some(disconnected) =
455+
participants.iter().find(|(_, (_, id))| *id == peer_id)
456+
else {
457+
println!("Connection error, please try again.");
458+
std::process::exit(1);
459+
};
460+
461+
let disconnected = disconnected.1 .0.clone();
462+
463+
println!("\nParticipant {disconnected} disconnected");
464+
465+
if swarm.connected_peers().count() == 0 && is_leader {
466+
participants.retain(|_, (_, id)| *id != peer_id);
467+
} else if is_leader {
468+
let msg = Msg::Quit(peer_id, disconnected).serialize()?;
469+
swarm
470+
.behaviour_mut()
471+
.gossipsub
472+
.publish(topic.clone(), msg)?;
473+
474+
participants.retain(|_, (_, id)| *id != peer_id);
475+
476+
print_participants(&participants);
477+
478+
let msg = Msg::Participants(participants.clone()).serialize()?;
479+
if let Err(e) = swarm.behaviour_mut().gossipsub.publish(topic.clone(), msg)
480+
{
481+
error!("Could not publish to gossipsub: {e:?}");
482+
}
483+
}
484+
continue;
485+
} else {
486+
std::process::exit(0);
487+
}
488+
}
489+
(Phase::WaitingForParticipants, Event::Msg(msg, peer_id)) => match msg {
448490
Msg::Join(public_key, name) => {
449491
if is_leader {
450492
println!("{public_key} - {name}");
451-
participants.insert(public_key, name);
493+
participants.insert(public_key, (name, peer_id));
452494
let msg = Msg::Participants(participants.clone()).serialize()?;
453495
if let Err(e) = swarm.behaviour_mut().gossipsub.publish(topic.clone(), msg)
454496
{
455497
error!("Could not publish to gossipsub: {e:?}");
456498
}
457499
}
458500
}
501+
Msg::Quit(_, name) => {
502+
println!("\nParticipant {name} disconnected");
503+
504+
print_participants(&participants);
505+
}
459506
Msg::Participants(all_participants) => {
460-
for (public_key, name) in all_participants.iter() {
507+
for (public_key, (name, _)) in all_participants.iter() {
461508
if !participants.contains_key(public_key) {
462509
println!("{public_key} - {name}");
463510
}
@@ -485,14 +532,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
485532
std::process::exit(1);
486533
}
487534
},
488-
(Phase::SendingShares, Event::Msg(msg)) => match msg {
535+
(Phase::SendingShares, Event::Msg(msg, _peer_id)) => match msg {
489536
Msg::Join(_, _) | Msg::Participants(_) | Msg::LobbyNowClosed => {
490537
println!(
491538
"Already waiting for shares, but some participant still tried to join!"
492539
);
493540
continue;
494541
}
495-
Msg::Share { .. } => {}
542+
Msg::Quit(..) | Msg::Share { .. } => {},
496543
Msg::Sum(public_key, sum) => {
497544
if is_leader {
498545
sums.insert(public_key, sum);
@@ -503,6 +550,23 @@ async fn main() -> Result<(), Box<dyn Error>> {
503550
std::process::exit(0);
504551
}
505552
},
553+
(Phase::SendingShares, Event::ConnectionClosed(peer_id)) => {
554+
if is_leader {
555+
let Some((_, (disconnected, _))) =
556+
participants.iter().find(|(_, (_, id))| *id == peer_id)
557+
else {
558+
println!("Connection error, please try again.");
559+
std::process::exit(1);
560+
};
561+
562+
println!(
563+
"Aborting benchmark: participant {disconnected} left the while waiting for shares"
564+
);
565+
} else {
566+
println!("Aborting benchmark: a participant left while waiting for shares");
567+
}
568+
std::process::exit(1);
569+
}
506570
(Phase::ConfirmingParticipants, _) => {}
507571
}
508572
}

tests/cli.rs

+131
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,137 @@ fn invalid_address() -> Result<(), Box<dyn std::error::Error>> {
7070
Ok(())
7171
}
7272

73+
#[test]
74+
fn quit_and_rejoin_session() -> Result<(), Box<dyn std::error::Error>> {
75+
let mut new_session = new_command("foo", None, "tests/test_files/valid_json.json")?;
76+
77+
let mut leader = new_session
78+
.stdout(Stdio::piped())
79+
.stdin(Stdio::piped())
80+
.spawn()?;
81+
let stdout = leader.stdout.take().unwrap();
82+
let reader = BufReader::new(stdout);
83+
let stdin = leader.stdin.take().unwrap();
84+
let mut writer = BufWriter::new(stdin);
85+
let mut lines = reader.lines();
86+
87+
let address = loop {
88+
if let Some(Ok(l)) = lines.next() {
89+
if l.contains("--address=/ip4/") {
90+
break l
91+
.split(" ")
92+
.find(|s| s.contains("--address=/ip4/"))
93+
.unwrap()
94+
.replace("--address=", "");
95+
}
96+
}
97+
};
98+
99+
let bar_address = address.clone();
100+
let bar_handle = thread::spawn(move || {
101+
let mut participant = new_command(
102+
"bar",
103+
Some(&bar_address),
104+
"tests/test_files/valid_json.json",
105+
)
106+
.unwrap()
107+
.stdout(Stdio::piped())
108+
.spawn()
109+
.unwrap();
110+
111+
let stdout = participant.stdout.take().unwrap();
112+
let reader = BufReader::new(stdout);
113+
let mut lines = reader.lines();
114+
115+
while let Some(Ok(l)) = lines.next() {
116+
println!("bar > {l}");
117+
if l.contains("- foo") {
118+
participant.kill().unwrap();
119+
break;
120+
}
121+
}
122+
});
123+
124+
while let Some(Ok(l)) = lines.next() {
125+
println!("foo > {l}");
126+
if l.contains("bar disconnected") {
127+
break;
128+
}
129+
}
130+
131+
bar_handle.join().unwrap();
132+
133+
let mut threads = vec![];
134+
for name in ["baz", "qux"] {
135+
sleep(Duration::from_millis(200));
136+
let address = address.clone();
137+
threads.push(thread::spawn(move || {
138+
let mut participant =
139+
new_command(name, Some(&address), "tests/test_files/valid_json.json")
140+
.unwrap()
141+
.stdin(Stdio::piped())
142+
.stdout(Stdio::piped())
143+
.spawn()
144+
.unwrap();
145+
146+
let stdout = participant.stdout.take().unwrap();
147+
let reader = BufReader::new(stdout);
148+
let stdin = participant.stdin.take().unwrap();
149+
let mut writer = BufWriter::new(stdin);
150+
let mut lines = reader.lines();
151+
152+
while let Some(Ok(l)) = lines.next() {
153+
println!("{name} > {l}");
154+
155+
if l.contains("Do you want to join the benchmark?") {
156+
sleep(Duration::from_millis(200));
157+
writeln!(writer, "y").unwrap();
158+
writer.flush().unwrap();
159+
}
160+
161+
if l.contains("results") {
162+
participant.kill().unwrap();
163+
return;
164+
}
165+
}
166+
}));
167+
}
168+
169+
let mut participant_count = 1;
170+
let mut benchmark_complete = false;
171+
while let Some(Ok(l)) = lines.next() {
172+
println!("foo > {}", l);
173+
if l.contains("- baz") || l.contains("- qux") {
174+
participant_count += 1;
175+
}
176+
if participant_count == 3 {
177+
sleep(Duration::from_millis(200));
178+
writeln!(writer, "").unwrap();
179+
writer.flush().unwrap();
180+
}
181+
if l.contains("results") {
182+
benchmark_complete = true;
183+
break;
184+
}
185+
}
186+
187+
sleep(Duration::from_millis(200));
188+
leader.kill()?;
189+
190+
for t in threads {
191+
t.join().unwrap();
192+
}
193+
194+
if benchmark_complete {
195+
Ok(())
196+
} else {
197+
Err(Box::new(Error::new(
198+
ErrorKind::Other,
199+
"Could not complete benchmark",
200+
)))
201+
}
202+
}
203+
73204
#[test]
74205
fn session() -> Result<(), Box<dyn std::error::Error>> {
75206
let mut new_session = new_command("foo", None, "tests/test_files/valid_json.json")?;

0 commit comments

Comments
 (0)