From 6be05759ea00c7947a85c06cf72ccaf7d2e3f922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20=C5=81abor?= Date: Mon, 3 Mar 2025 20:35:05 -0500 Subject: [PATCH] Handle stream error correctly --- tests/integration_tests/tests/status.rs | 90 +++++++++++++++++++++++++ tonic/src/codec/decode.rs | 10 ++- 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 84337d130..d75052b80 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -6,9 +6,12 @@ use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, OutputStream, }; +use integration_tests::BoxFuture; use std::error::Error; +use std::task::{Context, Poll}; use std::time::Duration; use tokio::{net::TcpListener, sync::oneshot}; +use tonic::body::Body; use tonic::metadata::{MetadataMap, MetadataValue}; use tonic::{ transport::{server::TcpIncoming, Endpoint, Server}, @@ -209,6 +212,93 @@ async fn status_from_server_stream_with_source() { source.downcast_ref::().unwrap(); } +#[tokio::test] +async fn status_from_server_stream_with_inferred_status() { + integration_tests::trace_init(); + + struct Svc; + + #[tonic::async_trait] + impl test_stream_server::TestStream for Svc { + type StreamCallStream = Stream; + + async fn stream_call( + &self, + _: Request, + ) -> Result, Status> { + let s = tokio_stream::once(Ok(OutputStream {})); + Ok(Response::new(Box::pin(s) as Self::StreamCallStream)) + } + } + + #[derive(Clone)] + struct TestLayer; + + impl tower::Layer for TestLayer { + type Service = TestService; + + fn layer(&self, _: S) -> Self::Service { + TestService + } + } + + #[derive(Clone)] + struct TestService; + + impl tower::Service> for TestService { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: http::Request) -> Self::Future { + Box::pin(async { + Ok(http::Response::builder() + .status(http::StatusCode::BAD_GATEWAY) + .body(Body::empty()) + .unwrap()) + }) + } + } + + let svc = test_stream_server::TestStreamServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming: TcpIncoming = TcpIncoming::from(listener).with_nodelay(Some(true)); + + tokio::spawn(async move { + Server::builder() + .layer(TestLayer) + .add_service(svc) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = test_stream_client::TestStreamClient::connect(format!("http://{addr}")) + .await + .unwrap(); + + let mut stream = client + .stream_call(InputStream {}) + .await + .unwrap() + .into_inner(); + + assert_eq!( + stream.message().await.unwrap_err().code(), + Code::Unavailable + ); + + assert_eq!(stream.message().await.unwrap(), None); +} + #[tokio::test] async fn message_and_then_status_from_server_stream() { integration_tests::trace_init(); diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 1e6876fc8..4742291fd 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -400,14 +400,12 @@ impl Stream for Streaming { } if ready!(self.inner.poll_frame(cx))?.is_none() { - break; + match self.inner.response() { + Ok(()) => return Poll::Ready(None), + Err(err) => self.inner.state = State::Error(Some(err)), + } } } - - Poll::Ready(match self.inner.response() { - Ok(()) => None, - Err(err) => Some(Err(err)), - }) } }