Skip to content

Commit ea8d82e

Browse files
committed
Add automatic MTU detection
1 parent 33c14ca commit ea8d82e

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

Cargo.lock

+71
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

talpid-wireguard/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ chrono = { workspace = true, features = ["clock"] }
2828
tokio = { workspace = true, features = ["process", "rt-multi-thread", "fs"] }
2929
tunnel-obfuscation = { path = "../tunnel-obfuscation" }
3030
rand = "0.8.5"
31+
surge-ping = "0.8.0"
3132

3233
[target.'cfg(target_os="android")'.dependencies]
3334
duct = "0.13"

talpid-wireguard/src/lib.rs

+103
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ pub enum Error {
7474
#[error(display = "Failed to setup routing")]
7575
SetupRoutingError(#[error(source)] talpid_routing::Error),
7676

77+
/// Failed to set MTU
78+
#[error(display = "Failed to detect MTU because every ping was dropped.")]
79+
MtuDetectionAllDropped,
80+
81+
/// Failed to set MTU
82+
#[error(display = "Failed to detect MTU because of unexpected ping error.")]
83+
MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),
84+
7785
/// Tunnel timed out
7886
#[error(display = "Tunnel timed out")]
7987
TimeoutError,
@@ -949,6 +957,101 @@ impl WireguardMonitor {
949957
}
950958
}
951959

960+
/// Detects the maximum MTU that does not cause dropped packets.
961+
///
962+
/// The detection works by sending evenly spread out range of pings between 576 and the given
963+
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
964+
#[cfg(target_os = "linux")]
965+
async fn auto_mtu_detection(
966+
gateway: std::net::Ipv4Addr,
967+
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
968+
current_mtu: u16,
969+
) -> Result<u16> {
970+
use futures::{future, stream::FuturesUnordered, TryStreamExt};
971+
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
972+
use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
973+
use tokio_stream::StreamExt;
974+
975+
/// Max time to wait for any ping, when this expires, we give up and throw an error.
976+
const PING_TIMEOUT: Duration = Duration::from_secs(10);
977+
/// Max time to wait after the first ping arrives. Every ping after this timeout is considered
978+
/// dropped, so we return the largest collected packet size.
979+
const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);
980+
981+
let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
982+
#[cfg(any(target_os = "macos", target_os = "linux"))]
983+
let config_builder = config_builder.interface(&iface_name);
984+
let client = Client::new(&config_builder.build()).unwrap();
985+
986+
let step_size = 20;
987+
let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);
988+
989+
let payload_buf = vec![0; current_mtu as usize];
990+
991+
let mut ping_stream = linspace
992+
.iter()
993+
.enumerate()
994+
.map(|(i, &mtu)| {
995+
let client = client.clone();
996+
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
997+
let payload = &payload_buf[0..payload_size];
998+
async move {
999+
log::trace!("Sending ICMP ping of total size {mtu}");
1000+
client
1001+
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
1002+
.await
1003+
.timeout(PING_TIMEOUT)
1004+
.ping(PingSequence(i as u16), payload)
1005+
.await
1006+
}
1007+
})
1008+
.collect::<FuturesUnordered<_>>()
1009+
.map_ok(|(packet, _rtt)| {
1010+
let surge_ping::IcmpPacket::V4(packet) = packet else {
1011+
unreachable!("ICMP ping response was not of IPv4 type");
1012+
};
1013+
let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
1014+
log::trace!("Got ICMP ping response of total size {size}");
1015+
debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
1016+
size
1017+
});
1018+
1019+
let first_ping_size = ping_stream
1020+
.next()
1021+
.await
1022+
.expect("At least one pings should be sent")
1023+
// Short-circuit and return on error
1024+
.map_err(|e| match e {
1025+
// If the first ping we get back timed out, then all of them did
1026+
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
1027+
// Unexpected error type
1028+
e => Error::MtuDetectionPingError(e),
1029+
})?;
1030+
1031+
ping_stream
1032+
.timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
1033+
.map_while(|res| res.ok()) // Stop waiting for pings after this timeout
1034+
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
1035+
.await
1036+
.map_err(Error::MtuDetectionPingError)
1037+
}
1038+
1039+
/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
1040+
/// points.
1041+
#[cfg(target_os = "linux")]
1042+
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
1043+
if mtu_min > mtu_max {
1044+
panic!("Invalid MTU detection range: `mtu_min`={mtu_min}, `mtu_max`={mtu_max}.");
1045+
}
1046+
let second_mtu = mtu_min.next_multiple_of(step_size);
1047+
let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
1048+
let mut ret = Vec::with_capacity(((mtu_max - second_mtu).div_ceil(step_size) + 2) as usize);
1049+
ret.push(mtu_min);
1050+
ret.extend(in_between);
1051+
ret.push(mtu_max);
1052+
ret
1053+
}
1054+
9521055
#[derive(Debug)]
9531056
enum CloseMsg {
9541057
Stop,

0 commit comments

Comments
 (0)