diff --git a/src/substream/mod.rs b/src/substream/mod.rs index f4e9f9a8..f3d4615d 100644 --- a/src/substream/mod.rs +++ b/src/substream/mod.rs @@ -645,6 +645,13 @@ impl Stream for Substream { } this.offset = 0; + // Handle empty payloads detected as 0-length frame. + // The offset must be cleared to 0 to not interfere + // with next framing. + if size == 0 { + return Poll::Ready(Some(Ok(BytesMut::new()))); + } + this.current_frame_size = Some(size); this.read_buffer = BytesMut::zeroed(size); } diff --git a/tests/protocol/notification.rs b/tests/protocol/notification.rs index 10bf748c..afb6d74d 100644 --- a/tests/protocol/notification.rs +++ b/tests/protocol/notification.rs @@ -2009,6 +2009,280 @@ async fn dialer_fallback_protocol_works(transport1: Transport, transport2: Trans ); } +#[tokio::test] +async fn zero_byte_handshake_tcp() { + // Full node role. + zero_byte_handshake( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + vec![1], + ) + .await; + + // Invalid role set as `ObservedRole::NONE`. + zero_byte_handshake( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + vec![0], + ) + .await; + + // Light client role provided by smoldot. + zero_byte_handshake( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + vec![], + ) + .await; +} + +#[cfg(feature = "quic")] +#[tokio::test] +async fn zero_byte_handshake_quic() { + zero_byte_handshake( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + vec![1], + ) + .await; + + zero_byte_handshake( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + vec![0], + ) + .await; + + zero_byte_handshake( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + vec![], + ) + .await; +} + +#[cfg(feature = "websocket")] +#[tokio::test] +async fn zero_byte_handshake_websocket() { + // Full node role. + zero_byte_handshake( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + vec![1], + ) + .await; + + // Invalid role set as `ObservedRole::NONE`. + zero_byte_handshake( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + vec![0], + ) + .await; + + // Light client role provided by smoldot. + zero_byte_handshake( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + vec![], + ) + .await; +} + +async fn zero_byte_handshake(transport1: Transport, transport2: Transport, handshake: Vec) { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(handshake.clone()) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = add_transport(config1, transport1).build(); + + let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(handshake.clone()) + .build(); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = add_transport(config2, transport2).build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + tracing::info!("Opening substream handle1 => handle2"); + handle1.open_substream(peer2).await.unwrap(); + + tracing::info!("Expecting validate substream event..."); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: handshake.clone(), + } + ); + + tracing::info!("Send validation result... peer2 => peer1"); + handle2.send_validation_result(peer1, ValidationResult::Accept); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: handshake.clone(), + } + ); + + tracing::info!("Send validation result... peer1 => peer2"); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + tracing::info!("Handle2 expecting notification stream opened event..."); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: handshake.clone(), + } + ); + + tracing::info!("Handle1 expecting notification stream opened event..."); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: handshake, + } + ); + + // This step ensures we have not messed with the notification frames. + tracing::info!("Send sync notification..."); + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); + + // Ensure the handle can send empty notifications. + tracing::info!("Send empty sync notification..."); + handle1.send_sync_notification(peer2, vec![]).unwrap(); + handle2.send_sync_notification(peer1, vec![]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[][..]), + } + ); + + // Double check non-empty notifications. + tracing::info!("Send sync notification..."); + handle1.send_sync_notification(peer2, vec![1, 3, 3, 9]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 4]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 9][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 4][..]), + } + ); +} + #[tokio::test] async fn listener_fallback_protocol_works_tcp() { listener_fallback_protocol_works(