Skip to content

Commit 51693fe

Browse files
Serock3dlon
authored andcommitted
Add windows MTU detection implementation.
1 parent 4011170 commit 51693fe

File tree

4 files changed

+53
-14
lines changed

4 files changed

+53
-14
lines changed

talpid-core/src/tunnel/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,14 @@ impl TunnelMonitor {
176176
.map(|mtu| Self::clamp_mtu(params, mtu))
177177
.unwrap_or(default_mtu);
178178

179-
#[cfg(target_os = "linux")]
179+
#[cfg(any(target_os = "linux", windows))]
180180
let detect_mtu = params.options.mtu.is_none();
181181

182182
let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?;
183183
let monitor = talpid_wireguard::WireguardMonitor::start(
184184
config,
185185
params.options.quantum_resistant,
186-
#[cfg(target_os = "linux")]
186+
#[cfg(any(target_os = "linux", windows))]
187187
detect_mtu,
188188
log.as_deref(),
189189
args,

talpid-windows/src/net.rs

+22
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,28 @@ pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Resul
332332
win32_err!(unsafe { CreateUnicastIpAddressEntry(&row) }).map_err(Error::CreateUnicastEntry)
333333
}
334334

335+
/// Sets MTU on the specified network interface identified by `luid`.
336+
pub fn set_mtu(luid: NET_LUID_LH, mtu: u32, use_ipv6: bool) -> io::Result<()> {
337+
let ip_families: &[AddressFamily] = if use_ipv6 {
338+
&[AddressFamily::Ipv4, AddressFamily::Ipv6]
339+
} else {
340+
&[AddressFamily::Ipv4]
341+
};
342+
for family in ip_families {
343+
let mut row = match get_ip_interface_entry(*family, &luid) {
344+
Ok(row) => row,
345+
Err(error) if error.raw_os_error() == Some(ERROR_NOT_FOUND as i32) => continue,
346+
Err(error) => return Err(error),
347+
};
348+
349+
row.NlMtu = mtu;
350+
351+
set_ip_interface_entry(&mut row)?;
352+
}
353+
354+
Ok(())
355+
}
356+
335357
/// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are
336358
/// returned.
337359
pub fn get_unicast_table(

talpid-wireguard/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ duct = "0.13"
3737
byteorder = "1"
3838
internet-checksum = "0.2"
3939
socket2 = { version = "0.5.3", features = ["all"] }
40+
tokio-stream = { version = "0.1", features = ["io-util"] }
4041

4142
[target.'cfg(unix)'.dependencies]
4243
nix = "0.23"
@@ -48,7 +49,6 @@ netlink-packet-route = "0.13"
4849
netlink-packet-utils = "0.5.1"
4950
netlink-proto = "0.10"
5051
talpid-dbus = { path = "../talpid-dbus" }
51-
tokio-stream = { version = "0.1", features = ["io-util"] }
5252

5353
[target.'cfg(windows)'.dependencies]
5454
bitflags = "1.2"

talpid-wireguard/src/lib.rs

+28-11
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
#![deny(missing_docs)]
44

55
use self::config::Config;
6-
use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
76
#[cfg(windows)]
8-
use futures::{channel::mpsc, StreamExt};
7+
use futures::channel::mpsc;
8+
use futures::future::{abortable, AbortHandle as FutureAbortHandle, BoxFuture, Future};
99
#[cfg(target_os = "linux")]
1010
use once_cell::sync::Lazy;
1111
#[cfg(target_os = "android")]
@@ -26,6 +26,8 @@ use talpid_routing as routing;
2626
use talpid_routing::{self, RequiredRoute};
2727
#[cfg(not(windows))]
2828
use talpid_tunnel::tun_provider;
29+
#[cfg(not(target_os = "android"))]
30+
use talpid_tunnel::IPV4_HEADER_SIZE;
2931
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
3032

3133
use ipnetwork::IpNetwork;
@@ -42,9 +44,6 @@ use tunnel_obfuscation::{
4244
create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings,
4345
};
4446

45-
#[cfg(any(target_os = "linux", target_os = "macos"))]
46-
use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
47-
4847
/// WireGuard config data-types
4948
pub mod config;
5049
mod connectivity_check;
@@ -270,7 +269,7 @@ impl WireguardMonitor {
270269
>(
271270
mut config: Config,
272271
psk_negotiation: bool,
273-
#[cfg(target_os = "linux")] detect_mtu: bool,
272+
#[cfg(any(target_os = "linux", windows))] detect_mtu: bool,
274273
log_path: Option<&Path>,
275274
args: TunnelArgs<'_, F>,
276275
) -> Result<WireguardMonitor> {
@@ -389,7 +388,8 @@ impl WireguardMonitor {
389388
)
390389
.await?;
391390
}
392-
#[cfg(target_os = "linux")]
391+
392+
#[cfg(any(target_os = "linux", windows))]
393393
if detect_mtu {
394394
let iface_name_clone = iface_name.clone();
395395
tokio::task::spawn(async move {
@@ -411,7 +411,20 @@ impl WireguardMonitor {
411411

412412
if verified_mtu != config.mtu {
413413
log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
414-
if let Err(e) = unix::set_mtu(&iface_name_clone, verified_mtu) {
414+
#[cfg(target_os = "linux")]
415+
let res = unix::set_mtu(&iface_name_clone, verified_mtu);
416+
#[cfg(windows)]
417+
let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
418+
|luid| {
419+
talpid_windows::net::set_mtu(
420+
luid,
421+
verified_mtu as u32,
422+
config.ipv6_gateway.is_some(),
423+
)
424+
},
425+
);
426+
427+
if let Err(e) = res {
415428
log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
416429
};
417430
} else {
@@ -664,6 +677,8 @@ impl WireguardMonitor {
664677
addresses: &[IpAddr],
665678
mut setup_done_rx: mpsc::Receiver<std::result::Result<(), BoxedError>>,
666679
) -> std::result::Result<(), CloseMsg> {
680+
use futures::StreamExt;
681+
667682
setup_done_rx
668683
.next()
669684
.await
@@ -936,6 +951,8 @@ impl WireguardMonitor {
936951

937952
#[cfg(any(target_os = "linux", target_os = "macos"))]
938953
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
954+
use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
955+
939956
if !config.is_multihop() {
940957
route
941958
} else {
@@ -991,7 +1008,7 @@ impl WireguardMonitor {
9911008
///
9921009
/// The detection works by sending evenly spread out range of pings between 576 and the given
9931010
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
994-
#[cfg(target_os = "linux")]
1011+
#[cfg(any(target_os = "linux", windows))]
9951012
async fn auto_mtu_detection(
9961013
gateway: std::net::Ipv4Addr,
9971014
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
@@ -1068,7 +1085,7 @@ async fn auto_mtu_detection(
10681085

10691086
/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
10701087
/// points.
1071-
#[cfg(target_os = "linux")]
1088+
#[cfg(any(target_os = "linux", windows))]
10721089
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
10731090
assert!(mtu_min < mtu_max);
10741091
assert!(step_size < mtu_max);
@@ -1084,7 +1101,7 @@ fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
10841101
ret
10851102
}
10861103

1087-
#[cfg(all(test, target_os = "linux"))]
1104+
#[cfg(all(test, any(target_os = "linux", windows)))]
10881105
mod tests {
10891106
use crate::mtu_spacing;
10901107
use proptest::prelude::*;

0 commit comments

Comments
 (0)