Skip to content

Commit 79ce820

Browse files
committed
use a concrete error type in async-h1
1 parent 2d17054 commit 79ce820

File tree

11 files changed

+180
-97
lines changed

11 files changed

+180
-97
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ log = "0.4.11"
2828
pin-project = "1.0.2"
2929
async-channel = "1.5.1"
3030
async-dup = "1.2.2"
31+
thiserror = "1.0.22"
3132

3233
[dev-dependencies]
3334
pretty_assertions = "0.6.1"

src/client/decode.rs

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
use async_std::io::{BufReader, Read};
22
use async_std::prelude::*;
3-
use http_types::{ensure, ensure_eq, format_err};
3+
use http_types::content::ContentLength;
44
use http_types::{
5-
headers::{CONTENT_LENGTH, DATE, TRANSFER_ENCODING},
5+
headers::{DATE, TRANSFER_ENCODING},
66
Body, Response, StatusCode,
77
};
88

99
use std::convert::TryFrom;
1010

11-
use crate::chunked::ChunkedDecoder;
1211
use crate::date::fmt_http_date;
12+
use crate::{chunked::ChunkedDecoder, Error};
1313
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1414

1515
const CR: u8 = b'\r';
1616
const LF: u8 = b'\n';
1717

1818
/// Decode an HTTP response on the client.
19-
pub async fn decode<R>(reader: R) -> http_types::Result<Response>
19+
pub async fn decode<R>(reader: R) -> crate::Result<Option<Response>>
2020
where
2121
R: Read + Unpin + Send + Sync + 'static,
2222
{
@@ -29,13 +29,14 @@ where
2929
loop {
3030
let bytes_read = reader.read_until(LF, &mut buf).await?;
3131
// No more bytes are yielded from the stream.
32-
assert!(bytes_read != 0, "Empty response"); // TODO: ensure?
32+
if bytes_read == 0 {
33+
return Ok(None);
34+
}
3335

3436
// Prevent CWE-400 DDOS with large HTTP Headers.
35-
ensure!(
36-
buf.len() < MAX_HEAD_LENGTH,
37-
"Head byte length should be less than 8kb"
38-
);
37+
if buf.len() >= MAX_HEAD_LENGTH {
38+
return Err(Error::HeadersTooLong);
39+
}
3940

4041
// We've hit the end delimiter of the stream.
4142
let idx = buf.len() - 1;
@@ -49,17 +50,23 @@ where
4950

5051
// Convert our header buf into an httparse instance, and validate.
5152
let status = httparse_res.parse(&buf)?;
52-
ensure!(!status.is_partial(), "Malformed HTTP head");
53+
if status.is_partial() {
54+
return Err(Error::PartialHead);
55+
}
5356

54-
let code = httparse_res.code;
55-
let code = code.ok_or_else(|| format_err!("No status code found"))?;
57+
let code = httparse_res.code.ok_or_else(|| Error::MissingStatusCode)?;
5658

5759
// Convert httparse headers + body into a `http_types::Response` type.
58-
let version = httparse_res.version;
59-
let version = version.ok_or_else(|| format_err!("No version found"))?;
60-
ensure_eq!(version, 1, "Unsupported HTTP version");
60+
let version = httparse_res.version.ok_or_else(|| Error::MissingVersion)?;
61+
62+
if version != 1 {
63+
return Err(Error::UnsupportedVersion(version));
64+
}
65+
66+
let status_code =
67+
StatusCode::try_from(code).map_err(|_| Error::UnrecognizedStatusCode(code))?;
68+
let mut res = Response::new(status_code);
6169

62-
let mut res = Response::new(StatusCode::try_from(code)?);
6370
for header in httparse_res.headers.iter() {
6471
res.append_header(header.name, std::str::from_utf8(header.value)?);
6572
}
@@ -69,13 +76,13 @@ where
6976
res.insert_header(DATE, &format!("date: {}\r\n", date)[..]);
7077
}
7178

72-
let content_length = res.header(CONTENT_LENGTH);
79+
let content_length =
80+
ContentLength::from_headers(&res).map_err(|_| Error::MalformedHeader("content-length"))?;
7381
let transfer_encoding = res.header(TRANSFER_ENCODING);
7482

75-
ensure!(
76-
content_length.is_none() || transfer_encoding.is_none(),
77-
"Unexpected Content-Length header"
78-
);
83+
if content_length.is_some() && transfer_encoding.is_some() {
84+
return Err(Error::UnexpectedHeader("content-length"));
85+
}
7986

8087
if let Some(encoding) = transfer_encoding {
8188
if encoding.last().as_str() == "chunked" {
@@ -84,16 +91,16 @@ where
8491
res.set_body(Body::from_reader(reader, None));
8592

8693
// Return the response.
87-
return Ok(res);
94+
return Ok(Some(res));
8895
}
8996
}
9097

9198
// Check for Content-Length.
92-
if let Some(len) = content_length {
93-
let len = len.last().as_str().parse::<usize>()?;
94-
res.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
99+
if let Some(content_length) = content_length {
100+
let len = content_length.len();
101+
res.set_body(Body::from_reader(reader.take(len), Some(len as usize)));
95102
}
96103

97104
// Return the response.
98-
Ok(res)
105+
Ok(Some(res))
99106
}

src/client/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub use decode::decode;
1010
pub use encode::Encoder;
1111

1212
/// Opens an HTTP/1.1 connection to a remote host.
13-
pub async fn connect<RW>(mut stream: RW, req: Request) -> http_types::Result<Response>
13+
pub async fn connect<RW>(mut stream: RW, req: Request) -> crate::Result<Option<Response>>
1414
where
1515
RW: Read + Write + Send + Sync + Unpin + 'static,
1616
{

src/error.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use std::str::Utf8Error;
2+
3+
use http_types::url;
4+
use thiserror::Error;
5+
6+
/// Concrete errors that occur within async-h1
7+
#[derive(Error, Debug)]
8+
#[non_exhaustive]
9+
pub enum Error {
10+
/// [`std::io::Error`]
11+
#[error(transparent)]
12+
IO(#[from] std::io::Error),
13+
14+
/// [`url::ParseError`]
15+
#[error(transparent)]
16+
Url(#[from] url::ParseError),
17+
18+
/// this error describes a malformed request with a path that does
19+
/// not start with / or http:// or https://
20+
#[error("unexpected uri format")]
21+
UnexpectedURIFormat,
22+
23+
/// this error describes a http 1.1 request that is missing a Host
24+
/// header
25+
#[error("mandatory host header missing")]
26+
HostHeaderMissing,
27+
28+
/// this error describes a request that does not specify a path
29+
#[error("request path missing")]
30+
RequestPathMissing,
31+
32+
/// [`httparse::Error`]
33+
#[error(transparent)]
34+
Httparse(#[from] httparse::Error),
35+
36+
/// an incomplete http head
37+
#[error("partial http head")]
38+
PartialHead,
39+
40+
/// we were unable to parse a header
41+
#[error("malformed http header {0}")]
42+
MalformedHeader(&'static str),
43+
44+
/// async-h1 doesn't speak this http version
45+
/// this error is deprecated
46+
#[error("unsupported http version 1.{0}")]
47+
UnsupportedVersion(u8),
48+
49+
/// we were unable to parse this http method
50+
#[error("unsupported http method {0}")]
51+
UnrecognizedMethod(String),
52+
53+
/// this request did not have a method
54+
#[error("missing method")]
55+
MissingMethod,
56+
57+
/// this request did not have a status code
58+
#[error("missing status code")]
59+
MissingStatusCode,
60+
61+
/// we were unable to parse this http method
62+
#[error("unrecognized http status code {0}")]
63+
UnrecognizedStatusCode(u16),
64+
65+
/// this request did not have a version, but we expect one
66+
/// this error is deprecated
67+
#[error("missing version")]
68+
MissingVersion,
69+
70+
/// we expected utf8, but there was an encoding error
71+
#[error(transparent)]
72+
EncodingError(#[from] Utf8Error),
73+
74+
/// we received a header that does not make sense in context
75+
#[error("unexpected header: {0}")]
76+
UnexpectedHeader(&'static str),
77+
78+
/// for security reasons, we do not allow request headers beyond 8kb.
79+
#[error("Head byte length should be less than 8kb")]
80+
HeadersTooLong,
81+
}
82+
83+
/// this crate's result type
84+
pub type Result<T> = std::result::Result<T, Error>;

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ use async_std::io::Cursor;
116116
use body_encoder::BodyEncoder;
117117
pub use client::connect;
118118
pub use server::{accept, accept_with_opts, ServerOptions};
119+
mod error;
120+
pub use error::{Error, Result};
119121

120122
#[derive(Debug)]
121123
pub(crate) enum EncoderState {

src/server/decode.rs

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
//! Process HTTP connections on the server.
22
3-
use std::str::FromStr;
4-
53
use async_dup::{Arc, Mutex};
64
use async_std::io::{BufReader, Read, Write};
75
use async_std::{prelude::*, task};
8-
use http_types::content::ContentLength;
9-
use http_types::headers::{EXPECT, TRANSFER_ENCODING};
10-
use http_types::{ensure, ensure_eq, format_err};
11-
use http_types::{Body, Method, Request, Url};
6+
use http_types::{
7+
content::ContentLength,
8+
headers::{EXPECT, TRANSFER_ENCODING},
9+
Version,
10+
};
11+
use http_types::{Body, Request, Url};
12+
13+
use crate::{Error, Result};
1214

1315
use super::body_reader::BodyReader;
1416
use crate::chunked::ChunkedDecoder;
@@ -17,14 +19,11 @@ use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1719

1820
const LF: u8 = b'\n';
1921

20-
/// The number returned from httparse when the request is HTTP 1.1
21-
const HTTP_1_1_VERSION: u8 = 1;
22-
2322
const CONTINUE_HEADER_VALUE: &str = "100-continue";
2423
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
2524

2625
/// Decode an HTTP request on the server.
27-
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
26+
pub async fn decode<IO>(mut io: IO) -> Result<Option<(Request, BodyReader<IO>)>>
2827
where
2928
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
3029
{
@@ -42,10 +41,9 @@ where
4241
}
4342

4443
// Prevent CWE-400 DDOS with large HTTP Headers.
45-
ensure!(
46-
buf.len() < MAX_HEAD_LENGTH,
47-
"Head byte length should be less than 8kb"
48-
);
44+
if buf.len() >= MAX_HEAD_LENGTH {
45+
return Err(Error::HeadersTooLong);
46+
}
4947

5048
// We've hit the end delimiter of the stream.
5149
let idx = buf.len() - 1;
@@ -57,44 +55,37 @@ where
5755
// Convert our header buf into an httparse instance, and validate.
5856
let status = httparse_req.parse(&buf)?;
5957

60-
ensure!(!status.is_partial(), "Malformed HTTP head");
61-
62-
// Convert httparse headers + body into a `http_types::Request` type.
63-
let method = httparse_req.method;
64-
let method = method.ok_or_else(|| format_err!("No method found"))?;
65-
66-
let version = httparse_req.version;
67-
let version = version.ok_or_else(|| format_err!("No version found"))?;
58+
if status.is_partial() {
59+
return Err(Error::PartialHead);
60+
}
6861

69-
ensure_eq!(
70-
version,
71-
HTTP_1_1_VERSION,
72-
"Unsupported HTTP version 1.{}",
73-
version
74-
);
62+
let method = httparse_req
63+
.method
64+
.ok_or_else(|| Error::MissingMethod)?
65+
.parse()
66+
.map_err(|_| Error::UnrecognizedMethod(httparse_req.method.unwrap().to_string()))?;
7567

7668
let url = url_from_httparse_req(&httparse_req)?;
7769

78-
let mut req = Request::new(Method::from_str(method)?, url);
70+
let mut req = Request::new(method, url);
7971

80-
req.set_version(Some(http_types::Version::Http1_1));
72+
match httparse_req.version {
73+
Some(1) => req.set_version(Some(Version::Http1_1)),
74+
Some(version) => return Err(Error::UnsupportedVersion(version)),
75+
None => return Err(Error::MissingVersion),
76+
}
8177

8278
for header in httparse_req.headers.iter() {
8379
req.append_header(header.name, std::str::from_utf8(header.value)?);
8480
}
8581

86-
let content_length = ContentLength::from_headers(&req)?;
82+
let content_length =
83+
ContentLength::from_headers(&req).map_err(|_| Error::MalformedHeader("content-length"))?;
8784
let transfer_encoding = req.header(TRANSFER_ENCODING);
8885

89-
// Return a 400 status if both Content-Length and Transfer-Encoding headers
90-
// are set to prevent request smuggling attacks.
91-
//
92-
// https://tools.ietf.org/html/rfc7230#section-3.3.3
93-
http_types::ensure_status!(
94-
content_length.is_none() || transfer_encoding.is_none(),
95-
400,
96-
"Unexpected Content-Length header"
97-
);
86+
if content_length.is_some() && transfer_encoding.is_some() {
87+
return Err(Error::UnexpectedHeader("content-length"));
88+
}
9889

9990
// Establish a channel to wait for the body to be read. This
10091
// allows us to avoid sending 100-continue in situations that
@@ -128,8 +119,8 @@ where
128119
let reader = BufReader::new(reader);
129120
req.set_body(Body::from_reader(reader, None));
130121
return Ok(Some((req, BodyReader::Chunked(reader_clone))));
131-
} else if let Some(len) = content_length {
132-
let len = len.len();
122+
} else if let Some(content_length) = content_length {
123+
let len = content_length.len();
133124
let reader = Arc::new(Mutex::new(reader.take(len)));
134125
req.set_body(Body::from_reader(
135126
BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
@@ -141,14 +132,14 @@ where
141132
}
142133
}
143134

144-
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
145-
let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
135+
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> Result<Url> {
136+
let path = req.path.ok_or_else(|| Error::RequestPathMissing)?;
146137

147138
let host = req
148139
.headers
149140
.iter()
150141
.find(|x| x.name.eq_ignore_ascii_case("host"))
151-
.ok_or_else(|| format_err!("Mandatory Host header missing"))?
142+
.ok_or_else(|| Error::HostHeaderMissing)?
152143
.value;
153144

154145
let host = std::str::from_utf8(host)?;
@@ -160,7 +151,7 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
160151
} else if req.method.unwrap().eq_ignore_ascii_case("connect") {
161152
Ok(Url::parse(&format!("http://{}/", path))?)
162153
} else {
163-
Err(format_err!("unexpected uri format"))
154+
Err(Error::UnexpectedURIFormat)
164155
}
165156
}
166157

0 commit comments

Comments
 (0)