diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index c016b4c8a1dd..e395d8ae5fbd 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -11,11 +11,10 @@ async fn main() { let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("Failed to load runtime"); - let relay_list_request = RelayListProxy::new( - runtime - .mullvad_rest_handle(ApiConnectionMode::Direct.into_repeat()) - .await, - ) + let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle( + ApiConnectionMode::Direct, + ApiConnectionMode::Direct.into_repeat(), + )) .relay_list(None) .await; diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 26d6b9758cd7..17e80b66c00b 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -382,9 +382,10 @@ impl Runtime { } /// Creates a new request service and returns a handle to it. - async fn new_request_service + Unpin + Send + 'static>( + fn new_request_service + Unpin + Send + 'static>( &self, sni_hostname: Option, + initial_connection_mode: ApiConnectionMode, proxy_provider: T, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> rest::RequestServiceHandle { @@ -392,28 +393,26 @@ impl Runtime { sni_hostname, self.api_availability.handle(), self.address_cache.clone(), + initial_connection_mode, proxy_provider, #[cfg(target_os = "android")] socket_bypass_tx, ) - .await } /// Returns a request factory initialized to create requests for the master API - pub async fn mullvad_rest_handle< - T: Stream + Unpin + Send + 'static, - >( + pub fn mullvad_rest_handle + Unpin + Send + 'static>( &self, + initial_connection_mode: ApiConnectionMode, proxy_provider: T, ) -> rest::MullvadRestHandle { - let service = self - .new_request_service( - Some(API.host().to_string()), - proxy_provider, - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ) - .await; + let service = self.new_request_service( + Some(API.host().to_string()), + initial_connection_mode, + proxy_provider, + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), + ); let token_store = access::AccessTokenStore::new(service.clone()); let factory = rest::RequestFactory::new(API.host(), Some(token_store)); @@ -426,15 +425,14 @@ impl Runtime { } /// This is only to be used in test code - pub async fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { - let service = self - .new_request_service( - Some(hostname.clone()), - futures::stream::repeat(ApiConnectionMode::Direct), - #[cfg(target_os = "android")] - self.socket_bypass_tx.clone(), - ) - .await; + pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { + let service = self.new_request_service( + Some(hostname.clone()), + ApiConnectionMode::Direct, + futures::stream::repeat(ApiConnectionMode::Direct), + #[cfg(target_os = "android")] + self.socket_bypass_tx.clone(), + ); let token_store = access::AccessTokenStore::new(service.clone()); let factory = rest::RequestFactory::new(hostname, Some(token_store)); @@ -447,14 +445,14 @@ impl Runtime { } /// Returns a new request service handle - pub async fn rest_handle(&self) -> rest::RequestServiceHandle { + pub fn rest_handle(&self) -> rest::RequestServiceHandle { self.new_request_service( None, + ApiConnectionMode::Direct, ApiConnectionMode::Direct.into_repeat(), #[cfg(target_os = "android")] None, ) - .await } pub fn handle(&mut self) -> &mut tokio::runtime::Handle { diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 0560642bb09c..ca63f16c1fb6 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -131,11 +131,12 @@ pub(crate) struct RequestService> { impl + Unpin + Send + 'static> RequestService { /// Constructs a new request service. - pub async fn spawn( + pub fn spawn( sni_hostname: Option, api_availability: ApiAvailabilityHandle, address_cache: AddressCache, - mut proxy_config_provider: T, + initial_connection_mode: ApiConnectionMode, + proxy_config_provider: T, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( @@ -145,9 +146,7 @@ impl + Unpin + Send + 'static> RequestServic socket_bypass_tx.clone(), ); - if let Some(config) = proxy_config_provider.next().await { - connector_handle.set_connection_mode(config); - } + connector_handle.set_connection_mode(initial_connection_mode); let (command_tx, command_rx) = mpsc::unbounded(); let client = Client::builder().build(connector); diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 51bf6c1ea570..664fce6bfee1 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -260,14 +260,10 @@ where /// Create an [`ApiProxy`] which will perform all REST requests against one /// specific endpoint `proxy_provider`. - pub async fn create_limited_api_proxy( - &mut self, - proxy_provider: ApiConnectionMode, - ) -> ApiProxy { + pub fn create_limited_api_proxy(&mut self, proxy_provider: ApiConnectionMode) -> ApiProxy { let rest_handle = self .api_runtime - .mullvad_rest_handle(proxy_provider.into_repeat()) - .await; + .mullvad_rest_handle(proxy_provider, futures::stream::empty()); ApiProxy::new(rest_handle) } diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index a493b532d185..5db03c200874 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -243,7 +243,8 @@ impl AccessModeSelector { ) -> Result { let (cmd_tx, cmd_rx) = mpsc::unbounded(); - let (index, next) = Self::get_next_inner(0, &access_method_settings); + // Always start looking from the position of `Direct`. + let (index, next) = Self::select_next_active(0, &access_method_settings); let initial_connection_mode = Self::resolve_inner(next, &relay_selector, &address_cache).await; @@ -396,25 +397,28 @@ impl AccessModeSelector { if let Some(access_method) = self.set.take() { access_method } else { - let (index, next) = Self::get_next_inner(self.index, &self.access_method_settings); - self.index = index; + let (next_index, next) = + Self::select_next_active(self.index + 1, &self.access_method_settings); + self.index = next_index; next } } - fn get_next_inner(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { - let xs: Vec<_> = access_methods.iter().collect(); - for offset in 1..=access_methods.cardinality() { - let index = (start + offset) % access_methods.cardinality(); - if let Some(&candidate) = xs.get(index) { - if candidate.enabled { - return (index, candidate.clone()); - } - } - } - (0, access_methods.direct().clone()) + /// Find the next access method to use. + /// + /// * `start`: From which point in `access_methods` to start the search. + /// * `access_methods`: The search space. + fn select_next_active(start: usize, access_methods: &Settings) -> (usize, AccessMethodSetting) { + access_methods + .iter() + .cloned() + .enumerate() + .cycle() + .skip(start) + .take(access_methods.cardinality()) + .find(|(_index, access_method)| access_method.enabled()) + .unwrap_or_else(|| (0, access_methods.direct().clone())) } - fn on_update_access_methods( &mut self, tx: ResponseTx<()>, diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 00f4613ffbb9..ef4bcb86e077 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -717,9 +717,15 @@ where .await .map_err(Error::ApiConnectionModeError)?; - let api_handle = api_runtime - .mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream())) - .await; + let initial_connection_mode = connection_modes_handler + .get_current() + .await + .map_err(Error::ApiConnectionModeError)?; + + let api_handle = api_runtime.mullvad_rest_handle( + initial_connection_mode.connection_mode, + Box::pin(connection_modes_handler.clone().into_stream()), + ); let migration_complete = if let Some(migration_data) = migration_data { migrations::migrate_device( @@ -787,11 +793,6 @@ where let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned()); }); - let initial_api_endpoint = connection_modes_handler - .get_current() - .await - .map_err(Error::ApiConnectionModeError)? - .endpoint; let (offline_state_tx, offline_state_rx) = mpsc::unbounded(); #[cfg(target_os = "windows")] let (volume_update_tx, volume_update_rx) = mpsc::unbounded(); @@ -800,7 +801,7 @@ where allow_lan: settings.allow_lan, block_when_disconnected: settings.block_when_disconnected, dns_servers: dns::addresses_from_options(&settings.tunnel_options.dns_options), - allowed_endpoint: initial_api_endpoint, + allowed_endpoint: initial_connection_mode.endpoint, reset_firewall: *target_state != TargetState::Secured, #[cfg(windows)] exclude_paths, @@ -851,7 +852,7 @@ where relay_list_updater.update().await; let location_handler = GeoIpHandler::new( - api_runtime.rest_handle().await, + api_runtime.rest_handle(), internal_event_tx.clone().to_specialized_sender(), ); @@ -1248,9 +1249,7 @@ where GetCurrentAccessMethod(tx) => self.on_get_current_api_access_method(tx), SetApiAccessMethod(tx, method) => self.on_set_api_access_method(tx, method).await, TestApiAccessMethodById(tx, method) => self.on_test_api_access_method(tx, method).await, - TestCustomApiAccessMethod(tx, proxy) => { - self.on_test_proxy_as_access_method(tx, proxy).await - } + TestCustomApiAccessMethod(tx, proxy) => self.on_test_proxy_as_access_method(tx, proxy), IsPerformingPostUpgrade(tx) => self.on_is_performing_post_upgrade(tx), GetCurrentVersion(tx) => self.on_get_current_version(tx), #[cfg(not(target_os = "android"))] @@ -2478,7 +2477,7 @@ where }); } - async fn on_test_proxy_as_access_method( + fn on_test_proxy_as_access_method( &mut self, tx: ResponseTx, proxy: talpid_types::net::proxy::CustomProxy, @@ -2487,7 +2486,7 @@ where use talpid_types::net::AllowedEndpoint; let connection_mode = ApiConnectionMode::Proxied(ProxyConfig::from(proxy.clone())); - let api_proxy = self.create_limited_api_proxy(connection_mode.clone()).await; + let api_proxy = self.create_limited_api_proxy(connection_mode.clone()); let proxy_endpoint = AllowedEndpoint { endpoint: proxy.get_remote_endpoint().endpoint, clients: api::allowed_clients(&connection_mode), @@ -2533,9 +2532,7 @@ where } }; - let api_proxy = self - .create_limited_api_proxy(test_subject.connection_mode) - .await; + let api_proxy = self.create_limited_api_proxy(test_subject.connection_mode); let daemon_event_sender = self.tx.to_specialized_sender(); let access_method_selector = self.connection_modes_handler.clone(); diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 1f687b457093..bcd820bef5da 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -299,14 +299,9 @@ async fn send_problem_report_inner( .await .map_err(Error::CreateRpcClientError)?; + let connection_mode = ApiConnectionMode::try_from_cache(cache_dir).await; let api_client = mullvad_api::ProblemReportProxy::new( - api_runtime - .mullvad_rest_handle( - ApiConnectionMode::try_from_cache(cache_dir) - .await - .into_repeat(), - ) - .await, + api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()), ); for _attempt in 0..MAX_SEND_ATTEMPTS { diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index f89baeb04917..cf93b2d039e4 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -159,14 +159,9 @@ async fn remove_device() -> Result<(), Error> { .await .map_err(Error::RpcInitializationError)?; + let connection_mode = ApiConnectionMode::try_from_cache(&cache_path).await; let proxy = mullvad_api::DevicesProxy::new( - api_runtime - .mullvad_rest_handle( - ApiConnectionMode::try_from_cache(&cache_path) - .await - .into_repeat(), - ) - .await, + api_runtime.mullvad_rest_handle(connection_mode.clone(), connection_mode.into_repeat()), ); let device_removal = retry_future( diff --git a/mullvad-types/src/access_method.rs b/mullvad-types/src/access_method.rs index c0f648f25f48..e8f6bd4a5deb 100644 --- a/mullvad-types/src/access_method.rs +++ b/mullvad-types/src/access_method.rs @@ -83,7 +83,7 @@ impl Settings { } /// Iterate over references of built-in & custom access methods. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator + Clone { use std::iter::once; once(&self.direct) .chain(once(&self.mullvad_bridges)) diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 95b54b6432a7..1eeeb8c17039 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -23,7 +23,7 @@ pub async fn test_login( // Instruct daemon to log in // - clear_devices(&new_device_client().await) + clear_devices(&new_device_client()) .await .expect("failed to clear devices"); @@ -65,7 +65,7 @@ pub async fn test_too_many_devices( ) -> Result<(), Error> { log::info!("Using up all devices"); - let device_client = new_device_client().await; + let device_client = new_device_client(); const MAX_ATTEMPTS: usize = 15; @@ -151,7 +151,7 @@ pub async fn test_revoked_device( log::debug!("Removing current device"); - let device_client = new_device_client().await; + let device_client = new_device_client(); retry_if_throttled(|| { device_client.remove(TEST_CONFIG.account_number.clone(), device_id.clone()) }) @@ -217,9 +217,10 @@ pub async fn clear_devices(device_client: &DevicesProxy) -> Result<(), mullvad_a Ok(()) } -pub async fn new_device_client() -> DevicesProxy { - let api_endpoint = mullvad_api::ApiEndpoint::from_env_vars(); +pub fn new_device_client() -> DevicesProxy { + use mullvad_api::{proxy::ApiConnectionMode, ApiEndpoint, API}; + let api_endpoint = ApiEndpoint::from_env_vars(); let api_host = format!("api.{}", TEST_CONFIG.mullvad_host); let api_address = format!("{api_host}:443") .to_socket_addrs() @@ -228,7 +229,7 @@ pub async fn new_device_client() -> DevicesProxy { .unwrap(); // Override the API endpoint to use the one specified in the test config - let _ = mullvad_api::API.override_init(mullvad_api::ApiEndpoint { + let _ = API.override_init(ApiEndpoint { host: Some(api_host), address: Some(api_address), ..api_endpoint @@ -236,9 +237,10 @@ pub async fn new_device_client() -> DevicesProxy { let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("failed to create api runtime"); - let rest_handle = api - .mullvad_rest_handle(mullvad_api::proxy::ApiConnectionMode::Direct.into_repeat()) - .await; + let rest_handle = api.mullvad_rest_handle( + ApiConnectionMode::Direct, + ApiConnectionMode::Direct.into_repeat(), + ); DevicesProxy::new(rest_handle) } diff --git a/test/test-manager/src/tests/install.rs b/test/test-manager/src/tests/install.rs index 614b9d9bb93d..8f2f2cdff22f 100644 --- a/test/test-manager/src/tests/install.rs +++ b/test/test-manager/src/tests/install.rs @@ -49,7 +49,7 @@ pub async fn test_upgrade_app(ctx: TestContext, rpc: ServiceClient) -> Result<() return Err(Error::DaemonNotRunning); } - super::account::clear_devices(&super::account::new_device_client().await) + super::account::clear_devices(&super::account::new_device_client()) .await .expect("failed to clear devices"); @@ -227,10 +227,9 @@ pub async fn test_uninstall_app( } // verify that device was removed - let devices = - super::account::list_devices_with_retries(&super::account::new_device_client().await) - .await - .expect("failed to list devices"); + let devices = super::account::list_devices_with_retries(&super::account::new_device_client()) + .await + .expect("failed to list devices"); assert!( !devices.iter().any(|device| device.id == uninstalled_device),