Skip to content

Commit d8014bb

Browse files
committed
Move MTU detection to separate module
1 parent 175f8dd commit d8014bb

File tree

2 files changed

+211
-181
lines changed

2 files changed

+211
-181
lines changed

talpid-wireguard/src/lib.rs

+14-181
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ use talpid_routing as routing;
2626
use talpid_routing::{self, RequiredRoute};
2727
#[cfg(not(windows))]
2828
use talpid_tunnel::tun_provider;
29-
#[cfg(not(target_os = "android"))]
30-
use talpid_tunnel::IPV4_HEADER_SIZE;
3129
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
3230

3331
use ipnetwork::IpNetwork;
@@ -59,6 +57,9 @@ pub(crate) mod wireguard_kernel;
5957
#[cfg(windows)]
6058
mod wireguard_nt;
6159

60+
#[cfg(not(target_os = "android"))]
61+
mod mtu_detection;
62+
6263
#[cfg(wireguard_go)]
6364
use self::wireguard_go::WgGoTunnel;
6465

@@ -73,19 +74,6 @@ pub enum Error {
7374
#[error(display = "Failed to setup routing")]
7475
SetupRoutingError(#[error(source)] talpid_routing::Error),
7576

76-
/// Failed to set MTU
77-
#[error(display = "Failed to detect MTU because every ping was dropped.")]
78-
MtuDetectionAllDropped,
79-
80-
/// Failed to set MTU
81-
#[error(display = "Failed to detect MTU because of unexpected ping error.")]
82-
MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),
83-
84-
/// Failed to set MTU
85-
#[cfg(target_os = "macos")]
86-
#[error(display = "Failed to set buffer size")]
87-
MtuSetBufferSize(#[error(source)] nix::Error),
88-
8977
/// Tunnel timed out
9078
#[error(display = "Tunnel timed out")]
9179
TimeoutError,
@@ -396,45 +384,19 @@ impl WireguardMonitor {
396384

397385
#[cfg(not(target_os = "android"))]
398386
if detect_mtu {
399-
let iface_name_clone = iface_name.clone();
387+
let iface_name = iface_name.clone();
388+
let config = config.clone();
400389
tokio::task::spawn(async move {
401-
log::debug!("Starting MTU detection");
402-
let verified_mtu = match auto_mtu_detection(
403-
gateway,
404-
#[cfg(any(target_os = "macos", target_os = "linux"))]
405-
iface_name_clone.clone(),
406-
config.mtu,
407-
)
408-
.await
390+
if let Err(e) =
391+
mtu_detection::automatic_mtu_correction(gateway, iface_name, &config).await
409392
{
410-
Ok(mtu) => mtu,
411-
Err(e) => {
412-
log::error!("{}", e.display_chain_with_msg("Failed to detect MTU"));
413-
return;
414-
}
415-
};
416-
417-
if verified_mtu != config.mtu {
418-
log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
419-
#[cfg(any(target_os = "linux", target_os = "macos"))]
420-
let res = unix::set_mtu(&iface_name_clone, verified_mtu);
421-
#[cfg(windows)]
422-
let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
423-
|luid| {
424-
talpid_windows::net::set_mtu(
425-
luid,
426-
verified_mtu as u32,
427-
config.ipv6_gateway.is_some(),
428-
)
429-
},
393+
log::error!(
394+
"{}",
395+
e.display_chain_with_msg(
396+
"Failed to automatically adjust MTU based on dropped packets"
397+
)
430398
);
431-
432-
if let Err(e) = res {
433-
log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
434-
};
435-
} else {
436-
log::debug!("MTU {verified_mtu} verified to not drop packets");
437-
}
399+
};
438400
});
439401
}
440402
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
@@ -956,7 +918,7 @@ impl WireguardMonitor {
956918

957919
#[cfg(any(target_os = "linux", target_os = "macos"))]
958920
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
959-
use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
921+
use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
960922

961923
if !config.is_multihop() {
962924
route
@@ -1009,135 +971,6 @@ impl WireguardMonitor {
1009971
}
1010972
}
1011973

1012-
/// Detects the maximum MTU that does not cause dropped packets.
1013-
///
1014-
/// The detection works by sending evenly spread out range of pings between 576 and the given
1015-
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
1016-
#[cfg(not(target_os = "android"))]
1017-
async fn auto_mtu_detection(
1018-
gateway: std::net::Ipv4Addr,
1019-
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
1020-
current_mtu: u16,
1021-
) -> Result<u16> {
1022-
use futures::{future, stream::FuturesUnordered, TryStreamExt};
1023-
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
1024-
use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
1025-
use tokio_stream::StreamExt;
1026-
1027-
/// Max time to wait for any ping, when this expires, we give up and throw an error.
1028-
const PING_TIMEOUT: Duration = Duration::from_secs(10);
1029-
/// Max time to wait after the first ping arrives. Every ping after this timeout is considered
1030-
/// dropped, so we return the largest collected packet size.
1031-
const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
1032-
1033-
let step_size = 20;
1034-
let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
1035-
1036-
let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
1037-
#[cfg(any(target_os = "macos", target_os = "linux"))]
1038-
let config_builder = config_builder.interface(&iface_name);
1039-
let client = Client::new(&config_builder.build()).unwrap();
1040-
// For macos, the default socket receive buffer size seems to be too small to handle the data we
1041-
// are sending here. The consequence will be dropped packets causing the MTU detection to set a
1042-
// low value. Here we manually increase this value, which fixes the problem.
1043-
// TODO: Make sure this fix is not needed for any other target OS
1044-
#[cfg(target_os = "macos")]
1045-
{
1046-
use nix::sys::socket::{setsockopt, sockopt};
1047-
let fd = client.get_socket().get_native_sock();
1048-
let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum();
1049-
setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
1050-
setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
1051-
}
1052-
1053-
let payload_buf = vec![0; current_mtu as usize];
1054-
1055-
let mut ping_stream = linspace
1056-
.iter()
1057-
.enumerate()
1058-
.map(|(i, &mtu)| {
1059-
let client = client.clone();
1060-
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
1061-
let payload = &payload_buf[0..payload_size];
1062-
async move {
1063-
log::trace!("Sending ICMP ping of total size {mtu}");
1064-
client
1065-
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
1066-
.await
1067-
.timeout(PING_TIMEOUT)
1068-
.ping(PingSequence(i as u16), payload)
1069-
.await
1070-
}
1071-
})
1072-
.collect::<FuturesUnordered<_>>()
1073-
.map_ok(|(packet, _rtt)| {
1074-
let surge_ping::IcmpPacket::V4(packet) = packet else {
1075-
unreachable!("ICMP ping response was not of IPv4 type");
1076-
};
1077-
let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
1078-
log::trace!("Got ICMP ping response of total size {size}");
1079-
debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
1080-
size
1081-
});
1082-
1083-
let first_ping_size = ping_stream
1084-
.next()
1085-
.await
1086-
.expect("At least one pings should be sent")
1087-
// Short-circuit and return on error
1088-
.map_err(|e| match e {
1089-
// If the first ping we get back timed out, then all of them did
1090-
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
1091-
// Unexpected error type
1092-
e => Error::MtuDetectionPingError(e),
1093-
})?;
1094-
1095-
ping_stream
1096-
.timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
1097-
.map_while(|res| res.ok()) // Stop waiting for pings after this timeout
1098-
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
1099-
.await
1100-
.map_err(Error::MtuDetectionPingError)
1101-
}
1102-
1103-
/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
1104-
/// points.
1105-
#[cfg(not(target_os = "android"))]
1106-
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
1107-
assert!(mtu_min < mtu_max);
1108-
assert!(step_size < mtu_max);
1109-
assert_ne!(step_size, 0);
1110-
1111-
let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
1112-
let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
1113-
1114-
let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
1115-
ret.push(mtu_min);
1116-
ret.extend(in_between);
1117-
ret.push(mtu_max);
1118-
ret
1119-
}
1120-
1121-
#[cfg(all(test, not(target_os = "android")))]
1122-
mod tests {
1123-
use crate::mtu_spacing;
1124-
use proptest::prelude::*;
1125-
1126-
proptest! {
1127-
#[test]
1128-
fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
1129-
let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);
1130-
1131-
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
1132-
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
1133-
prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
1134-
let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
1135-
prop_assert!(diffs.all(|diff| diff <= step_size));
1136-
1137-
}
1138-
}
1139-
}
1140-
1141974
#[derive(Debug)]
1142975
enum CloseMsg {
1143976
Stop,

0 commit comments

Comments
 (0)