Skip to content

Commit

Permalink
Move MTU detection to separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
Serock3 authored and dlon committed Feb 12, 2024
1 parent 345eed7 commit f14fdae
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 177 deletions.
193 changes: 16 additions & 177 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ use talpid_routing as routing;
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
#[cfg(not(target_os = "android"))]
use talpid_tunnel::IPV4_HEADER_SIZE;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};

use ipnetwork::IpNetwork;
Expand Down Expand Up @@ -59,6 +57,9 @@ pub(crate) mod wireguard_kernel;
#[cfg(windows)]
mod wireguard_nt;

#[cfg(not(target_os = "android"))]
mod mtu_detection;

#[cfg(wireguard_go)]
use self::wireguard_go::WgGoTunnel;

Expand All @@ -73,19 +74,6 @@ pub enum Error {
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] talpid_routing::Error),

/// Failed to set MTU
#[error(display = "Failed to detect MTU because every ping was dropped.")]
MtuDetectionAllDropped,

/// Failed to set MTU
#[error(display = "Failed to detect MTU because of unexpected ping error.")]
MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),

/// Failed to set MTU
#[cfg(target_os = "macos")]
#[error(display = "Failed to set buffer size")]
MtuSetBufferSize(#[error(source)] nix::Error),

/// Tunnel timed out
#[error(display = "Tunnel timed out")]
TimeoutError,
Expand Down Expand Up @@ -396,45 +384,25 @@ impl WireguardMonitor {

#[cfg(not(target_os = "android"))]
if detect_mtu {
let iface_name_clone = iface_name.clone();
let config = config.clone();
let iface_name = iface_name.clone();
tokio::task::spawn(async move {
log::debug!("Starting MTU detection");
let verified_mtu = match auto_mtu_detection(
if let Err(e) = mtu_detection::automatic_mtu_correction(
gateway,
#[cfg(any(target_os = "macos", target_os = "linux"))]
iface_name_clone.clone(),
iface_name,
config.mtu,
#[cfg(windows)]
config.ipv6_gateway.is_some(),
)
.await
{
Ok(mtu) => mtu,
Err(e) => {
log::error!("{}", e.display_chain_with_msg("Failed to detect MTU"));
return;
}
};

if verified_mtu != config.mtu {
log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
#[cfg(any(target_os = "linux", target_os = "macos"))]
let res = unix::set_mtu(&iface_name_clone, verified_mtu);
#[cfg(windows)]
let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
|luid| {
talpid_windows::net::set_mtu(
luid,
verified_mtu as u32,
config.ipv6_gateway.is_some(),
)
},
log::error!(
"{}",
e.display_chain_with_msg(
"Failed to automatically adjust MTU based on dropped packets"
)
);

if let Err(e) = res {
log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
};
} else {
log::debug!("MTU {verified_mtu} verified to not drop packets");
}
};
});
}
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -956,7 +924,7 @@ impl WireguardMonitor {

#[cfg(any(target_os = "linux", target_os = "macos"))]
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};

if !config.is_multihop() {
route
Expand Down Expand Up @@ -1009,135 +977,6 @@ impl WireguardMonitor {
}
}

/// Detects the maximum MTU that does not cause dropped packets.
///
/// The detection works by sending evenly spread out range of pings between 576 and the given
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
#[cfg(not(target_os = "android"))]
async fn auto_mtu_detection(
gateway: std::net::Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
current_mtu: u16,
) -> Result<u16> {
use futures::{future, stream::FuturesUnordered, TryStreamExt};
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
use tokio_stream::StreamExt;

/// Max time to wait for any ping, when this expires, we give up and throw an error.
const PING_TIMEOUT: Duration = Duration::from_secs(10);
/// Max time to wait after the first ping arrives. Every ping after this timeout is considered
/// dropped, so we return the largest collected packet size.
const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);

let step_size = 20;
let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);

let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
#[cfg(any(target_os = "macos", target_os = "linux"))]
let config_builder = config_builder.interface(&iface_name);
let client = Client::new(&config_builder.build()).unwrap();
// For macos, the default socket receive buffer size seems to be too small to handle the data we
// are sending here. The consequence will be dropped packets causing the MTU detection to set a
// low value. Here we manually increase this value, which fixes the problem.
// TODO: Make sure this fix is not needed for any other target OS
#[cfg(target_os = "macos")]
{
use nix::sys::socket::{setsockopt, sockopt};
let fd = client.get_socket().get_native_sock();
let buf_size = linspace.iter().map(|sz| usize::from(*sz)).sum();
setsockopt(fd, sockopt::SndBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
setsockopt(fd, sockopt::RcvBuf, &buf_size).map_err(Error::MtuSetBufferSize)?;
}

let payload_buf = vec![0; current_mtu as usize];

let mut ping_stream = linspace
.iter()
.enumerate()
.map(|(i, &mtu)| {
let client = client.clone();
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
let payload = &payload_buf[0..payload_size];
async move {
log::trace!("Sending ICMP ping of total size {mtu}");
client
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
.await
.timeout(PING_TIMEOUT)
.ping(PingSequence(i as u16), payload)
.await
}
})
.collect::<FuturesUnordered<_>>()
.map_ok(|(packet, _rtt)| {
let surge_ping::IcmpPacket::V4(packet) = packet else {
unreachable!("ICMP ping response was not of IPv4 type");
};
let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
log::trace!("Got ICMP ping response of total size {size}");
debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
size
});

let first_ping_size = ping_stream
.next()
.await
.expect("At least one pings should be sent")
// Short-circuit and return on error
.map_err(|e| match e {
// If the first ping we get back timed out, then all of them did
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
// Unexpected error type
e => Error::MtuDetectionPingError(e),
})?;

ping_stream
.timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
.map_while(|res| res.ok()) // Stop waiting for pings after this timeout
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
.await
.map_err(Error::MtuDetectionPingError)
}

/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
/// points.
#[cfg(not(target_os = "android"))]
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
assert!(mtu_min < mtu_max);
assert!(step_size < mtu_max);
assert_ne!(step_size, 0);

let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
let in_between = (second_mtu..mtu_max).step_by(step_size as usize);

let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
ret.push(mtu_min);
ret.extend(in_between);
ret.push(mtu_max);
ret
}

#[cfg(all(test, not(target_os = "android")))]
mod tests {
use crate::mtu_spacing;
use proptest::prelude::*;

proptest! {
#[test]
fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);

prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
prop_assert!(diffs.all(|diff| diff <= step_size));

}
}
}

#[derive(Debug)]
enum CloseMsg {
Stop,
Expand Down
Loading

0 comments on commit f14fdae

Please sign in to comment.