diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index f30b61317146..847150c47731 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -16,7 +16,6 @@ use mullvad_types::{ version::AppVersionInfo, wireguard::{PublicKey, QuantumResistantState, RotationInterval}, }; -#[cfg(target_os = "windows")] use std::path::Path; use std::str::FromStr; #[cfg(target_os = "windows")] @@ -592,7 +591,6 @@ impl MullvadProxyClient { .map(drop) } - #[cfg(target_os = "linux")] pub async fn get_split_tunnel_processes(&mut self) -> Result> { use futures::TryStreamExt; @@ -605,7 +603,6 @@ impl MullvadProxyClient { procs.try_collect().await.map_err(Error::Rpc) } - #[cfg(target_os = "linux")] pub async fn add_split_tunnel_process(&mut self, pid: i32) -> Result<()> { self.0 .add_split_tunnel_process(pid) @@ -614,7 +611,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "linux")] pub async fn remove_split_tunnel_process(&mut self, pid: i32) -> Result<()> { self.0 .remove_split_tunnel_process(pid) @@ -623,7 +619,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "linux")] pub async fn clear_split_tunnel_processes(&mut self) -> Result<()> { self.0 .clear_split_tunnel_processes(()) @@ -632,7 +627,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] pub async fn add_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -642,7 +636,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] pub async fn remove_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -652,7 +645,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] pub async fn clear_split_tunnel_apps(&mut self) -> Result<()> { self.0 .clear_split_tunnel_apps(()) @@ -661,7 +653,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] pub async fn set_split_tunnel_state(&mut self, state: bool) -> Result<()> { self.0 .set_split_tunnel_state(state) diff --git a/test/Cargo.lock b/test/Cargo.lock index b974ef9217ce..5a7771fb4701 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -463,6 +463,33 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" +[[package]] +name = "color-eyre" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a667583cca8c4f8436db8de46ea8233c42a7d9ae424a82d338f2e4675229204" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.0" @@ -490,6 +517,19 @@ dependencies = [ "memchr", ] +[[package]] +name = "connection-checker" +version = "0.0.0" +dependencies = [ + "clap", + "color-eyre", + "eyre", + "ping", + "reqwest", + "serde", + "socket2 0.5.4", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -745,6 +785,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encoding_rs" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum-as-inner" version = "0.6.0" @@ -821,6 +870,16 @@ dependencies = [ "libc", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fast-socks5" version = "0.9.5" @@ -1197,7 +1256,7 @@ dependencies = [ "rustls-native-certs", "tokio", "tokio-rustls", - "webpki-roots", + "webpki-roots 0.23.1", ] [[package]] @@ -1245,6 +1304,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -1943,6 +2008,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "p256" version = "0.11.1" @@ -2089,6 +2160,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ping" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122ee1f5a6843bec84fcbd5c6ba3622115337a6b8965b93a61aad347648f4e8d" +dependencies = [ + "rand 0.8.5", + "socket2 0.4.9", + "thiserror", +] + [[package]] name = "pkcs8" version = "0.9.0" @@ -2464,6 +2546,47 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "reqwest" +version = "0.11.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" +dependencies = [ + "base64 0.21.4", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pemfile 1.0.3", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-rustls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots 0.25.4", + "winreg", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -2710,18 +2833,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.188" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", @@ -3006,6 +3129,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "talpid-dbus" version = "0.0.0" @@ -3042,6 +3186,7 @@ dependencies = [ "base64 0.13.1", "ipnetwork 0.16.0", "jnix", + "log", "serde", "thiserror", "x25519-dalek", @@ -3229,6 +3374,7 @@ dependencies = [ "proc-macro2", "quote", "syn 1.0.109", + "test-rpc", ] [[package]] @@ -3552,6 +3698,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-opentelemetry" version = "0.17.4" @@ -3792,6 +3948,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.87" @@ -3840,6 +4008,12 @@ dependencies = [ "rustls-webpki 0.100.3", ] +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + [[package]] name = "which" version = "4.4.2" diff --git a/test/Cargo.toml b/test/Cargo.toml index 977f9082d82b..0fd7b4a2e97e 100644 --- a/test/Cargo.toml +++ b/test/Cargo.toml @@ -7,7 +7,13 @@ rust-version = "1.75.0" [workspace] resolver = "2" -members = ["test-manager", "test-runner", "test-rpc", "socks-server"] +members = [ + "test-manager", + "test-runner", + "test-rpc", + "socks-server", + "connection-checker", +] [workspace.lints.rust] rust_2018_idioms = "deny" diff --git a/test/build.sh b/test/build.sh index 2a8f7c7063fa..d3a3c174704e 100755 --- a/test/build.sh +++ b/test/build.sh @@ -17,9 +17,12 @@ if [[ $TARGET == x86_64-unknown-linux-gnu ]]; then -e CARGO_HOME=/root/.cargo/registry \ -e CARGO_TARGET_DIR=/src/test/target \ mullvadvpn-app-tests \ - /bin/bash -c "cd /src/test/; cargo build --bin test-runner --release --target ${TARGET}" + /bin/bash -c "cd /src/test/; cargo build --bin test-runner --bin connection-checker --release --target ${TARGET}" else - cargo build --bin test-runner --release --target "${TARGET}" + cargo build \ + --bin test-runner \ + --bin connection-checker \ + --release --target "${TARGET}" fi # Only build runner image for Windows diff --git a/test/connection-checker/Cargo.toml b/test/connection-checker/Cargo.toml new file mode 100644 index 000000000000..d579510bd1e8 --- /dev/null +++ b/test/connection-checker/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "connection-checker" +description = "Simple cli for testing Mullvad VPN connections" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lints] +workspace = true + +[dependencies] +clap = { workspace = true, features = ["derive"] } +color-eyre = "0.6.2" +eyre = "0.6.12" +ping = "0.5.2" +reqwest = { version = "0.11.24", default-features = false, features = ["blocking", "rustls-tls", "json"] } +serde = { version = "1.0.197", features = ["derive"] } +socket2 = { version = "0.5.4", features = ["all"] } diff --git a/test/connection-checker/src/cli.rs b/test/connection-checker/src/cli.rs new file mode 100644 index 000000000000..dddb348b255c --- /dev/null +++ b/test/connection-checker/src/cli.rs @@ -0,0 +1,36 @@ +use std::net::SocketAddr; + +use clap::Parser; + +/// CLI tool that queries to check if the machine is connected to +/// Mullvad VPN. +#[derive(Parser)] +pub struct Opt { + /// Interactive mode, press enter to check if you are Mullvad. + #[clap(short, long)] + pub interactive: bool, + + /// Timeout for network connection to am.i.mullvad (in millis). + #[clap(short, long, default_value = "3000")] + pub timeout: u64, + + /// Try to send some junk data over TCP to . + #[clap(long, requires = "leak")] + pub leak_tcp: bool, + + /// Try to send some junk data over UDP to . + #[clap(long, requires = "leak")] + pub leak_udp: bool, + + /// Try to send ICMP request to . + #[clap(long, requires = "leak")] + pub leak_icmp: bool, + + /// Target of , or . + #[clap(long)] + pub leak: Option, + + /// Timeout for leak check network connections (in millis). + #[clap(long, default_value = "1000")] + pub leak_timeout: u64, +} diff --git a/test/connection-checker/src/lib.rs b/test/connection-checker/src/lib.rs new file mode 100644 index 000000000000..cb36c236b0be --- /dev/null +++ b/test/connection-checker/src/lib.rs @@ -0,0 +1,2 @@ +pub mod cli; +pub mod net; diff --git a/test/connection-checker/src/main.rs b/test/connection-checker/src/main.rs new file mode 100644 index 000000000000..ed48999970ce --- /dev/null +++ b/test/connection-checker/src/main.rs @@ -0,0 +1,73 @@ +use clap::Parser; +use eyre::{eyre, Context}; +use reqwest::blocking::Client; +use serde::Deserialize; +use std::{io::stdin, time::Duration}; + +use connection_checker::cli::Opt; +use connection_checker::net::{send_ping, send_tcp, send_udp}; + +fn main() -> eyre::Result<()> { + let opt = Opt::parse(); + color_eyre::install()?; + + if opt.interactive { + let stdin = stdin(); + for line in stdin.lines() { + let _ = line.wrap_err("Failed to read from stdin")?; + test_connection(&opt)?; + } + } else { + test_connection(&opt)?; + } + + Ok(()) +} + +fn test_connection(opt: &Opt) -> eyre::Result { + if let Some(destination) = opt.leak { + if opt.leak_tcp { + let _ = send_tcp(opt, destination); + } + if opt.leak_udp { + let _ = send_udp(opt, destination); + } + if opt.leak_icmp { + let _ = send_ping(opt, destination.ip()); + } + } + am_i_mullvad(opt) +} + +/// Check if connected to Mullvad and print the result to stdout +fn am_i_mullvad(opt: &Opt) -> eyre::Result { + #[derive(Debug, Deserialize)] + struct Response { + ip: String, + mullvad_exit_ip_hostname: Option, + } + + let url = "https://am.i.mullvad.net/json"; + + let client = Client::new(); + let response: Response = client + .get(url) + .timeout(Duration::from_millis(opt.timeout)) + .send() + .and_then(|r| r.json()) + .wrap_err_with(|| eyre!("Failed to GET {url}"))?; + + if let Some(server) = &response.mullvad_exit_ip_hostname { + println!( + "You are connected to Mullvad (server {}). Your IP address is {}", + server, response.ip + ); + Ok(true) + } else { + println!( + "You are not connected to Mullvad. Your IP address is {}", + response.ip + ); + Ok(false) + } +} diff --git a/test/connection-checker/src/net.rs b/test/connection-checker/src/net.rs new file mode 100644 index 000000000000..6634be41b0c8 --- /dev/null +++ b/test/connection-checker/src/net.rs @@ -0,0 +1,78 @@ +use eyre::{eyre, Context}; +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use crate::cli::Opt; + +pub fn send_tcp(opt: &Opt, destination: SocketAddr) -> eyre::Result<()> { + let bind_addr: SocketAddr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); + + let family = match &destination { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }; + let sock = socket2::Socket::new(family, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) + .wrap_err(eyre!("Failed to create TCP socket"))?; + + eprintln!("Leaking TCP packets to {destination}"); + + sock.bind(&socket2::SockAddr::from(bind_addr)) + .wrap_err(eyre!("Failed to bind TCP socket to {bind_addr}"))?; + + let timeout = Duration::from_millis(opt.leak_timeout); + sock.set_write_timeout(Some(timeout))?; + sock.set_read_timeout(Some(timeout))?; + + sock.connect_timeout(&socket2::SockAddr::from(destination), timeout) + .wrap_err(eyre!("Failed to connect to {destination}"))?; + + let mut stream = std::net::TcpStream::from(sock); + stream + .write_all(b"hello there") + .wrap_err(eyre!("Failed to send message to {destination}"))?; + + Ok(()) +} + +pub fn send_udp(_opt: &Opt, destination: SocketAddr) -> Result<(), eyre::Error> { + let bind_addr: SocketAddr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); + + eprintln!("Leaking UDP packets to {destination}"); + + let family = match &destination { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }; + let sock = socket2::Socket::new(family, socket2::Type::DGRAM, Some(socket2::Protocol::UDP)) + .wrap_err("Failed to create UDP socket")?; + + sock.bind(&socket2::SockAddr::from(bind_addr)) + .wrap_err(eyre!("Failed to bind UDP socket to {bind_addr}"))?; + + //log::debug!("Send message from {bind_addr} to {destination}/UDP"); + + let std_socket = std::net::UdpSocket::from(sock); + std_socket + .send_to(b"Hello there!", destination) + .wrap_err(eyre!("Failed to send message to {destination}"))?; + + Ok(()) +} + +pub fn send_ping(opt: &Opt, destination: IpAddr) -> eyre::Result<()> { + eprintln!("Leaking IMCP packets to {destination}"); + + ping::ping( + destination, + Some(Duration::from_millis(opt.leak_timeout)), + None, + None, + None, + None, + )?; + + Ok(()) +} diff --git a/test/scripts/build-runner-image.sh b/test/scripts/build-runner-image.sh index fe8077b33777..30252d844510 100755 --- a/test/scripts/build-runner-image.sh +++ b/test/scripts/build-runner-image.sh @@ -33,6 +33,7 @@ case $TARGET in mcopy \ -i "${TEST_RUNNER_IMAGE_PATH}" \ "${SCRIPT_DIR}/../target/$TARGET/release/test-runner.exe" \ + "${SCRIPT_DIR}/../target/$TARGET/release/connection-checker.exe" \ "${PACKAGES_DIR}/"*.exe \ "${SCRIPT_DIR}/../openvpn.ca.crt" \ "::" diff --git a/test/scripts/ssh-setup.sh b/test/scripts/ssh-setup.sh index a3809e023036..b3d358f5a013 100644 --- a/test/scripts/ssh-setup.sh +++ b/test/scripts/ssh-setup.sh @@ -16,7 +16,7 @@ echo "Copying test-runner to $RUNNER_DIR" mkdir -p "$RUNNER_DIR" -for file in test-runner $CURRENT_APP $PREVIOUS_APP $UI_RUNNER openvpn.ca.crt; do +for file in test-runner connection-checker $CURRENT_APP $PREVIOUS_APP $UI_RUNNER openvpn.ca.crt; do echo "Moving $file to $RUNNER_DIR" cp -f "$SCRIPT_DIR/$file" "$RUNNER_DIR" done diff --git a/test/test-manager/src/config.rs b/test/test-manager/src/config.rs index 1605661d5307..6921c0b33f3b 100644 --- a/test/test-manager/src/config.rs +++ b/test/test-manager/src/config.rs @@ -139,6 +139,16 @@ pub struct VmConfig { #[serde(default)] #[arg(long)] pub tpm: bool, + + /// Override the path to `OVMF_VARS.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_vars_path: Option, + + /// Override the path to `OVMF_CODE.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_code_path: Option, } impl VmConfig { diff --git a/test/test-manager/src/logging.rs b/test/test-manager/src/logging.rs index cd0bd4af2840..e85920b1cd9c 100644 --- a/test/test-manager/src/logging.rs +++ b/test/test-manager/src/logging.rs @@ -1,4 +1,4 @@ -use crate::tests::Error; +use anyhow::Error; use colored::Colorize; use std::sync::{Arc, Mutex}; use test_rpc::logging::{LogOutput, Output}; diff --git a/test/test-manager/src/run_tests.rs b/test/test-manager/src/run_tests.rs index 6af153656277..6b3da3713808 100644 --- a/test/test-manager/src/run_tests.rs +++ b/test/test-manager/src/run_tests.rs @@ -2,9 +2,7 @@ use crate::summary::{self, maybe_log_test_result}; use crate::tests::{config::TEST_CONFIG, TestContext}; use crate::{ logging::{panic_as_string, TestOutput}, - mullvad_daemon, tests, - tests::Error, - vm, + mullvad_daemon, tests, vm, }; use anyhow::{Context, Result}; use futures::FutureExt; @@ -187,7 +185,7 @@ pub async fn run_test( ) -> TestOutput where F: Fn(super::tests::TestContext, ServiceClient, MullvadClient) -> R, - R: Future>, + R: Future>, { let _flushed = runner_rpc.try_poll_output().await; diff --git a/test/test-manager/src/tests/mod.rs b/test/test-manager/src/tests/mod.rs index 0cf135769687..48d75b9e3f96 100644 --- a/test/test-manager/src/tests/mod.rs +++ b/test/test-manager/src/tests/mod.rs @@ -6,6 +6,7 @@ mod helpers; mod install; mod settings; mod software; +mod split_tunnel; mod test_metadata; mod tunnel; mod tunnel_state; @@ -32,7 +33,7 @@ pub type TestWrapperFunction = fn( TestContext, ServiceClient, Box, -) -> BoxFuture<'static, Result<(), Error>>; +) -> BoxFuture<'static, anyhow::Result<()>>; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -40,7 +41,7 @@ pub enum Error { Rpc(#[from] test_rpc::Error), #[error("geoip lookup failed")] - GeoipLookup(test_rpc::Error), + GeoipLookup(#[source] test_rpc::Error), #[error("Found running daemon unexpectedly")] DaemonRunning, diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs new file mode 100644 index 000000000000..336ee5b5ab33 --- /dev/null +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -0,0 +1,357 @@ +use anyhow::{anyhow, bail, ensure, Context}; +use mullvad_management_interface::MullvadProxyClient; +use pcap::Direction; +use pnet_packet::ip::IpNextHeaderProtocols; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str, + time::Duration, +}; +use test_macro::test_function; +use test_rpc::{meta::Os, ServiceClient, SpawnOpts}; +use tokio::time::{sleep, timeout}; + +use crate::network_monitor::{start_packet_monitor, MonitorOptions}; + +use super::{config::TEST_CONFIG, helpers, TestContext}; + +const CHECKER_FILENAME_WINDOWS: &str = "connection-checker.exe"; +const CHECKER_FILENAME_UNIX: &str = "connection-checker"; +const LEAK_DESTINATION: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1337); + +/// Test that split tunneling works by asserting the following: +/// - Splitting a process shouldn't do anything if tunnel is not connected. +/// - A split process should never push traffic through the tunnel. +/// - Splitting/unsplitting should work regardless if process is running. +#[test_function(target_os = "linux", target_os = "windows")] +pub async fn test_split_tunnel( + _ctx: TestContext, + rpc: ServiceClient, + mut mullvad_client: MullvadProxyClient, +) -> anyhow::Result<()> { + let mut checker = ConnChecker::new(rpc.clone(), mullvad_client.clone()); + + // Test that program is behaving when we are disconnected + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and unsplit")?; + checker.split().await?; + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and split")?; + checker.unsplit().await?; + + // Test that program is behaving being split/unsplit while running and we are disconnected + let mut handle = checker.spawn().await?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being split while running")?; + handle.unsplit().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being unsplit while running")?; + drop(handle); + + helpers::connect_and_wait(&mut mullvad_client).await?; + + // Test running an unsplit program + checker + .spawn() + .await? + .assert_secure() + .await + .with_context(|| "Test connected and unsplit")?; + + // Test running a split program + checker.split().await?; + checker + .spawn() + .await? + .assert_insecure() + .await + .with_context(|| "Test connected and split")?; + + checker.unsplit().await?; + + // Test splitting and unsplitting a program while it's running + let mut handle = checker.spawn().await?; + (handle.assert_secure().await).with_context(|| "Test connected and unsplit (again)")?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test connected and being split while running")?; + handle.unsplit().await?; + (handle.assert_secure().await) + .with_context(|| "Test connected and being unsplit while running")?; + + Ok(()) +} + +/// This helper spawns a seperate process which checks if we are connected to Mullvad, and tries to +/// leak traffic outside the tunnel by sending TCP, UDP, and ICMP packets to [LEAK_DESTINATION]. +struct ConnChecker { + rpc: ServiceClient, + mullvad_client: MullvadProxyClient, + + /// Path to the process binary. + executable_path: String, + + /// Whether the process should be split when spawned. Needed on Linux. + split: bool, +} + +struct ConnCheckerHandle<'a> { + checker: &'a mut ConnChecker, + + /// ID of the spawned process. + pid: u32, +} + +struct ConnectionStatus { + /// True if reported we are connected. + am_i_mullvad: bool, + + /// True if we sniffed TCP packets going outside the tunnel. + leaked_tcp: bool, + + /// True if we sniffed UDP packets going outside the tunnel. + leaked_udp: bool, + + /// True if we sniffed ICMP packets going outside the tunnel. + leaked_icmp: bool, +} + +impl ConnChecker { + pub fn new(rpc: ServiceClient, mullvad_client: MullvadProxyClient) -> Self { + let artifacts_dir = &TEST_CONFIG.artifacts_dir; + let executable_path = match TEST_CONFIG.os { + Os::Linux | Os::Macos => format!("{artifacts_dir}/{CHECKER_FILENAME_UNIX}"), + Os::Windows => format!("{artifacts_dir}\\{CHECKER_FILENAME_WINDOWS}"), + }; + + Self { + rpc, + mullvad_client, + split: false, + executable_path, + } + } + + /// Spawn the connecton checker process and return a handle to it. + /// + /// Dropping the handle will stop the process. + /// **NOTE**: The handle must be dropped from a tokio runtime context. + pub async fn spawn(&mut self) -> anyhow::Result> { + log::debug!("spawning connection checker"); + + let opts = SpawnOpts { + attach_stdin: true, + attach_stdout: true, + args: [ + "--interactive", + "--timeout", + "10000", + // try to leak traffic to LEAK_DESTINATION + "--leak", + &LEAK_DESTINATION.to_string(), + "--leak-timeout", + "500", + "--leak-tcp", + "--leak-udp", + "--leak-icmp", + ] + .map(String::from) + .to_vec(), + ..SpawnOpts::new(&self.executable_path) + }; + + let pid = self.rpc.spawn(opts).await?; + + if self.split && TEST_CONFIG.os == Os::Linux { + self.mullvad_client + .add_split_tunnel_process(pid as i32) + .await?; + } + + Ok(ConnCheckerHandle { pid, checker: self }) + } + + /// Enable split tunneling for the connection checker. + pub async fn split(&mut self) -> anyhow::Result<()> { + log::debug!("enable split tunnel"); + self.split = true; + + match TEST_CONFIG.os { + Os::Linux => { /* linux programs can't be split until they are spawned */ } + Os::Windows => { + self.mullvad_client + .add_split_tunnel_app(&self.executable_path) + .await?; + self.mullvad_client.set_split_tunnel_state(true).await?; + } + Os::Macos => unimplemented!("MacOS"), + } + + Ok(()) + } + + /// Disable split tunneling for the connection checker. + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + log::debug!("disable split tunnel"); + self.split = false; + + match TEST_CONFIG.os { + Os::Linux => {} + Os::Windows => { + self.mullvad_client.set_split_tunnel_state(false).await?; + self.mullvad_client + .remove_split_tunnel_app(&self.executable_path) + .await?; + } + Os::Macos => unimplemented!("MacOS"), + } + + Ok(()) + } +} + +impl ConnCheckerHandle<'_> { + pub async fn split(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .add_split_tunnel_process(self.pid as i32) + .await?; + } + + self.checker.split().await + } + + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .remove_split_tunnel_process(self.pid as i32) + .await?; + } + + self.checker.unsplit().await + } + + /// Assert that traffic is flowing through the Mullvad tunnel and that no packets are leaked. + pub async fn assert_secure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is secure"); + let status = self.check_connection().await?; + ensure!(status.am_i_mullvad); + ensure!(!status.leaked_tcp); + ensure!(!status.leaked_udp); + ensure!(!status.leaked_icmp); + + Ok(()) + } + + /// Assert that traffic is NOT flowing through the Mullvad tunnel and that packets ARE leaked. + pub async fn assert_insecure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is not secure"); + let status = self.check_connection().await?; + ensure!(!status.am_i_mullvad); + ensure!(status.leaked_tcp); + ensure!(status.leaked_udp); + ensure!(status.leaked_icmp); + + Ok(()) + } + + async fn check_connection(&mut self) -> anyhow::Result { + // Monitor all pakets going to LEAK_DESTINATION during the check. + let monitor = start_packet_monitor( + |packet| packet.destination.ip() == LEAK_DESTINATION.ip(), + MonitorOptions { + direction: Some(Direction::In), + ..MonitorOptions::default() + }, + ) + .await; + + // Write a newline to the connection checker to prompt it to perform the check. + self.checker + .rpc + .write_child_stdin(self.pid, "Say the line, Bart!\r\n".into()) + .await?; + + // The checker responds when the check is complete. + let line = self.read_stdout_line().await?; + + let monitor_result = monitor + .into_result() + .await + .map_err(|_e| anyhow!("Packet monitor unexpectedly stopped"))?; + + Ok(ConnectionStatus { + am_i_mullvad: parse_am_i_mullvad(line)?, + + leaked_tcp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Tcp), + + leaked_udp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Udp), + + leaked_icmp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Icmp), + }) + } + + /// Try to a single line of output from the spawned process + async fn read_stdout_line(&mut self) -> anyhow::Result { + // Add a timeout to avoid waiting forever. + timeout(Duration::from_secs(8), async { + let mut line = String::new(); + + // tarpc doesn't support streams, so we poll the checker process in a loop instead + loop { + let Some(output) = self.checker.rpc.read_child_stdout(self.pid).await? else { + bail!("got EOF from connection checker process"); + }; + + if output.is_empty() { + sleep(Duration::from_millis(500)).await; + continue; + } + + line.push_str(&output); + + if line.contains('\n') { + log::info!("output from child process: {output:?}"); + return Ok(line); + } + } + }) + .await + .with_context(|| "Timeout reading stdout from connection checker")? + } +} + +impl Drop for ConnCheckerHandle<'_> { + fn drop(&mut self) { + let rpc = self.checker.rpc.clone(); + let pid = self.pid; + + let Ok(runtime_handle) = tokio::runtime::Handle::try_current() else { + log::error!("ConnCheckerHandle dropped outside of a tokio runtime."); + return; + }; + + runtime_handle.spawn(async move { + // Make sure child process is stopped when this handle is dropped. + // Closing stdin does the trick. + let _ = rpc.close_child_stdin(pid).await; + }); + } +} + +/// Parse output from connection-checker. Returns true if connected to Mullvad. +fn parse_am_i_mullvad(result: String) -> anyhow::Result { + Ok(if result.contains("You are connected") { + true + } else if result.contains("You are not connected") { + false + } else { + bail!("Unexpected output from connection-checker: {result:?}") + }) +} diff --git a/test/test-manager/src/tests/test_metadata.rs b/test/test-manager/src/tests/test_metadata.rs index 3e28a4380b6a..d4ffa9bfd029 100644 --- a/test/test-manager/src/tests/test_metadata.rs +++ b/test/test-manager/src/tests/test_metadata.rs @@ -5,7 +5,7 @@ use test_rpc::mullvad_daemon::MullvadClientVersion; pub struct TestMetadata { pub name: &'static str, pub command: &'static str, - pub target_os: Option, + pub targets: &'static [Os], pub mullvad_client_version: MullvadClientVersion, pub func: TestWrapperFunction, pub priority: Option, @@ -16,9 +16,7 @@ pub struct TestMetadata { impl TestMetadata { pub fn should_run_on_os(&self, os: Os) -> bool { - self.target_os - .map(|target_os| target_os == os) - .unwrap_or(true) + self.targets.is_empty() || self.targets.contains(&os) } } diff --git a/test/test-manager/src/vm/provision.rs b/test/test-manager/src/vm/provision.rs index 5f01e8f192b9..8667b6c1338b 100644 --- a/test/test-manager/src/vm/provision.rs +++ b/test/test-manager/src/vm/provision.rs @@ -106,6 +106,11 @@ fn blocking_ssh( ssh_send_file_path(&session, &source, temp_dir) .context("Failed to send test runner to remote")?; + // Transfer connection-checker + let source = local_runner_dir.join("connection-checker"); + ssh_send_file_path(&session, &source, temp_dir) + .context("Failed to send connection-checker to remote")?; + // Transfer app packages ssh_send_file_path(&session, &local_app_manifest.current_app_path, temp_dir) .context("Failed to send current app package to remote")?; diff --git a/test/test-manager/src/vm/qemu.rs b/test/test-manager/src/vm/qemu.rs index 5688f47101c9..62613d5e1daa 100644 --- a/test/test-manager/src/vm/qemu.rs +++ b/test/test-manager/src/vm/qemu.rs @@ -134,7 +134,7 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result // Configure OVMF. Currently, this is enabled implicitly if using a TPM let ovmf_handle = if vm_config.tpm { - let handle = OvmfHandle::new().await?; + let handle = OvmfHandle::new(vm_config).await?; handle.append_qemu_args(&mut qemu_cmd); Some(handle) } else { @@ -202,32 +202,50 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result /// Used to set up UEFI and append options to the QEMU command struct OvmfHandle { temp_vars: TempFile, + ovmf_code_path: String, } impl OvmfHandle { - pub async fn new() -> Result { - const OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + pub async fn new(config: &VmConfig) -> Result { + const DEFAULT_OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + const DEFAULT_OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; + + let ovmf_code_path = config + .ovmf_code_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_CODE_PATH) + .to_owned(); + + let ovmf_vars_path = config + .ovmf_vars_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_VARS_PATH); // Create a local copy of OVMF_VARS let temp_vars_path = random_tempfile_name(); - fs::copy(OVMF_VARS_PATH, &temp_vars_path) + fs::copy(ovmf_vars_path, &temp_vars_path) .await .map_err(Error::CopyOvmfVars)?; let temp_vars = TempFile::from_existing(temp_vars_path, async_tempfile::Ownership::Owned) .await .map_err(|_| Error::WrapOvmfVars)?; - Ok(OvmfHandle { temp_vars }) + + Ok(OvmfHandle { + temp_vars, + ovmf_code_path, + }) } pub fn append_qemu_args(&self, qemu_cmd: &mut Command) { - const OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; - qemu_cmd.args([ "-global", "driver=cfi.pflash01,property=secure,value=on", "-drive", - &format!("if=pflash,format=raw,unit=0,file={OVMF_CODE_PATH},readonly=on"), + &format!( + "if=pflash,format=raw,unit=0,file={},readonly=on", + self.ovmf_code_path + ), "-drive", &format!( "if=pflash,format=raw,unit=1,file={}", diff --git a/test/test-manager/test_macro/Cargo.toml b/test/test-manager/test_macro/Cargo.toml index a064b6d200f1..19a405d08fe9 100644 --- a/test/test-manager/test_macro/Cargo.toml +++ b/test/test-manager/test_macro/Cargo.toml @@ -14,3 +14,4 @@ proc-macro = true syn = "1.0" quote = "1.0" proc-macro2 = "1.0" +test-rpc = { path = "../../test-rpc" } diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs index d95c3f883211..7cb8407230eb 100644 --- a/test/test-manager/test_macro/src/lib.rs +++ b/test/test-manager/test_macro/src/lib.rs @@ -1,6 +1,7 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{AttributeArgs, Lit, Meta, NestedMeta}; +use syn::{AttributeArgs, Lit, Meta, NestedMeta, Result}; +use test_rpc::meta::Os; /// Register an `async` function to be run by `test-manager`. /// @@ -52,7 +53,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -67,7 +68,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -76,7 +77,10 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream let function: syn::ItemFn = syn::parse(code).unwrap(); let attributes = syn::parse_macro_input!(attributes as AttributeArgs); - let test_function = parse_marked_test_function(&attributes, &function); + let test_function = match parse_marked_test_function(&attributes, &function) { + Ok(tf) => tf, + Err(e) => return e.into_compile_error().into(), + }; let register_test = create_test(test_function); @@ -88,73 +92,91 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream .into() } -fn parse_marked_test_function(attributes: &AttributeArgs, function: &syn::ItemFn) -> TestFunction { - let macro_parameters = get_test_macro_parameters(attributes); +/// Shorthand for `return syn::Error::new(...)`. +macro_rules! bail { + ($span:expr, $($tt:tt)*) => {{ + return ::core::result::Result::Err(::syn::Error::new( + ::syn::spanned::Spanned::span(&$span), + ::core::format_args!($($tt)*), + )) + }}; +} - let function_parameters = get_test_function_parameters(&function.sig.inputs); +fn parse_marked_test_function( + attributes: &AttributeArgs, + function: &syn::ItemFn, +) -> Result { + let macro_parameters = get_test_macro_parameters(attributes)?; + let function_parameters = get_test_function_parameters(&function.sig.inputs)?; - TestFunction { + Ok(TestFunction { name: function.sig.ident.clone(), function_parameters, macro_parameters, - } + }) } -fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> MacroParameters { +fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result { let mut priority = None; let mut cleanup = true; let mut always_run = false; let mut must_succeed = false; - let mut target_os = None; + let mut targets = vec![]; for attribute in attributes { - if let NestedMeta::Meta(Meta::NameValue(nv)) = attribute { - if nv.path.is_ident("priority") { - match &nv.lit { - Lit::Int(lit_int) => { - priority = Some(lit_int.base10_parse().unwrap()); - } - _ => panic!("'priority' should have an integer value"), - } - } else if nv.path.is_ident("always_run") { - match &nv.lit { - Lit::Bool(lit_bool) => { - always_run = lit_bool.value(); - } - _ => panic!("'always_run' should have a bool value"), - } - } else if nv.path.is_ident("must_succeed") { - match &nv.lit { - Lit::Bool(lit_bool) => { - must_succeed = lit_bool.value(); - } - _ => panic!("'must_succeed' should have a bool value"), - } - } else if nv.path.is_ident("cleanup") { - match &nv.lit { - Lit::Bool(lit_bool) => { - cleanup = lit_bool.value(); - } - _ => panic!("'cleanup' should have a bool value"), - } - } else if nv.path.is_ident("target_os") { - match &nv.lit { - Lit::Str(lit_str) => { - target_os = Some(lit_str.value()); - } - _ => panic!("'target_os' should have a string value"), - } + // we only use name-value attributes + let NestedMeta::Meta(Meta::NameValue(nv)) = attribute else { + bail!(attribute, "unknown attribute"); + }; + let lit = &nv.lit; + + if nv.path.is_ident("priority") { + match lit { + Lit::Int(lit_int) => priority = Some(lit_int.base10_parse().unwrap()), + _ => bail!(nv, "'priority' should have an integer value"), + } + } else if nv.path.is_ident("always_run") { + match lit { + Lit::Bool(lit_bool) => always_run = lit_bool.value(), + _ => bail!(nv, "'always_run' should have a bool value"), } + } else if nv.path.is_ident("must_succeed") { + match lit { + Lit::Bool(lit_bool) => must_succeed = lit_bool.value(), + _ => bail!(nv, "'must_succeed' should have a bool value"), + } + } else if nv.path.is_ident("cleanup") { + match lit { + Lit::Bool(lit_bool) => cleanup = lit_bool.value(), + _ => bail!(nv, "'cleanup' should have a bool value"), + } + } else if nv.path.is_ident("target_os") { + let Lit::Str(lit_str) = lit else { + bail!(nv, "'target_os' should have a string value"); + }; + + let target = match lit_str.value().parse() { + Ok(os) => os, + Err(e) => bail!(lit_str, "{e}"), + }; + + if targets.contains(&target) { + bail!(nv, "Duplicate target"); + } + + targets.push(target); + } else { + bail!(nv, "unknown attribute"); } } - MacroParameters { + Ok(MacroParameters { priority, cleanup, always_run, must_succeed, - target_os, - } + targets, + }) } fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { @@ -162,17 +184,14 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { Some(priority) => quote! { Some(#priority) }, None => quote! { None }, }; - let target_os = match test_function.macro_parameters.target_os.as_deref() { - Some("linux") => quote! { Some(::test_rpc::meta::Os::Linux) }, - Some("macos") => quote! { Some(::test_rpc::meta::Os::Macos) }, - Some("windows") => quote! { Some(::test_rpc::meta::Os::Windows) }, - Some(target_os) => { - return quote! { - compile_error!("invalid target_os: {:?}", #target_os); - }; - } - None => quote! { None }, - }; + let targets: proc_macro2::TokenStream = (test_function.macro_parameters.targets.iter()) + .map(|&os| match os { + Os::Linux => quote! { ::test_rpc::meta::Os::Linux, }, + Os::Macos => quote! { ::test_rpc::meta::Os::Macos, }, + Os::Windows => quote! { ::test_rpc::meta::Os::Windows, }, + }) + .collect(); + let should_cleanup = test_function.macro_parameters.cleanup; let always_run = test_function.macro_parameters.always_run; let must_succeed = test_function.macro_parameters.must_succeed; @@ -193,7 +212,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { use std::any::Any; let mullvad_client = mullvad_client.downcast::<#mullvad_client_type>().expect("invalid mullvad client"); Box::pin(async move { - #func_name(test_context, rpc, *mullvad_client).await + #func_name(test_context, rpc, *mullvad_client).await.map_err(Into::into) }) } } @@ -202,9 +221,9 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { quote! { |test_context: crate::tests::TestContext, rpc: test_rpc::ServiceClient, - mullvad_client: Box| { + _mullvad_client: Box| { Box::pin(async move { - #func_name(test_context, rpc).await + #func_name(test_context, rpc).await.map_err(Into::into) }) } } @@ -215,7 +234,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { inventory::submit!(crate::tests::test_metadata::TestMetadata { name: stringify!(#func_name), command: stringify!(#func_name), - target_os: #target_os, + targets: &[#targets], mullvad_client_version: #function_mullvad_version, func: #wrapper_closure, priority: #test_function_priority, @@ -237,7 +256,7 @@ struct MacroParameters { cleanup: bool, always_run: bool, must_succeed: bool, - target_os: Option, + targets: Vec, } enum MullvadClient { @@ -269,36 +288,38 @@ struct FunctionParameters { } fn get_test_function_parameters( - inputs: &syn::punctuated::Punctuated, -) -> FunctionParameters { - if inputs.len() > 2 { - match inputs[2].clone() { - syn::FnArg::Typed(pat_type) => { - let mullvad_client = match &*pat_type.ty { - syn::Type::Path(syn::TypePath { path, .. }) => { - match path.segments[0].ident.to_string().as_str() { - "mullvad_management_interface" | "MullvadProxyClient" => { - let mullvad_client_version = - quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; - MullvadClient::New { - mullvad_client_type: pat_type.ty, - mullvad_client_version, - } - } - _ => panic!("cannot infer mullvad client type"), - } - } - _ => panic!("unexpected 'mullvad_client' type"), - }; - FunctionParameters { mullvad_client } - } - syn::FnArg::Receiver(_) => panic!("unexpected 'mullvad_client' arg"), - } - } else { - FunctionParameters { + args: &syn::punctuated::Punctuated, +) -> Result { + if args.len() <= 2 { + return Ok(FunctionParameters { mullvad_client: MullvadClient::None { - mullvad_client_version: quote! { test_rpc::mullvad_daemon::MullvadClientVersion::None }, + mullvad_client_version: quote! { + test_rpc::mullvad_daemon::MullvadClientVersion::None + }, }, - } + }); } + + let arg = args[2].clone(); + let syn::FnArg::Typed(pat_type) = arg else { + bail!(arg, "unexpected 'mullvad_client' arg"); + }; + + let syn::Type::Path(syn::TypePath { path, .. }) = &*pat_type.ty else { + bail!(pat_type, "unexpected 'mullvad_client' type"); + }; + + let mullvad_client = match path.segments[0].ident.to_string().as_str() { + "mullvad_management_interface" | "MullvadProxyClient" => { + let mullvad_client_version = + quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; + MullvadClient::New { + mullvad_client_type: pat_type.ty, + mullvad_client_version, + } + } + _ => bail!(pat_type, "cannot infer mullvad client type"), + }; + + Ok(FunctionParameters { mullvad_client }) } diff --git a/test/test-rpc/src/client.rs b/test/test-rpc/src/client.rs index b4fb67f5c069..324669de3fc6 100644 --- a/test/test-rpc/src/client.rs +++ b/test/test-rpc/src/client.rs @@ -351,4 +351,26 @@ impl ServiceClient { .make_device_json_old(tarpc::context::current()) .await? } + + pub async fn spawn(&self, opts: SpawnOpts) -> Result { + self.client.spawn(tarpc::context::current(), opts).await? + } + + pub async fn read_child_stdout(&self, pid: u32) -> Result, Error> { + self.client + .read_child_stdout(tarpc::context::current(), pid) + .await? + } + + pub async fn write_child_stdin(&self, pid: u32, data: String) -> Result<(), Error> { + self.client + .write_child_stdin(tarpc::context::current(), pid, data) + .await? + } + + pub async fn close_child_stdin(&self, pid: u32) -> Result<(), Error> { + self.client + .close_child_stdin(tarpc::context::current(), pid) + .await? + } } diff --git a/test/test-rpc/src/lib.rs b/test/test-rpc/src/lib.rs index d1515206015f..e0088a67b50b 100644 --- a/test/test-rpc/src/lib.rs +++ b/test/test-rpc/src/lib.rs @@ -57,6 +57,10 @@ pub enum Error { Timeout, #[error("TCP forward error")] TcpForward, + #[error("Unknown process ID: {0}")] + UnknownPid(u32), + #[error("{0}")] + Other(String), } /// Response from am.i.mullvad.net @@ -80,6 +84,27 @@ impl ExecResult { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SpawnOpts { + pub path: String, + pub args: Vec, + pub env: BTreeMap, + pub attach_stdin: bool, + pub attach_stdout: bool, +} + +impl SpawnOpts { + pub fn new(path: impl Into) -> SpawnOpts { + SpawnOpts { + path: path.into(), + args: Default::default(), + env: Default::default(), + attach_stdin: Default::default(), + attach_stdout: Default::default(), + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub enum AppTrace { Path(PathBuf), @@ -197,6 +222,28 @@ mod service { async fn reboot() -> Result<(), Error>; async fn make_device_json_old() -> Result<(), Error>; + + /// Spawn a child process and return the PID. + async fn spawn(opts: SpawnOpts) -> Result; + + /// Read from stdout of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdout`. + /// Returns `None` if process stdout is closed. + async fn read_child_stdout(pid: u32) -> Result, Error>; + + /// Write to stdin of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdin`. + async fn write_child_stdin(pid: u32, data: String) -> Result<(), Error>; + + /// Close stdin of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdin`. + async fn close_child_stdin(pid: u32) -> Result<(), Error>; + + /// Kill a process spawned through [Service::spawn]. + async fn kill_child(pid: u32) -> Result<(), Error>; } } diff --git a/test/test-rpc/src/transport.rs b/test/test-rpc/src/transport.rs index b8086b41456b..f5f461702688 100644 --- a/test/test-rpc/src/transport.rs +++ b/test/test-rpc/src/transport.rs @@ -1,5 +1,5 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::{channel::mpsc, FutureExt, SinkExt, StreamExt}; use serde::{de::DeserializeOwned, Serialize}; use std::{ fmt::Write, @@ -256,13 +256,12 @@ async fn forward_messages< let mut mullvad_daemon_forwarder = LengthDelimitedCodec::new().framed(mullvad_daemon_forwarder); loop { - match futures::future::select( - futures::future::select(serial_stream.next(), handshaker.1.next()), - futures::future::select(runner_forwarder.next(), mullvad_daemon_forwarder.next()), - ) - .await - { - futures::future::Either::Left((futures::future::Either::Left((Some(frame), _)), _)) => { + futures::select! { + frame = serial_stream.next().fuse() => { + let Some(frame) = frame else { + break Ok(()); + }; + let frame = frame.map_err(ForwardError::SerialConnection)?; // @@ -294,7 +293,12 @@ async fn forward_messages< } } } - futures::future::Either::Left((futures::future::Either::Right((Some(()), _)), _)) => { + + handshake = handshaker.1.next().fuse() => { + if handshake.is_none() { + break Ok(()); + } + log::trace!("shake: send"); // Ping the other end @@ -303,10 +307,12 @@ async fn forward_messages< .await .map_err(ForwardError::HandshakeError)?; } - futures::future::Either::Right(( - futures::future::Either::Left((Some(message), _)), - _, - )) => { + + message = runner_forwarder.next().fuse() => { + let Some(message) = message else { + break Ok(()); + }; + let message = message.map_err(ForwardError::TestRunnerChannel)?; // @@ -321,10 +327,16 @@ async fn forward_messages< .await .map_err(ForwardError::SerialConnection)?; } - futures::future::Either::Right(( - futures::future::Either::Right((Some(data), _)), - _, - )) => { + + data = mullvad_daemon_forwarder.next().fuse() => { + let Some(data) = data else { + // + // Force management interface socket to close + // + let _ = serial_stream.send(Frame::DaemonRpc(Bytes::new())).await; + break Ok(()); + }; + let data = data.map_err(ForwardError::DaemonChannel)?; // @@ -336,17 +348,6 @@ async fn forward_messages< .await .map_err(ForwardError::SerialConnection)?; } - futures::future::Either::Right((futures::future::Either::Right((None, _)), _)) => { - // - // Force management interface socket to close - // - let _ = serial_stream.send(Frame::DaemonRpc(Bytes::new())).await; - - break Ok(()); - } - _ => { - break Ok(()); - } } } } diff --git a/test/test-runner/Cargo.toml b/test/test-runner/Cargo.toml index 8e2ae8cbf687..50f3ddda6a91 100644 --- a/test/test-runner/Cargo.toml +++ b/test/test-runner/Cargo.toml @@ -33,7 +33,7 @@ test-rpc = { path = "../test-rpc" } mullvad-paths = { path = "../../mullvad-paths" } talpid-platform-metadata = { path = "../../talpid-platform-metadata" } -socket2 = { version = "0.5", features = ["all"] } +socket2 = { version = "0.5.4", features = ["all"] } [target."cfg(target_os=\"windows\")".dependencies] talpid-windows = { path = "../../talpid-windows" } diff --git a/test/test-runner/src/main.rs b/test/test-runner/src/main.rs index 3511d78cec55..d864968bbee5 100644 --- a/test/test-runner/src/main.rs +++ b/test/test-runner/src/main.rs @@ -1,10 +1,14 @@ -use futures::{pin_mut, SinkExt, StreamExt}; +use futures::{pin_mut, select, select_biased, FutureExt, SinkExt, StreamExt}; use logging::LOGGER; use std::{ collections::{BTreeMap, HashMap}, net::{IpAddr, SocketAddr}, path::{Path, PathBuf}, + process::Stdio, + sync::Arc, + time::Duration, }; +use util::OnDrop; use tarpc::{context, server::Channel}; use test_rpc::{ @@ -12,12 +16,14 @@ use test_rpc::{ net::SockHandleId, package::Package, transport::GrpcForwarder, - AppTrace, Service, + AppTrace, Service, SpawnOpts, }; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - process::Command, - sync::broadcast::error::TryRecvError, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{ChildStdin, ChildStdout, Command}, + sync::{broadcast::error::TryRecvError, oneshot, Mutex}, + task, + time::sleep, }; use tokio_util::codec::{Decoder, LengthDelimitedCodec}; @@ -27,9 +33,23 @@ mod logging; mod net; mod package; mod sys; +mod util; -#[derive(Clone)] -pub struct TestServer(pub ()); +#[derive(Clone, Default)] +pub struct TestServer(Arc>); + +#[derive(Default)] +struct State { + spawned_procs: HashMap, +} + +struct SpawnedProcess { + stdout: Option, + stdin: Option, + + #[allow(dead_code)] + abort_handle: OnDrop, +} #[tarpc::server] impl Service for TestServer { @@ -319,6 +339,192 @@ impl Service for TestServer { async fn make_device_json_old(self, _: context::Context) -> Result<(), test_rpc::Error> { app::make_device_json_old().await } + + async fn spawn(self, _: context::Context, opts: SpawnOpts) -> Result { + let mut cmd = Command::new(&opts.path); + cmd.args(&opts.args); + + // Make sure that PATH is updated + // TODO: We currently do not need this on non-Windows + #[cfg(target_os = "windows")] + cmd.env("PATH", sys::get_system_path_var()?); + + cmd.envs(opts.env); + + if opts.attach_stdin { + cmd.stdin(Stdio::piped()); + } else { + cmd.stdin(Stdio::null()); + } + + if opts.attach_stdout { + cmd.stdout(Stdio::piped()); + } + + cmd.stderr(Stdio::piped()); + + let mut child = cmd.kill_on_drop(true).spawn().map_err(|error| { + log::error!("Failed to spawn {}: {error}", opts.path); + test_rpc::Error::Syscall + })?; + + let pid = child + .id() + .expect("Child hasn't been polled to completion yet"); + + log::info!("spawned {} (args={:?}) (pid={pid})", opts.path, opts.args); + + let (abort_tx, abort_rx) = oneshot::channel(); + let abort_handle = || { + let _ = abort_tx.send(()); + }; + + let spawned_process = SpawnedProcess { + stdout: child.stdout.take(), + stdin: child.stdin.take(), + abort_handle: OnDrop::new(Box::new(abort_handle)), + }; + + let mut state = self.0.lock().await; + state.spawned_procs.insert(pid, spawned_process); + drop(state); + + // spawn a task to log child stdout + if let Some(stderr) = child.stderr.take() { + task::spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + loop { + match stderr.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim_end_matches(&['\r', '\n']); + log::info!("child stderr (pid={pid}): {trimmed}"); + line.clear(); + } + Err(e) => { + log::error!("failed to read child stderr (pid={pid}): {e}"); + break; + } + } + } + }); + } + + // spawn a task to monitor if the child exits + task::spawn(async move { + select! { + result = child.wait().fuse() => match result { + Err(e) => { + log::error!("failed to await child process (pid={pid}): {e}"); + } + Ok(status) => { + log::info!("child process (pid={pid}) exited with status: {status}"); + } + }, + + _ = abort_rx.fuse() => { + if let Err(e) = child.kill().await { + log::error!("failed to kill child process (pid={pid}): {e}"); + } + } + } + + let mut state = self.0.lock().await; + state.spawned_procs.remove(&pid); + }); + + Ok(pid) + } + + async fn read_child_stdout( + self, + _: context::Context, + pid: u32, + ) -> Result, test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + let Some(stdout) = child.stdout.as_mut() else { + return Ok(None); + }; + + let mut buf = vec![0u8; 512]; + + let n = select_biased! { + result = stdout.read(&mut buf).fuse() => result + .map_err(|e| format!("Failed to read from child stdout: {e}")) + .map_err(test_rpc::Error::Other)?, + + _ = sleep(Duration::from_millis(500)).fuse() => return Ok(Some(String::new())), + }; + + // check for EOF + if n == 0 { + child.stdout = None; + return Ok(None); + } + + buf.truncate(n); + let output = String::from_utf8(buf) + .map_err(|_| test_rpc::Error::Other("Child wrote non UTF-8 to stdout".into()))?; + + Ok(Some(output)) + } + + async fn write_child_stdin( + self, + _: context::Context, + pid: u32, + data: String, + ) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + let Some(stdin) = child.stdin.as_mut() else { + return Err(test_rpc::Error::Other("Child stdin is closed.".into())); + }; + + stdin + .write_all(data.as_bytes()) + .await + .map_err(|e| format!("Error writing to child stdin: {e}")) + .map_err(test_rpc::Error::Other)?; + + log::debug!("wrote {} bytes to pid {pid}", data.len()); + + Ok(()) + } + + async fn close_child_stdin(self, _: context::Context, pid: u32) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + child.stdin = None; + + Ok(()) + } + + async fn kill_child(self, _: context::Context, pid: u32) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .remove(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + drop(child); // I swear officer, it's not what you think! + + Ok(()) + } } fn get_pipe_status() -> ServiceStatus { @@ -364,7 +570,7 @@ async fn main() -> Result<(), Error> { )); let server = tarpc::server::BaseChannel::with_defaults(runner_transport); - server.execute(TestServer(()).serve()).await; + server.execute(TestServer::default().serve()).await; log::error!("Restarting server since it stopped"); } diff --git a/test/test-runner/src/util.rs b/test/test-runner/src/util.rs new file mode 100644 index 000000000000..03a334321412 --- /dev/null +++ b/test/test-runner/src/util.rs @@ -0,0 +1,23 @@ +/// Drop guard that executes the provided callback function when dropped. +pub struct OnDrop> +where + F: FnOnce() + Send, +{ + callback: Option, +} + +impl Drop for OnDrop { + fn drop(&mut self) { + if let Some(callback) = self.callback.take() { + callback(); + } + } +} + +impl OnDrop { + pub fn new(callback: F) -> Self { + Self { + callback: Some(callback), + } + } +}