diff --git a/arti/src/proxy.rs b/arti/src/proxy.rs index 9a55fdb5d..90704392d 100644 --- a/arti/src/proxy.rs +++ b/arti/src/proxy.rs @@ -11,8 +11,7 @@ use std::sync::Arc; use tor_client::{ConnectPrefs, TorClient}; use tor_proto::circuit::IpVersionPreference; -use tor_rtcompat::timer::TimeoutError; -use tor_rtcompat::{Runtime, TcpListener}; +use tor_rtcompat::{Runtime, TcpListener, TimeoutError}; use tor_socksproto::{SocksCmd, SocksRequest}; use anyhow::{Context, Result}; diff --git a/tor-chanmgr/src/lib.rs b/tor-chanmgr/src/lib.rs index 656d5124b..c73e3edad 100644 --- a/tor-chanmgr/src/lib.rs +++ b/tor-chanmgr/src/lib.rs @@ -36,7 +36,7 @@ use std::time::Duration; pub use err::Error; -use tor_rtcompat::Runtime; +use tor_rtcompat::{Runtime, SleepProviderExt}; /// A Type that remembers a set of live channels, and launches new /// ones on request. @@ -170,12 +170,10 @@ impl ChanMgr { // XXXX make this a parameter. let timeout = Duration::new(5, 0); - let result = tor_rtcompat::timer::timeout_rt( - &self.runtime, - timeout, - self.build_channel_once(target), - ) - .await; + let result = self + .runtime + .timeout(timeout, self.build_channel_once(target)) + .await; match result { Ok(Ok(chan)) => Ok(chan), diff --git a/tor-circmgr/src/lib.rs b/tor-circmgr/src/lib.rs index 8f36e81b6..8e84889e9 100644 --- a/tor-circmgr/src/lib.rs +++ b/tor-circmgr/src/lib.rs @@ -17,7 +17,7 @@ use tor_netdir::{fallback::FallbackDir, NetDir}; use tor_netdoc::types::policy::PortPolicy; use tor_proto::circuit::{CircParameters, ClientCirc, UniqId}; use tor_retry::RetryError; -use tor_rtcompat::Runtime; +use tor_rtcompat::{Runtime, SleepProviderExt}; use anyhow::Result; use futures::lock::Mutex; @@ -520,12 +520,10 @@ impl CircMgr { let mut error = RetryError::while_doing("build a circuit"); for _ in 0..n_tries { - let result = tor_rtcompat::timer::timeout_rt( - &self.runtime, - timeout, - self.build_once_by_usage(rng, netdir, target_usage), - ) - .await; + let result = self + .runtime + .timeout(timeout, self.build_once_by_usage(rng, netdir, target_usage)) + .await; match result { Ok(Ok((circ, usage))) => { diff --git a/tor-client/src/client.rs b/tor-client/src/client.rs index 63226dc00..e11424387 100644 --- a/tor-client/src/client.rs +++ b/tor-client/src/client.rs @@ -8,7 +8,7 @@ use tor_circmgr::TargetPort; use tor_dirmgr::NetDirConfig; use tor_proto::circuit::IpVersionPreference; use tor_proto::stream::DataStream; -use tor_rtcompat::Runtime; +use tor_rtcompat::{Runtime, SleepProviderExt}; use std::sync::Arc; use std::time::Duration; @@ -131,8 +131,10 @@ impl TorClient { let stream_timeout = Duration::new(10, 0); let stream_future = circ.begin_stream(&addr, port, Some(flags.begin_flags())); - let stream = - tor_rtcompat::timer::timeout_rt(&self.runtime, stream_timeout, stream_future).await??; + let stream = self + .runtime + .timeout(stream_timeout, stream_future) + .await??; Ok(stream) } diff --git a/tor-dirclient/src/err.rs b/tor-dirclient/src/err.rs index 26b7a056c..50705618e 100644 --- a/tor-dirclient/src/err.rs +++ b/tor-dirclient/src/err.rs @@ -1,7 +1,7 @@ //! Declare dirclient-specific errors. use thiserror::Error; -use tor_rtcompat::timer::TimeoutError; +use tor_rtcompat::TimeoutError; /// An error originating from the tor-dirclient crate. #[derive(Error, Debug)] diff --git a/tor-dirclient/src/lib.rs b/tor-dirclient/src/lib.rs index 5198b7844..75909d5ed 100644 --- a/tor-dirclient/src/lib.rs +++ b/tor-dirclient/src/lib.rs @@ -16,7 +16,7 @@ mod response; mod util; use tor_circmgr::{CircMgr, DirInfo}; -use tor_rtcompat::{Runtime, SleepProvider}; +use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt}; use async_compression::futures::bufread::{XzDecoder, ZlibDecoder, ZstdDecoder}; use futures::io::{ @@ -57,8 +57,6 @@ where R: Runtime, SP: SleepProvider, { - use tor_rtcompat::timer::timeout_rt; - let circuit = circ_mgr.get_or_launch_dir(dirinfo).await?; // XXXX should be an option, and is too long. @@ -66,7 +64,9 @@ where let source = SourceInfo::new(circuit.unique_id()); // Launch the stream. - let mut stream = timeout_rt(runtime, begin_timeout, circuit.begin_dir_stream()).await??; // XXXX handle fatalities here too + let mut stream = runtime + .timeout(begin_timeout, circuit.begin_dir_stream()) + .await??; // XXXX handle fatalities here too // TODO: Perhaps we want separate timeouts for each phase of this. // For now, we just use higher-level timeouts in `dirmgr`. diff --git a/tor-dirmgr/src/bootstrap.rs b/tor-dirmgr/src/bootstrap.rs index 12ea7c0c8..c6c40658d 100644 --- a/tor-dirmgr/src/bootstrap.rs +++ b/tor-dirmgr/src/bootstrap.rs @@ -17,8 +17,7 @@ use futures::FutureExt; use futures::StreamExt; use log::{info, warn}; use tor_dirclient::DirResponse; -use tor_rtcompat::timer::sleep_until_wallclock_rt; -use tor_rtcompat::Runtime; +use tor_rtcompat::{Runtime, SleepProviderExt}; /// Try to read a set of documents from `dirmgr` by ID. async fn load_all( @@ -228,7 +227,7 @@ pub(crate) async fn download( Ok(changed) => changed } } - _ = sleep_until_wallclock_rt(&runtime, reset_time).fuse() => { + _ = runtime.sleep_until_wallclock(reset_time).fuse() => { // We need to reset. This can happen if (for // example) we're downloading the last few // microdescriptors on a consensus that now @@ -259,7 +258,7 @@ pub(crate) async fn download( let reset_time = no_more_than_a_week(state.reset_time()); let delay = retry.next_delay(&mut rand::thread_rng()); futures::select_biased! { - _ = sleep_until_wallclock_rt(&runtime, reset_time).fuse() => { + _ = runtime.sleep_until_wallclock(reset_time).fuse() => { state = state.reset()?; continue 'next_state; } diff --git a/tor-dirmgr/src/lib.rs b/tor-dirmgr/src/lib.rs index 56b158c73..ec294b4b5 100644 --- a/tor-dirmgr/src/lib.rs +++ b/tor-dirmgr/src/lib.rs @@ -34,8 +34,7 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use futures::{channel::oneshot, lock::Mutex, task::SpawnExt}; use log::{info, warn}; -use tor_rtcompat::timer::sleep_until_wallclock_rt; -use tor_rtcompat::Runtime; +use tor_rtcompat::{Runtime, SleepProviderExt}; use std::sync::Arc; use std::{collections::HashMap, sync::Weak}; @@ -296,7 +295,7 @@ impl DirMgr { let reset_at = state.reset_time(); match reset_at { - Some(t) => sleep_until_wallclock_rt(&runtime, t).await, + Some(t) => runtime.sleep_until_wallclock(t).await, None => return Ok(()), } state = state.reset()?; diff --git a/tor-rtcompat/src/lib.rs b/tor-rtcompat/src/lib.rs index e308c722a..262a780c7 100644 --- a/tor-rtcompat/src/lib.rs +++ b/tor-rtcompat/src/lib.rs @@ -27,12 +27,15 @@ use once_cell::sync::OnceCell; //#![deny(clippy::missing_docs_in_private_items)] pub(crate) mod impls; +mod timer; mod traits; pub use traits::{ CertifiedConn, Runtime, SleepProvider, SpawnBlocking, TcpListener, TcpProvider, TlsProvider, }; +pub use timer::{SleepProviderExt, Timeout, TimeoutError}; + pub mod tls { pub use crate::traits::{CertifiedConn, TlsConnector}; } @@ -73,115 +76,3 @@ pub mod task { crate::runtime_ref().block_on(task) } } - -/// Functions and types for manipulating timers. -pub mod timer { - use crate::traits::SleepProvider; - use futures::Future; - use pin_project::pin_project; - use std::{ - pin::Pin, - task::{Context, Poll}, - time::{Duration, SystemTime}, - }; - - #[derive(Copy, Clone, Debug)] - pub struct TimeoutError; - impl std::error::Error for TimeoutError {} - impl std::fmt::Display for TimeoutError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Timeout expired") - } - } - - #[pin_project] - pub struct Timeout { - #[pin] - future: T, - #[pin] - sleep_future: S, - } - - pub fn timeout_rt( - runtime: &R, - duration: Duration, - future: F, - ) -> impl Future> { - let sleep_future = runtime.sleep(duration); - - Timeout { - future, - sleep_future, - } - } - - impl Future for Timeout - where - T: Future, - S: Future, - { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - if let Poll::Ready(x) = this.future.poll(cx) { - return Poll::Ready(Ok(x)); - } - - match this.sleep_future.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(()) => Poll::Ready(Err(TimeoutError)), - } - } - } - - /// Pause until the wall-clock is at `when` or later, trying to - /// recover from clock jumps. - pub async fn sleep_until_wallclock_rt(runtime: &R, when: SystemTime) - where - R: SleepProvider, - { - loop { - let now = SystemTime::now(); - if now >= when { - return; - } - let delay = calc_next_delay(now, when); - runtime.sleep(delay).await; - } - } - - /// Return the amount of time we should wait next, when running - /// sleep_until_wallclock(). - /// - /// (This is a separate function for testing.) - fn calc_next_delay(now: SystemTime, when: SystemTime) -> Duration { - /// We never sleep more than this much, in case our system clock jumps - const MAX_SLEEP: Duration = Duration::from_secs(600); - let remainder = when - .duration_since(now) - .unwrap_or_else(|_| Duration::from_secs(0)); - std::cmp::min(MAX_SLEEP, remainder) - } - - #[cfg(test)] - mod test { - use super::*; - #[test] - fn sleep_delay() { - use calc_next_delay as calc; - let minute = Duration::from_secs(60); - let second = Duration::from_secs(1); - let start = SystemTime::now(); - - let target = start + 30 * minute; - - assert_eq!(calc(start, target), minute * 10); - assert_eq!(calc(target + minute, target), minute * 0); - assert_eq!(calc(target, target), minute * 0); - assert_eq!(calc(target - second, target), second); - assert_eq!(calc(target - minute * 9, target), minute * 9); - assert_eq!(calc(target - minute * 11, target), minute * 10); - } - } -} diff --git a/tor-rtcompat/src/timer.rs b/tor-rtcompat/src/timer.rs new file mode 100644 index 000000000..8cebd3cb0 --- /dev/null +++ b/tor-rtcompat/src/timer.rs @@ -0,0 +1,107 @@ +use crate::traits::SleepProvider; +use async_trait::async_trait; +use futures::Future; +use pin_project::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, + time::{Duration, SystemTime}, +}; + +#[derive(Copy, Clone, Debug)] +pub struct TimeoutError; +impl std::error::Error for TimeoutError {} +impl std::fmt::Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Timeout expired") + } +} + +#[async_trait] +pub trait SleepProviderExt: SleepProvider { + fn timeout(&self, duration: Duration, future: F) -> Timeout { + let sleep_future = self.sleep(duration); + + Timeout { + future, + sleep_future, + } + } + + /// Pause until the wall-clock is at `when` or later, trying to + /// recover from clock jumps. + async fn sleep_until_wallclock(&self, when: SystemTime) { + loop { + let now = SystemTime::now(); + if now >= when { + return; + } + let delay = calc_next_delay(now, when); + self.sleep(delay).await; + } + } +} + +impl SleepProviderExt for T {} + +#[pin_project] +pub struct Timeout { + #[pin] + future: T, + #[pin] + sleep_future: S, +} + +impl Future for Timeout +where + T: Future, + S: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + if let Poll::Ready(x) = this.future.poll(cx) { + return Poll::Ready(Ok(x)); + } + + match this.sleep_future.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => Poll::Ready(Err(TimeoutError)), + } + } +} + +/// Return the amount of time we should wait next, when running +/// sleep_until_wallclock(). +/// +/// (This is a separate function for testing.) +fn calc_next_delay(now: SystemTime, when: SystemTime) -> Duration { + /// We never sleep more than this much, in case our system clock jumps + const MAX_SLEEP: Duration = Duration::from_secs(600); + let remainder = when + .duration_since(now) + .unwrap_or_else(|_| Duration::from_secs(0)); + std::cmp::min(MAX_SLEEP, remainder) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn sleep_delay() { + use calc_next_delay as calc; + let minute = Duration::from_secs(60); + let second = Duration::from_secs(1); + let start = SystemTime::now(); + + let target = start + 30 * minute; + + assert_eq!(calc(start, target), minute * 10); + assert_eq!(calc(target + minute, target), minute * 0); + assert_eq!(calc(target, target), minute * 0); + assert_eq!(calc(target - second, target), second); + assert_eq!(calc(target - minute * 9, target), minute * 9); + assert_eq!(calc(target - minute * 11, target), minute * 10); + } +}