Skip to content

Commit f30d6d1

Browse files
committed
Add RPCs for running TCP forwarder on test runner
1 parent d6f28b6 commit f30d6d1

File tree

7 files changed

+247
-4
lines changed

7 files changed

+247
-4
lines changed

test/socks-server/src/lib.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@ use std::io;
33
use std::net::SocketAddr;
44

55
#[derive(err_derive::Error, Debug)]
6+
#[error(no_from)]
67
pub enum Error {
78
#[error(display = "Failed to start SOCKS5 server")]
89
StartSocksServer(#[error(source)] io::Error),
910
}
1011

11-
pub async fn spawn(bind_addr: SocketAddr) -> Result<tokio::task::JoinHandle<()>, Error> {
12+
pub struct Handle {
13+
handle: tokio::task::JoinHandle<()>,
14+
}
15+
16+
/// Spawn a SOCKS server bound to `bind_addr`
17+
pub async fn spawn(bind_addr: SocketAddr) -> Result<Handle, Error> {
1218
let socks_server: fast_socks5::server::Socks5Server =
1319
fast_socks5::server::Socks5Server::bind(bind_addr)
1420
.await
@@ -21,6 +27,8 @@ pub async fn spawn(bind_addr: SocketAddr) -> Result<tokio::task::JoinHandle<()>,
2127
match new_client {
2228
Ok(socket) => {
2329
let fut = socket.upgrade_to_socks5();
30+
31+
// Act as normal SOCKS server
2432
tokio::spawn(async move {
2533
match fut.await {
2634
Ok(_socket) => log::info!("socks client disconnected"),
@@ -34,5 +42,17 @@ pub async fn spawn(bind_addr: SocketAddr) -> Result<tokio::task::JoinHandle<()>,
3442
}
3543
}
3644
});
37-
Ok(handle)
45+
Ok(Handle { handle })
46+
}
47+
48+
impl Handle {
49+
pub fn close(&self) {
50+
self.handle.abort();
51+
}
52+
}
53+
54+
impl Drop for Handle {
55+
fn drop(&mut self) {
56+
self.close();
57+
}
3858
}

test/test-manager/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ async fn main() -> Result<()> {
299299
if display {
300300
instance.wait().await;
301301
}
302-
socks.abort();
302+
socks.close();
303303
result
304304
}
305305
Commands::FormatTestReports { reports } => {

test/test-rpc/src/client.rs

+10
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,16 @@ impl ServiceClient {
213213
.await?
214214
}
215215

216+
/// Start forwarding TCP from a server listening on `bind_addr` to the given address, and return a handle that closes the
217+
/// server when dropped
218+
pub async fn start_tcp_forward(
219+
&self,
220+
bind_addr: SocketAddr,
221+
via_addr: SocketAddr,
222+
) -> Result<crate::net::SockHandle, Error> {
223+
crate::net::SockHandle::start_tcp_forward(self.client.clone(), bind_addr, via_addr).await
224+
}
225+
216226
/// Restarts the app.
217227
///
218228
/// Shuts down a running app, making it disconnect from any current tunnel

test/test-rpc/src/lib.rs

+12
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ pub enum Error {
5353
InvalidUrl,
5454
#[error(display = "Timeout")]
5555
Timeout,
56+
#[error(display = "TCP forward error")]
57+
TcpForward,
5658
}
5759

5860
/// Response from am.i.mullvad.net
@@ -148,6 +150,16 @@ mod service {
148150
/// Perform DNS resolution.
149151
async fn resolve_hostname(hostname: String) -> Result<Vec<SocketAddr>, Error>;
150152

153+
/// Start forwarding TCP bound to the given address. Return an ID that can be used with
154+
/// `stop_tcp_forward`, and the address that the listening socket was actually bound to.
155+
async fn start_tcp_forward(
156+
bind_addr: SocketAddr,
157+
via_addr: SocketAddr,
158+
) -> Result<(net::SockHandleId, SocketAddr), Error>;
159+
160+
/// Stop forwarding TCP that was previously started with `start_tcp_forward`.
161+
async fn stop_tcp_forward(id: net::SockHandleId) -> Result<(), Error>;
162+
151163
/// Restart the Mullvad VPN application.
152164
async fn restart_mullvad_daemon() -> Result<(), Error>;
153165

test/test-rpc/src/net.rs

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use futures::channel::oneshot;
12
use hyper::{Client, Uri};
23
use once_cell::sync::Lazy;
3-
use serde::de::DeserializeOwned;
4+
use serde::{de::DeserializeOwned, Deserialize, Serialize};
5+
use std::net::SocketAddr;
46
use tokio_rustls::rustls::ClientConfig;
57

68
use crate::{AmIMullvad, Error};
@@ -17,6 +19,59 @@ static CLIENT_CONFIG: Lazy<ClientConfig> = Lazy::new(|| {
1719
.with_no_client_auth()
1820
});
1921

22+
#[derive(Debug, Serialize, Deserialize, Clone, Copy, Hash, PartialEq, Eq)]
23+
pub struct SockHandleId(pub usize);
24+
25+
pub struct SockHandle {
26+
stop_tx: Option<oneshot::Sender<()>>,
27+
bind_addr: SocketAddr,
28+
}
29+
30+
impl SockHandle {
31+
pub(crate) async fn start_tcp_forward(
32+
client: crate::service::ServiceClient,
33+
bind_addr: SocketAddr,
34+
via_addr: SocketAddr,
35+
) -> Result<Self, Error> {
36+
let (stop_tx, stop_rx) = oneshot::channel();
37+
38+
let (id, bind_addr) = client
39+
.start_tcp_forward(tarpc::context::current(), bind_addr, via_addr)
40+
.await??;
41+
42+
tokio::spawn(async move {
43+
let _ = stop_rx.await;
44+
45+
log::trace!("Stopping TCP forward");
46+
47+
if let Err(error) = client.stop_tcp_forward(tarpc::context::current(), id).await {
48+
log::error!("Failed to stop TCP forward: {error}");
49+
}
50+
});
51+
52+
Ok(SockHandle {
53+
stop_tx: Some(stop_tx),
54+
bind_addr,
55+
})
56+
}
57+
58+
pub fn stop(&mut self) {
59+
if let Some(stop_tx) = self.stop_tx.take() {
60+
let _ = stop_tx.send(());
61+
}
62+
}
63+
64+
pub fn bind_addr(&self) -> SocketAddr {
65+
self.bind_addr
66+
}
67+
}
68+
69+
impl Drop for SockHandle {
70+
fn drop(&mut self) {
71+
self.stop()
72+
}
73+
}
74+
2075
pub async fn geoip_lookup(mullvad_host: String) -> Result<AmIMullvad, Error> {
2176
let uri = Uri::try_from(format!("https://ipv4.am.i.{mullvad_host}/json"))
2277
.map_err(|_| Error::InvalidUrl)?;

test/test-runner/src/forward.rs

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
use once_cell::sync::Lazy;
2+
use std::collections::HashMap;
3+
use std::net::SocketAddr;
4+
use std::sync::atomic::{AtomicUsize, Ordering};
5+
use std::sync::{Arc, Mutex};
6+
use test_rpc::net::SockHandleId;
7+
use tokio::net::TcpListener;
8+
use tokio::net::TcpStream;
9+
10+
static SERVERS: Lazy<Mutex<HashMap<SockHandleId, Handle>>> =
11+
Lazy::new(|| Mutex::new(HashMap::new()));
12+
13+
/// Spawn a TCP forwarder that sends TCP via `via_addr`
14+
pub async fn start_server(
15+
bind_addr: SocketAddr,
16+
via_addr: SocketAddr,
17+
) -> Result<(SockHandleId, SocketAddr), test_rpc::Error> {
18+
let next_nonce = {
19+
static NONCE: AtomicUsize = AtomicUsize::new(0);
20+
|| NONCE.fetch_add(1, Ordering::Relaxed)
21+
};
22+
let id = SockHandleId(next_nonce());
23+
24+
let handle = tcp_forward(bind_addr, via_addr).await.map_err(|error| {
25+
log::error!("Failed to start TCP forwarder listener: {error}");
26+
test_rpc::Error::TcpForward
27+
})?;
28+
29+
let bind_addr = handle.local_addr();
30+
31+
let mut servers = SERVERS.lock().unwrap();
32+
servers.insert(id, handle);
33+
34+
Ok((id, bind_addr))
35+
}
36+
37+
/// Stop TCP forwarder given some ID returned by `start_server`
38+
pub async fn stop_server(id: SockHandleId) -> Result<(), test_rpc::Error> {
39+
let handle = {
40+
let mut servers = SERVERS.lock().unwrap();
41+
servers.remove(&id)
42+
};
43+
44+
if let Some(handle) = handle {
45+
handle.close();
46+
}
47+
Ok(())
48+
}
49+
50+
struct Handle {
51+
handle: tokio::task::JoinHandle<()>,
52+
bind_addr: SocketAddr,
53+
clients: Arc<Mutex<Vec<tokio::task::JoinHandle<()>>>>,
54+
}
55+
56+
impl Handle {
57+
pub fn close(&self) {
58+
self.handle.abort();
59+
60+
let mut clients = self.clients.lock().unwrap();
61+
for client in clients.drain(..) {
62+
client.abort();
63+
}
64+
}
65+
66+
pub fn local_addr(&self) -> SocketAddr {
67+
self.bind_addr
68+
}
69+
}
70+
71+
impl Drop for Handle {
72+
fn drop(&mut self) {
73+
self.close();
74+
}
75+
}
76+
77+
/// Forward TCP traffic via `proxy_addr`
78+
async fn tcp_forward(
79+
bind_addr: SocketAddr,
80+
proxy_addr: SocketAddr,
81+
) -> Result<Handle, test_rpc::Error> {
82+
let listener = TcpListener::bind(&bind_addr).await.map_err(|error| {
83+
log::error!("Failed to bind TCP forward socket: {error}");
84+
test_rpc::Error::TcpForward
85+
})?;
86+
let bind_addr = listener.local_addr().map_err(|error| {
87+
log::error!("Failed to get TCP socket addr: {error}");
88+
test_rpc::Error::TcpForward
89+
})?;
90+
91+
let clients = Arc::new(Mutex::new(vec![]));
92+
93+
let clients_copy = clients.clone();
94+
95+
let handle = tokio::spawn(async move {
96+
loop {
97+
match listener.accept().await {
98+
Ok((mut client, _addr)) => {
99+
let client_handle = tokio::spawn(async move {
100+
let mut proxy = match TcpStream::connect(proxy_addr).await {
101+
Ok(proxy) => proxy,
102+
Err(error) => {
103+
log::error!("failed to connect to TCP proxy: {error}");
104+
return;
105+
}
106+
};
107+
108+
if let Err(error) =
109+
tokio::io::copy_bidirectional(&mut client, &mut proxy).await
110+
{
111+
log::error!("copy_directional failed: {error}");
112+
}
113+
});
114+
clients_copy.lock().unwrap().push(client_handle);
115+
}
116+
Err(error) => {
117+
log::error!("failed to accept TCP client: {error}");
118+
}
119+
}
120+
}
121+
});
122+
Ok(Handle {
123+
handle,
124+
bind_addr,
125+
clients,
126+
})
127+
}

test/test-runner/src/main.rs

+19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use tarpc::context;
1010
use tarpc::server::Channel;
1111
use test_rpc::{
1212
mullvad_daemon::{ServiceStatus, SOCKET_PATH},
13+
net::SockHandleId,
1314
package::Package,
1415
transport::GrpcForwarder,
1516
AppTrace, Service,
@@ -22,6 +23,7 @@ use tokio::{
2223
use tokio_util::codec::{Decoder, LengthDelimitedCodec};
2324

2425
mod app;
26+
mod forward;
2527
mod logging;
2628
mod net;
2729
mod package;
@@ -167,6 +169,23 @@ impl Service for TestServer {
167169
.collect())
168170
}
169171

172+
async fn start_tcp_forward(
173+
self,
174+
_: context::Context,
175+
bind_addr: SocketAddr,
176+
via_addr: SocketAddr,
177+
) -> Result<(SockHandleId, SocketAddr), test_rpc::Error> {
178+
forward::start_server(bind_addr, via_addr).await
179+
}
180+
181+
async fn stop_tcp_forward(
182+
self,
183+
_: context::Context,
184+
id: SockHandleId,
185+
) -> Result<(), test_rpc::Error> {
186+
forward::stop_server(id).await
187+
}
188+
170189
async fn get_interface_ip(
171190
self,
172191
_: context::Context,

0 commit comments

Comments
 (0)