Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic MTU detection linux #5736

Merged
merged 8 commits into from
Feb 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add automatic MTU detection
  • Loading branch information
Serock3 authored and dlon committed Feb 8, 2024
commit 62b803248405eeb17068e41ced9b8998c5d20095
71 changes: 71 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions talpid-wireguard/Cargo.toml
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ chrono = { workspace = true, features = ["clock"] }
tokio = { workspace = true, features = ["process", "rt-multi-thread", "fs"] }
tunnel-obfuscation = { path = "../tunnel-obfuscation" }
rand = "0.8.5"
surge-ping = "0.8.0"

[target.'cfg(target_os="android")'.dependencies]
duct = "0.13"
103 changes: 103 additions & 0 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
@@ -74,6 +74,14 @@ pub enum Error {
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] talpid_routing::Error),

/// Failed to set MTU
#[error(display = "Failed to detect MTU because every ping was dropped.")]
MtuDetectionAllDropped,

/// Failed to set MTU
#[error(display = "Failed to detect MTU because of unexpected ping error.")]
MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),

/// Tunnel timed out
#[error(display = "Tunnel timed out")]
TimeoutError,
@@ -949,6 +957,101 @@ impl WireguardMonitor {
}
}

/// Detects the maximum MTU that does not cause dropped packets.
///
/// The detection works by sending evenly spread out range of pings between 576 and the given
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
#[cfg(target_os = "linux")]
async fn auto_mtu_detection(
gateway: std::net::Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
current_mtu: u16,
) -> Result<u16> {
use futures::{future, stream::FuturesUnordered, TryStreamExt};
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
use tokio_stream::StreamExt;

/// Max time to wait for any ping, when this expires, we give up and throw an error.
const PING_TIMEOUT: Duration = Duration::from_secs(10);
/// Max time to wait after the first ping arrives. Every ping after this timeout is considered
/// dropped, so we return the largest collected packet size.
const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);

let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
#[cfg(any(target_os = "macos", target_os = "linux"))]
let config_builder = config_builder.interface(&iface_name);
let client = Client::new(&config_builder.build()).unwrap();

let step_size = 20;
let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);

let payload_buf = vec![0; current_mtu as usize];

let mut ping_stream = linspace
.iter()
.enumerate()
.map(|(i, &mtu)| {
let client = client.clone();
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
let payload = &payload_buf[0..payload_size];
async move {
log::trace!("Sending ICMP ping of total size {mtu}");
client
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
.await
.timeout(PING_TIMEOUT)
.ping(PingSequence(i as u16), payload)
.await
}
})
.collect::<FuturesUnordered<_>>()
.map_ok(|(packet, _rtt)| {
let surge_ping::IcmpPacket::V4(packet) = packet else {
unreachable!("ICMP ping response was not of IPv4 type");
};
let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
log::trace!("Got ICMP ping response of total size {size}");
debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
size
});

let first_ping_size = ping_stream
.next()
.await
.expect("At least one pings should be sent")
// Short-circuit and return on error
.map_err(|e| match e {
// If the first ping we get back timed out, then all of them did
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
// Unexpected error type
e => Error::MtuDetectionPingError(e),
})?;

ping_stream
.timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
.map_while(|res| res.ok()) // Stop waiting for pings after this timeout
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
.await
.map_err(Error::MtuDetectionPingError)
}

/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
/// points.
#[cfg(target_os = "linux")]
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
if mtu_min > mtu_max {
panic!("Invalid MTU detection range: `mtu_min`={mtu_min}, `mtu_max`={mtu_max}.");
}
let second_mtu = mtu_min.next_multiple_of(step_size);
let in_between = (second_mtu..mtu_max).step_by(step_size as usize);
let mut ret = Vec::with_capacity(((mtu_max - second_mtu).div_ceil(step_size) + 2) as usize);
ret.push(mtu_min);
ret.extend(in_between);
ret.push(mtu_max);
ret
}

#[derive(Debug)]
enum CloseMsg {
Stop,