Skip to content

Commit e471d07

Browse files
committed
Refactor API access methods
1 parent c8a3a3b commit e471d07

File tree

11 files changed

+295
-313
lines changed

11 files changed

+295
-313
lines changed

mullvad-api/src/bin/relay_list.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@ async fn main() {
1111
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
1212
.expect("Failed to load runtime");
1313

14-
let relay_list_request = RelayListProxy::new(runtime.mullvad_rest_handle(
15-
ApiConnectionMode::Direct,
16-
ApiConnectionMode::Direct.into_repeat(),
17-
))
18-
.relay_list(None)
19-
.await;
14+
let relay_list_request =
15+
RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()))
16+
.relay_list(None)
17+
.await;
2018

2119
let relay_list = match relay_list_request {
2220
Ok(relay_list) => relay_list,

mullvad-api/src/lib.rs

+9-16
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#[cfg(target_os = "android")]
22
use futures::channel::mpsc;
3-
use futures::Stream;
43
use hyper::Method;
54
#[cfg(target_os = "android")]
65
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
76
use mullvad_types::{
87
account::{AccountData, AccountToken, VoucherSubmission},
98
version::AppVersion,
109
};
11-
use proxy::ApiConnectionMode;
10+
use proxy::{ApiConnectionMode, ConnectionModeProvider};
1211
use std::{
1312
cell::Cell,
1413
collections::BTreeMap,
@@ -408,34 +407,30 @@ impl Runtime {
408407
}
409408

410409
/// Creates a new request service and returns a handle to it.
411-
fn new_request_service<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
410+
fn new_request_service<T: ConnectionModeProvider + 'static>(
412411
&self,
413412
sni_hostname: Option<String>,
414-
initial_connection_mode: ApiConnectionMode,
415-
proxy_provider: T,
413+
connection_mode_provider: T,
416414
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
417415
) -> rest::RequestServiceHandle {
418416
rest::RequestService::spawn(
419417
sni_hostname,
420418
self.api_availability.handle(),
421419
self.address_cache.clone(),
422-
initial_connection_mode,
423-
proxy_provider,
420+
connection_mode_provider,
424421
#[cfg(target_os = "android")]
425422
socket_bypass_tx,
426423
)
427424
}
428425

429426
/// Returns a request factory initialized to create requests for the master API
430-
pub fn mullvad_rest_handle<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static>(
427+
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
431428
&self,
432-
initial_connection_mode: ApiConnectionMode,
433-
proxy_provider: T,
429+
connection_mode_provider: T,
434430
) -> rest::MullvadRestHandle {
435431
let service = self.new_request_service(
436432
Some(API.host().to_string()),
437-
initial_connection_mode,
438-
proxy_provider,
433+
connection_mode_provider,
439434
#[cfg(target_os = "android")]
440435
self.socket_bypass_tx.clone(),
441436
);
@@ -454,8 +449,7 @@ impl Runtime {
454449
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
455450
let service = self.new_request_service(
456451
Some(hostname.clone()),
457-
ApiConnectionMode::Direct,
458-
futures::stream::repeat(ApiConnectionMode::Direct),
452+
ApiConnectionMode::Direct.into_provider(),
459453
#[cfg(target_os = "android")]
460454
self.socket_bypass_tx.clone(),
461455
);
@@ -474,8 +468,7 @@ impl Runtime {
474468
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
475469
self.new_request_service(
476470
None,
477-
ApiConnectionMode::Direct,
478-
ApiConnectionMode::Direct.into_repeat(),
471+
ApiConnectionMode::Direct.into_provider(),
479472
#[cfg(target_os = "android")]
480473
None,
481474
)

mullvad-api/src/proxy.rs

+37-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use futures::Stream;
21
use hyper::client::connect::Connected;
32
use serde::{Deserialize, Serialize};
43
use std::{
@@ -18,6 +17,41 @@ use tokio::{
1817

1918
const CURRENT_CONFIG_FILENAME: &str = "api-endpoint.json";
2019

20+
pub trait ConnectionModeProvider: Send {
21+
/// Initial connection mode
22+
fn initial(&self) -> ApiConnectionMode;
23+
24+
/// Request a new connection mode from the provider
25+
fn rotate(&self) -> impl std::future::Future<Output = ()> + Send;
26+
27+
/// Receive changes to the connection mode, announced by the provider
28+
fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send;
29+
}
30+
31+
pub struct StaticConnectionModeProvider {
32+
mode: ApiConnectionMode,
33+
}
34+
35+
impl StaticConnectionModeProvider {
36+
pub fn new(mode: ApiConnectionMode) -> Self {
37+
Self { mode }
38+
}
39+
}
40+
41+
impl ConnectionModeProvider for StaticConnectionModeProvider {
42+
fn initial(&self) -> ApiConnectionMode {
43+
self.mode.clone()
44+
}
45+
46+
fn rotate(&self) -> impl std::future::Future<Output = ()> + Send {
47+
futures::future::ready(())
48+
}
49+
50+
fn receive(&mut self) -> impl std::future::Future<Output = Option<ApiConnectionMode>> + Send {
51+
futures::future::pending()
52+
}
53+
}
54+
2155
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
2256
pub enum ApiConnectionMode {
2357
/// Connect directly to the target.
@@ -153,10 +187,8 @@ impl ApiConnectionMode {
153187
*self != ApiConnectionMode::Direct
154188
}
155189

156-
/// Convenience function that returns a stream that repeats
157-
/// this config forever.
158-
pub fn into_repeat(self) -> impl Stream<Item = ApiConnectionMode> {
159-
futures::stream::repeat(self)
190+
pub fn into_provider(self) -> StaticConnectionModeProvider {
191+
StaticConnectionModeProvider::new(self)
160192
}
161193
}
162194

mullvad-api/src/rest.rs

+32-33
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ use crate::{
55
address_cache::AddressCache,
66
availability::ApiAvailabilityHandle,
77
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
8-
proxy::ApiConnectionMode,
8+
proxy::ConnectionModeProvider,
99
};
1010
use futures::{
1111
channel::{mpsc, oneshot},
1212
stream::StreamExt,
13-
Stream,
1413
};
1514
use hyper::{
1615
client::{connect::Connect, Client},
@@ -120,23 +119,22 @@ impl Error {
120119

121120
/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
122121
/// requests
123-
pub(crate) struct RequestService<T: Stream<Item = ApiConnectionMode>> {
122+
pub(crate) struct RequestService<T: ConnectionModeProvider> {
124123
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
125124
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
126125
connector_handle: HttpsConnectorWithSniHandle,
127126
client: hyper::Client<HttpsConnectorWithSni, hyper::Body>,
128-
proxy_config_provider: T,
127+
connection_mode_provider: T,
129128
api_availability: ApiAvailabilityHandle,
130129
}
131130

132-
impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestService<T> {
131+
impl<T: ConnectionModeProvider + 'static> RequestService<T> {
133132
/// Constructs a new request service.
134133
pub fn spawn(
135134
sni_hostname: Option<String>,
136135
api_availability: ApiAvailabilityHandle,
137136
address_cache: AddressCache,
138-
initial_connection_mode: ApiConnectionMode,
139-
proxy_config_provider: T,
137+
connection_mode_provider: T,
140138
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
141139
) -> RequestServiceHandle {
142140
let (connector, connector_handle) = HttpsConnectorWithSni::new(
@@ -146,7 +144,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
146144
socket_bypass_tx.clone(),
147145
);
148146

149-
connector_handle.set_connection_mode(initial_connection_mode);
147+
connector_handle.set_connection_mode(connection_mode_provider.initial());
150148

151149
let (command_tx, command_rx) = mpsc::unbounded();
152150
let client = Client::builder().build(connector);
@@ -158,14 +156,35 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
158156
command_rx,
159157
connector_handle,
160158
client,
161-
proxy_config_provider,
159+
connection_mode_provider,
162160
api_availability,
163161
};
164162
let handle = RequestServiceHandle { tx: command_tx };
165163
tokio::spawn(service.into_future());
166164
handle
167165
}
168166

167+
async fn into_future(mut self) {
168+
loop {
169+
tokio::select! {
170+
new_mode = self.connection_mode_provider.receive() => {
171+
let Some(new_mode) = new_mode else {
172+
break;
173+
};
174+
self.connector_handle.set_connection_mode(new_mode);
175+
}
176+
command = self.command_rx.next() => {
177+
let Some(command) = command else {
178+
break;
179+
};
180+
181+
self.process_command(command).await;
182+
}
183+
}
184+
}
185+
self.connector_handle.reset();
186+
}
187+
169188
async fn process_command(&mut self, command: RequestCommand) {
170189
match command {
171190
RequestCommand::NewRequest(request, completion_tx) => {
@@ -174,11 +193,8 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
174193
RequestCommand::Reset => {
175194
self.connector_handle.reset();
176195
}
177-
RequestCommand::NextApiConfig(completion_tx) => {
178-
if let Some(connection_mode) = self.proxy_config_provider.next().await {
179-
self.connector_handle.set_connection_mode(connection_mode);
180-
}
181-
let _ = completion_tx.send(Ok(()));
196+
RequestCommand::NextApiConfig => {
197+
self.connection_mode_provider.rotate().await;
182198
}
183199
}
184200
}
@@ -201,22 +217,14 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
201217
if err.is_network_error() && !api_availability.get_state().is_offline() {
202218
log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
203219
if let Some(tx) = tx {
204-
let (completion_tx, _completion_rx) = oneshot::channel();
205-
let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx));
220+
let _ = tx.unbounded_send(RequestCommand::NextApiConfig);
206221
}
207222
}
208223
}
209224

210225
let _ = completion_tx.send(response);
211226
});
212227
}
213-
214-
async fn into_future(mut self) {
215-
while let Some(command) = self.command_rx.next().await {
216-
self.process_command(command).await;
217-
}
218-
self.connector_handle.reset();
219-
}
220228
}
221229

222230
#[derive(Clone)]
@@ -239,15 +247,6 @@ impl RequestServiceHandle {
239247
.map_err(|_| Error::RestServiceDown)?;
240248
completion_rx.await.map_err(|_| Error::RestServiceDown)?
241249
}
242-
243-
/// Forcibly update the connection mode.
244-
pub async fn next_api_endpoint(&self) -> Result<()> {
245-
let (completion_tx, completion_rx) = oneshot::channel();
246-
self.tx
247-
.unbounded_send(RequestCommand::NextApiConfig(completion_tx))
248-
.map_err(|_| Error::RestServiceDown)?;
249-
completion_rx.await.map_err(|_| Error::RestServiceDown)?
250-
}
251250
}
252251

253252
#[derive(Debug)]
@@ -257,7 +256,7 @@ pub(crate) enum RequestCommand {
257256
oneshot::Sender<std::result::Result<Response, Error>>,
258257
),
259258
Reset,
260-
NextApiConfig(oneshot::Sender<std::result::Result<(), Error>>),
259+
NextApiConfig,
261260
}
262261

263262
/// A REST request that is sent to the RequestService to be executed.

0 commit comments

Comments
 (0)