Skip to content

Commit

Permalink
Merge branch 'mtu-detection-macos'
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Feb 12, 2024
2 parents 88fcf98 + f14fdae commit 1ad14a0
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 170 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ impl TunnelMonitor {
{
let default_mtu = DEFAULT_MTU;

#[cfg(any(target_os = "linux", target_os = "windows"))]
// Detects the MTU of the device and sets the default tunnel MTU to that minus headers and
// the safety margin
#[cfg(any(target_os = "linux", target_os = "windows"))]
let default_mtu = args
.runtime
.block_on(
Expand All @@ -176,14 +176,14 @@ impl TunnelMonitor {
.map(|mtu| Self::clamp_mtu(params, mtu))
.unwrap_or(default_mtu);

#[cfg(any(target_os = "linux", windows))]
#[cfg(not(target_os = "android"))]
let detect_mtu = params.options.mtu.is_none();

let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?;
let monitor = talpid_wireguard::WireguardMonitor::start(
config,
params.options.quantum_resistant,
#[cfg(any(target_os = "linux", windows))]
#[cfg(not(target_os = "android"))]
detect_mtu,
log.as_deref(),
args,
Expand Down
2 changes: 1 addition & 1 deletion talpid-wireguard/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ futures = "0.3.15"
hex = "0.4"
ipnetwork = "0.16"
once_cell = { workspace = true }
libc = "0.2"
libc = "0.2.150"
log = { workspace = true }
parking_lot = "0.12.0"
talpid-routing = { path = "../talpid-routing" }
Expand Down
182 changes: 19 additions & 163 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ use talpid_routing as routing;
use talpid_routing::{self, RequiredRoute};
#[cfg(not(windows))]
use talpid_tunnel::tun_provider;
#[cfg(not(target_os = "android"))]
use talpid_tunnel::IPV4_HEADER_SIZE;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};

use ipnetwork::IpNetwork;
Expand All @@ -50,7 +48,7 @@ mod connectivity_check;
mod logging;
mod ping_monitor;
mod stats;
#[cfg(target_os = "linux")]
#[cfg(any(target_os = "linux", target_os = "macos"))]
mod unix;
#[cfg(wireguard_go)]
mod wireguard_go;
Expand All @@ -59,6 +57,9 @@ pub(crate) mod wireguard_kernel;
#[cfg(windows)]
mod wireguard_nt;

#[cfg(not(target_os = "android"))]
mod mtu_detection;

#[cfg(wireguard_go)]
use self::wireguard_go::WgGoTunnel;

Expand All @@ -73,14 +74,6 @@ pub enum Error {
#[error(display = "Failed to setup routing")]
SetupRoutingError(#[error(source)] talpid_routing::Error),

/// Failed to set MTU
#[error(display = "Failed to detect MTU because every ping was dropped.")]
MtuDetectionAllDropped,

/// Failed to set MTU
#[error(display = "Failed to detect MTU because of unexpected ping error.")]
MtuDetectionPingError(#[error(source)] surge_ping::SurgeError),

/// Tunnel timed out
#[error(display = "Tunnel timed out")]
TimeoutError,
Expand Down Expand Up @@ -269,7 +262,7 @@ impl WireguardMonitor {
>(
mut config: Config,
psk_negotiation: bool,
#[cfg(any(target_os = "linux", windows))] detect_mtu: bool,
#[cfg(not(target_os = "android"))] detect_mtu: bool,
log_path: Option<&Path>,
args: TunnelArgs<'_, F>,
) -> Result<WireguardMonitor> {
Expand Down Expand Up @@ -389,47 +382,27 @@ impl WireguardMonitor {
.await?;
}

#[cfg(any(target_os = "linux", windows))]
#[cfg(not(target_os = "android"))]
if detect_mtu {
let iface_name_clone = iface_name.clone();
let config = config.clone();
let iface_name = iface_name.clone();
tokio::task::spawn(async move {
log::debug!("Starting MTU detection");
let verified_mtu = match auto_mtu_detection(
if let Err(e) = mtu_detection::automatic_mtu_correction(
gateway,
#[cfg(any(target_os = "macos", target_os = "linux"))]
iface_name_clone.clone(),
iface_name,
config.mtu,
#[cfg(windows)]
config.ipv6_gateway.is_some(),
)
.await
{
Ok(mtu) => mtu,
Err(e) => {
log::error!("{}", e.display_chain_with_msg("Failed to detect MTU"));
return;
}
};

if verified_mtu != config.mtu {
log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu);
#[cfg(target_os = "linux")]
let res = unix::set_mtu(&iface_name_clone, verified_mtu);
#[cfg(windows)]
let res = talpid_windows::net::luid_from_alias(iface_name_clone).and_then(
|luid| {
talpid_windows::net::set_mtu(
luid,
verified_mtu as u32,
config.ipv6_gateway.is_some(),
)
},
log::error!(
"{}",
e.display_chain_with_msg(
"Failed to automatically adjust MTU based on dropped packets"
)
);

if let Err(e) = res {
log::error!("{}", e.display_chain_with_msg("Failed to set MTU"))
};
} else {
log::debug!("MTU {verified_mtu} verified to not drop packets");
}
};
});
}
let mut connectivity_monitor = tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -951,7 +924,7 @@ impl WireguardMonitor {

#[cfg(any(target_os = "linux", target_os = "macos"))]
fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute {
use talpid_tunnel::{IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};
use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE};

if !config.is_multihop() {
route
Expand Down Expand Up @@ -1004,123 +977,6 @@ impl WireguardMonitor {
}
}

/// Detects the maximum MTU that does not cause dropped packets.
///
/// The detection works by sending evenly spread out range of pings between 576 and the given
/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout.
#[cfg(any(target_os = "linux", windows))]
async fn auto_mtu_detection(
gateway: std::net::Ipv4Addr,
#[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String,
current_mtu: u16,
) -> Result<u16> {
use futures::{future, stream::FuturesUnordered, TryStreamExt};
use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError};
use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU};
use tokio_stream::StreamExt;

/// Max time to wait for any ping, when this expires, we give up and throw an error.
const PING_TIMEOUT: Duration = Duration::from_secs(10);
/// Max time to wait after the first ping arrives. Every ping after this timeout is considered
/// dropped, so we return the largest collected packet size.
const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2);

let config_builder = Config::builder().kind(surge_ping::ICMP::V4);
#[cfg(any(target_os = "macos", target_os = "linux"))]
let config_builder = config_builder.interface(&iface_name);
let client = Client::new(&config_builder.build()).unwrap();

let step_size = 20;
let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size);

let payload_buf = vec![0; current_mtu as usize];

let mut ping_stream = linspace
.iter()
.enumerate()
.map(|(i, &mtu)| {
let client = client.clone();
let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize;
let payload = &payload_buf[0..payload_size];
async move {
log::trace!("Sending ICMP ping of total size {mtu}");
client
.pinger(IpAddr::V4(gateway), PingIdentifier(0))
.await
.timeout(PING_TIMEOUT)
.ping(PingSequence(i as u16), payload)
.await
}
})
.collect::<FuturesUnordered<_>>()
.map_ok(|(packet, _rtt)| {
let surge_ping::IcmpPacket::V4(packet) = packet else {
unreachable!("ICMP ping response was not of IPv4 type");
};
let size = packet.get_size() as u16 + IPV4_HEADER_SIZE;
log::trace!("Got ICMP ping response of total size {size}");
debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]);
size
});

let first_ping_size = ping_stream
.next()
.await
.expect("At least one pings should be sent")
// Short-circuit and return on error
.map_err(|e| match e {
// If the first ping we get back timed out, then all of them did
SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped,
// Unexpected error type
e => Error::MtuDetectionPingError(e),
})?;

ping_stream
.timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout
.map_while(|res| res.ok()) // Stop waiting for pings after this timeout
.try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping
.await
.map_err(Error::MtuDetectionPingError)
}

/// Creates a linear spacing of MTU values with the given step size. Always includes the given end
/// points.
#[cfg(any(target_os = "linux", windows))]
fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec<u16> {
assert!(mtu_min < mtu_max);
assert!(step_size < mtu_max);
assert_ne!(step_size, 0);

let second_mtu = (mtu_min + 1).next_multiple_of(step_size);
let in_between = (second_mtu..mtu_max).step_by(step_size as usize);

let mut ret = Vec::with_capacity(in_between.clone().count() + 2);
ret.push(mtu_min);
ret.extend(in_between);
ret.push(mtu_max);
ret
}

#[cfg(all(test, any(target_os = "linux", windows)))]
mod tests {
use crate::mtu_spacing;
use proptest::prelude::*;

proptest! {
#[test]
fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) {
let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size);

prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1);
prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1);
prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len());
let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]);
prop_assert!(diffs.all(|diff| diff <= step_size));

}
}
}

#[derive(Debug)]
enum CloseMsg {
Stop,
Expand Down
Loading

0 comments on commit 1ad14a0

Please sign in to comment.