Skip to content

Commit e80115d

Browse files
committed
Merge branch 'set_mtu_on_wireguard'
2 parents b51a6de + a036cbf commit e80115d

File tree

5 files changed

+135
-9
lines changed

5 files changed

+135
-9
lines changed

talpid-core/src/routing/linux.rs

+70
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ pub enum Error {
113113
#[error(display = "No netlink response for route query")]
114114
NoRouteError,
115115

116+
#[error(display = "No link found")]
117+
LinkNotFoundError,
118+
116119
/// Unable to create routing table for tagged connections and packets.
117120
#[error(display = "Cannot find a free routing table ID")]
118121
NoFreeRoutingTableId,
@@ -363,6 +366,9 @@ impl RouteManagerImpl {
363366
RouteManagerCommand::GetDestinationRoute(destination, set_mark, result_tx) => {
364367
let _ = result_tx.send(self.get_destination_route(&destination, set_mark).await);
365368
}
369+
RouteManagerCommand::GetMtuForRoute(ip, result_tx) => {
370+
let _ = result_tx.send(self.get_mtu_for_route(ip).await);
371+
}
366372
RouteManagerCommand::ClearRoutes => {
367373
log::debug!("Clearing routes");
368374
self.cleanup_routes().await;
@@ -720,6 +726,70 @@ impl RouteManagerImpl {
720726
}
721727
}
722728

729+
async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
730+
// RECURSION_LIMIT controls how many times we recurse to find the device name by looking up
731+
// an IP with `get_destination_route`.
732+
const RECURSION_LIMIT: usize = 10;
733+
const STANDARD_MTU: u16 = 1500;
734+
let mut attempted_ip = ip;
735+
for _ in 0..RECURSION_LIMIT {
736+
let route = self.get_destination_route(&attempted_ip, false).await?;
737+
match route {
738+
Some(route) => {
739+
let node = route.get_node();
740+
match (node.get_device(), node.get_address()) {
741+
(Some(device), None) => {
742+
let mtu = self.get_device_mtu(device.to_string()).await?;
743+
if mtu != STANDARD_MTU {
744+
log::info!(
745+
"Found MTU: {} on device {} which is different from the standard {}",
746+
mtu,
747+
device,
748+
STANDARD_MTU
749+
);
750+
}
751+
return Ok(mtu);
752+
}
753+
(None, Some(address)) => attempted_ip = address,
754+
_ => {
755+
panic!("Route must contain either an IP or a device.");
756+
}
757+
}
758+
}
759+
None => {
760+
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
761+
return Err(Error::NoRouteError);
762+
}
763+
}
764+
}
765+
log::error!(
766+
"Retried {} times looking for the correct device and could not find it",
767+
RECURSION_LIMIT
768+
);
769+
Err(Error::NoRouteError)
770+
}
771+
772+
async fn get_device_mtu(&self, device: String) -> Result<u16> {
773+
let mut links = self.handle.link().get().execute();
774+
let target_device = LinkNla::IfName(device);
775+
while let Some(msg) = links
776+
.try_next()
777+
.await
778+
.map_err(|_| Error::LinkNotFoundError)?
779+
{
780+
let found = msg.nlas.iter().any(|e| *e == target_device);
781+
if found {
782+
if let Some(LinkNla::Mtu(mtu)) =
783+
msg.nlas.iter().find(|e| matches!(e, LinkNla::Mtu(_)))
784+
{
785+
return Ok(u16::try_from(*mtu)
786+
.expect("MTU returned by device does not fit into a u16"));
787+
}
788+
}
789+
}
790+
Err(Error::LinkNotFoundError)
791+
}
792+
723793
async fn get_destination_route(
724794
&self,
725795
destination: &IpAddr,

talpid-core/src/routing/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub use imp::{Error, RouteManager};
2222

2323
pub use imp::RouteManagerHandle;
2424

25-
/// A netowrk route with a specific network node, destinaiton and an optional metric.
25+
/// A network route with a specific network node, destinaiton and an optional metric.
2626
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
2727
pub struct Route {
2828
node: Node,

talpid-core/src/routing/unix.rs

+15
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ impl RouteManagerHandle {
133133
.map_err(|_| Error::ManagerChannelDown)?
134134
.map_err(Error::PlatformError)
135135
}
136+
137+
/// Listen for route changes.
138+
#[cfg(target_os = "linux")]
139+
pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16, Error> {
140+
let (response_tx, response_rx) = oneshot::channel();
141+
self.tx
142+
.unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
143+
.map_err(|_| Error::RouteManagerDown)?;
144+
response_rx
145+
.await
146+
.map_err(|_| Error::ManagerChannelDown)?
147+
.map_err(Error::PlatformError)
148+
}
136149
}
137150

138151
/// Commands for the underlying route manager object.
@@ -151,6 +164,8 @@ pub(crate) enum RouteManagerCommand {
151164
#[cfg(target_os = "linux")]
152165
NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>),
153166
#[cfg(target_os = "linux")]
167+
GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>),
168+
#[cfg(target_os = "linux")]
154169
GetDestinationRoute(
155170
IpAddr,
156171
bool,

talpid-core/src/tunnel/mod.rs

+47-6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ pub enum Error {
6161
/// There was an error listening for events from the Wireguard tunnel
6262
#[error(display = "Failed while listening for events from the Wireguard tunnel")]
6363
WireguardTunnelMonitoringError(#[error(source)] wireguard::Error),
64+
65+
/// Could not detect and assign the correct mtu
66+
#[error(display = "Could not detect and assign a correct MTU for the Wireguard tunnel")]
67+
AssignMtuError,
6468
}
6569

6670
/// Possible events from the VPN tunnel and the child process managing it.
@@ -101,7 +105,7 @@ impl TunnelMonitor {
101105
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
102106
pub fn start<L>(
103107
runtime: tokio::runtime::Handle,
104-
tunnel_parameters: &TunnelParameters,
108+
tunnel_parameters: &mut TunnelParameters,
105109
log_dir: &Option<PathBuf>,
106110
resource_dir: &Path,
107111
on_event: L,
@@ -134,9 +138,9 @@ impl TunnelMonitor {
134138
#[cfg(target_os = "android")]
135139
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),
136140

137-
TunnelParameters::Wireguard(config) => Self::start_wireguard_tunnel(
141+
TunnelParameters::Wireguard(ref mut config) => Self::start_wireguard_tunnel(
138142
runtime,
139-
&config,
143+
config,
140144
log_file,
141145
resource_dir,
142146
on_event,
@@ -172,7 +176,7 @@ impl TunnelMonitor {
172176

173177
fn start_wireguard_tunnel<L>(
174178
runtime: tokio::runtime::Handle,
175-
params: &wireguard_types::TunnelParameters,
179+
params: &mut wireguard_types::TunnelParameters,
176180
log: Option<PathBuf>,
177181
resource_dir: &Path,
178182
on_event: L,
@@ -188,11 +192,13 @@ impl TunnelMonitor {
188192
+ Clone
189193
+ 'static,
190194
{
191-
let config = wireguard::config::Config::from_parameters(&params)?;
195+
#[cfg(target_os = "linux")]
196+
runtime.block_on(Self::assign_mtu(&route_manager, params));
197+
let config = wireguard::config::Config::from_parameters(params)?;
192198
let monitor = wireguard::WireguardMonitor::start(
193199
runtime,
194200
config,
195-
log.as_ref().map(|p| p.as_path()),
201+
log.as_deref(),
196202
resource_dir,
197203
on_event,
198204
tun_provider,
@@ -205,6 +211,41 @@ impl TunnelMonitor {
205211
})
206212
}
207213

214+
#[cfg(target_os = "linux")]
215+
fn set_mtu(params: &mut wireguard_types::TunnelParameters, mtu: u16) {
216+
const WIREGUARD_HEADER_SIZE: u16 = 80;
217+
// The largest tunnel MTU that we allow. Standard MTU - Wireguard header
218+
const MAX_TUNNEL_MTU: u16 = 1420;
219+
// The minimum allowed MTU size for our tunnel in IPv6 is 1280
220+
const MIN_IPV6_MTU: u16 = 1280;
221+
const MIN_IPV4_MTU: u16 = 576;
222+
let min_mtu = match params.generic_options.enable_ipv6 {
223+
true => MIN_IPV6_MTU,
224+
false => MIN_IPV4_MTU,
225+
};
226+
let mtu = std::cmp::max(
227+
mtu.checked_sub(WIREGUARD_HEADER_SIZE).unwrap_or(min_mtu),
228+
min_mtu,
229+
);
230+
let upstream_mtu = std::cmp::min(MAX_TUNNEL_MTU, mtu);
231+
params.options.mtu = Some(upstream_mtu);
232+
}
233+
234+
#[cfg(target_os = "linux")]
235+
async fn assign_mtu(
236+
route_manager: &RouteManagerHandle,
237+
params: &mut wireguard_types::TunnelParameters,
238+
) {
239+
// It is fine to leave the params untouched if getting the mtu for the route fails. In that
240+
// case we will do our regular default.
241+
if let Ok(mtu) = route_manager
242+
.get_mtu_for_route(params.connection.peer.endpoint.ip())
243+
.await
244+
{
245+
Self::set_mtu(params, mtu);
246+
}
247+
}
248+
208249
#[cfg(not(target_os = "android"))]
209250
async fn start_openvpn_tunnel<L>(
210251
config: &openvpn_types::TunnelParameters,

talpid-core/src/tunnel_state_machine/connecting_state.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl ConnectingState {
115115
let (tunnel_close_tx, tunnel_close_rx) = oneshot::channel();
116116
let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel();
117117

118-
let tunnel_parameters = parameters.clone();
118+
let mut tunnel_parameters = parameters.clone();
119119

120120
tokio::task::spawn_blocking(move || {
121121
let start = Instant::now();
@@ -141,7 +141,7 @@ impl ConnectingState {
141141

142142
let block_reason = match TunnelMonitor::start(
143143
runtime,
144-
&tunnel_parameters,
144+
&mut tunnel_parameters,
145145
&log_dir,
146146
&resource_dir,
147147
on_tunnel_event,

0 commit comments

Comments
 (0)