Skip to content

Optional http/1.0 support #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 43 additions & 38 deletions src/server/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@ use std::str::FromStr;
use async_dup::{Arc, Mutex};
use async_std::io::{BufReader, Read, Write};
use async_std::{prelude::*, task};
use http_types::content::ContentLength;
use http_types::headers::{EXPECT, TRANSFER_ENCODING};
use http_types::{ensure, ensure_eq, format_err};
use http_types::{content::ContentLength, Version};
use http_types::{ensure, format_err};
use http_types::{
headers::{EXPECT, TRANSFER_ENCODING},
StatusCode,
};
use http_types::{Body, Method, Request, Url};

use super::body_reader::BodyReader;
use crate::chunked::ChunkedDecoder;
use crate::read_notifier::ReadNotifier;
use crate::{chunked::ChunkedDecoder, ServerOptions};
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};

const LF: u8 = b'\n';

/// The number returned from httparse when the request is HTTP 1.1
const HTTP_1_1_VERSION: u8 = 1;

const CONTINUE_HEADER_VALUE: &str = "100-continue";
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

/// Decode an HTTP request on the server.
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
pub async fn decode<IO>(
mut io: IO,
opts: &ServerOptions,
) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
where
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
{
Expand Down Expand Up @@ -63,21 +66,22 @@ where
let method = httparse_req.method;
let method = method.ok_or_else(|| format_err!("No method found"))?;

let version = httparse_req.version;
let version = version.ok_or_else(|| format_err!("No version found"))?;

ensure_eq!(
version,
HTTP_1_1_VERSION,
"Unsupported HTTP version 1.{}",
version
);
let version = match (&opts.default_host, httparse_req.version) {
(Some(_), None) | (Some(_), Some(0)) => Version::Http1_0,
(_, Some(1)) => Version::Http1_1,
_ => {
let mut err = format_err!("http version not supported");
err.set_status(StatusCode::HttpVersionNotSupported);
return Err(err);
}
};

let url = url_from_httparse_req(&httparse_req)?;
let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref())
.ok_or_else(|| format_err!("unable to construct url from request"))?;

let mut req = Request::new(Method::from_str(method)?, url);

req.set_version(Some(http_types::Version::Http1_1));
req.set_version(Some(version));

for header in httparse_req.headers.iter() {
req.append_header(header.name, std::str::from_utf8(header.value)?);
Expand Down Expand Up @@ -141,26 +145,27 @@ where
}
}

fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
fn url_from_httparse_req(
req: &httparse::Request<'_, '_>,
default_host: Option<&str>,
) -> Option<Url> {
let path = req.path?;

let host = req
.headers
.iter()
.find(|x| x.name.eq_ignore_ascii_case("host"))
.ok_or_else(|| format_err!("Mandatory Host header missing"))?
.value;

let host = std::str::from_utf8(host)?;
.and_then(|x| std::str::from_utf8(x.value).ok())
.or(default_host)?;

if path.starts_with("http://") || path.starts_with("https://") {
Ok(Url::parse(path)?)
Url::parse(path).ok()
} else if path.starts_with('/') {
Ok(Url::parse(&format!("http://{}{}", host, path))?)
Url::parse(&format!("http://{}{}", host, path)).ok()
} else if req.method.unwrap().eq_ignore_ascii_case("connect") {
Ok(Url::parse(&format!("http://{}/", path))?)
Url::parse(&format!("http://{}/", path)).ok()
} else {
Err(format_err!("unexpected uri format"))
None
}
}

Expand All @@ -180,7 +185,7 @@ mod tests {
httparse_req(
"CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(url.as_str(), "http://server.example.com:443/");
},
);
Expand All @@ -191,7 +196,7 @@ mod tests {
httparse_req(
"GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(url.as_str(), "http://server.example.com:443/some/resource");
},
)
Expand All @@ -202,7 +207,7 @@ mod tests {
httparse_req(
"GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(url.as_str(), "http://domain.com/some/resource"); // host header MUST be ignored according to spec
},
)
Expand All @@ -213,7 +218,7 @@ mod tests {
httparse_req(
"CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(url.as_str(), "http://server.example.com:443/");
},
)
Expand All @@ -224,7 +229,7 @@ mod tests {
httparse_req(
"GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n",
|req| {
assert!(url_from_httparse_req(&req).is_err());
assert!(url_from_httparse_req(&req, None).is_none());
},
)
}
Expand All @@ -234,7 +239,7 @@ mod tests {
httparse_req(
"GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(
url.as_str(),
"http://server.example.com:443//double/slashes"
Expand All @@ -247,7 +252,7 @@ mod tests {
httparse_req(
"GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(
url.as_str(),
"http://server.example.com:443///triple/slashes"
Expand All @@ -261,7 +266,7 @@ mod tests {
httparse_req(
"GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1");
},
)
Expand All @@ -272,7 +277,7 @@ mod tests {
httparse_req(
"GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n",
|req| {
let url = url_from_httparse_req(&req).unwrap();
let url = url_from_httparse_req(&req, None).unwrap();
assert_eq!(
url.as_str(),
"http://server.example.com:443/foo?bar=1#anchor"
Expand Down
46 changes: 42 additions & 4 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

use async_std::future::{timeout, Future, TimeoutError};
use async_std::io::{self, Read, Write};
use http_types::headers::{CONNECTION, UPGRADE};
use http_types::upgrade::Connection;
use http_types::{
headers::{CONNECTION, UPGRADE},
Version,
};
use http_types::{Request, Response, StatusCode};
use std::{marker::PhantomData, time::Duration};
mod body_reader;
Expand All @@ -18,12 +21,42 @@ pub use encode::Encoder;
pub struct ServerOptions {
/// Timeout to handle headers. Defaults to 60s.
headers_timeout: Option<Duration>,
default_host: Option<String>,
}

impl ServerOptions {
/// constructs a new ServerOptions with default settings
pub fn new() -> Self {
Self::default()
}

/// sets the timeout by which the headers must have been received
pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self {
self.headers_timeout = Some(headers_timeout);
self
}

/// Sets the default http 1.0 host for this server. If no host
/// header is provided on an http/1.0 request, this host will be
/// used to construct the Request Url.
///
/// If this is not provided, the server will respond to all
/// http/1.0 requests with status `505 http version not
/// supported`, whether or not a host header is provided.
Comment on lines +39 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 -- I really like how this works! This can very naturally translate to a CLI flag which sets the default URL, or a config which mandates the hostname is set. This feels exactly right.

///
/// The default value for this is None, and as a result async-h1
/// is by default an http-1.1-only server.
pub fn with_default_host(mut self, default_host: &str) -> Self {
self.default_host = Some(default_host.into());
self
}
Comment on lines +49 to +52
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: I'm increasingly convinced that async-h1 could be nicer to use if we had concrete Client and Server structs, and not the intermediate short-hand constructors. Rather than using a separate Options struct, exposing the host param through the constructor may be easier:

let mut server = Server::new();
assert_eq!(server.host(), None);

let mut server = Server::with_host("website.com");
assert_eq!(server.host(), Some("website.com"));

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Entirely agreed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's been my intent for a few PRs — I've been leaving the functions around until the interface for the structs feels right. My hunch is that implementing future for the struct will simplify a bunch of things

}

impl Default for ServerOptions {
fn default() -> Self {
Self {
headers_timeout: Some(Duration::from_secs(60)),
default_host: None,
}
}
}
Expand Down Expand Up @@ -111,7 +144,7 @@ where
Fut: Future<Output = Response>,
{
// Decode a new request, timing out if this takes longer than the timeout duration.
let fut = decode(self.io.clone());
let fut = decode(self.io.clone(), &self.opts);

let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
match timeout(timeout_duration, fut).await {
Expand All @@ -133,7 +166,12 @@ where
.unwrap_or("");

let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade");
let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");

let mut close_connection = if req.version() == Some(Version::Http1_0) {
!connection_header_as_str.eq_ignore_ascii_case("keep-alive")
} else {
connection_header_as_str.eq_ignore_ascii_case("close")
};

let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;

Expand Down Expand Up @@ -168,7 +206,7 @@ where

if let Some(upgrade_sender) = upgrade_sender {
upgrade_sender.send(Connection::new(self.io.clone())).await;
return Ok(ConnectionStatus::Close);
Ok(ConnectionStatus::Close)
} else if close_connection {
Ok(ConnectionStatus::Close)
} else {
Expand Down
8 changes: 6 additions & 2 deletions tests/continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ async fn test_with_expect_when_reading_body() -> Result<()> {
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;

let (mut request, _) = async_h1::server::decode(server).await?.unwrap();
let (mut request, _) = async_h1::server::decode(server, &Default::default())
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written

Expand Down Expand Up @@ -44,7 +46,9 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> {
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;

let (_, _) = async_h1::server::decode(server).await?.unwrap();
let (_, _) = async_h1::server::decode(server, &Default::default())
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel

Expand Down
2 changes: 1 addition & 1 deletion tests/server-chunked-encode-large.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async fn server_chunked_large() -> Result<()> {
let (mut client, server) = TestIO::new();
async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?;

let (request, _) = server::decode(server).await?.unwrap();
let (request, _) = server::decode(server, &Default::default()).await?.unwrap();

let mut response = Response::new(200);
response.set_body(Body::from_reader(request, None));
Expand Down
Loading