diff --git a/sdk/core/azure_core/CHANGELOG.md b/sdk/core/azure_core/CHANGELOG.md index ac4b04af90..5108792222 100644 --- a/sdk/core/azure_core/CHANGELOG.md +++ b/sdk/core/azure_core/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +- Added `get_async_runtime()` and `set_async_runtime()` to allow customers to replace +the asynchronous runtime used by the Azure SDK. + ### Breaking Changes ### Bugs Fixed @@ -16,7 +19,7 @@ - Added `#[safe]` attribute helper for `SafeDebug` derive macro to show or hide types and members as appropriate. - Added `Page` trait to facilitate the `ItemIterator`. -- Added `PageIterator` to asynchronously iterator all pages. +- Added `PageIterator` to asynchronously iterate all pages. ### Breaking Changes diff --git a/sdk/core/azure_core/README.md b/sdk/core/azure_core/README.md index b60b48e082..e4b1b663df 100644 --- a/sdk/core/azure_core/README.md +++ b/sdk/core/azure_core/README.md @@ -20,12 +20,12 @@ you can find the [package on crates.io][Package (crates.io)]. The main shared concepts of `azure_core` - and Azure SDK libraries using `azure_core` - include: -- Configuring service clients, e.g. configuring retries, logging (`ClientOptions`). -- Accessing HTTP response details (`Response`). -- Paging and asynchronous streams (`Pager`). -- Errors from service requests in a consistent fashion. (`azure_core::Error`). -- Customizing requests (`ClientOptions`). -- Abstractions for representing Azure SDK credentials. (`TokenCredentials`). +- Configuring service clients, e.g. configuring retries, logging (`ClientOptions`). +- Accessing HTTP response details (`Response`). +- Paging and asynchronous streams (`Pager`). +- Errors from service requests in a consistent fashion. (`azure_core::Error`). +- Customizing requests (`ClientOptions`). +- Abstractions for representing Azure SDK credentials. (`TokenCredentials`). ### Thread safety @@ -34,23 +34,25 @@ We guarantee that all client instance methods are thread-safe and independent of ### Additional concepts + [Client options](#configuring-service-clients-using-clientoptions) | [Accessing the response](#accessing-http-response-details-using-responset) | [Handling Errors Results](#handling-errors-results) | [Consuming Service Methods Returning `Pager`](#consuming-service-methods-returning-pagert) + ## Features -- `debug`: enables extra information for developers e.g., emitting all fields in `std::fmt::Debug` implementation. -- `hmac_openssl`: configures HMAC using `openssl`. -- `hmac_rust`: configures HMAC using pure Rust. -- `reqwest` (default): enables and sets `reqwest` as the default `HttpClient`. Enables `reqwest`'s `native-tls` feature. -- `reqwest_deflate` (default): enables deflate compression for `reqwest`. -- `reqwest_gzip` (default): enables gzip compression for `reqwest`. -- `reqwest_rustls`: enables `reqwest`'s `rustls-tls-native-roots-no-provider` feature, -- `tokio`: enables and sets `tokio` as the default async runtime. -- `xml`: enables XML support. +- `debug`: enables extra information for developers e.g., emitting all fields in `std::fmt::Debug` implementation. +- `hmac_openssl`: configures HMAC using `openssl`. +- `hmac_rust`: configures HMAC using pure Rust. +- `reqwest` (default): enables and sets `reqwest` as the default `HttpClient`. Enables `reqwest`'s `native-tls` feature. +- `reqwest_deflate` (default): enables deflate compression for `reqwest`. +- `reqwest_gzip` (default): enables gzip compression for `reqwest`. +- `reqwest_rustls`: enables `reqwest`'s `rustls-tls-native-roots-no-provider` feature, +- `tokio`: enables and sets `tokio` as the default async runtime. +- `xml`: enables XML support. ## Examples @@ -244,6 +246,36 @@ async fn main() -> Result<(), Box> { } ``` +### Replacing the async runtime + +Internally, the Azure SDK uses either the `tokio` async runtime (with the `tokio` feature), or it implements asynchronous functionality using functions in the `std` namespace. + +If your application uses a different asynchronous runtime, you can replace the asynchronous runtime used for internal functions by providing your own implementation of the `azure_core::async_runtime::AsyncRuntime` trait. + +You provide the implementation by calling the `set_async_runtime()` API: + +```rust no_run +use azure_core::async_runtime::{ + set_async_runtime, AsyncRuntime, TaskFuture, SpawnedTask}; +use std::sync::Arc; +use futures::FutureExt; + +struct CustomRuntime; + +impl AsyncRuntime for CustomRuntime { + fn spawn(&self, f: TaskFuture) -> SpawnedTask { + unimplemented!("Custom spawn not implemented"); + } + fn sleep(&self, duration: std::time::Duration) -> TaskFuture { + unimplemented!("Custom sleep not implemented"); + } + } + + set_async_runtime(Arc::new(CustomRuntime)).expect("Failed to set async runtime"); +``` + +There can only be one async runtime set in a given process, so attempts to set the async runtime multiple times will fail. + ## Troubleshooting ### Logging diff --git a/sdk/core/azure_core/src/lib.rs b/sdk/core/azure_core/src/lib.rs index 172ccca9e8..90c1a7a87f 100644 --- a/sdk/core/azure_core/src/lib.rs +++ b/sdk/core/azure_core/src/lib.rs @@ -15,7 +15,6 @@ pub mod fs; pub mod hmac; pub mod http; pub mod process; -pub mod task; #[cfg(feature = "test")] pub mod test; @@ -24,7 +23,7 @@ pub use constants::*; // Re-export modules in typespec_client_core such that azure_core-based crates don't need to reference it directly. pub use typespec_client_core::{ - base64, create_enum, create_extensible_enum, date, + async_runtime, base64, create_enum, create_extensible_enum, date, error::{self, Error, Result}, fmt, json, sleep, stream, Bytes, Uuid, }; diff --git a/sdk/core/azure_core/src/task/tokio_spawn.rs b/sdk/core/azure_core/src/task/tokio_spawn.rs deleted file mode 100644 index 6d48f007b1..0000000000 --- a/sdk/core/azure_core/src/task/tokio_spawn.rs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -use super::{SpawnedTask, TaskFuture, TaskSpawner}; -use std::fmt::Debug; - -/// A [`TaskSpawner`] using [`tokio::spawn`]. -#[derive(Debug)] -pub struct TokioSpawner; - -impl TaskSpawner for TokioSpawner { - fn spawn(&self, f: TaskFuture) -> SpawnedTask { - let handle = ::tokio::spawn(f); - Box::pin(async move { - handle - .await - .map_err(|e| Box::new(e) as Box) - }) - } -} diff --git a/sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs b/sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs index 103a4ff844..6b656845b9 100644 --- a/sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs +++ b/sdk/eventhubs/azure_messaging_eventhubs/src/common/authorizer.rs @@ -5,10 +5,10 @@ use super::recoverable_connection::RecoverableConnection; use crate::error::{ErrorKind, EventHubsError}; use async_lock::Mutex as AsyncMutex; use azure_core::{ + async_runtime::{get_async_runtime, SpawnedTask}, credentials::{AccessToken, TokenCredential}, error::ErrorKind as AzureErrorKind, http::Url, - task::{new_task_spawner, SpawnedTask}, Result, }; use azure_core_amqp::AmqpClaimsBasedSecurityApis as _; @@ -113,8 +113,8 @@ impl Authorizer { self.authorization_refresher.get_or_init(|| { debug!("Starting authorization refresh task."); let self_clone = self.clone(); - let spawner = new_task_spawner(); - spawner.spawn(Box::pin(self_clone.refresh_tokens_task())) + let async_runtime = get_async_runtime(); + async_runtime.spawn(Box::pin(self_clone.refresh_tokens_task())) }); } else { debug!("Token already exists for path: {path}"); diff --git a/sdk/typespec/typespec_client_core/CHANGELOG.md b/sdk/typespec/typespec_client_core/CHANGELOG.md index c4edf1008f..f9ed9564f4 100644 --- a/sdk/typespec/typespec_client_core/CHANGELOG.md +++ b/sdk/typespec/typespec_client_core/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +- Added `get_async_runtime()` and `set_async_runtime()` to allow customers to replace + the default asynchronous runtime with another. + ### Breaking Changes ### Bugs Fixed @@ -28,29 +31,29 @@ ### Breaking Changes -- The `reqwest_rustls` feature enables `rustls-tls-native-roots-no-provider` instead of `rustls-tls-native-roots` to remove the dependency on the `ring` crate. +- The `reqwest_rustls` feature enables `rustls-tls-native-roots-no-provider` instead of `rustls-tls-native-roots` to remove the dependency on the `ring` crate. ### Other Changes -- Deriving `SafeDebug` formats non-exhaustive types by default. Enable `debug` feature to format normal `Debug` output. -- Updated dependencies. +- Deriving `SafeDebug` formats non-exhaustive types by default. Enable `debug` feature to format normal `Debug` output. +- Updated dependencies. ## 0.2.0 (2025-04-08) ### Breaking Changes -- Consolidated all the `tokio` features into a single feature named `tokio`. Traits remain separate but `tokio` support is enabled with a single feature. -- Removed `Header` re-export from `http` module. It is still defined in the `http::headers` module. -- Removed `http-types` dependency and implemented `Method` instead. -- Removed `Pager`. -- Removed `parsing` module. +- Consolidated all the `tokio` features into a single feature named `tokio`. Traits remain separate but `tokio` support is enabled with a single feature. +- Removed `Header` re-export from `http` module. It is still defined in the `http::headers` module. +- Removed `http-types` dependency and implemented `Method` instead. +- Removed `Pager`. +- Removed `parsing` module. ### Other Changes -- Use `std::sync::LazyLock` added in rustc 1.80 instead of `once_cell::sync::Lazy`. +- Use `std::sync::LazyLock` added in rustc 1.80 instead of `once_cell::sync::Lazy`. ## 0.1.0 (2025-02-18) ### Features Added -- Initial supported release. +- Initial supported release. diff --git a/sdk/core/azure_core/src/task/mod.rs b/sdk/typespec/typespec_client_core/src/async_runtime/mod.rs similarity index 51% rename from sdk/core/azure_core/src/task/mod.rs rename to sdk/typespec/typespec_client_core/src/async_runtime/mod.rs index 1209b4fb66..b1720c97f5 100644 --- a/sdk/core/azure_core/src/task/mod.rs +++ b/sdk/typespec/typespec_client_core/src/async_runtime/mod.rs @@ -11,13 +11,13 @@ //! Example usage: //! //! ``` -//! use azure_core::task::{new_task_spawner, TaskSpawner}; +//! use typespec_client_core::async_runtime::get_async_runtime; //! use futures::FutureExt; //! //! #[tokio::main] //! async fn main() { -//! let spawner = new_task_spawner(); -//! let handle = spawner.spawn(async { +//! let async_runtime = get_async_runtime(); +//! let handle = async_runtime.spawn(async { //! // Simulate some work //! std::thread::sleep(std::time::Duration::from_secs(1)); //! }.boxed()); @@ -29,22 +29,26 @@ //! ``` //! //! -use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc}; +use std::{ + future::Future, + pin::Pin, + sync::{Arc, OnceLock}, +}; -mod standard_spawn; +mod standard_runtime; #[cfg(feature = "tokio")] -mod tokio_spawn; +mod tokio_runtime; #[cfg(test)] mod tests; #[cfg(not(target_arch = "wasm32"))] -pub(crate) type TaskFuture = Pin + Send + 'static>>; +pub type TaskFuture = Pin + Send + 'static>>; // WASM32 does not support `Send` futures, so we use a non-Send future type. #[cfg(target_arch = "wasm32")] -pub(crate) type TaskFuture = Pin + 'static>>; +pub type TaskFuture = Pin + 'static>>; /// A `SpawnedTask` is a future that represents a running task. /// It can be awaited to block until the task has completed. @@ -61,9 +65,11 @@ pub type SpawnedTask = Pin< pub type SpawnedTask = Pin>> + 'static>>; -/// An async command runner. +/// An Asynchronous Runtime. /// -pub trait TaskSpawner: Send + Sync + Debug { +/// This trait defines the various +/// +pub trait AsyncRuntime: Send + Sync { /// Spawn a task that executes a given future and returns the output. /// /// # Arguments @@ -76,13 +82,13 @@ pub trait TaskSpawner: Send + Sync + Debug { /// /// # Example /// ``` - /// use azure_core::task::{new_task_spawner, TaskSpawner}; + /// use typespec_client_core::async_runtime::get_async_runtime; /// use futures::FutureExt; /// /// #[tokio::main] /// async fn main() { - /// let spawner = new_task_spawner(); - /// let future = spawner.spawn(async { + /// let async_runtime = get_async_runtime(); + /// let future = async_runtime.spawn(async { /// // Simulate some work /// std::thread::sleep(std::time::Duration::from_secs(1)); /// }.boxed()); @@ -99,41 +105,98 @@ pub trait TaskSpawner: Send + Sync + Debug { /// that can be awaited. /// fn spawn(&self, f: TaskFuture) -> SpawnedTask; + + fn sleep( + &self, + duration: std::time::Duration, + ) -> Pin + Send + 'static>>; } -/// Creates a new [`TaskSpawner`] to enable running tasks asynchronously. +static ASYNC_RUNTIME_IMPLEMENTATION: OnceLock> = OnceLock::new(); + +/// Returns an [`AsyncRuntime`] to enable running operations which need to interact with an +/// asynchronous runtime. /// /// /// The implementation depends on the target architecture and the features enabled: -/// - If the `tokio` feature is enabled, it uses a tokio based spawner. -/// - If the `tokio` feature is not enabled and the target architecture is not `wasm32`, it uses a std::thread based spawner. +/// - If the `tokio` feature is enabled, it uses a tokio based spawner and timer. +/// - If the `tokio` feature is not enabled and the target architecture is not `wasm32`, it uses a std::thread based spawner and timer. /// /// # Returns -/// A new instance of a [`TaskSpawner`] which can be used to spawn background tasks. +/// An instance of a [`AsyncRuntime`] which can be used to spawn background tasks or perform other asynchronous operations. /// /// # Example /// /// ``` -/// use azure_core::task::{new_task_spawner, TaskSpawner}; +/// use typespec_client_core::async_runtime::get_async_runtime; /// use futures::FutureExt; /// /// #[tokio::main] /// async fn main() { -/// let spawner = new_task_spawner(); -/// let handle = spawner.spawn(async { +/// let async_runtime = get_async_runtime(); +/// let handle = async_runtime.spawn(async { /// // Simulate some work /// std::thread::sleep(std::time::Duration::from_secs(1)); /// }.boxed()); /// } /// ``` /// -pub fn new_task_spawner() -> Arc { +pub fn get_async_runtime() -> Arc { + ASYNC_RUNTIME_IMPLEMENTATION + .get_or_init(|| create_async_runtime()) + .clone() +} + +/// Sets the current [`AsyncRuntime`] to enable running operations which need to interact with an +/// asynchronous runtime. +/// +/// # Arguments +/// * `runtime` - An instance of a type that implements the [`AsyncRuntime`] trait. +/// +/// # Returns +/// Ok if the async runtime was set successfully, or an error if it has already been set. +/// +/// # Example +/// +/// ``` +/// use typespec_client_core::async_runtime::{ +/// set_async_runtime, AsyncRuntime, TaskFuture, SpawnedTask}; +/// use std::sync::Arc; +/// use futures::FutureExt; +/// +/// struct CustomRuntime; +/// +/// impl AsyncRuntime for CustomRuntime { +/// fn spawn(&self, f: TaskFuture) -> SpawnedTask { +/// unimplemented!("Custom spawn not implemented"); +/// } +/// fn sleep(&self, duration: std::time::Duration) -> TaskFuture { +/// unimplemented!("Custom sleep not implemented"); +/// } +/// } +/// +/// set_async_runtime(Arc::new(CustomRuntime)).expect("Failed to set async runtime"); +/// ``` +/// +pub fn set_async_runtime(runtime: Arc) -> crate::Result<()> { + let result = ASYNC_RUNTIME_IMPLEMENTATION.set(runtime); + if result.is_err() { + Err(crate::Error::message( + crate::error::ErrorKind::Other, + "Async runtime has already been set.", + )) + } else { + Ok(()) + } +} + +fn create_async_runtime() -> Arc { #[cfg(not(feature = "tokio"))] { - Arc::new(standard_spawn::StdSpawner) + Arc::new(standard_runtime::StdRuntime) } #[cfg(feature = "tokio")] { - Arc::new(tokio_spawn::TokioSpawner) as Arc + Arc::new(tokio_runtime::TokioRuntime) as Arc } } diff --git a/sdk/core/azure_core/src/task/standard_spawn.rs b/sdk/typespec/typespec_client_core/src/async_runtime/standard_runtime.rs similarity index 74% rename from sdk/core/azure_core/src/task/standard_spawn.rs rename to sdk/typespec/typespec_client_core/src/async_runtime/standard_runtime.rs index 9710c4cda0..a71a4885fc 100644 --- a/sdk/core/azure_core/src/task/standard_spawn.rs +++ b/sdk/typespec/typespec_client_core/src/async_runtime/standard_runtime.rs @@ -1,19 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use super::{SpawnedTask, TaskFuture, TaskSpawner}; +use super::{AsyncRuntime, SpawnedTask, TaskFuture}; + #[cfg(not(target_arch = "wasm32"))] use futures::{executor::LocalPool, task::SpawnExt}; + +#[cfg(not(target_arch = "wasm32"))] +use std::sync::atomic::{AtomicBool, Ordering}; #[cfg(not(target_arch = "wasm32"))] use std::{ future, - future::Future, - pin::Pin, sync::{Arc, Mutex}, - task::Waker, - task::{Context, Poll}, + task::{Context, Poll, Waker}, thread, }; +use std::{future::Future, pin::Pin}; #[cfg(not(target_arch = "wasm32"))] use tracing::debug; @@ -78,11 +80,11 @@ impl Future for ThreadJoinFuture { } } -/// A [`TaskSpawner`] using [`std::thread::spawn`]. -#[derive(Debug)] -pub struct StdSpawner; +/// An [`AsyncRuntime`] using [`std::thread::spawn`]. +#[allow(dead_code)] +pub(crate) struct StdRuntime; -impl TaskSpawner for StdSpawner { +impl AsyncRuntime for StdRuntime { #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] fn spawn(&self, f: TaskFuture) -> SpawnedTask { #[cfg(target_arch = "wasm32")] @@ -143,4 +145,57 @@ impl TaskSpawner for StdSpawner { Box::pin(join_future) } } + + /// Creates a future that resolves after a specified duration of time. + /// + /// Uses a simple thread based implementation for sleep. A more efficient + /// implementation is available by using the `tokio` crate feature. + #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] + fn sleep( + &self, + duration: std::time::Duration, + ) -> Pin + Send + 'static>> { + #[cfg(target_arch = "wasm32")] + { + panic!("sleep is not supported on wasm32") + } + #[cfg(not(target_arch = "wasm32"))] + Box::pin(Sleep { + signal: None, + duration, + }) + } +} + +#[derive(Debug)] +#[cfg(not(target_arch = "wasm32"))] +pub struct Sleep { + signal: Option>, + duration: std::time::Duration, +} + +#[cfg(not(target_arch = "wasm32"))] +impl Future for Sleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(signal) = &self.signal { + if signal.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + Poll::Pending + } + } else { + let signal = Arc::new(AtomicBool::new(false)); + let waker = cx.waker().clone(); + let duration = self.duration; + self.get_mut().signal = Some(signal.clone()); + thread::spawn(move || { + thread::sleep(duration); + signal.store(true, Ordering::Release); + waker.wake(); + }); + Poll::Pending + } + } } diff --git a/sdk/core/azure_core/src/task/tests.rs b/sdk/typespec/typespec_client_core/src/async_runtime/tests.rs similarity index 66% rename from sdk/core/azure_core/src/task/tests.rs rename to sdk/typespec/typespec_client_core/src/async_runtime/tests.rs index bd53de4e80..7ff3ad5643 100644 --- a/sdk/core/azure_core/src/task/tests.rs +++ b/sdk/typespec/typespec_client_core/src/async_runtime/tests.rs @@ -9,11 +9,11 @@ use std::time::Duration; #[cfg(not(feature = "tokio"))] #[test] fn test_task_spawner_execution() { - let spawner = new_task_spawner(); + let runtime = get_async_runtime(); let result = Arc::new(Mutex::new(false)); let result_clone = Arc::clone(&result); - let handle = spawner.spawn( + let handle = runtime.spawn( async move { // Simulate some work crate::sleep::sleep(Duration::from_millis(50)).await; @@ -32,11 +32,11 @@ fn test_task_spawner_execution() { #[cfg(feature = "tokio")] #[tokio::test] async fn tokio_task_spawner_execution() { - let spawner = new_task_spawner(); + let async_runtime = get_async_runtime(); let result = Arc::new(Mutex::new(false)); let result_clone = Arc::clone(&result); - let handle = spawner.spawn( + let handle = async_runtime.spawn( async move { // Simulate some work crate::sleep::sleep(Duration::from_millis(50)).await; @@ -55,7 +55,7 @@ async fn tokio_task_spawner_execution() { #[cfg(feature = "tokio")] #[tokio::test] async fn test_tokio_specific_handling() { - let spawner = Arc::new(tokio_spawn::TokioSpawner); + let spawner = Arc::new(tokio_runtime::TokioRuntime); let task_completed = Arc::new(Mutex::new(false)); let task_completed_clone = Arc::clone(&task_completed); @@ -73,7 +73,7 @@ async fn test_tokio_specific_handling() { #[cfg(feature = "tokio")] #[tokio::test] async fn tokio_multiple_tasks() { - let spawner = Arc::new(tokio_spawn::TokioSpawner); + let spawner = Arc::new(tokio_runtime::TokioRuntime); let counter = Arc::new(Mutex::new(0)); let mut handles = Vec::new(); @@ -101,7 +101,7 @@ async fn tokio_multiple_tasks() { #[cfg(feature = "tokio")] #[tokio::test] async fn tokio_task_execution() { - let spawner = Arc::new(tokio_spawn::TokioSpawner); + let spawner = Arc::new(tokio_runtime::TokioRuntime); let result = Arc::new(Mutex::new(false)); let result_clone = Arc::clone(&result); @@ -126,7 +126,7 @@ async fn tokio_task_execution() { // When the "tokio" feature is not enabled, it uses std::thread::sleep which does not require a tokio runtime. #[test] fn std_specific_handling() { - let spawner = Arc::new(standard_spawn::StdSpawner); + let spawner = Arc::new(standard_runtime::StdRuntime); let task_completed = Arc::new(Mutex::new(false)); let task_completed_clone = Arc::clone(&task_completed); @@ -145,7 +145,7 @@ fn std_specific_handling() { #[test] fn std_multiple_tasks() { - let spawner = Arc::new(standard_spawn::StdSpawner); + let spawner = Arc::new(standard_runtime::StdRuntime); let counter = Arc::new(Mutex::new(0)); let mut handles = Vec::new(); @@ -175,11 +175,11 @@ fn std_multiple_tasks() { #[cfg(not(feature = "tokio"))] #[test] fn std_task_execution() { - let spawner = Arc::new(standard_spawn::StdSpawner); + let runtime = Arc::new(standard_runtime::StdRuntime); let result = Arc::new(Mutex::new(false)); let result_clone = Arc::clone(&result); - let handle = spawner.spawn( + let handle = runtime.spawn( async move { // Simulate some work crate::sleep::sleep(Duration::from_millis(500)).await; @@ -195,3 +195,77 @@ fn std_task_execution() { // Verify the task executed assert!(*result.lock().unwrap()); } + +// Basic test that launches 10k futures and waits for them to complete: +// it has a high chance of failing if there is a race condition in the sleep method; +// otherwise, it runs quickly. +#[cfg(not(feature = "tokio"))] +#[tokio::test] +async fn test_timeout() { + use super::*; + use std::time::Duration; + use tokio::task::JoinSet; + + let async_runtime = get_async_runtime(); + let mut join_set = JoinSet::default(); + let total = 10000; + for _i in 0..total { + let runtime = async_runtime.clone(); + join_set.spawn(async move { + runtime.sleep(Duration::from_millis(10)).await; + }); + } + + loop { + let res = + tokio::time::timeout(std::time::Duration::from_secs(10), join_set.join_next()).await; + assert!(res.is_ok()); + if let Ok(None) = res { + break; + } + } +} + +#[tokio::test] +async fn test_sleep() { + let runtime = get_async_runtime(); + let start = std::time::Instant::now(); + runtime.sleep(Duration::from_millis(100)).await; + let elapsed = start.elapsed(); + assert!(elapsed >= Duration::from_millis(100)); +} + +#[test] +fn test_get_runtime() { + // Ensure that the runtime can be retrieved without panicking + let _runtime = get_async_runtime(); +} + +struct TestRuntime; + +impl AsyncRuntime for TestRuntime { + fn spawn(&self, _f: TaskFuture) -> SpawnedTask { + unimplemented!("TestRuntime does not support spawning tasks"); + } + + fn sleep( + &self, + _duration: std::time::Duration, + ) -> Pin + Send + 'static>> { + unimplemented!("TestRuntime does not support sleeping"); + } +} + +// This test is ignored because by default, cargo test runs all tests in parallel, but +// this test sets the runtime, which will fail if run in parallel with other tests that +// get the runtime. +#[test] +#[ignore = "Skipping the runtime set test to avoid conflicts with parallel test execution"] +fn test_set_runtime() { + let runtime = Arc::new(TestRuntime); + // Ensure that the runtime can be set without panicking + set_async_runtime(runtime.clone()).unwrap(); + + // Ensure that setting the runtime again fails + set_async_runtime(runtime.clone()).unwrap_err(); +} diff --git a/sdk/typespec/typespec_client_core/src/async_runtime/tokio_runtime.rs b/sdk/typespec/typespec_client_core/src/async_runtime/tokio_runtime.rs new file mode 100644 index 0000000000..e9b19bd55d --- /dev/null +++ b/sdk/typespec/typespec_client_core/src/async_runtime/tokio_runtime.rs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use super::{AsyncRuntime, SpawnedTask, TaskFuture}; +use std::pin::Pin; + +/// An [`AsyncRuntime`] using `tokio` based APIs. +pub(crate) struct TokioRuntime; + +impl AsyncRuntime for TokioRuntime { + fn spawn(&self, f: TaskFuture) -> SpawnedTask { + let handle = ::tokio::spawn(f); + Box::pin(async move { + handle + .await + .map_err(|e| Box::new(e) as Box) + }) + } + + fn sleep( + &self, + duration: std::time::Duration, + ) -> Pin + Send>> { + Box::pin(::tokio::time::sleep(duration)) + } +} diff --git a/sdk/typespec/typespec_client_core/src/lib.rs b/sdk/typespec/typespec_client_core/src/lib.rs index 8281ae3d9e..6f2e832cd2 100644 --- a/sdk/typespec/typespec_client_core/src/lib.rs +++ b/sdk/typespec/typespec_client_core/src/lib.rs @@ -6,6 +6,7 @@ #[macro_use] mod macros; +pub mod async_runtime; pub mod base64; pub mod date; pub mod error; @@ -23,3 +24,5 @@ pub mod xml; pub use crate::error::{Error, Result}; pub use bytes::Bytes; pub use uuid::Uuid; + +pub use sleep::sleep; diff --git a/sdk/typespec/typespec_client_core/src/sleep.rs b/sdk/typespec/typespec_client_core/src/sleep.rs new file mode 100644 index 0000000000..0a850aee0b --- /dev/null +++ b/sdk/typespec/typespec_client_core/src/sleep.rs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Sleep functions. + +use crate::async_runtime::get_async_runtime; + +/// Sleeps for the specified duration using the configured async runtime. +/// +/// # Arguments +/// * `duration` - The duration to sleep for. +/// +/// # Returns +/// A future that resolves when the sleep duration has elapsed. +/// +/// # Example +/// ``` +/// use typespec_client_core::sleep; +/// use std::time::Duration; +/// +/// #[tokio::main] +/// async fn main() { +/// // Sleep for 1 second +/// sleep(Duration::from_secs(1)).await; +/// println!("Slept for 1 second"); +/// } +/// ``` +pub async fn sleep(duration: std::time::Duration) { + get_async_runtime().sleep(duration).await +} diff --git a/sdk/typespec/typespec_client_core/src/sleep/mod.rs b/sdk/typespec/typespec_client_core/src/sleep/mod.rs deleted file mode 100644 index f59a29afe8..0000000000 --- a/sdk/typespec/typespec_client_core/src/sleep/mod.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Sleep functions. - -#[cfg(any(not(feature = "tokio"), target_arch = "wasm32"))] -mod thread; - -#[cfg(any(not(feature = "tokio"), target_arch = "wasm32"))] -pub use self::thread::{sleep, Sleep}; - -#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))] -pub use tokio::time::{sleep, Sleep}; - -// Unit tests -#[cfg(test)] -mod tests { - - // Basic test that launches 10k futures and waits for them to complete: - // it has a high chance of failing if there is a race condition in the sleep method; - // otherwise, it runs quickly. - #[cfg(not(feature = "tokio"))] - #[tokio::test] - async fn test_timeout() { - use super::*; - use std::time::Duration; - use tokio::task::JoinSet; - - let mut join_set = JoinSet::default(); - let total = 10000; - for _i in 0..total { - join_set.spawn(async move { - sleep(Duration::from_millis(10)).await; - }); - } - - loop { - let res = - tokio::time::timeout(std::time::Duration::from_secs(10), join_set.join_next()) - .await; - assert!(res.is_ok()); - if let Ok(None) = res { - break; - } - } - } -} diff --git a/sdk/typespec/typespec_client_core/src/sleep/thread.rs b/sdk/typespec/typespec_client_core/src/sleep/thread.rs deleted file mode 100644 index 1e8684f83b..0000000000 --- a/sdk/typespec/typespec_client_core/src/sleep/thread.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -use futures::Future; -use std::{ - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll}, - thread, - time::Duration, -}; - -/// Creates a future that resolves after a specified duration of time. -/// -/// Uses a simple thread based implementation for sleep. A more efficient -/// implementation is available by using the `tokio` crate feature. -pub fn sleep(duration: Duration) -> Sleep { - Sleep { - signal: None, - duration, - } -} - -#[derive(Debug)] -pub struct Sleep { - signal: Option>, - duration: Duration, -} - -impl Future for Sleep { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(signal) = &self.signal { - if signal.load(Ordering::Acquire) { - Poll::Ready(()) - } else { - Poll::Pending - } - } else { - let signal = Arc::new(AtomicBool::new(false)); - let waker = cx.waker().clone(); - let duration = self.duration; - self.get_mut().signal = Some(signal.clone()); - thread::spawn(move || { - thread::sleep(duration); - signal.store(true, Ordering::Release); - waker.wake(); - }); - Poll::Pending - } - } -}