Skip to content

Commit babe680

Browse files
committed
Fixed Azure#2506, creating an AsyncRuntime trait which can be used by customers to replace the async runtime
1 parent 772047b commit babe680

File tree

9 files changed

+216
-138
lines changed

9 files changed

+216
-138
lines changed

sdk/core/azure_core/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ pub mod fs;
1515
pub mod hmac;
1616
pub mod http;
1717
pub mod process;
18-
pub mod task;
1918

2019
#[cfg(feature = "test")]
2120
pub mod test;
@@ -31,3 +30,10 @@ pub use typespec_client_core::{
3130

3231
#[cfg(feature = "xml")]
3332
pub use typespec_client_core::xml;
33+
34+
pub use typespec_client_core::get_async_runtime;
35+
pub use typespec_client_core::set_async_runtime;
36+
37+
pub mod async_runtime {
38+
pub use typespec_client_core::async_runtime::SpawnedTask;
39+
}

sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ use super::recoverable_connection::RecoverableConnection;
55
use crate::error::{ErrorKind, EventHubsError};
66
use async_lock::Mutex as AsyncMutex;
77
use azure_core::{
8+
async_runtime::SpawnedTask,
89
credentials::{AccessToken, TokenCredential},
910
error::ErrorKind as AzureErrorKind,
11+
get_async_runtime,
1012
http::Url,
11-
task::{new_task_spawner, SpawnedTask},
1213
Result,
1314
};
1415
use azure_core_amqp::AmqpClaimsBasedSecurityApis as _;
@@ -113,8 +114,8 @@ impl Authorizer {
113114
self.authorization_refresher.get_or_init(|| {
114115
debug!("Starting authorization refresh task.");
115116
let self_clone = self.clone();
116-
let spawner = new_task_spawner();
117-
spawner.spawn(Box::pin(self_clone.refresh_tokens_task()))
117+
let async_runtime = get_async_runtime();
118+
async_runtime.spawn(Box::pin(self_clone.refresh_tokens_task()))
118119
});
119120
} else {
120121
debug!("Token already exists for path: {path}");

sdk/core/azure_core/src/task/mod.rs renamed to sdk/typespec/typespec_client_core/src/async_runtime/mod.rs

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,17 @@
2929
//! ```
3030
//!
3131
//!
32-
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
32+
use std::{
33+
fmt::Debug,
34+
future::Future,
35+
pin::Pin,
36+
sync::{Arc, OnceLock},
37+
};
3338

34-
mod standard_spawn;
39+
mod standard_runtime;
3540

3641
#[cfg(feature = "tokio")]
37-
mod tokio_spawn;
42+
mod tokio_runtime;
3843

3944
#[cfg(test)]
4045
mod tests;
@@ -61,9 +66,11 @@ pub type SpawnedTask = Pin<
6166
pub type SpawnedTask =
6267
Pin<Box<dyn Future<Output = std::result::Result<(), Box<dyn std::error::Error>>> + 'static>>;
6368

64-
/// An async command runner.
69+
/// An Asynchronous Runtime.
6570
///
66-
pub trait TaskSpawner: Send + Sync + Debug {
71+
/// This trait defines the various
72+
///
73+
pub trait AsyncRuntime: Send + Sync + Debug {
6774
/// Spawn a task that executes a given future and returns the output.
6875
///
6976
/// # Arguments
@@ -76,13 +83,13 @@ pub trait TaskSpawner: Send + Sync + Debug {
7683
///
7784
/// # Example
7885
/// ```
79-
/// use azure_core::task::{new_task_spawner, TaskSpawner};
86+
/// use azure_core::async_runtime::{get_async_runtime, TaskSpawner};
8087
/// use futures::FutureExt;
8188
///
8289
/// #[tokio::main]
8390
/// async fn main() {
84-
/// let spawner = new_task_spawner();
85-
/// let future = spawner.spawn(async {
91+
/// let async_runtime = get_async_runtime();
92+
/// let future = async_runtime.spawn(async {
8693
/// // Simulate some work
8794
/// std::thread::sleep(std::time::Duration::from_secs(1));
8895
/// }.boxed());
@@ -99,41 +106,86 @@ pub trait TaskSpawner: Send + Sync + Debug {
99106
/// that can be awaited.
100107
///
101108
fn spawn(&self, f: TaskFuture) -> SpawnedTask;
109+
110+
fn sleep(&self, duration: std::time::Duration) -> TaskFuture;
102111
}
103112

104-
/// Creates a new [`TaskSpawner`] to enable running tasks asynchronously.
113+
static ASYNC_RUNTIME_IMPLEMENTATION: OnceLock<Arc<dyn AsyncRuntime>> = OnceLock::new();
114+
115+
/// Returns an [`AsyncRuntime`] to enable running operations which need to interact with an
116+
/// asynchronous runtime.
105117
///
106118
///
107119
/// The implementation depends on the target architecture and the features enabled:
108-
/// - If the `tokio` feature is enabled, it uses a tokio based spawner.
109-
/// - If the `tokio` feature is not enabled and the target architecture is not `wasm32`, it uses a std::thread based spawner.
120+
/// - If the `tokio` feature is enabled, it uses a tokio based spawner and timer.
121+
/// - If the `tokio` feature is not enabled and the target architecture is not `wasm32`, it uses a std::thread based spawner and timer.
110122
///
111123
/// # Returns
112-
/// A new instance of a [`TaskSpawner`] which can be used to spawn background tasks.
124+
/// A new instance of a [`AsyncRuntime`] which can be used to spawn background tasks.
113125
///
114126
/// # Example
115127
///
116128
/// ```
117-
/// use azure_core::task::{new_task_spawner, TaskSpawner};
129+
/// use azure_core::get_async_runtime;
118130
/// use futures::FutureExt;
119131
///
120132
/// #[tokio::main]
121133
/// async fn main() {
122-
/// let spawner = new_task_spawner();
123-
/// let handle = spawner.spawn(async {
134+
/// let async_runtime = get_async_runtime();
135+
/// let handle = async_runtime.spawn(async {
136+
/// // Simulate some work
137+
/// std::thread::sleep(std::time::Duration::from_secs(1));
138+
/// }.boxed());
139+
/// }
140+
/// ```
141+
///
142+
pub fn get_async_runtime() -> Arc<dyn AsyncRuntime> {
143+
ASYNC_RUNTIME_IMPLEMENTATION
144+
.get_or_init(|| create_async_runtime())
145+
.clone()
146+
}
147+
148+
/// Sets the current [`AsyncRuntime`] to enable running operations which need to interact with an
149+
/// asynchronous runtime.
150+
///
151+
///
152+
/// # Returns
153+
/// Ok if the async runtime was set successfully, or an error if it has already been set.
154+
///
155+
/// # Example
156+
///
157+
/// ```
158+
/// use azure_core::async_runtime::{get_async_runtime, AsyncRuntime};
159+
/// use futures::FutureExt;
160+
///
161+
/// async fn main() {
162+
/// let async_runtime = set_async_runtime();
163+
/// let handle = async_runtime.spawn(async {
124164
/// // Simulate some work
125165
/// std::thread::sleep(std::time::Duration::from_secs(1));
126166
/// }.boxed());
127167
/// }
128168
/// ```
129169
///
130-
pub fn new_task_spawner() -> Arc<dyn TaskSpawner> {
170+
pub fn set_async_runtime(runtime: Arc<dyn AsyncRuntime>) -> crate::Result<()> {
171+
let result = ASYNC_RUNTIME_IMPLEMENTATION.set(runtime);
172+
if result.is_err() {
173+
Err(crate::Error::message(
174+
crate::error::ErrorKind::Other,
175+
"Async runtime has already been set.",
176+
))
177+
} else {
178+
Ok(())
179+
}
180+
}
181+
182+
fn create_async_runtime() -> Arc<dyn AsyncRuntime> {
131183
#[cfg(not(feature = "tokio"))]
132184
{
133-
Arc::new(standard_spawn::StdSpawner)
185+
Arc::new(standard_runtime::StdRuntime)
134186
}
135187
#[cfg(feature = "tokio")]
136188
{
137-
Arc::new(tokio_spawn::TokioSpawner) as Arc<dyn TaskSpawner>
189+
Arc::new(tokio_runtime::TokioRuntime) as Arc<dyn AsyncRuntime>
138190
}
139191
}

sdk/core/azure_core/src/task/standard_spawn.rs renamed to sdk/typespec/typespec_client_core/src/async_runtime/standard_runtime.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
use super::{SpawnedTask, TaskFuture, TaskSpawner};
4+
use super::{AsyncRuntime, SpawnedTask, TaskFuture};
5+
56
#[cfg(not(target_arch = "wasm32"))]
67
use futures::{executor::LocalPool, task::SpawnExt};
8+
9+
use std::sync::atomic::{AtomicBool, Ordering};
710
#[cfg(not(target_arch = "wasm32"))]
811
use std::{
912
future,
1013
future::Future,
1114
pin::Pin,
1215
sync::{Arc, Mutex},
13-
task::Waker,
14-
task::{Context, Poll},
16+
task::{Context, Poll, Waker},
1517
thread,
18+
time::Duration,
1619
};
1720
#[cfg(not(target_arch = "wasm32"))]
1821
use tracing::debug;
@@ -78,11 +81,11 @@ impl Future for ThreadJoinFuture {
7881
}
7982
}
8083

81-
/// A [`TaskSpawner`] using [`std::thread::spawn`].
84+
/// An [`AsyncRuntime`] using [`std::thread::spawn`].
8285
#[derive(Debug)]
83-
pub struct StdSpawner;
86+
pub struct StdRuntime;
8487

85-
impl TaskSpawner for StdSpawner {
88+
impl AsyncRuntime for StdRuntime {
8689
#[cfg_attr(target_arch = "wasm32", allow(unused_variables))]
8790
fn spawn(&self, f: TaskFuture) -> SpawnedTask {
8891
#[cfg(target_arch = "wasm32")]
@@ -143,4 +146,46 @@ impl TaskSpawner for StdSpawner {
143146
Box::pin(join_future)
144147
}
145148
}
149+
150+
/// Creates a future that resolves after a specified duration of time.
151+
///
152+
/// Uses a simple thread based implementation for sleep. A more efficient
153+
/// implementation is available by using the `tokio` crate feature.
154+
fn sleep(&self, duration: Duration) -> TaskFuture {
155+
Box::pin(Sleep {
156+
signal: None,
157+
duration,
158+
})
159+
}
160+
}
161+
162+
#[derive(Debug)]
163+
pub struct Sleep {
164+
signal: Option<Arc<AtomicBool>>,
165+
duration: Duration,
166+
}
167+
168+
impl Future for Sleep {
169+
type Output = ();
170+
171+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172+
if let Some(signal) = &self.signal {
173+
if signal.load(Ordering::Acquire) {
174+
Poll::Ready(())
175+
} else {
176+
Poll::Pending
177+
}
178+
} else {
179+
let signal = Arc::new(AtomicBool::new(false));
180+
let waker = cx.waker().clone();
181+
let duration = self.duration;
182+
self.get_mut().signal = Some(signal.clone());
183+
thread::spawn(move || {
184+
thread::sleep(duration);
185+
signal.store(true, Ordering::Release);
186+
waker.wake();
187+
});
188+
Poll::Pending
189+
}
190+
}
146191
}

0 commit comments

Comments
 (0)