From 5d04fbf3e7f6e0b195056933ba0aa979b7b42969 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Fri, 18 Dec 2020 15:43:58 -0800 Subject: [PATCH 1/3] change: async-h1 endpoints always return a response --- examples/server.rs | 4 ++-- src/lib.rs | 2 +- src/server/mod.rs | 10 +++++----- tests/accept.rs | 14 +++++++------- tests/test_utils.rs | 8 ++++---- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/server.rs b/examples/server.rs index 1ea9491..7081018 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -26,11 +26,11 @@ async fn main() -> http_types::Result<()> { // Take a TCP stream, and convert it into sequential HTTP request / response pairs. async fn accept(stream: TcpStream) -> http_types::Result<()> { println!("starting new connection from {}", stream.peer_addr()?); - async_h1::accept(stream.clone(), |_req| async move { + async_h1::accept(stream, |_req| async move { let mut res = Response::new(StatusCode::Ok); res.insert_header("Content-Type", "text/plain"); res.set_body("Hello world"); - Ok(res) + res }) .await?; Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 55d569b..9781be8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,7 +81,7 @@ //! let mut res = Response::new(StatusCode::Ok); //! res.insert_header("Content-Type", "text/plain"); //! res.set_body("Hello"); -//! Ok(res) +//! res //! }) //! .await?; //! Ok(()) diff --git a/src/server/mod.rs b/src/server/mod.rs index 1cfa4e9..8a96474 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -35,7 +35,7 @@ pub async fn accept(io: RW, endpoint: F) -> http_types::Result<()> where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { Server::new(io, endpoint).accept().await } @@ -51,7 +51,7 @@ pub async fn accept_with_opts( where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { Server::new(io, endpoint).with_opts(opts).accept().await } @@ -79,7 +79,7 @@ impl Server where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { /// builds a new server pub fn new(io: RW, endpoint: F) -> Self { @@ -108,7 +108,7 @@ where where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { // Decode a new request, timing out if this takes longer than the timeout duration. let fut = decode(self.io.clone()); @@ -142,7 +142,7 @@ where let method = req.method(); // Pass the request to the endpoint and encode the response. - let mut res = (self.endpoint)(req).await?; + let mut res = (self.endpoint)(req).await; close_connection |= res .header(CONNECTION) diff --git a/tests/accept.rs b/tests/accept.rs index 92283a8..58438b4 100644 --- a/tests/accept.rs +++ b/tests/accept.rs @@ -11,7 +11,7 @@ mod accept { let mut response = Response::new(200); let len = req.len(); response.set_body(Body::from_reader(req, len)); - Ok(response) + response }); let content_length = 10; @@ -35,7 +35,7 @@ mod accept { #[async_std::test] async fn request_close() -> Result<()> { - let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + let mut server = TestServer::new(|_| async { Response::new(200) }); server .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: Close\r\n\r\n") @@ -53,7 +53,7 @@ mod accept { let mut server = TestServer::new(|_| async { let mut response = Response::new(200); response.insert_header(CONNECTION, "close"); - Ok(response) + response }); server @@ -69,7 +69,7 @@ mod accept { #[async_std::test] async fn keep_alive_short_fixed_length_unread_body() -> Result<()> { - let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + let mut server = TestServer::new(|_| async { Response::new(200) }); let content_length = 10; @@ -95,7 +95,7 @@ mod accept { #[async_std::test] async fn keep_alive_short_chunked_unread_body() -> Result<()> { - let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + let mut server = TestServer::new(|_| async { Response::new(200) }); let content_length = 100; @@ -125,7 +125,7 @@ mod accept { #[async_std::test] async fn keep_alive_long_fixed_length_unread_body() -> Result<()> { - let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + let mut server = TestServer::new(|_| async { Response::new(200) }); let content_length = 10000; @@ -151,7 +151,7 @@ mod accept { #[async_std::test] async fn keep_alive_long_chunked_unread_body() -> Result<()> { - let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + let mut server = TestServer::new(|_| async { Response::new(200) }); let content_length = 10000; diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 8194590..d7fdd24 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -3,7 +3,7 @@ use async_h1::{ server::{ConnectionStatus, Server}, }; use async_std::io::{Read, Write}; -use http_types::{Request, Response, Result}; +use http_types::{Request, Response}; use std::{ fmt::{Debug, Display}, future::Future, @@ -25,7 +25,7 @@ pub struct TestServer { impl TestServer where F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { #[allow(dead_code)] pub fn new(f: F) -> Self { @@ -61,7 +61,7 @@ where impl Read for TestServer where F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { fn poll_read( self: Pin<&mut Self>, @@ -75,7 +75,7 @@ where impl Write for TestServer where F: Fn(Request) -> Fut, - Fut: Future>, + Fut: Future, { fn poll_write( self: Pin<&mut Self>, From d30fa6d58b60eebd1f0edf527a2f6c564a64324b Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Fri, 18 Dec 2020 18:20:35 -0800 Subject: [PATCH 2/3] add opt-in http/1.0 support --- src/server/decode.rs | 81 +++++++++++---------- src/server/mod.rs | 46 ++++++++++-- tests/continue.rs | 8 ++- tests/server-chunked-encode-large.rs | 2 +- tests/server_decode.rs | 101 +++++++++++++++++++++++++-- 5 files changed, 186 insertions(+), 52 deletions(-) diff --git a/src/server/decode.rs b/src/server/decode.rs index 133b58e..65def6e 100644 --- a/src/server/decode.rs +++ b/src/server/decode.rs @@ -5,26 +5,29 @@ use std::str::FromStr; use async_dup::{Arc, Mutex}; use async_std::io::{BufReader, Read, Write}; use async_std::{prelude::*, task}; -use http_types::content::ContentLength; -use http_types::headers::{EXPECT, TRANSFER_ENCODING}; -use http_types::{ensure, ensure_eq, format_err}; +use http_types::{content::ContentLength, Version}; +use http_types::{ensure, format_err}; +use http_types::{ + headers::{EXPECT, TRANSFER_ENCODING}, + StatusCode, +}; use http_types::{Body, Method, Request, Url}; use super::body_reader::BodyReader; -use crate::chunked::ChunkedDecoder; use crate::read_notifier::ReadNotifier; +use crate::{chunked::ChunkedDecoder, ServerOptions}; use crate::{MAX_HEADERS, MAX_HEAD_LENGTH}; const LF: u8 = b'\n'; -/// The number returned from httparse when the request is HTTP 1.1 -const HTTP_1_1_VERSION: u8 = 1; - const CONTINUE_HEADER_VALUE: &str = "100-continue"; const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n"; /// Decode an HTTP request on the server. -pub async fn decode(mut io: IO) -> http_types::Result)>> +pub async fn decode( + mut io: IO, + opts: &ServerOptions, +) -> http_types::Result)>> where IO: Read + Write + Clone + Send + Sync + Unpin + 'static, { @@ -63,21 +66,22 @@ where let method = httparse_req.method; let method = method.ok_or_else(|| format_err!("No method found"))?; - let version = httparse_req.version; - let version = version.ok_or_else(|| format_err!("No version found"))?; - - ensure_eq!( - version, - HTTP_1_1_VERSION, - "Unsupported HTTP version 1.{}", - version - ); + let version = match (&opts.default_host, httparse_req.version) { + (Some(_), None) | (Some(_), Some(0)) => Version::Http1_0, + (_, Some(1)) => Version::Http1_1, + _ => { + let mut err = format_err!("http version not supported"); + err.set_status(StatusCode::HttpVersionNotSupported); + return Err(err); + } + }; - let url = url_from_httparse_req(&httparse_req)?; + let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref()) + .ok_or_else(|| format_err!("unable to construct url from request"))?; let mut req = Request::new(Method::from_str(method)?, url); - req.set_version(Some(http_types::Version::Http1_1)); + req.set_version(Some(version)); for header in httparse_req.headers.iter() { req.append_header(header.name, std::str::from_utf8(header.value)?); @@ -141,26 +145,27 @@ where } } -fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result { - let path = req.path.ok_or_else(|| format_err!("No uri found"))?; +fn url_from_httparse_req( + req: &httparse::Request<'_, '_>, + default_host: Option<&str>, +) -> Option { + let path = req.path?; let host = req .headers .iter() .find(|x| x.name.eq_ignore_ascii_case("host")) - .ok_or_else(|| format_err!("Mandatory Host header missing"))? - .value; - - let host = std::str::from_utf8(host)?; + .and_then(|x| std::str::from_utf8(x.value).ok()) + .or(default_host)?; if path.starts_with("http://") || path.starts_with("https://") { - Ok(Url::parse(path)?) + Url::parse(path).ok() } else if path.starts_with('/') { - Ok(Url::parse(&format!("http://{}{}", host, path))?) + Url::parse(&format!("http://{}{}", host, path)).ok() } else if req.method.unwrap().eq_ignore_ascii_case("connect") { - Ok(Url::parse(&format!("http://{}/", path))?) + Url::parse(&format!("http://{}/", path)).ok() } else { - Err(format_err!("unexpected uri format")) + None } } @@ -180,7 +185,7 @@ mod tests { httparse_req( "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/"); }, ); @@ -191,7 +196,7 @@ mod tests { httparse_req( "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/some/resource"); }, ) @@ -202,7 +207,7 @@ mod tests { httparse_req( "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://domain.com/some/resource"); // host header MUST be ignored according to spec }, ) @@ -213,7 +218,7 @@ mod tests { httparse_req( "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/"); }, ) @@ -224,7 +229,7 @@ mod tests { httparse_req( "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n", |req| { - assert!(url_from_httparse_req(&req).is_err()); + assert!(url_from_httparse_req(&req, None).is_none()); }, ) } @@ -234,7 +239,7 @@ mod tests { httparse_req( "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443//double/slashes" @@ -247,7 +252,7 @@ mod tests { httparse_req( "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443///triple/slashes" @@ -261,7 +266,7 @@ mod tests { httparse_req( "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1"); }, ) @@ -272,7 +277,7 @@ mod tests { httparse_req( "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443/foo?bar=1#anchor" diff --git a/src/server/mod.rs b/src/server/mod.rs index 8a96474..dac1204 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,8 +2,11 @@ use async_std::future::{timeout, Future, TimeoutError}; use async_std::io::{self, Read, Write}; -use http_types::headers::{CONNECTION, UPGRADE}; use http_types::upgrade::Connection; +use http_types::{ + headers::{CONNECTION, UPGRADE}, + Version, +}; use http_types::{Request, Response, StatusCode}; use std::{marker::PhantomData, time::Duration}; mod body_reader; @@ -18,12 +21,42 @@ pub use encode::Encoder; pub struct ServerOptions { /// Timeout to handle headers. Defaults to 60s. headers_timeout: Option, + default_host: Option, +} + +impl ServerOptions { + /// constructs a new ServerOptions with default settings + pub fn new() -> Self { + Self::default() + } + + /// sets the timeout by which the headers must have been received + pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self { + self.headers_timeout = Some(headers_timeout); + self + } + + /// Sets the default http 1.0 host for this server. If no host + /// header is provided on an http/1.0 request, this host will be + /// used to construct the Request Url. + /// + /// If this is not provided, the server will respond to all + /// http/1.0 requests with status `505 http version not + /// supported`, whether or not a host header is provided. + /// + /// The default value for this is None, and as a result async-h1 + /// is by default an http-1.1-only server. + pub fn with_default_host(mut self, default_host: &str) -> Self { + self.default_host = Some(default_host.into()); + self + } } impl Default for ServerOptions { fn default() -> Self { Self { headers_timeout: Some(Duration::from_secs(60)), + default_host: None, } } } @@ -111,7 +144,7 @@ where Fut: Future, { // Decode a new request, timing out if this takes longer than the timeout duration. - let fut = decode(self.io.clone()); + let fut = decode(self.io.clone(), &self.opts); let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout { match timeout(timeout_duration, fut).await { @@ -135,7 +168,12 @@ where let connection_header_is_upgrade = connection_header_as_str .split(',') .any(|s| s.trim().eq_ignore_ascii_case("upgrade")); - let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close"); + + let mut close_connection = if req.version() == Some(Version::Http1_0) { + !connection_header_as_str.eq_ignore_ascii_case("keep-alive") + } else { + connection_header_as_str.eq_ignore_ascii_case("close") + }; let upgrade_requested = has_upgrade_header && connection_header_is_upgrade; @@ -170,7 +208,7 @@ where if let Some(upgrade_sender) = upgrade_sender { upgrade_sender.send(Connection::new(self.io.clone())).await; - return Ok(ConnectionStatus::Close); + Ok(ConnectionStatus::Close) } else if close_connection { Ok(ConnectionStatus::Close) } else { diff --git a/tests/continue.rs b/tests/continue.rs index 933fbfe..2a84540 100644 --- a/tests/continue.rs +++ b/tests/continue.rs @@ -16,7 +16,9 @@ async fn test_with_expect_when_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (mut request, _) = async_h1::server::decode(server).await?.unwrap(); + let (mut request, _) = async_h1::server::decode(server, &Default::default()) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written @@ -44,7 +46,9 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (_, _) = async_h1::server::decode(server).await?.unwrap(); + let (_, _) = async_h1::server::decode(server, &Default::default()) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel diff --git a/tests/server-chunked-encode-large.rs b/tests/server-chunked-encode-large.rs index 3252b0c..df916a9 100644 --- a/tests/server-chunked-encode-large.rs +++ b/tests/server-chunked-encode-large.rs @@ -76,7 +76,7 @@ async fn server_chunked_large() -> Result<()> { let (mut client, server) = TestIO::new(); async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?; - let (request, _) = server::decode(server).await?.unwrap(); + let (request, _) = server::decode(server, &Default::default()).await?.unwrap(); let mut response = Response::new(200); response.set_body(Body::from_reader(request, None)); diff --git a/tests/server_decode.rs b/tests/server_decode.rs index 10c6701..eafe689 100644 --- a/tests/server_decode.rs +++ b/tests/server_decode.rs @@ -1,26 +1,31 @@ mod test_utils; mod server_decode { use super::test_utils::TestIO; + use async_h1::ServerOptions; use async_std::io::prelude::*; - use http_types::headers::TRANSFER_ENCODING; use http_types::Request; use http_types::Result; use http_types::Url; + use http_types::{headers::TRANSFER_ENCODING, Version}; use pretty_assertions::assert_eq; - async fn decode_lines(lines: Vec<&str>) -> Result> { + async fn decode_lines(lines: Vec<&str>, options: ServerOptions) -> Result> { let s = lines.join("\r\n"); let (mut client, server) = TestIO::new(); client.write_all(s.as_bytes()).await?; client.close(); - async_h1::server::decode(server) + async_h1::server::decode(server, &options) .await .map(|r| r.map(|(r, _)| r)) } + async fn decode_lines_default(lines: Vec<&str>) -> Result> { + decode_lines(lines, ServerOptions::default()).await + } + #[async_std::test] async fn post_with_body() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: localhost:8080", "content-length: 5", @@ -53,7 +58,7 @@ mod server_decode { #[async_std::test] async fn chunked() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: localhost:8080", "transfer-encoding: chunked", @@ -83,7 +88,7 @@ mod server_decode { "#] #[async_std::test] async fn invalid_trailer() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "GET / HTTP/1.1", "host: domain.com", "content-type: application/octet-stream", @@ -104,7 +109,7 @@ mod server_decode { #[async_std::test] async fn unexpected_eof() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: example.com", "content-type: text/plain", @@ -125,4 +130,86 @@ mod server_decode { Ok(()) } + + #[async_std::test] + async fn http_1_0_without_host_header() -> Result<()> { + let request = decode_lines( + vec!["GET /path?query#fragment HTTP/1.0", "", ""], + ServerOptions::new().with_default_host("website.com"), + ) + .await? + .unwrap(); + + assert_eq!(request.version(), Some(Version::Http1_0)); + assert_eq!( + request.url().to_string(), + "http://website.com/path?query#fragment" + ); + Ok(()) + } + + #[async_std::test] + async fn http_1_1_without_host_header() -> Result<()> { + let result = decode_lines( + vec!["GET /path?query#fragment HTTP/1.1", "", ""], + ServerOptions::default(), + ) + .await; + + assert!(result.is_err()); + + Ok(()) + } + + #[async_std::test] + async fn http_1_0_with_host_header() -> Result<()> { + let request = decode_lines( + vec![ + "GET /path?query#fragment HTTP/1.0", + "host: example.com", + "", + "", + ], + ServerOptions::new().with_default_host("website.com"), + ) + .await? + .unwrap(); + + assert_eq!(request.version(), Some(Version::Http1_0)); + assert_eq!( + request.url().to_string(), + "http://example.com/path?query#fragment" + ); + Ok(()) + } + + #[async_std::test] + async fn http_1_0_request_with_no_default_host_is_provided() -> Result<()> { + let request = decode_lines( + vec!["GET /path?query#fragment HTTP/1.0", "", ""], + ServerOptions::default(), + ) + .await; + + assert!(request.is_err()); + Ok(()) + } + + #[async_std::test] + async fn http_1_0_request_with_no_default_host_is_provided_even_if_host_header_exists( + ) -> Result<()> { + let result = decode_lines( + vec![ + "GET /path?query#fragment HTTP/1.0", + "host: example.com", + "", + "", + ], + ServerOptions::default(), + ) + .await; + + assert!(result.is_err()); + Ok(()) + } } From 695ba391fb539075a847e6a623ba74f471690390 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Fri, 18 Dec 2020 15:15:12 -0800 Subject: [PATCH 3/3] use a concrete error type in async-h1 --- Cargo.toml | 1 + src/client/decode.rs | 59 ++++++++++-------- src/client/mod.rs | 2 +- src/error.rs | 84 ++++++++++++++++++++++++++ src/lib.rs | 2 + src/server/decode.rs | 90 ++++++++++++++-------------- src/server/mod.rs | 19 +++--- tests/client_decode.rs | 19 +++--- tests/server-chunked-encode-large.rs | 2 +- tests/server_decode.rs | 4 +- tests/test_utils.rs | 2 +- 11 files changed, 189 insertions(+), 95 deletions(-) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 16e9ea7..4911848 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ log = "0.4.11" pin-project = "1.0.2" async-channel = "1.5.1" async-dup = "1.2.2" +thiserror = "1.0.22" [dev-dependencies] pretty_assertions = "0.6.1" diff --git a/src/client/decode.rs b/src/client/decode.rs index 9ef6317..5c190cc 100644 --- a/src/client/decode.rs +++ b/src/client/decode.rs @@ -1,22 +1,22 @@ use async_std::io::{BufReader, Read}; use async_std::prelude::*; -use http_types::{ensure, ensure_eq, format_err}; +use http_types::content::ContentLength; use http_types::{ - headers::{CONTENT_LENGTH, DATE, TRANSFER_ENCODING}, + headers::{DATE, TRANSFER_ENCODING}, Body, Response, StatusCode, }; use std::convert::TryFrom; -use crate::chunked::ChunkedDecoder; use crate::date::fmt_http_date; +use crate::{chunked::ChunkedDecoder, Error}; use crate::{MAX_HEADERS, MAX_HEAD_LENGTH}; const CR: u8 = b'\r'; const LF: u8 = b'\n'; /// Decode an HTTP response on the client. -pub async fn decode(reader: R) -> http_types::Result +pub async fn decode(reader: R) -> crate::Result> where R: Read + Unpin + Send + Sync + 'static, { @@ -29,13 +29,14 @@ where loop { let bytes_read = reader.read_until(LF, &mut buf).await?; // No more bytes are yielded from the stream. - assert!(bytes_read != 0, "Empty response"); // TODO: ensure? + if bytes_read == 0 { + return Ok(None); + } // Prevent CWE-400 DDOS with large HTTP Headers. - ensure!( - buf.len() < MAX_HEAD_LENGTH, - "Head byte length should be less than 8kb" - ); + if buf.len() >= MAX_HEAD_LENGTH { + return Err(Error::HeadersTooLong); + } // We've hit the end delimiter of the stream. let idx = buf.len() - 1; @@ -49,17 +50,23 @@ where // Convert our header buf into an httparse instance, and validate. let status = httparse_res.parse(&buf)?; - ensure!(!status.is_partial(), "Malformed HTTP head"); + if status.is_partial() { + return Err(Error::PartialHead); + } - let code = httparse_res.code; - let code = code.ok_or_else(|| format_err!("No status code found"))?; + let code = httparse_res.code.ok_or(Error::MissingStatusCode)?; // Convert httparse headers + body into a `http_types::Response` type. - let version = httparse_res.version; - let version = version.ok_or_else(|| format_err!("No version found"))?; - ensure_eq!(version, 1, "Unsupported HTTP version"); + let version = httparse_res.version.ok_or(Error::MissingVersion)?; + + if version != 1 { + return Err(Error::UnsupportedVersion(version)); + } + + let status_code = + StatusCode::try_from(code).map_err(|_| Error::UnrecognizedStatusCode(code))?; + let mut res = Response::new(status_code); - let mut res = Response::new(StatusCode::try_from(code)?); for header in httparse_res.headers.iter() { res.append_header(header.name, std::str::from_utf8(header.value)?); } @@ -69,13 +76,13 @@ where res.insert_header(DATE, &format!("date: {}\r\n", date)[..]); } - let content_length = res.header(CONTENT_LENGTH); + let content_length = + ContentLength::from_headers(&res).map_err(|_| Error::MalformedHeader("content-length"))?; let transfer_encoding = res.header(TRANSFER_ENCODING); - ensure!( - content_length.is_none() || transfer_encoding.is_none(), - "Unexpected Content-Length header" - ); + if content_length.is_some() && transfer_encoding.is_some() { + return Err(Error::UnexpectedHeader("content-length")); + } if let Some(encoding) = transfer_encoding { if encoding.last().as_str() == "chunked" { @@ -84,16 +91,16 @@ where res.set_body(Body::from_reader(reader, None)); // Return the response. - return Ok(res); + return Ok(Some(res)); } } // Check for Content-Length. - if let Some(len) = content_length { - let len = len.last().as_str().parse::()?; - res.set_body(Body::from_reader(reader.take(len as u64), Some(len))); + if let Some(content_length) = content_length { + let len = content_length.len(); + res.set_body(Body::from_reader(reader.take(len), Some(len as usize))); } // Return the response. - Ok(res) + Ok(Some(res)) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 5e5edad..5fe1f10 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -10,7 +10,7 @@ pub use decode::decode; pub use encode::Encoder; /// Opens an HTTP/1.1 connection to a remote host. -pub async fn connect(mut stream: RW, req: Request) -> http_types::Result +pub async fn connect(mut stream: RW, req: Request) -> crate::Result> where RW: Read + Write + Send + Sync + Unpin + 'static, { diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..7233dc7 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,84 @@ +use std::str::Utf8Error; + +use http_types::url; +use thiserror::Error; + +/// Concrete errors that occur within async-h1 +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum Error { + /// [`std::io::Error`] + #[error(transparent)] + IO(#[from] std::io::Error), + + /// [`url::ParseError`] + #[error(transparent)] + Url(#[from] url::ParseError), + + /// this error describes a malformed request with a path that does + /// not start with / or http:// or https:// + #[error("unexpected uri format")] + UnexpectedURIFormat, + + /// this error describes a http 1.1 request that is missing a Host + /// header + #[error("mandatory host header missing")] + HostHeaderMissing, + + /// this error describes a request that does not specify a path + #[error("request path missing")] + RequestPathMissing, + + /// [`httparse::Error`] + #[error(transparent)] + Httparse(#[from] httparse::Error), + + /// an incomplete http head + #[error("partial http head")] + PartialHead, + + /// we were unable to parse a header + #[error("malformed http header {0}")] + MalformedHeader(&'static str), + + /// async-h1 doesn't speak this http version + /// this error is deprecated + #[error("unsupported http version 1.{0}")] + UnsupportedVersion(u8), + + /// we were unable to parse this http method + #[error("unsupported http method {0}")] + UnrecognizedMethod(String), + + /// this request did not have a method + #[error("missing method")] + MissingMethod, + + /// this request did not have a status code + #[error("missing status code")] + MissingStatusCode, + + /// we were unable to parse this http method + #[error("unrecognized http status code {0}")] + UnrecognizedStatusCode(u16), + + /// this request did not have a version, but we expect one + /// this error is deprecated + #[error("missing version")] + MissingVersion, + + /// we expected utf8, but there was an encoding error + #[error(transparent)] + EncodingError(#[from] Utf8Error), + + /// we received a header that does not make sense in context + #[error("unexpected header: {0}")] + UnexpectedHeader(&'static str), + + /// for security reasons, we do not allow request headers beyond 8kb. + #[error("Head byte length should be less than 8kb")] + HeadersTooLong, +} + +/// this crate's result type +pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index 9781be8..8436912 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,6 +116,8 @@ use async_std::io::Cursor; use body_encoder::BodyEncoder; pub use client::connect; pub use server::{accept, accept_with_opts, ServerOptions}; +mod error; +pub use error::{Error, Result}; #[derive(Debug)] pub(crate) enum EncoderState { diff --git a/src/server/decode.rs b/src/server/decode.rs index 65def6e..27d1357 100644 --- a/src/server/decode.rs +++ b/src/server/decode.rs @@ -1,17 +1,17 @@ //! Process HTTP connections on the server. -use std::str::FromStr; - use async_dup::{Arc, Mutex}; use async_std::io::{BufReader, Read, Write}; use async_std::{prelude::*, task}; -use http_types::{content::ContentLength, Version}; -use http_types::{ensure, format_err}; + use http_types::{ + content::ContentLength, headers::{EXPECT, TRANSFER_ENCODING}, - StatusCode, + Version, }; -use http_types::{Body, Method, Request, Url}; +use http_types::{Body, Request, Url}; + +use crate::{Error, Result}; use super::body_reader::BodyReader; use crate::read_notifier::ReadNotifier; @@ -24,10 +24,11 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue"; const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n"; /// Decode an HTTP request on the server. + pub async fn decode( mut io: IO, opts: &ServerOptions, -) -> http_types::Result)>> +) -> Result)>> where IO: Read + Write + Clone + Send + Sync + Unpin + 'static, { @@ -45,10 +46,9 @@ where } // Prevent CWE-400 DDOS with large HTTP Headers. - ensure!( - buf.len() < MAX_HEAD_LENGTH, - "Head byte length should be less than 8kb" - ); + if buf.len() >= MAX_HEAD_LENGTH { + return Err(Error::HeadersTooLong); + } // We've hit the end delimiter of the stream. let idx = buf.len() - 1; @@ -60,26 +60,27 @@ where // Convert our header buf into an httparse instance, and validate. let status = httparse_req.parse(&buf)?; - ensure!(!status.is_partial(), "Malformed HTTP head"); + if status.is_partial() { + return Err(Error::PartialHead); + } // Convert httparse headers + body into a `http_types::Request` type. - let method = httparse_req.method; - let method = method.ok_or_else(|| format_err!("No method found"))?; + let method = httparse_req + .method + .ok_or(Error::MissingMethod)? + .parse() + .map_err(|_| Error::UnrecognizedMethod(httparse_req.method.unwrap().to_string()))?; let version = match (&opts.default_host, httparse_req.version) { (Some(_), None) | (Some(_), Some(0)) => Version::Http1_0, (_, Some(1)) => Version::Http1_1, - _ => { - let mut err = format_err!("http version not supported"); - err.set_status(StatusCode::HttpVersionNotSupported); - return Err(err); - } + (None, Some(0)) | (None, None) => return Err(Error::HostHeaderMissing), + (_, Some(other_version)) => return Err(Error::UnsupportedVersion(other_version)), }; - let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref()) - .ok_or_else(|| format_err!("unable to construct url from request"))?; + let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref())?; - let mut req = Request::new(Method::from_str(method)?, url); + let mut req = Request::new(method, url); req.set_version(Some(version)); @@ -87,18 +88,13 @@ where req.append_header(header.name, std::str::from_utf8(header.value)?); } - let content_length = ContentLength::from_headers(&req)?; + let content_length = + ContentLength::from_headers(&req).map_err(|_| Error::MalformedHeader("content-length"))?; let transfer_encoding = req.header(TRANSFER_ENCODING); - // Return a 400 status if both Content-Length and Transfer-Encoding headers - // are set to prevent request smuggling attacks. - // - // https://tools.ietf.org/html/rfc7230#section-3.3.3 - http_types::ensure_status!( - content_length.is_none() || transfer_encoding.is_none(), - 400, - "Unexpected Content-Length header" - ); + if content_length.is_some() && transfer_encoding.is_some() { + return Err(Error::UnexpectedHeader("content-length")); + } // Establish a channel to wait for the body to be read. This // allows us to avoid sending 100-continue in situations that @@ -132,8 +128,8 @@ where let reader = BufReader::new(reader); req.set_body(Body::from_reader(reader, None)); return Ok(Some((req, BodyReader::Chunked(reader_clone)))); - } else if let Some(len) = content_length { - let len = len.len(); + } else if let Some(content_length) = content_length { + let len = content_length.len(); let reader = Arc::new(Mutex::new(reader.take(len))); req.set_body(Body::from_reader( BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)), @@ -148,24 +144,27 @@ where fn url_from_httparse_req( req: &httparse::Request<'_, '_>, default_host: Option<&str>, -) -> Option { - let path = req.path?; +) -> Result { + let path = req.path.ok_or(Error::RequestPathMissing)?; let host = req .headers .iter() - .find(|x| x.name.eq_ignore_ascii_case("host")) - .and_then(|x| std::str::from_utf8(x.value).ok()) - .or(default_host)?; + .find(|x| x.name.eq_ignore_ascii_case("host")); + + let host = match host { + Some(header) => std::str::from_utf8(header.value)?, + None => default_host.ok_or(Error::HostHeaderMissing)?, + }; if path.starts_with("http://") || path.starts_with("https://") { - Url::parse(path).ok() + Ok(Url::parse(path)?) } else if path.starts_with('/') { - Url::parse(&format!("http://{}{}", host, path)).ok() + Ok(Url::parse(&format!("http://{}{}", host, path))?) } else if req.method.unwrap().eq_ignore_ascii_case("connect") { - Url::parse(&format!("http://{}/", path)).ok() + Ok(Url::parse(&format!("http://{}/", path))?) } else { - None + Err(Error::UnexpectedURIFormat) } } @@ -229,7 +228,10 @@ mod tests { httparse_req( "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n", |req| { - assert!(url_from_httparse_req(&req, None).is_none()); + assert!(matches!( + url_from_httparse_req(&req, None), + Err(Error::UnexpectedURIFormat) + )); }, ) } diff --git a/src/server/mod.rs b/src/server/mod.rs index dac1204..99838d9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -64,7 +64,7 @@ impl Default for ServerOptions { /// Accept a new incoming HTTP/1.1 connection. /// /// Supports `KeepAlive` requests by default. -pub async fn accept(io: RW, endpoint: F) -> http_types::Result<()> +pub async fn accept(io: RW, endpoint: F) -> crate::Result<()> where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, @@ -80,7 +80,7 @@ pub async fn accept_with_opts( io: RW, endpoint: F, opts: ServerOptions, -) -> http_types::Result<()> +) -> crate::Result<()> where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, @@ -131,13 +131,13 @@ where } /// accept in a loop - pub async fn accept(&mut self) -> http_types::Result<()> { + pub async fn accept(&mut self) -> crate::Result<()> { while ConnectionStatus::KeepAlive == self.accept_one().await? {} Ok(()) } /// accept one request - pub async fn accept_one(&mut self) -> http_types::Result + pub async fn accept_one(&mut self) -> crate::Result where RW: Read + Write + Clone + Send + Sync + Unpin + 'static, F: Fn(Request) -> Fut, @@ -180,22 +180,23 @@ where let method = req.method(); // Pass the request to the endpoint and encode the response. - let mut res = (self.endpoint)(req).await; + let mut response = (self.endpoint)(req).await; - close_connection |= res + close_connection |= response .header(CONNECTION) .map(|c| c.as_str().eq_ignore_ascii_case("close")) .unwrap_or(false); - let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade(); + let upgrade_provided = + response.status() == StatusCode::SwitchingProtocols && response.has_upgrade(); let upgrade_sender = if upgrade_requested && upgrade_provided { - Some(res.send_upgrade()) + Some(response.send_upgrade()) } else { None }; - let mut encoder = Encoder::new(res, method); + let mut encoder = Encoder::new(response, method); let bytes_written = io::copy(&mut encoder, &mut self.io).await?; log::trace!("wrote {} response bytes", bytes_written); diff --git a/tests/client_decode.rs b/tests/client_decode.rs index 65cd1f5..269cd65 100644 --- a/tests/client_decode.rs +++ b/tests/client_decode.rs @@ -6,7 +6,7 @@ mod client_decode { use http_types::Result; use pretty_assertions::assert_eq; - async fn decode_lines(s: Vec<&str>) -> Result { + async fn decode_lines(s: Vec<&str>) -> async_h1::Result> { client::decode(Cursor::new(s.join("\r\n"))).await } @@ -19,7 +19,8 @@ mod client_decode { "", "", ]) - .await?; + .await? + .unwrap(); assert_eq!(res.header(&headers::DATE).is_some(), true); Ok(()) @@ -36,7 +37,8 @@ mod client_decode { "", "", ]) - .await?; + .await? + .unwrap(); assert_eq!(res.header(&headers::SET_COOKIE).unwrap().iter().count(), 2); Ok(()) @@ -53,15 +55,10 @@ mod client_decode { "http specifies headers are separated with \r\n but many servers don't do that", "", ]) - .await?; + .await? + .unwrap(); - assert_eq!( - res[headers::CONTENT_LENGTH] - .as_str() - .parse::() - .unwrap(), - 78 - ); + assert_eq!(res[headers::CONTENT_LENGTH], "78"); Ok(()) } diff --git a/tests/server-chunked-encode-large.rs b/tests/server-chunked-encode-large.rs index df916a9..43e97d5 100644 --- a/tests/server-chunked-encode-large.rs +++ b/tests/server-chunked-encode-large.rs @@ -83,7 +83,7 @@ async fn server_chunked_large() -> Result<()> { let response_encoder = server::Encoder::new(response, Method::Get); - let mut response = client::decode(response_encoder).await?; + let mut response = client::decode(response_encoder).await?.unwrap(); assert_eq!(response.body_string().await?, BODY); Ok(()) diff --git a/tests/server_decode.rs b/tests/server_decode.rs index eafe689..3fd39b7 100644 --- a/tests/server_decode.rs +++ b/tests/server_decode.rs @@ -14,9 +14,9 @@ mod server_decode { let (mut client, server) = TestIO::new(); client.write_all(s.as_bytes()).await?; client.close(); - async_h1::server::decode(server, &options) + Ok(async_h1::server::decode(server, &options) .await - .map(|r| r.map(|(r, _)| r)) + .map(|r| r.map(|(r, _)| r))?) } async fn decode_lines_default(lines: Vec<&str>) -> Result> { diff --git a/tests/test_utils.rs b/tests/test_utils.rs index d7fdd24..6814c6a 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -37,7 +37,7 @@ where } #[allow(dead_code)] - pub async fn accept_one(&mut self) -> http_types::Result { + pub async fn accept_one(&mut self) -> async_h1::Result { self.server.accept_one().await }