Skip to content

Commit 8ceeb4d

Browse files
Set DAITA per peer instead of device
Add safety comments
1 parent accd268 commit 8ceeb4d

File tree

3 files changed

+30
-23
lines changed

3 files changed

+30
-23
lines changed

Diff for: talpid-wireguard/src/wireguard_go/daita.rs

+16-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::{ffi::CStr, io};
22

3+
use talpid_types::net::wireguard::PublicKey;
4+
use wireguard_go_rs::wgActivateDaita;
5+
36
/// Maximum number of events that can be stored in the underlying buffer
47
const EVENTS_CAPACITY: u32 = 1000;
58
/// Maximum number of actions that can be stored in the underlying buffer
@@ -11,12 +14,21 @@ pub struct Session {
1114
}
1215

1316
impl Session {
14-
/// Call `wgActivateDaita` for an existing WireGuard interface
15-
pub(super) fn from_adapter(tunnel_handle: i32, machines: &CStr) -> io::Result<Session> {
17+
/// Enable DAITA for an existing WireGuard interface.
18+
pub(super) fn from_adapter(
19+
tunnel_handle: i32,
20+
peer_public_key: &PublicKey,
21+
machines: &CStr,
22+
) -> io::Result<Session> {
23+
// SAFETY:
24+
// peer_public_key and machines lives for the duration of this function call.
25+
26+
// TODO: ´machines` must be valid UTF-8
1627
let res = unsafe {
17-
super::wgActivateDaita(
18-
machines.as_ptr(),
28+
wgActivateDaita(
1929
tunnel_handle,
30+
peer_public_key.as_bytes().as_ptr(),
31+
machines.as_ptr(),
2032
EVENTS_CAPACITY,
2133
ACTIONS_CAPACITY,
2234
)
@@ -29,7 +41,4 @@ impl Session {
2941
_tunnel_handle: tunnel_handle,
3042
})
3143
}
32-
33-
// TODO:
34-
// pub(super) fn stop(self) { ... }
3544
}

Diff for: talpid-wireguard/src/wireguard_go/mod.rs

+13-15
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl Drop for LoggingContext {
4545

4646
pub struct WgGoTunnel {
4747
interface_name: String,
48-
handle: Option<i32>, // TODO(sebastian): Remove option
48+
tunnel_handle: i32,
4949
// holding on to the tunnel device and the log file ensures that the associated file handles
5050
// live long enough and get closed when the tunnel is stopped
5151
_tunnel_device: Tun,
@@ -55,6 +55,7 @@ pub struct WgGoTunnel {
5555
tun_provider: Arc<Mutex<TunProvider>>,
5656
daita_handle: Option<daita::Session>,
5757
resource_dir: PathBuf,
58+
config: Config,
5859
}
5960

6061
impl WgGoTunnel {
@@ -95,18 +96,17 @@ impl WgGoTunnel {
9596
Self::bypass_tunnel_sockets(&mut tunnel_device, handle)
9697
.map_err(TunnelError::BypassError)?;
9798

98-
let wg_go_tunnel = WgGoTunnel {
99+
Ok(WgGoTunnel {
99100
interface_name,
100-
handle: Some(handle),
101+
tunnel_handle: handle,
101102
_tunnel_device: tunnel_device,
102103
_logging_context: logging_context,
103104
#[cfg(target_os = "android")]
104105
tun_provider: tun_provider_clone,
105106
resource_dir: resource_dir.to_owned(),
106107
daita_handle: None,
107-
};
108-
109-
Ok(wg_go_tunnel)
108+
config: config.clone(),
109+
})
110110
}
111111

112112
fn create_tunnel_config(
@@ -156,11 +156,9 @@ impl WgGoTunnel {
156156
}
157157

158158
fn stop_tunnel(&mut self) -> Result<()> {
159-
if let Some(handle) = self.handle.take() {
160-
let status = unsafe { wgTurnOff(handle) };
161-
if status < 0 {
162-
return Err(TunnelError::StopWireguardError { status });
163-
}
159+
let status = unsafe { wgTurnOff(self.tunnel_handle) };
160+
if status < 0 {
161+
return Err(TunnelError::StopWireguardError { status });
164162
}
165163
Ok(())
166164
}
@@ -215,7 +213,7 @@ impl Tunnel for WgGoTunnel {
215213

216214
fn get_tunnel_stats(&self) -> Result<StatsMap> {
217215
let config_str = unsafe {
218-
let ptr = wgGetConfig(self.handle.unwrap());
216+
let ptr = wgGetConfig(self.tunnel_handle);
219217
if ptr.is_null() {
220218
log::error!("Failed to get config !");
221219
return Err(TunnelError::GetConfigError);
@@ -250,7 +248,7 @@ impl Tunnel for WgGoTunnel {
250248
config: Config,
251249
) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
252250
let wg_config_str = config.to_userspace_format();
253-
let handle = self.handle.unwrap();
251+
let handle = self.tunnel_handle;
254252
#[cfg(target_os = "android")]
255253
let tun_provider = self.tun_provider.clone();
256254
Box::pin(async move {
@@ -297,8 +295,8 @@ impl Tunnel for WgGoTunnel {
297295
})?;
298296

299297
log::info!("Initializing DAITA for wireguard device");
300-
let tunnel_handle = self.handle.expect("Tunnel should be active");
301-
let session = daita::Session::from_adapter(tunnel_handle, machines)
298+
let peer_public_key = &self.config.entry_peer.public_key;
299+
let session = daita::Session::from_adapter(self.tunnel_handle, peer_public_key, machines)
302300
.expect("Wireguard-go should fetch current tunnel from ID");
303301
self.daita_handle = Some(session);
304302

Diff for: wireguard-go-rs/libwg/libwg.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include <stdbool.h>
33

44
/// Activate DAITA for the specified tunnel.
5-
int32_t wgActivateDaita(int8_t* machines, int32_t tunnelHandle, uint32_t eventsCapacity, uint32_t actionsCapacity);
5+
int32_t wgActivateDaita(int32_t tunnelHandle, uint8_t* noisePublic, char* machines, uint32_t eventsCapacity, uint32_t actionsCapacity);
66
char* wgGetConfig(int32_t tunnelHandle);
77
int32_t wgSetConfig(int32_t tunnelHandle, char* cSettings);
88
void wgFreePtr(void*);

0 commit comments

Comments
 (0)