Skip to content

Commit 8275764

Browse files
Fix off-by-one error when choosing access method candidates
1 parent 042aa20 commit 8275764

File tree

11 files changed

+86
-102
lines changed

11 files changed

+86
-102
lines changed

mullvad-api/src/bin/relay_list.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ async fn main() {
1111
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
1212
.expect("Failed to load runtime");
1313

14-
let relay_list_request = RelayListProxy::new(
15-
runtime
16-
.mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat())
17-
.await,
18-
)
14+
let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(
15+
ApiConnectionMode::Direct,
16+
ApiConnectionMode::Direct.into_repeat(),
17+
))
1918
.relay_list(None)
2019
.await;
2120

mullvad-api/src/lib.rs

+22-24
Original file line numberDiff line numberDiff line change
@@ -384,38 +384,37 @@ impl Runtime {
384384
}
385385

386386
/// Creates a new request service and returns a handle to it.
387-
async fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
387+
fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
388388
&self,
389389
sni_hostname: Option<String>,
390+
initial_connection_mode: ApiConnectionMode,
390391
proxy_provider: T,
391392
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
392393
) -> rest::RequestServiceHandle {
393394
rest::RequestService::spawn(
394395
sni_hostname,
395396
self.api_availability.handle(),
396397
self.address_cache.clone(),
398+
initial_connection_mode,
397399
proxy_provider,
398400
#[cfg(target_os = "android")]
399401
socket_bypass_tx,
400402
)
401-
.await
402403
}
403404

404405
/// Returns a request factory initialized to create requests for the master API
405-
pub async fn mullvad_rest_handle<
406-
T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
407-
>(
406+
pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
408407
&self,
408+
initial_connection_mode: ApiConnectionMode,
409409
proxy_provider: T,
410410
) -> rest::MullvadRestHandle {
411-
let service = self
412-
.new_request_service(
413-
Some(API.host().to_string()),
414-
proxy_provider,
415-
#[cfg(target_os = "android")]
416-
self.socket_bypass_tx.clone(),
417-
)
418-
.await;
411+
let service = self.new_request_service(
412+
Some(API.host().to_string()),
413+
initial_connection_mode,
414+
proxy_provider,
415+
#[cfg(target_os = "android")]
416+
self.socket_bypass_tx.clone(),
417+
);
419418
let token_store = access::AccessTokenStore::new(service.clone());
420419
let factory = rest::RequestFactory::new(API.host(), Some(token_store));
421420

@@ -428,15 +427,14 @@ impl Runtime {
428427
}
429428

430429
/// This is only to be used in test code
431-
pub async fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
432-
let service = self
433-
.new_request_service(
434-
Some(hostname.clone()),
435-
futures::stream::repeat(ApiConnectionMode::Direct),
436-
#[cfg(target_os = "android")]
437-
self.socket_bypass_tx.clone(),
438-
)
439-
.await;
430+
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
431+
let service = self.new_request_service(
432+
Some(hostname.clone()),
433+
ApiConnectionMode::Direct,
434+
futures::stream::repeat(ApiConnectionMode::Direct),
435+
#[cfg(target_os = "android")]
436+
self.socket_bypass_tx.clone(),
437+
);
440438
let token_store = access::AccessTokenStore::new(service.clone());
441439
let factory = rest::RequestFactory::new(hostname, Some(token_store));
442440

@@ -449,14 +447,14 @@ impl Runtime {
449447
}
450448

451449
/// Returns a new request service handle
452-
pub async fn rest_handle(&self) -> rest::RequestServiceHandle {
450+
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
453451
self.new_request_service(
454452
None,
453+
ApiConnectionMode::Direct,
455454
ApiConnectionMode::Direct.into_repeat(),
456455
#[cfg(target_os = "android")]
457456
None,
458457
)
459-
.await
460458
}
461459

462460
pub fn handle(&mut self) -> &mut tokio::runtime::Handle {

mullvad-api/src/rest.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,12 @@ pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {
131131

132132
impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
133133
/// Constructs a new request service.
134-
pub async fn spawn(
134+
pub fn spawn(
135135
sni_hostname: Option<String>,
136136
api_availability: ApiAvailabilityHandle,
137137
address_cache: AddressCache,
138-
mut proxy_config_provider: T,
138+
initial_connection_mode: ApiConnectionMode,
139+
proxy_config_provider: T,
139140
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
140141
) -> RequestServiceHandle {
141142
let (connector, connector_handle) = HttpsConnectorWithSni::new(
@@ -145,9 +146,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
145146
socket_bypass_tx.clone(),
146147
);
147148

148-
if let Some(config) = proxy_config_provider.next().await {
149-
connector_handle.set_connection_mode(config);
150-
}
149+
connector_handle.set_connection_mode(initial_connection_mode);
151150

152151
let (command_tx, command_rx) = mpsc::unbounded();
153152
let client = Client::builder().build(connector);

mullvad-daemon/src/access_method.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,10 @@ where
260260

261261
/// Create an [`ApiProxy`] which will perform all REST requests against one
262262
/// specific endpoint `proxy_provider`.
263-
pub async fn create_limited_api_proxy(
264-
&mut self,
265-
proxy_provider: ApiConnectionMode,
266-
) -> ApiProxy {
263+
pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy {
267264
let rest_handle = self
268265
.api_runtime
269-
.mullvad_rest_handle(proxy_provider.into_repeat())
270-
.await;
266+
.mullvad_rest_handle(proxy_provider, futures::stream::empty());
271267
ApiProxy::new(rest_handle)
272268
}
273269

mullvad-daemon/src/api.rs

+19-15
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ impl AccessModeSelector {
243243
) -> Result<AccessModeSelectorHandle> {
244244
let (cmd_tx, cmd_rx) = mpsc::unbounded();
245245

246-
let (index, next) = Self::get_next_inner(0, &access_method_settings);
246+
// Always start looking from the position of `Direct`.
247+
let (index, next) = Self::select_next_active(0, &access_method_settings);
247248
let initial_connection_mode =
248249
Self::resolve_inner(next, &relay_selector, &address_cache).await;
249250

@@ -396,25 +397,28 @@ impl AccessModeSelector {
396397
if let Some(access_method) = self.set.take() {
397398
access_method
398399
} else {
399-
let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings);
400-
self.index = index;
400+
let (next_index, next) =
401+
Self::select_next_active(self.index + 1, &self.access_method_settings);
402+
self.index = next_index;
401403
next
402404
}
403405
}
404406

405-
fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
406-
let xs: Vec<_> = access_methods.iter().collect();
407-
for offset in 1..=access_methods.cardinality() {
408-
let index = (start + offset) % access_methods.cardinality();
409-
if let Some(&candidate) = xs.get(index) {
410-
if candidate.enabled {
411-
return (index, candidate.clone());
412-
}
413-
}
414-
}
415-
(0, access_methods.direct().clone())
407+
/// Find the next access method to use.
408+
///
409+
/// * `start`: From which point in `access_methods` to start the search.
410+
/// * `access_methods`: The search space.
411+
fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
412+
access_methods
413+
.iter()
414+
.cloned()
415+
.enumerate()
416+
.cycle()
417+
.skip(start)
418+
.take(access_methods.cardinality())
419+
.find(|(_index, access_method)| access_method.enabled())
420+
.unwrap_or_else(|| (0, access_methods.direct().clone()))
416421
}
417-
418422
fn on_update_access_methods(
419423
&mut self,
420424
tx: ResponseTx<()>,

mullvad-daemon/src/lib.rs

+15-18
Original file line numberDiff line numberDiff line change
@@ -717,9 +717,15 @@ where
717717
.await
718718
.map_err(Error::ApiConnectionModeError)?;
719719

720-
let api_handle = api_runtime
721-
.mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream()))
722-
.await;
720+
let initial_connection_mode = connection_modes_handler
721+
.get_current()
722+
.await
723+
.map_err(Error::ApiConnectionModeError)?;
724+
725+
let api_handle = api_runtime.mullvad_rest_handle(
726+
initial_connection_mode.connection_mode,
727+
Box::pin(connection_modes_handler.clone().into_stream()),
728+
);
723729

724730
let migration_complete = if let Some(migration_data) = migration_data {
725731
migrations::migrate_device(
@@ -787,11 +793,6 @@ where
787793
let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned());
788794
});
789795

790-
let initial_api_endpoint = connection_modes_handler
791-
.get_current()
792-
.await
793-
.map_err(Error::ApiConnectionModeError)?
794-
.endpoint;
795796
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
796797
#[cfg(target_os = "windows")]
797798
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
@@ -800,7 +801,7 @@ where
800801
allow_lan: settings.allow_lan,
801802
block_when_disconnected: settings.block_when_disconnected,
802803
dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options),
803-
allowed_endpoint: initial_api_endpoint,
804+
allowed_endpoint: initial_connection_mode.endpoint,
804805
reset_firewall: *target_state != TargetState::Secured,
805806
#[cfg(windows)]
806807
exclude_paths,
@@ -851,7 +852,7 @@ where
851852
relay_list_updater.update().await;
852853

853854
let location_handler = GeoIpHandler::new(
854-
api_runtime.rest_handle().await,
855+
api_runtime.rest_handle(),
855856
internal_event_tx.clone().to_specialized_sender(),
856857
);
857858

@@ -1248,9 +1249,7 @@ where
12481249
GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx),
12491250
SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await,
12501251
TestApiAccessMethodById(tx, method) => self.on_test_api_access_method(tx, method).await,
1251-
TestCustomApiAccessMethod(tx, proxy) => {
1252-
self.on_test_proxy_as_access_method(tx, proxy).await
1253-
}
1252+
TestCustomApiAccessMethod(tx, proxy) => self.on_test_proxy_as_access_method(tx, proxy),
12541253
IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx),
12551254
GetCurrentVersion(tx) => self.on_get_current_version(tx),
12561255
#[cfg(not(target_os = "android"))]
@@ -2478,7 +2477,7 @@ where
24782477
});
24792478
}
24802479

2481-
async fn on_test_proxy_as_access_method(
2480+
fn on_test_proxy_as_access_method(
24822481
&mut self,
24832482
tx: ResponseTx<bool, Error>,
24842483
proxy: talpid_types::net::proxy::CustomProxy,
@@ -2487,7 +2486,7 @@ where
24872486
use talpid_types::net::AllowedEndpoint;
24882487

24892488
let connection_mode = ApiConnectionMode::Proxied(ProxyConfig::from(proxy.clone()));
2490-
let api_proxy = self.create_limited_api_proxy(connection_mode.clone()).await;
2489+
let api_proxy = self.create_limited_api_proxy(connection_mode.clone());
24912490
let proxy_endpoint = AllowedEndpoint {
24922491
endpoint: proxy.get_remote_endpoint().endpoint,
24932492
clients: api::allowed_clients(&connection_mode),
@@ -2533,9 +2532,7 @@ where
25332532
}
25342533
};
25352534

2536-
let api_proxy = self
2537-
.create_limited_api_proxy(test_subject.connection_mode)
2538-
.await;
2535+
let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode);
25392536
let daemon_event_sender = self.tx.to_specialized_sender();
25402537
let access_method_selector = self.connection_modes_handler.clone();
25412538

mullvad-problem-report/src/lib.rs

+2-7
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,9 @@ async fn send_problem_report_inner(
299299
.await
300300
.map_err(Error::CreateRpcClientError)?;
301301

302+
let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await;
302303
let api_client = mullvad_api::ProblemReportProxy::new(
303-
api_runtime
304-
.mullvad_rest_handle(
305-
ApiConnectionMode::try_from_cache(cache_dir)
306-
.await
307-
.into_repeat(),
308-
)
309-
.await,
304+
api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
310305
);
311306

312307
for _attempt in 0..MAX_SEND_ATTEMPTS {

mullvad-setup/src/main.rs

+2-7
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,9 @@ async fn remove_device() -> Result<(), Error> {
159159
.await
160160
.map_err(Error::RpcInitializationError)?;
161161

162+
let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await;
162163
let proxy = mullvad_api::DevicesProxy::new(
163-
api_runtime
164-
.mullvad_rest_handle(
165-
ApiConnectionMode::try_from_cache(&cache_path)
166-
.await
167-
.into_repeat(),
168-
)
169-
.await,
164+
api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
170165
);
171166

172167
let device_removal = retry_future(

mullvad-types/src/access_method.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl Settings {
8383
}
8484

8585
/// Iterate over references of built-in & custom access methods.
86-
pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> {
86+
pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> + Clone {
8787
use std::iter::once;
8888
once(&self.direct)
8989
.chain(once(&self.mullvad_bridges))

test/test-manager/src/tests/account.rs

+11-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub async fn test_login(
2323
// Instruct daemon to log in
2424
//
2525

26-
clear_devices(&new_device_client().await)
26+
clear_devices(&new_device_client())
2727
.await
2828
.expect("failed to clear devices");
2929

@@ -65,7 +65,7 @@ pub async fn test_too_many_devices(
6565
) -> Result<(), Error> {
6666
log::info!("Using up all devices");
6767

68-
let device_client = new_device_client().await;
68+
let device_client = new_device_client();
6969

7070
const MAX_ATTEMPTS: usize = 15;
7171

@@ -151,7 +151,7 @@ pub async fn test_revoked_device(
151151

152152
log::debug!("Removing current device");
153153

154-
let device_client = new_device_client().await;
154+
let device_client = new_device_client();
155155
retry_if_throttled(|| {
156156
device_client.remove(TEST_CONFIG.account_number.clone(), device_id.clone())
157157
})
@@ -217,9 +217,10 @@ pub async fn clear_devices(device_client: &DevicesProxy) -> Result<(), mullvad_a
217217
Ok(())
218218
}
219219

220-
pub async fn new_device_client() -> DevicesProxy {
221-
let api_endpoint = mullvad_api::ApiEndpoint::from_env_vars();
220+
pub fn new_device_client() -> DevicesProxy {
221+
use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, API};
222222

223+
let api_endpoint = ApiEndpoint::from_env_vars();
223224
let api_host = format!("api.{}", TEST_CONFIG.mullvad_host);
224225
let api_address = format!("{api_host}:443")
225226
.to_socket_addrs()
@@ -228,17 +229,18 @@ pub async fn new_device_client() -> DevicesProxy {
228229
.unwrap();
229230

230231
// Override the API endpoint to use the one specified in the test config
231-
let _ = mullvad_api::API.override_init(mullvad_api::ApiEndpoint {
232+
let _ = API.override_init(ApiEndpoint {
232233
host: Some(api_host),
233234
address: Some(api_address),
234235
..api_endpoint
235236
});
236237

237238
let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
238239
.expect("failed to create api runtime");
239-
let rest_handle = api
240-
.mullvad_rest_handle(mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat())
241-
.await;
240+
let rest_handle = api.mullvad_rest_handle(
241+
ApiConnectionMode::Direct,
242+
ApiConnectionMode::Direct.into_repeat(),
243+
);
242244
DevicesProxy::new(rest_handle)
243245
}
244246

0 commit comments

Comments
 (0)