Skip to content

Commit 73906b7

Browse files
authored
Merge pull request #6 from jbr/use-async-channel
use async-channel to notify senders of disconnection
2 parents 0aa18c0 + 95ddc4b commit 73906b7

File tree

5 files changed

+62
-91
lines changed

5 files changed

+62
-91
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ async-std = { version = "1.6.0", features = ["unstable"] }
2121
http-types = "2.0.1"
2222
log = "0.4.8"
2323
memchr = "2.3.3"
24-
pin-project = "0.4.22"
24+
pin-project-lite = "0.1.4"
25+
async-channel = "1.1.1"
2526

2627
[dev-dependencies]
2728
femme = "2.0.0"

src/encoder.rs

Lines changed: 29 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,19 @@
1-
use async_std::sync;
2-
use std::io;
3-
use std::time::Duration;
4-
51
use async_std::io::Read as AsyncRead;
62
use async_std::prelude::*;
73
use async_std::task::{ready, Context, Poll};
4+
5+
use std::io;
86
use std::pin::Pin;
9-
use std::sync::atomic::{AtomicBool, Ordering};
10-
use std::sync::Arc;
11-
12-
use pin_project::{pin_project, pinned_drop};
13-
14-
#[pin_project(PinnedDrop)]
15-
/// An SSE protocol encoder.
16-
#[derive(Debug)]
17-
pub struct Encoder {
18-
buf: Option<Vec<u8>>,
19-
#[pin]
20-
receiver: sync::Receiver<Vec<u8>>,
21-
cursor: usize,
22-
disconnected: Arc<AtomicBool>,
23-
}
7+
use std::time::Duration;
248

25-
#[pinned_drop]
26-
impl PinnedDrop for Encoder {
27-
fn drop(self: Pin<&mut Self>) {
28-
self.disconnected.store(true, Ordering::Relaxed);
9+
pin_project_lite::pin_project! {
10+
/// An SSE protocol encoder.
11+
#[derive(Debug)]
12+
pub struct Encoder {
13+
buf: Option<Vec<u8>>,
14+
#[pin]
15+
receiver: async_channel::Receiver<Vec<u8>>,
16+
cursor: usize,
2917
}
3018
}
3119

@@ -91,79 +79,56 @@ impl AsyncRead for Encoder {
9179

9280
/// The sending side of the encoder.
9381
#[derive(Debug, Clone)]
94-
pub struct Sender {
95-
sender: sync::Sender<Vec<u8>>,
96-
disconnected: Arc<std::sync::atomic::AtomicBool>,
97-
}
82+
pub struct Sender(async_channel::Sender<Vec<u8>>);
9883

9984
/// Create a new SSE encoder.
10085
pub fn encode() -> (Sender, Encoder) {
101-
let (sender, receiver) = sync::channel(1);
102-
let disconnected = Arc::new(AtomicBool::new(false));
103-
86+
let (sender, receiver) = async_channel::bounded(1);
10487
let encoder = Encoder {
10588
receiver,
10689
buf: None,
10790
cursor: 0,
108-
disconnected: disconnected.clone(),
10991
};
110-
111-
let sender = Sender {
112-
sender,
113-
disconnected,
114-
};
115-
116-
(sender, encoder)
92+
(Sender(sender), encoder)
11793
}
11894

119-
/// An error that represents that the [Encoder] has been dropped.
120-
#[derive(Debug, Eq, PartialEq)]
121-
pub struct DisconnectedError;
122-
impl std::error::Error for DisconnectedError {}
123-
impl std::fmt::Display for DisconnectedError {
124-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125-
write!(f, "Disconnected")
95+
impl Sender {
96+
async fn inner_send(&self, bytes: impl Into<Vec<u8>>) -> io::Result<()> {
97+
self.0
98+
.send(bytes.into())
99+
.await
100+
.map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "sse disconnected"))
126101
}
127-
}
128102

129-
#[must_use]
130-
impl Sender {
131103
/// Send a new message over SSE.
132-
pub async fn send(
133-
&self,
134-
name: &str,
135-
data: &str,
136-
id: Option<&str>,
137-
) -> Result<(), DisconnectedError> {
138-
if self.disconnected.load(Ordering::Relaxed) {
139-
return Err(DisconnectedError);
140-
}
141-
104+
pub async fn send(&self, name: &str, data: &str, id: Option<&str>) -> io::Result<()> {
142105
// Write the event name
143106
let msg = format!("event:{}\n", name);
144-
self.sender.send(msg.into_bytes()).await;
107+
self.inner_send(msg).await?;
145108

146109
// Write the id
147110
if let Some(id) = id {
148-
self.sender.send(format!("id:{}\n", id).into_bytes()).await;
111+
self.inner_send(format!("id:{}\n", id)).await?;
149112
}
150113

151114
// Write the data section, and end.
152115
let msg = format!("data:{}\n\n", data);
153-
self.sender.send(msg.into_bytes()).await;
116+
self.inner_send(msg).await?;
117+
154118
Ok(())
155119
}
156120

157121
/// Send a new "retry" message over SSE.
158-
pub async fn send_retry(&self, dur: Duration, id: Option<&str>) {
122+
pub async fn send_retry(&self, dur: Duration, id: Option<&str>) -> io::Result<()> {
159123
// Write the id
160124
if let Some(id) = id {
161-
self.sender.send(format!("id:{}\n", id).into_bytes()).await;
125+
self.inner_send(format!("id:{}\n", id)).await?;
162126
}
163127

164128
// Write the retry section, and end.
165129
let dur = dur.as_secs_f64() as u64;
166130
let msg = format!("retry:{}\n\n", dur);
167-
self.sender.send(msg.into_bytes()).await;
131+
self.inner_send(msg).await?;
132+
Ok(())
168133
}
169134
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ mod lines;
4343
mod message;
4444

4545
pub use decoder::{decode, Decoder};
46-
pub use encoder::{encode, DisconnectedError, Encoder, Sender};
46+
pub use encoder::{encode, Encoder, Sender};
4747
pub use event::Event;
4848
pub use handshake::upgrade;
4949
pub use message::Message;

src/lines.rs

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,30 @@ use std::mem;
22
use std::pin::Pin;
33
use std::str;
44

5-
use pin_project::pin_project;
5+
use pin_project_lite::pin_project;
66

77
use async_std::io::{self, BufRead};
88
use async_std::stream::Stream;
99
use async_std::task::{ready, Context, Poll};
1010

11-
/// A stream of lines in a byte stream.
12-
///
13-
/// This stream is created by the [`lines`] method on types that implement [`BufRead`].
14-
///
15-
/// This type is an async version of [`std::io::Lines`].
16-
///
17-
/// [`lines`]: trait.BufRead.html#method.lines
18-
/// [`BufRead`]: trait.BufRead.html
19-
/// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html
20-
#[pin_project]
21-
#[derive(Debug)]
22-
pub(crate) struct Lines<R> {
23-
#[pin]
24-
pub(crate) reader: R,
25-
pub(crate) buf: String,
26-
pub(crate) bytes: Vec<u8>,
27-
pub(crate) read: usize,
11+
pin_project! {
12+
/// A stream of lines in a byte stream.
13+
///
14+
/// This stream is created by the [`lines`] method on types that implement [`BufRead`].
15+
///
16+
/// This type is an async version of [`std::io::Lines`].
17+
///
18+
/// [`lines`]: trait.BufRead.html#method.lines
19+
/// [`BufRead`]: trait.BufRead.html
20+
/// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html
21+
#[derive(Debug)]
22+
pub(crate) struct Lines<R> {
23+
#[pin]
24+
pub(crate) reader: R,
25+
pub(crate) buf: String,
26+
pub(crate) bytes: Vec<u8>,
27+
pub(crate) read: usize,
28+
}
2829
}
2930

3031
impl<R> Lines<R> {

tests/encode.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async fn encode_retry() -> http_types::Result<()> {
5353
let (sender, encoder) = encode();
5454
task::spawn(async move {
5555
let dur = Duration::from_secs(12);
56-
sender.send_retry(dur, None).await;
56+
sender.send_retry(dur, None).await.unwrap();
5757
});
5858

5959
let mut reader = decode(BufReader::new(encoder));
@@ -65,16 +65,20 @@ async fn encode_retry() -> http_types::Result<()> {
6565
#[async_std::test]
6666
async fn dropping_encoder() -> http_types::Result<()> {
6767
let (sender, encoder) = encode();
68-
let reader = BufReader::new(encoder);
6968
let sender_clone = sender.clone();
70-
task::spawn(async move { sender_clone.send("cat", "chashu", Some("0")).await.unwrap() });
69+
task::spawn(async move { sender_clone.send("cat", "chashu", None).await });
7170

72-
//move the encoder into Lines, which gets dropped after this
73-
assert_eq!(reader.lines().next().await.unwrap().unwrap(), "event:cat");
71+
let mut reader = decode(BufReader::new(encoder));
72+
let event = reader.next().await.unwrap()?;
73+
assert_message(&event, "cat", "chashu", None);
74+
75+
std::mem::drop(reader);
7476

77+
let response = sender.send("cat", "chashu", None).await;
78+
assert!(response.is_err());
7579
assert_eq!(
76-
sender.send("cat", "nori", None).await,
77-
Err(async_sse::DisconnectedError)
80+
response.unwrap_err().kind(),
81+
async_std::io::ErrorKind::ConnectionAborted
7882
);
7983
Ok(())
8084
}

0 commit comments

Comments
 (0)