diff --git a/src/server/decode.rs b/src/server/decode.rs index 133b58e..65def6e 100644 --- a/src/server/decode.rs +++ b/src/server/decode.rs @@ -5,26 +5,29 @@ use std::str::FromStr; use async_dup::{Arc, Mutex}; use async_std::io::{BufReader, Read, Write}; use async_std::{prelude::*, task}; -use http_types::content::ContentLength; -use http_types::headers::{EXPECT, TRANSFER_ENCODING}; -use http_types::{ensure, ensure_eq, format_err}; +use http_types::{content::ContentLength, Version}; +use http_types::{ensure, format_err}; +use http_types::{ + headers::{EXPECT, TRANSFER_ENCODING}, + StatusCode, +}; use http_types::{Body, Method, Request, Url}; use super::body_reader::BodyReader; -use crate::chunked::ChunkedDecoder; use crate::read_notifier::ReadNotifier; +use crate::{chunked::ChunkedDecoder, ServerOptions}; use crate::{MAX_HEADERS, MAX_HEAD_LENGTH}; const LF: u8 = b'\n'; -/// The number returned from httparse when the request is HTTP 1.1 -const HTTP_1_1_VERSION: u8 = 1; - const CONTINUE_HEADER_VALUE: &str = "100-continue"; const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n"; /// Decode an HTTP request on the server. -pub async fn decode(mut io: IO) -> http_types::Result)>> +pub async fn decode( + mut io: IO, + opts: &ServerOptions, +) -> http_types::Result)>> where IO: Read + Write + Clone + Send + Sync + Unpin + 'static, { @@ -63,21 +66,22 @@ where let method = httparse_req.method; let method = method.ok_or_else(|| format_err!("No method found"))?; - let version = httparse_req.version; - let version = version.ok_or_else(|| format_err!("No version found"))?; - - ensure_eq!( - version, - HTTP_1_1_VERSION, - "Unsupported HTTP version 1.{}", - version - ); + let version = match (&opts.default_host, httparse_req.version) { + (Some(_), None) | (Some(_), Some(0)) => Version::Http1_0, + (_, Some(1)) => Version::Http1_1, + _ => { + let mut err = format_err!("http version not supported"); + err.set_status(StatusCode::HttpVersionNotSupported); + return Err(err); + } + }; - let url = url_from_httparse_req(&httparse_req)?; + let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref()) + .ok_or_else(|| format_err!("unable to construct url from request"))?; let mut req = Request::new(Method::from_str(method)?, url); - req.set_version(Some(http_types::Version::Http1_1)); + req.set_version(Some(version)); for header in httparse_req.headers.iter() { req.append_header(header.name, std::str::from_utf8(header.value)?); @@ -141,26 +145,27 @@ where } } -fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result { - let path = req.path.ok_or_else(|| format_err!("No uri found"))?; +fn url_from_httparse_req( + req: &httparse::Request<'_, '_>, + default_host: Option<&str>, +) -> Option { + let path = req.path?; let host = req .headers .iter() .find(|x| x.name.eq_ignore_ascii_case("host")) - .ok_or_else(|| format_err!("Mandatory Host header missing"))? - .value; - - let host = std::str::from_utf8(host)?; + .and_then(|x| std::str::from_utf8(x.value).ok()) + .or(default_host)?; if path.starts_with("http://") || path.starts_with("https://") { - Ok(Url::parse(path)?) + Url::parse(path).ok() } else if path.starts_with('/') { - Ok(Url::parse(&format!("http://{}{}", host, path))?) + Url::parse(&format!("http://{}{}", host, path)).ok() } else if req.method.unwrap().eq_ignore_ascii_case("connect") { - Ok(Url::parse(&format!("http://{}/", path))?) + Url::parse(&format!("http://{}/", path)).ok() } else { - Err(format_err!("unexpected uri format")) + None } } @@ -180,7 +185,7 @@ mod tests { httparse_req( "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/"); }, ); @@ -191,7 +196,7 @@ mod tests { httparse_req( "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/some/resource"); }, ) @@ -202,7 +207,7 @@ mod tests { httparse_req( "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://domain.com/some/resource"); // host header MUST be ignored according to spec }, ) @@ -213,7 +218,7 @@ mod tests { httparse_req( "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/"); }, ) @@ -224,7 +229,7 @@ mod tests { httparse_req( "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n", |req| { - assert!(url_from_httparse_req(&req).is_err()); + assert!(url_from_httparse_req(&req, None).is_none()); }, ) } @@ -234,7 +239,7 @@ mod tests { httparse_req( "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443//double/slashes" @@ -247,7 +252,7 @@ mod tests { httparse_req( "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443///triple/slashes" @@ -261,7 +266,7 @@ mod tests { httparse_req( "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1"); }, ) @@ -272,7 +277,7 @@ mod tests { httparse_req( "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n", |req| { - let url = url_from_httparse_req(&req).unwrap(); + let url = url_from_httparse_req(&req, None).unwrap(); assert_eq!( url.as_str(), "http://server.example.com:443/foo?bar=1#anchor" diff --git a/src/server/mod.rs b/src/server/mod.rs index c9304ad..f3eb671 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,8 +2,11 @@ use async_std::future::{timeout, Future, TimeoutError}; use async_std::io::{self, Read, Write}; -use http_types::headers::{CONNECTION, UPGRADE}; use http_types::upgrade::Connection; +use http_types::{ + headers::{CONNECTION, UPGRADE}, + Version, +}; use http_types::{Request, Response, StatusCode}; use std::{marker::PhantomData, time::Duration}; mod body_reader; @@ -18,12 +21,42 @@ pub use encode::Encoder; pub struct ServerOptions { /// Timeout to handle headers. Defaults to 60s. headers_timeout: Option, + default_host: Option, +} + +impl ServerOptions { + /// constructs a new ServerOptions with default settings + pub fn new() -> Self { + Self::default() + } + + /// sets the timeout by which the headers must have been received + pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self { + self.headers_timeout = Some(headers_timeout); + self + } + + /// Sets the default http 1.0 host for this server. If no host + /// header is provided on an http/1.0 request, this host will be + /// used to construct the Request Url. + /// + /// If this is not provided, the server will respond to all + /// http/1.0 requests with status `505 http version not + /// supported`, whether or not a host header is provided. + /// + /// The default value for this is None, and as a result async-h1 + /// is by default an http-1.1-only server. + pub fn with_default_host(mut self, default_host: &str) -> Self { + self.default_host = Some(default_host.into()); + self + } } impl Default for ServerOptions { fn default() -> Self { Self { headers_timeout: Some(Duration::from_secs(60)), + default_host: None, } } } @@ -111,7 +144,7 @@ where Fut: Future, { // Decode a new request, timing out if this takes longer than the timeout duration. - let fut = decode(self.io.clone()); + let fut = decode(self.io.clone(), &self.opts); let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout { match timeout(timeout_duration, fut).await { @@ -133,7 +166,12 @@ where .unwrap_or(""); let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade"); - let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close"); + + let mut close_connection = if req.version() == Some(Version::Http1_0) { + !connection_header_as_str.eq_ignore_ascii_case("keep-alive") + } else { + connection_header_as_str.eq_ignore_ascii_case("close") + }; let upgrade_requested = has_upgrade_header && connection_header_is_upgrade; @@ -168,7 +206,7 @@ where if let Some(upgrade_sender) = upgrade_sender { upgrade_sender.send(Connection::new(self.io.clone())).await; - return Ok(ConnectionStatus::Close); + Ok(ConnectionStatus::Close) } else if close_connection { Ok(ConnectionStatus::Close) } else { diff --git a/tests/continue.rs b/tests/continue.rs index 933fbfe..2a84540 100644 --- a/tests/continue.rs +++ b/tests/continue.rs @@ -16,7 +16,9 @@ async fn test_with_expect_when_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (mut request, _) = async_h1::server::decode(server).await?.unwrap(); + let (mut request, _) = async_h1::server::decode(server, &Default::default()) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written @@ -44,7 +46,9 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (_, _) = async_h1::server::decode(server).await?.unwrap(); + let (_, _) = async_h1::server::decode(server, &Default::default()) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel diff --git a/tests/server-chunked-encode-large.rs b/tests/server-chunked-encode-large.rs index 3252b0c..df916a9 100644 --- a/tests/server-chunked-encode-large.rs +++ b/tests/server-chunked-encode-large.rs @@ -76,7 +76,7 @@ async fn server_chunked_large() -> Result<()> { let (mut client, server) = TestIO::new(); async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?; - let (request, _) = server::decode(server).await?.unwrap(); + let (request, _) = server::decode(server, &Default::default()).await?.unwrap(); let mut response = Response::new(200); response.set_body(Body::from_reader(request, None)); diff --git a/tests/server_decode.rs b/tests/server_decode.rs index 10c6701..eafe689 100644 --- a/tests/server_decode.rs +++ b/tests/server_decode.rs @@ -1,26 +1,31 @@ mod test_utils; mod server_decode { use super::test_utils::TestIO; + use async_h1::ServerOptions; use async_std::io::prelude::*; - use http_types::headers::TRANSFER_ENCODING; use http_types::Request; use http_types::Result; use http_types::Url; + use http_types::{headers::TRANSFER_ENCODING, Version}; use pretty_assertions::assert_eq; - async fn decode_lines(lines: Vec<&str>) -> Result> { + async fn decode_lines(lines: Vec<&str>, options: ServerOptions) -> Result> { let s = lines.join("\r\n"); let (mut client, server) = TestIO::new(); client.write_all(s.as_bytes()).await?; client.close(); - async_h1::server::decode(server) + async_h1::server::decode(server, &options) .await .map(|r| r.map(|(r, _)| r)) } + async fn decode_lines_default(lines: Vec<&str>) -> Result> { + decode_lines(lines, ServerOptions::default()).await + } + #[async_std::test] async fn post_with_body() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: localhost:8080", "content-length: 5", @@ -53,7 +58,7 @@ mod server_decode { #[async_std::test] async fn chunked() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: localhost:8080", "transfer-encoding: chunked", @@ -83,7 +88,7 @@ mod server_decode { "#] #[async_std::test] async fn invalid_trailer() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "GET / HTTP/1.1", "host: domain.com", "content-type: application/octet-stream", @@ -104,7 +109,7 @@ mod server_decode { #[async_std::test] async fn unexpected_eof() -> Result<()> { - let mut request = decode_lines(vec![ + let mut request = decode_lines_default(vec![ "POST / HTTP/1.1", "host: example.com", "content-type: text/plain", @@ -125,4 +130,86 @@ mod server_decode { Ok(()) } + + #[async_std::test] + async fn http_1_0_without_host_header() -> Result<()> { + let request = decode_lines( + vec!["GET /path?query#fragment HTTP/1.0", "", ""], + ServerOptions::new().with_default_host("website.com"), + ) + .await? + .unwrap(); + + assert_eq!(request.version(), Some(Version::Http1_0)); + assert_eq!( + request.url().to_string(), + "http://website.com/path?query#fragment" + ); + Ok(()) + } + + #[async_std::test] + async fn http_1_1_without_host_header() -> Result<()> { + let result = decode_lines( + vec!["GET /path?query#fragment HTTP/1.1", "", ""], + ServerOptions::default(), + ) + .await; + + assert!(result.is_err()); + + Ok(()) + } + + #[async_std::test] + async fn http_1_0_with_host_header() -> Result<()> { + let request = decode_lines( + vec![ + "GET /path?query#fragment HTTP/1.0", + "host: example.com", + "", + "", + ], + ServerOptions::new().with_default_host("website.com"), + ) + .await? + .unwrap(); + + assert_eq!(request.version(), Some(Version::Http1_0)); + assert_eq!( + request.url().to_string(), + "http://example.com/path?query#fragment" + ); + Ok(()) + } + + #[async_std::test] + async fn http_1_0_request_with_no_default_host_is_provided() -> Result<()> { + let request = decode_lines( + vec!["GET /path?query#fragment HTTP/1.0", "", ""], + ServerOptions::default(), + ) + .await; + + assert!(request.is_err()); + Ok(()) + } + + #[async_std::test] + async fn http_1_0_request_with_no_default_host_is_provided_even_if_host_header_exists( + ) -> Result<()> { + let result = decode_lines( + vec![ + "GET /path?query#fragment HTTP/1.0", + "host: example.com", + "", + "", + ], + ServerOptions::default(), + ) + .await; + + assert!(result.is_err()); + Ok(()) + } }