From 0e4847bfebac67efdedd14335e9936f16bf0187a Mon Sep 17 00:00:00 2001 From: Daniel Sharifi Date: Thu, 5 Sep 2024 09:10:22 +0000 Subject: [PATCH] . --- .../quic_transport/src/connection_handle.rs | 55 +++++++++++++++---- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/rs/p2p/quic_transport/src/connection_handle.rs b/rs/p2p/quic_transport/src/connection_handle.rs index 43f9518055b..db76735cf32 100644 --- a/rs/p2p/quic_transport/src/connection_handle.rs +++ b/rs/p2p/quic_transport/src/connection_handle.rs @@ -11,7 +11,7 @@ use bytes::Bytes; use ic_base_types::NodeId; use ic_protobuf::transport::v1 as pb; use prost::Message; -use quinn::{Connection, RecvStream, SendStream}; +use quinn::{Connection, RecvStream, SendStream, VarInt}; use crate::{ metrics::{ @@ -21,6 +21,36 @@ use crate::{ ConnId, MessagePriority, MAX_MESSAGE_SIZE_BYTES, }; +/// Drop guard to send a [`SendStream::reset`] frame on drop. QUINN sends a [`SendStream::finish`] frame by default when dropping a [`SendStream`], +/// which can lead to the peer receiving the stream thinking a complete message was sent. This guard is used to send a reset frame instead, to signal +/// that the transmission of the message was cancelled. +struct SendStreamDropGuard { + send_stream: SendStream, + armed: bool, +} + +impl SendStreamDropGuard { + /// Disarm the guard, preventing it from sending a reset frame on drop. + fn disarm(mut self) { + self.armed = false; + } + + fn new(send_stream: SendStream) -> Self { + Self { + send_stream, + armed: true, + } + } +} + +impl Drop for SendStreamDropGuard { + fn drop(&mut self) { + if self.armed { + let _ = self.send_stream.reset(VarInt::from_u32(0)); + } + } +} + #[derive(Clone, Debug)] pub(crate) struct ConnectionHandle { pub peer_id: NodeId, @@ -66,13 +96,16 @@ impl ConnectionHandle { .connection_handle_bytes_received_total .with_label_values(&[request.uri().path()]); - let (mut send_stream, recv_stream) = self.connection.open_bi().await.map_err(|err| { + let (send_stream, recv_stream) = self.connection.open_bi().await.map_err(|err| { self.metrics .connection_handle_errors_total .with_label_values(&[REQUEST_TYPE_RPC, ERROR_TYPE_OPEN]); err })?; + let mut send_stream_guard = SendStreamDropGuard::new(send_stream); + let send_stream = &mut send_stream_guard.send_stream; + let priority = request .extensions() .get::() @@ -80,15 +113,13 @@ impl ConnectionHandle { .unwrap_or_default(); let _ = send_stream.set_priority(priority.into()); - write_request(&mut send_stream, request) - .await - .map_err(|err| { - self.metrics - .connection_handle_errors_total - .with_label_values(&[REQUEST_TYPE_RPC, ERROR_TYPE_WRITE]) - .inc(); - err - })?; + write_request(send_stream, request).await.map_err(|err| { + self.metrics + .connection_handle_errors_total + .with_label_values(&[REQUEST_TYPE_RPC, ERROR_TYPE_WRITE]) + .inc(); + err + })?; send_stream.finish().map_err(|err| { self.metrics @@ -116,7 +147,7 @@ impl ConnectionHandle { // Propagate PeerId from this request to upper layers. response.extensions_mut().insert(self.peer_id); - + send_stream_guard.disarm(); in_counter.inc_by(response.body().len() as u64); Ok(response) }