diff --git a/tor-rtcompat/src/impls/async_std.rs b/tor-rtcompat/src/impls/async_std.rs index 7a1b555b7..6e19ad29a 100644 --- a/tor-rtcompat/src/impls/async_std.rs +++ b/tor-rtcompat/src/impls/async_std.rs @@ -27,32 +27,10 @@ pub mod net { } /// Functions for launching and managing tasks (async_std implementation) -pub mod task { - pub use async_std_crate::task::{block_on, sleep, spawn, JoinHandle}; - - //pub use async_std_crate::task::JoinHandle; - - /// Stop the task `handle` from running. - /// - /// If you drop `handle` without calling this function, it will just - /// run to completion. - #[allow(unused)] - pub async fn cancel_task(handle: JoinHandle) { - handle.cancel().await; - } -} +pub mod task {} /// Functions and types for manipulating timers (async_std implementation) -pub mod timer { - use std::time::Duration; - - /// Return a future that will be ready after `duration` has passed. - pub fn sleep(duration: Duration) -> async_io::Timer { - async_io::Timer::after(duration) - } - - pub use async_std_crate::future::{timeout, TimeoutError}; -} +pub mod timer {} /// Implement TLS using async_std and async_native_tls. pub mod tls { diff --git a/tor-rtcompat/src/impls/tokio.rs b/tor-rtcompat/src/impls/tokio.rs index f6034e4be..bd09efa51 100644 --- a/tor-rtcompat/src/impls/tokio.rs +++ b/tor-rtcompat/src/impls/tokio.rs @@ -130,9 +130,7 @@ pub mod net { pub mod task {} /// Functions and types for manipulating timers (tokio implementation) -pub mod timer { - pub use tokio_crate::time::{error::Elapsed as TimeoutError, sleep, timeout}; -} +pub mod timer {} /// Implement a set of TLS wrappers for use with tokio. /// diff --git a/tor-rtcompat/src/lib.rs b/tor-rtcompat/src/lib.rs index b5396b234..402ae120e 100644 --- a/tor-rtcompat/src/lib.rs +++ b/tor-rtcompat/src/lib.rs @@ -89,9 +89,64 @@ pub mod task { /// Functions and types for manipulating timers. pub mod timer { - use std::time::{Duration, SystemTime}; + use futures::Future; + use pin_project::pin_project; + use std::{ + pin::Pin, + task::{Context, Poll}, + time::{Duration, SystemTime}, + }; - pub use crate::imp::timer::*; + pub use crate::task::sleep; // XXXX redundant. + + #[derive(Copy, Clone, Debug)] + pub struct TimeoutError; + impl std::error::Error for TimeoutError {} + impl std::fmt::Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Timeout expired") + } + } + + #[pin_project] + pub struct Timeout { + #[pin] + future: T, + #[pin] + sleep_future: S, + } + + pub fn timeout( + duration: Duration, + future: F, + ) -> impl Future> { + let sleep_future = crate::task::sleep(duration); + + Timeout { + future, + sleep_future, + } + } + + impl Future for Timeout + where + T: Future, + S: Future, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + if let Poll::Ready(x) = this.future.poll(cx) { + return Poll::Ready(Ok(x)); + } + + match this.sleep_future.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => Poll::Ready(Err(TimeoutError)), + } + } + } /// Pause until the wall-clock is at `when` or later, trying to /// recover from clock jumps.