Skip to content

Commit 885d8e0

Browse files
hultheMarkusPettersson98
authored andcommitted
Add a safe FFI wrapper in wireguard-go-rs
Also: - Use u64 instead of *mut void as log context - Make Tunnel::set_config take a &mut self - Use dyn Error instead of i32s for wg errors
1 parent 856f9ab commit 885d8e0

File tree

16 files changed

+427
-268
lines changed

16 files changed

+427
-268
lines changed

Cargo.lock

+7-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

talpid-wireguard/src/connectivity_check.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use crate::{
55
use std::{
66
cmp,
77
net::Ipv4Addr,
8-
sync::{mpsc, Mutex, Weak},
8+
sync::{mpsc, Weak},
99
time::{Duration, Instant},
1010
};
11+
use tokio::sync::Mutex;
1112

1213
use super::{Tunnel, TunnelError};
1314

@@ -211,11 +212,12 @@ impl ConnectivityMonitor {
211212

212213
/// If None is returned, then the underlying tunnel has already been closed and all subsequent
213214
/// calls will also return None.
215+
///
216+
/// NOTE: will panic if called from within a tokio runtime.
214217
fn get_stats(&self) -> Option<Result<StatsMap, Error>> {
215218
self.tunnel_handle
216219
.upgrade()?
217-
.lock()
218-
.ok()?
220+
.blocking_lock()
219221
.as_ref()
220222
.and_then(|tunnel| match tunnel.get_tunnel_stats() {
221223
Ok(stats) if stats.is_empty() => {
@@ -551,7 +553,7 @@ mod test {
551553
rx_bytes: 0,
552554
},
553555
);
554-
let peers = Mutex::new(map);
556+
let peers = std::sync::Mutex::new(map);
555557
Self {
556558
on_get_stats: Box::new(move || {
557559
let mut peers = peers.lock().unwrap();
@@ -608,7 +610,7 @@ mod test {
608610
}
609611

610612
fn set_config(
611-
&self,
613+
&mut self,
612614
_config: Config,
613615
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>> {
614616
Box::pin(async { Ok(()) })
@@ -746,7 +748,7 @@ mod test {
746748
rx_bytes: 0,
747749
},
748750
);
749-
let tunnel_stats = Mutex::new(map);
751+
let tunnel_stats = std::sync::Mutex::new(map);
750752

751753
let pinger = MockPinger::default();
752754
let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {

talpid-wireguard/src/lib.rs

+23-22
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl Error {
115115
Error::CreateObfuscatorError(_) => true,
116116
Error::ObfuscatorError(_) => true,
117117
Error::PskNegotiationError(_) => true,
118-
Error::TunnelError(TunnelError::RecoverableStartWireguardError) => true,
118+
Error::TunnelError(TunnelError::RecoverableStartWireguardError(..)) => true,
119119

120120
Error::SetupRoutingError(error) => error.is_recoverable(),
121121

@@ -144,7 +144,7 @@ impl Error {
144144
pub struct WireguardMonitor {
145145
runtime: tokio::runtime::Handle,
146146
/// Tunnel implementation
147-
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
147+
tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
148148
/// Callback to signal tunnel events
149149
event_callback: EventCallback,
150150
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
@@ -306,7 +306,7 @@ impl WireguardMonitor {
306306
let (pinger_tx, pinger_rx) = sync_mpsc::channel();
307307
let monitor = WireguardMonitor {
308308
runtime: args.runtime.clone(),
309-
tunnel: Arc::new(Mutex::new(Some(tunnel))),
309+
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
310310
event_callback,
311311
close_msg_receiver: close_obfs_listener,
312312
pinger_stop_sender: pinger_tx,
@@ -473,7 +473,7 @@ impl WireguardMonitor {
473473

474474
#[allow(clippy::too_many_arguments)]
475475
async fn config_ephemeral_peers<F>(
476-
tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
476+
tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
477477
config: &mut Config,
478478
retry_attempt: u32,
479479
on_event: F,
@@ -579,7 +579,7 @@ impl WireguardMonitor {
579579
#[cfg(daita)]
580580
if config.daita {
581581
// Start local DAITA machines
582-
let mut tunnel = tunnel.lock().unwrap();
582+
let mut tunnel = tunnel.lock().await;
583583
if let Some(tunnel) = tunnel.as_mut() {
584584
tunnel
585585
.start_daita()
@@ -601,7 +601,7 @@ impl WireguardMonitor {
601601
/// Reconfigures the tunnel to use the provided config while potentially modifying the config
602602
/// and restarting the obfuscation provider. Returns the new config used by the new tunnel.
603603
async fn reconfigure_tunnel(
604-
tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
604+
tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
605605
mut config: Config,
606606
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
607607
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
@@ -625,11 +625,12 @@ impl WireguardMonitor {
625625
}
626626
}
627627

628+
let mut tunnel = tunnel.lock().await;
629+
628630
let set_config_future = tunnel
629-
.lock()
630-
.unwrap()
631-
.as_ref()
631+
.as_mut()
632632
.map(|tunnel| tunnel.set_config(config.clone()));
633+
633634
if let Some(f) = set_config_future {
634635
f.await
635636
.map_err(Error::TunnelError)
@@ -846,8 +847,11 @@ impl WireguardMonitor {
846847
wait_result
847848
}
848849

850+
/// Tear down the tunnel.
851+
///
852+
/// NOTE: will panic if called from within a tokio runtime.
849853
fn stop_tunnel(&mut self) {
850-
match self.tunnel.lock().expect("Tunnel lock poisoned").take() {
854+
match self.tunnel.blocking_lock().take() {
851855
Some(tunnel) => {
852856
if let Err(e) = tunnel.stop() {
853857
log::error!("{}", e.display_chain_with_msg("Failed to stop tunnel"));
@@ -1028,10 +1032,10 @@ pub(crate) trait Tunnel: Send {
10281032
fn get_interface_name(&self) -> String;
10291033
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
10301034
fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
1031-
fn set_config(
1032-
&self,
1035+
fn set_config<'a>(
1036+
&'a mut self,
10331037
_config: Config,
1034-
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>;
1038+
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'a>>;
10351039
#[cfg(daita)]
10361040
/// A [`Tunnel`] capable of using DAITA.
10371041
fn start_daita(&mut self) -> std::result::Result<(), TunnelError>;
@@ -1045,22 +1049,19 @@ pub enum TunnelError {
10451049
/// This is an error returned by the implementation that indicates that trying to establish the
10461050
/// tunnel again should work normally. The error encountered is known to be sporadic.
10471051
#[error("Recoverable error while starting wireguard tunnel")]
1048-
RecoverableStartWireguardError,
1052+
RecoverableStartWireguardError(#[source] Box<dyn std::error::Error + Send>),
10491053

10501054
/// An unrecoverable error occurred while starting the wireguard tunnel
10511055
///
10521056
/// This is an error returned by the implementation that indicates that trying to establish the
10531057
/// tunnel again will likely fail with the same error. An error was encountered during tunnel
10541058
/// configuration which can't be dealt with gracefully.
10551059
#[error("Failed to start wireguard tunnel")]
1056-
FatalStartWireguardError,
1060+
FatalStartWireguardError(#[source] Box<dyn std::error::Error + Send>),
10571061

10581062
/// Failed to tear down wireguard tunnel.
1059-
#[error("Failed to stop wireguard tunnel. Status: {status}")]
1060-
StopWireguardError {
1061-
/// Returned error code
1062-
status: i32,
1063-
},
1063+
#[error("Failed to tear down wireguard tunnel")]
1064+
StopWireguardError(#[source] Box<dyn std::error::Error + Send>),
10641065

10651066
/// Error whilst trying to parse the WireGuard config to read the stats
10661067
#[error("Reading tunnel stats failed")]
@@ -1114,8 +1115,8 @@ pub enum TunnelError {
11141115

11151116
/// Failed to receive DAITA event
11161117
#[cfg(daita)]
1117-
#[error("Failed to receive DAITA event")]
1118-
DaitaReceiveEvent(i32),
1118+
#[error("Failed to start DAITA")]
1119+
StartDaita(#[source] Box<dyn std::error::Error + Send>),
11191120

11201121
/// This tunnel does not support DAITA.
11211122
#[cfg(daita)]

talpid-wireguard/src/logging.rs

+18-17
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@ use once_cell::sync::Lazy;
22
use parking_lot::Mutex;
33
use std::{collections::HashMap, fmt, fs, io::Write, path::Path};
44

5-
static LOG_MUTEX: Lazy<Mutex<HashMap<u32, fs::File>>> = Lazy::new(|| Mutex::new(HashMap::new()));
5+
static LOG_MUTEX: Lazy<Mutex<LogState>> = Lazy::new(|| Mutex::new(LogState::default()));
66

7-
static mut LOG_CONTEXT_NEXT_ORDINAL: u32 = 0;
7+
#[derive(Default)]
8+
struct LogState {
9+
map: HashMap<u64, fs::File>,
10+
next_ordinal: u64,
11+
}
812

913
/// Errors encountered when initializing logging
1014
#[derive(thiserror::Error, Debug)]
@@ -14,18 +18,15 @@ pub enum Error {
1418
PrepareLogFileError(#[from] std::io::Error),
1519
}
1620

17-
pub fn initialize_logging(log_path: Option<&Path>) -> Result<u32, Error> {
21+
pub fn initialize_logging(log_path: Option<&Path>) -> Result<u64, Error> {
1822
let log_file = create_log_file(log_path)?;
1923

20-
let log_context_ordinal = unsafe {
21-
let mut map = LOG_MUTEX.lock();
22-
let ordinal = LOG_CONTEXT_NEXT_ORDINAL;
23-
LOG_CONTEXT_NEXT_ORDINAL += 1;
24-
map.insert(ordinal, log_file);
25-
ordinal
26-
};
24+
let mut state = LOG_MUTEX.lock();
25+
let ordinal = state.next_ordinal;
26+
state.next_ordinal += 1;
27+
state.map.insert(ordinal, log_file);
2728

28-
Ok(log_context_ordinal)
29+
Ok(ordinal)
2930
}
3031

3132
#[cfg(target_os = "windows")]
@@ -39,9 +40,9 @@ fn create_log_file(log_path: Option<&Path>) -> Result<fs::File, Error> {
3940
.map_err(Error::PrepareLogFileError)
4041
}
4142

42-
pub fn clean_up_logging(ordinal: u32) {
43-
let mut map = LOG_MUTEX.lock();
44-
map.remove(&ordinal);
43+
pub fn clean_up_logging(ordinal: u64) {
44+
let mut state = LOG_MUTEX.lock();
45+
state.map.remove(&ordinal);
4546
}
4647

4748
pub enum LogLevel {
@@ -71,9 +72,9 @@ impl AsRef<str> for LogLevel {
7172
}
7273
}
7374

74-
pub fn log(context: u32, level: LogLevel, tag: &str, msg: &str) {
75-
let mut map = LOG_MUTEX.lock();
76-
if let Some(logfile) = map.get_mut(&{ context }) {
75+
pub fn log(context: u64, level: LogLevel, tag: &str, msg: &str) {
76+
let mut state = LOG_MUTEX.lock();
77+
if let Some(logfile) = state.map.get_mut(&context) {
7778
log_inner(logfile, level, tag, msg);
7879
}
7980
}

talpid-wireguard/src/wireguard_go/daita.rs

-46
This file was deleted.

0 commit comments

Comments
 (0)