Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix off-by-one error when choosing access method candidates #5811

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");

let relay_list_request = RelayListProxy::new(
runtime
.mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat())
.await,
)
let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(
ApiConnectionMode::Direct,
ApiConnectionMode::Direct.into_repeat(),
))
.relay_list(None)
.await;

Expand Down
46 changes: 22 additions & 24 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,38 +382,37 @@ impl Runtime {
}

/// Creates a new request service and returns a handle to it.
async fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
&self,
sni_hostname: Option<String>,
initial_connection_mode: ApiConnectionMode,
proxy_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.handle(),
self.address_cache.clone(),
initial_connection_mode,
proxy_provider,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
.await
}

/// Returns a request factory initialized to create requests for the master API
pub async fn mullvad_rest_handle<
T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static,
>(
pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
&self,
initial_connection_mode: ApiConnectionMode,
proxy_provider: T,
) -> rest::MullvadRestHandle {
let service = self
.new_request_service(
Some(API.host().to_string()),
proxy_provider,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
)
.await;
let service = self.new_request_service(
Some(API.host().to_string()),
initial_connection_mode,
proxy_provider,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(API.host(), Some(token_store));

Expand All @@ -426,15 +425,14 @@ impl Runtime {
}

/// This is only to be used in test code
pub async fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
let service = self
.new_request_service(
Some(hostname.clone()),
futures::stream::repeat(ApiConnectionMode::Direct),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
)
.await;
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct,
futures::stream::repeat(ApiConnectionMode::Direct),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(hostname, Some(token_store));

Expand All @@ -447,14 +445,14 @@ impl Runtime {
}

/// Returns a new request service handle
pub async fn rest_handle(&self) -> rest::RequestServiceHandle {
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct,
ApiConnectionMode::Direct.into_repeat(),
#[cfg(target_os = "android")]
None,
)
.await
}

pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
Expand Down
9 changes: 4 additions & 5 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {

impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
/// Constructs a new request service.
pub async fn spawn(
pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailabilityHandle,
address_cache: AddressCache,
mut proxy_config_provider: T,
initial_connection_mode: ApiConnectionMode,
proxy_config_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
Expand All @@ -145,9 +146,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
socket_bypass_tx.clone(),
);

if let Some(config) = proxy_config_provider.next().await {
connector_handle.set_connection_mode(config);
}
connector_handle.set_connection_mode(initial_connection_mode);

let (command_tx, command_rx) = mpsc::unbounded();
let client = Client::builder().build(connector);
Expand Down
8 changes: 2 additions & 6 deletions mullvad-daemon/src/access_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,10 @@ where

/// Create an [`ApiProxy`] which will perform all REST requests against one
/// specific endpoint `proxy_provider`.
pub async fn create_limited_api_proxy(
&mut self,
proxy_provider: ApiConnectionMode,
) -> ApiProxy {
pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy {
let rest_handle = self
.api_runtime
.mullvad_rest_handle(proxy_provider.into_repeat())
.await;
.mullvad_rest_handle(proxy_provider, futures::stream::empty());
ApiProxy::new(rest_handle)
}

Expand Down
34 changes: 19 additions & 15 deletions mullvad-daemon/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ impl AccessModeSelector {
) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();

let (index, next) = Self::get_next_inner(0, &access_method_settings);
// Always start looking from the position of `Direct`.
let (index, next) = Self::select_next_active(0, &access_method_settings);
let initial_connection_mode =
Self::resolve_inner(next, &relay_selector, &address_cache).await;

Expand Down Expand Up @@ -396,25 +397,28 @@ impl AccessModeSelector {
if let Some(access_method) = self.set.take() {
access_method
} else {
let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings);
self.index = index;
let (next_index, next) =
Self::select_next_active(self.index + 1, &self.access_method_settings);
self.index = next_index;
next
}
}

fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
let xs: Vec<_> = access_methods.iter().collect();
for offset in 1..=access_methods.cardinality() {
let index = (start + offset) % access_methods.cardinality();
if let Some(&candidate) = xs.get(index) {
if candidate.enabled {
return (index, candidate.clone());
}
}
}
(0, access_methods.direct().clone())
/// Find the next access method to use.
///
/// * `start`: From which point in `access_methods` to start the search.
/// * `access_methods`: The search space.
fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) {
access_methods
.iter()
.cloned()
.enumerate()
.cycle()
.skip(start)
.take(access_methods.cardinality())
.find(|(_index, access_method)| access_method.enabled())
.unwrap_or_else(|| (0, access_methods.direct().clone()))
}

fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
Expand Down
33 changes: 15 additions & 18 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,15 @@ where
.await
.map_err(Error::ApiConnectionModeError)?;

let api_handle = api_runtime
.mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream()))
.await;
let initial_connection_mode = connection_modes_handler
.get_current()
.await
.map_err(Error::ApiConnectionModeError)?;

let api_handle = api_runtime.mullvad_rest_handle(
initial_connection_mode.connection_mode,
Box::pin(connection_modes_handler.clone().into_stream()),
);

let migration_complete = if let Some(migration_data) = migration_data {
migrations::migrate_device(
Expand Down Expand Up @@ -787,11 +793,6 @@ where
let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned());
});

let initial_api_endpoint = connection_modes_handler
.get_current()
.await
.map_err(Error::ApiConnectionModeError)?
.endpoint;
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
Expand All @@ -800,7 +801,7 @@ where
allow_lan: settings.allow_lan,
block_when_disconnected: settings.block_when_disconnected,
dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options),
allowed_endpoint: initial_api_endpoint,
allowed_endpoint: initial_connection_mode.endpoint,
reset_firewall: *target_state != TargetState::Secured,
#[cfg(windows)]
exclude_paths,
Expand Down Expand Up @@ -851,7 +852,7 @@ where
relay_list_updater.update().await;

let location_handler = GeoIpHandler::new(
api_runtime.rest_handle().await,
api_runtime.rest_handle(),
internal_event_tx.clone().to_specialized_sender(),
);

Expand Down Expand Up @@ -1248,9 +1249,7 @@ where
GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx),
SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await,
TestApiAccessMethodById(tx, method) => self.on_test_api_access_method(tx, method).await,
TestCustomApiAccessMethod(tx, proxy) => {
self.on_test_proxy_as_access_method(tx, proxy).await
}
TestCustomApiAccessMethod(tx, proxy) => self.on_test_proxy_as_access_method(tx, proxy),
IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx),
GetCurrentVersion(tx) => self.on_get_current_version(tx),
#[cfg(not(target_os = "android"))]
Expand Down Expand Up @@ -2478,7 +2477,7 @@ where
});
}

async fn on_test_proxy_as_access_method(
fn on_test_proxy_as_access_method(
&mut self,
tx: ResponseTx<bool, Error>,
proxy: talpid_types::net::proxy::CustomProxy,
Expand All @@ -2487,7 +2486,7 @@ where
use talpid_types::net::AllowedEndpoint;

let connection_mode = ApiConnectionMode::Proxied(ProxyConfig::from(proxy.clone()));
let api_proxy = self.create_limited_api_proxy(connection_mode.clone()).await;
let api_proxy = self.create_limited_api_proxy(connection_mode.clone());
let proxy_endpoint = AllowedEndpoint {
endpoint: proxy.get_remote_endpoint().endpoint,
clients: api::allowed_clients(&connection_mode),
Expand Down Expand Up @@ -2533,9 +2532,7 @@ where
}
};

let api_proxy = self
.create_limited_api_proxy(test_subject.connection_mode)
.await;
let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode);
let daemon_event_sender = self.tx.to_specialized_sender();
let access_method_selector = self.connection_modes_handler.clone();

Expand Down
9 changes: 2 additions & 7 deletions mullvad-problem-report/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,9 @@ async fn send_problem_report_inner(
.await
.map_err(Error::CreateRpcClientError)?;

let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await;
let api_client = mullvad_api::ProblemReportProxy::new(
api_runtime
.mullvad_rest_handle(
ApiConnectionMode::try_from_cache(cache_dir)
.await
.into_repeat(),
)
.await,
api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
);

for _attempt in 0..MAX_SEND_ATTEMPTS {
Expand Down
9 changes: 2 additions & 7 deletions mullvad-setup/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,9 @@ async fn remove_device() -> Result<(), Error> {
.await
.map_err(Error::RpcInitializationError)?;

let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await;
let proxy = mullvad_api::DevicesProxy::new(
api_runtime
.mullvad_rest_handle(
ApiConnectionMode::try_from_cache(&cache_path)
.await
.into_repeat(),
)
.await,
api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()),
);

let device_removal = retry_future(
Expand Down
2 changes: 1 addition & 1 deletion mullvad-types/src/access_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Settings {
}

/// Iterate over references of built-in & custom access methods.
pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> {
pub fn iter(&self) -> impl Iterator<Item = &AccessMethodSetting> + Clone {
use std::iter::once;
once(&self.direct)
.chain(once(&self.mullvad_bridges))
Expand Down
20 changes: 11 additions & 9 deletions test/test-manager/src/tests/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub async fn test_login(
// Instruct daemon to log in
//

clear_devices(&new_device_client().await)
clear_devices(&new_device_client())
.await
.expect("failed to clear devices");

Expand Down Expand Up @@ -65,7 +65,7 @@ pub async fn test_too_many_devices(
) -> Result<(), Error> {
log::info!("Using up all devices");

let device_client = new_device_client().await;
let device_client = new_device_client();

const MAX_ATTEMPTS: usize = 15;

Expand Down Expand Up @@ -151,7 +151,7 @@ pub async fn test_revoked_device(

log::debug!("Removing current device");

let device_client = new_device_client().await;
let device_client = new_device_client();
retry_if_throttled(|| {
device_client.remove(TEST_CONFIG.account_number.clone(), device_id.clone())
})
Expand Down Expand Up @@ -217,9 +217,10 @@ pub async fn clear_devices(device_client: &DevicesProxy) -> Result<(), mullvad_a
Ok(())
}

pub async fn new_device_client() -> DevicesProxy {
let api_endpoint = mullvad_api::ApiEndpoint::from_env_vars();
pub fn new_device_client() -> DevicesProxy {
use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, API};

let api_endpoint = ApiEndpoint::from_env_vars();
let api_host = format!("api.{}", TEST_CONFIG.mullvad_host);
let api_address = format!("{api_host}:443")
.to_socket_addrs()
Expand All @@ -228,17 +229,18 @@ pub async fn new_device_client() -> DevicesProxy {
.unwrap();

// Override the API endpoint to use the one specified in the test config
let _ = mullvad_api::API.override_init(mullvad_api::ApiEndpoint {
let _ = API.override_init(ApiEndpoint {
host: Some(api_host),
address: Some(api_address),
..api_endpoint
});

let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("failed to create api runtime");
let rest_handle = api
.mullvad_rest_handle(mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat())
.await;
let rest_handle = api.mullvad_rest_handle(
ApiConnectionMode::Direct,
ApiConnectionMode::Direct.into_repeat(),
);
DevicesProxy::new(rest_handle)
}

Expand Down
Loading