diff --git a/src/error.rs b/src/error.rs index 9ad4c0e5b3..48917db970 100644 --- a/src/error.rs +++ b/src/error.rs @@ -240,6 +240,10 @@ impl Error { /// Returns true if the error was caused by a timeout. pub fn is_timeout(&self) -> bool { + #[cfg(all(feature = "http1", feature = "server"))] + if matches!(self.inner.kind, Kind::HeaderTimeout) { + return true; + } self.find_source::().is_some() } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 59b2eb66f0..bea8faa221 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -1122,6 +1122,14 @@ impl State { if !T::should_read_first() { self.notify_read = true; } + + #[cfg(feature = "server")] + if self.h1_header_read_timeout.is_some() { + // Next read will start and poll the header read timeout, + // so we can close the connection if another header isn't + // received in a timely manner. + self.notify_read = true; + } } fn is_idle(&self) -> bool { diff --git a/tests/server.rs b/tests/server.rs index 253868c844..2ba6f92ca3 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1504,7 +1504,6 @@ async fn header_read_timeout_slow_writes() { tcp.write_all( b"\ Something: 1\r\n\ - \r\n\ ", ) .expect("write 2"); @@ -1512,6 +1511,7 @@ async fn header_read_timeout_slow_writes() { tcp.write_all( b"\ Works: 0\r\n\ + \r\n\ ", ) .expect_err("write 3"); @@ -1553,7 +1553,7 @@ async fn header_read_timeout_starts_immediately() { .timer(TokioTimer) .header_read_timeout(Duration::from_secs(2)) .serve_connection(socket, unreachable_service()); - conn.await.expect_err("header timeout"); + assert!(conn.await.unwrap_err().is_timeout()); } #[tokio::test] @@ -1601,7 +1601,6 @@ async fn header_read_timeout_slow_writes_multiple_requests() { b"\ GET / HTTP/1.1\r\n\ Something: 1\r\n\ - \r\n\ ", ) .expect("write 5"); @@ -1609,6 +1608,7 @@ async fn header_read_timeout_slow_writes_multiple_requests() { tcp.write_all( b"\ Works: 0\r\n\ + \r\n ", ) .expect_err("write 6"); @@ -1629,7 +1629,51 @@ async fn header_read_timeout_slow_writes_multiple_requests() { future::ready(Ok::<_, hyper::Error>(res)) }), ); - conn.without_shutdown().await.expect_err("header timeout"); + assert!(conn.without_shutdown().await.unwrap_err().is_timeout()); +} + +#[tokio::test] +async fn header_read_timeout_as_idle_timeout() { + let (listener, addr) = setup_tcp_listener(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + \r\n\ + ", + ) + .expect("request 1"); + + thread::sleep(Duration::from_secs(6)); + + tcp.write_all( + b"\ + GET / HTTP/1.1\r\n\ + \r\n\ + ", + ) + .expect_err("request 2"); + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + let conn = http1::Builder::new() + .timer(TokioTimer) + .header_read_timeout(Duration::from_secs(3)) + .serve_connection( + socket, + service_fn(|_| { + let res = Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(); + future::ready(Ok::<_, hyper::Error>(res)) + }), + ); + assert!(conn.without_shutdown().await.unwrap_err().is_timeout()); } #[tokio::test]