Move timer functions into an extension trait.

This commit is contained in:
Nick Mathewson 2021-04-17 09:37:46 -04:00
parent 6878fe8336
commit 92de7c58f0
10 changed files with 136 additions and 143 deletions

View File

@ -11,8 +11,7 @@ use std::sync::Arc;
use tor_client::{ConnectPrefs, TorClient};
use tor_proto::circuit::IpVersionPreference;
use tor_rtcompat::timer::TimeoutError;
use tor_rtcompat::{Runtime, TcpListener};
use tor_rtcompat::{Runtime, TcpListener, TimeoutError};
use tor_socksproto::{SocksCmd, SocksRequest};
use anyhow::{Context, Result};

View File

@ -36,7 +36,7 @@ use std::time::Duration;
pub use err::Error;
use tor_rtcompat::Runtime;
use tor_rtcompat::{Runtime, SleepProviderExt};
/// A Type that remembers a set of live channels, and launches new
/// ones on request.
@ -170,12 +170,10 @@ impl<R: Runtime> ChanMgr<R> {
// XXXX make this a parameter.
let timeout = Duration::new(5, 0);
let result = tor_rtcompat::timer::timeout_rt(
&self.runtime,
timeout,
self.build_channel_once(target),
)
.await;
let result = self
.runtime
.timeout(timeout, self.build_channel_once(target))
.await;
match result {
Ok(Ok(chan)) => Ok(chan),

View File

@ -17,7 +17,7 @@ use tor_netdir::{fallback::FallbackDir, NetDir};
use tor_netdoc::types::policy::PortPolicy;
use tor_proto::circuit::{CircParameters, ClientCirc, UniqId};
use tor_retry::RetryError;
use tor_rtcompat::Runtime;
use tor_rtcompat::{Runtime, SleepProviderExt};
use anyhow::Result;
use futures::lock::Mutex;
@ -520,12 +520,10 @@ impl<R: Runtime> CircMgr<R> {
let mut error = RetryError::while_doing("build a circuit");
for _ in 0..n_tries {
let result = tor_rtcompat::timer::timeout_rt(
&self.runtime,
timeout,
self.build_once_by_usage(rng, netdir, target_usage),
)
.await;
let result = self
.runtime
.timeout(timeout, self.build_once_by_usage(rng, netdir, target_usage))
.await;
match result {
Ok(Ok((circ, usage))) => {

View File

@ -8,7 +8,7 @@ use tor_circmgr::TargetPort;
use tor_dirmgr::NetDirConfig;
use tor_proto::circuit::IpVersionPreference;
use tor_proto::stream::DataStream;
use tor_rtcompat::Runtime;
use tor_rtcompat::{Runtime, SleepProviderExt};
use std::sync::Arc;
use std::time::Duration;
@ -131,8 +131,10 @@ impl<R: Runtime> TorClient<R> {
let stream_timeout = Duration::new(10, 0);
let stream_future = circ.begin_stream(&addr, port, Some(flags.begin_flags()));
let stream =
tor_rtcompat::timer::timeout_rt(&self.runtime, stream_timeout, stream_future).await??;
let stream = self
.runtime
.timeout(stream_timeout, stream_future)
.await??;
Ok(stream)
}

View File

@ -1,7 +1,7 @@
//! Declare dirclient-specific errors.
use thiserror::Error;
use tor_rtcompat::timer::TimeoutError;
use tor_rtcompat::TimeoutError;
/// An error originating from the tor-dirclient crate.
#[derive(Error, Debug)]

View File

@ -16,7 +16,7 @@ mod response;
mod util;
use tor_circmgr::{CircMgr, DirInfo};
use tor_rtcompat::{Runtime, SleepProvider};
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
use async_compression::futures::bufread::{XzDecoder, ZlibDecoder, ZstdDecoder};
use futures::io::{
@ -57,8 +57,6 @@ where
R: Runtime,
SP: SleepProvider,
{
use tor_rtcompat::timer::timeout_rt;
let circuit = circ_mgr.get_or_launch_dir(dirinfo).await?;
// XXXX should be an option, and is too long.
@ -66,7 +64,9 @@ where
let source = SourceInfo::new(circuit.unique_id());
// Launch the stream.
let mut stream = timeout_rt(runtime, begin_timeout, circuit.begin_dir_stream()).await??; // XXXX handle fatalities here too
let mut stream = runtime
.timeout(begin_timeout, circuit.begin_dir_stream())
.await??; // XXXX handle fatalities here too
// TODO: Perhaps we want separate timeouts for each phase of this.
// For now, we just use higher-level timeouts in `dirmgr`.

View File

@ -17,8 +17,7 @@ use futures::FutureExt;
use futures::StreamExt;
use log::{info, warn};
use tor_dirclient::DirResponse;
use tor_rtcompat::timer::sleep_until_wallclock_rt;
use tor_rtcompat::Runtime;
use tor_rtcompat::{Runtime, SleepProviderExt};
/// Try to read a set of documents from `dirmgr` by ID.
async fn load_all<R: Runtime>(
@ -228,7 +227,7 @@ pub(crate) async fn download<R: Runtime>(
Ok(changed) => changed
}
}
_ = sleep_until_wallclock_rt(&runtime, reset_time).fuse() => {
_ = runtime.sleep_until_wallclock(reset_time).fuse() => {
// We need to reset. This can happen if (for
// example) we're downloading the last few
// microdescriptors on a consensus that now
@ -259,7 +258,7 @@ pub(crate) async fn download<R: Runtime>(
let reset_time = no_more_than_a_week(state.reset_time());
let delay = retry.next_delay(&mut rand::thread_rng());
futures::select_biased! {
_ = sleep_until_wallclock_rt(&runtime, reset_time).fuse() => {
_ = runtime.sleep_until_wallclock(reset_time).fuse() => {
state = state.reset()?;
continue 'next_state;
}

View File

@ -34,8 +34,7 @@ use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::{channel::oneshot, lock::Mutex, task::SpawnExt};
use log::{info, warn};
use tor_rtcompat::timer::sleep_until_wallclock_rt;
use tor_rtcompat::Runtime;
use tor_rtcompat::{Runtime, SleepProviderExt};
use std::sync::Arc;
use std::{collections::HashMap, sync::Weak};
@ -296,7 +295,7 @@ impl<R: Runtime> DirMgr<R> {
let reset_at = state.reset_time();
match reset_at {
Some(t) => sleep_until_wallclock_rt(&runtime, t).await,
Some(t) => runtime.sleep_until_wallclock(t).await,
None => return Ok(()),
}
state = state.reset()?;

View File

@ -27,12 +27,15 @@ use once_cell::sync::OnceCell;
//#![deny(clippy::missing_docs_in_private_items)]
pub(crate) mod impls;
mod timer;
mod traits;
pub use traits::{
CertifiedConn, Runtime, SleepProvider, SpawnBlocking, TcpListener, TcpProvider, TlsProvider,
};
pub use timer::{SleepProviderExt, Timeout, TimeoutError};
pub mod tls {
pub use crate::traits::{CertifiedConn, TlsConnector};
}
@ -73,115 +76,3 @@ pub mod task {
crate::runtime_ref().block_on(task)
}
}
/// Functions and types for manipulating timers.
pub mod timer {
use crate::traits::SleepProvider;
use futures::Future;
use pin_project::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
time::{Duration, SystemTime},
};
#[derive(Copy, Clone, Debug)]
pub struct TimeoutError;
impl std::error::Error for TimeoutError {}
impl std::fmt::Display for TimeoutError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Timeout expired")
}
}
#[pin_project]
pub struct Timeout<T, S> {
#[pin]
future: T,
#[pin]
sleep_future: S,
}
pub fn timeout_rt<R: SleepProvider, F: Future>(
runtime: &R,
duration: Duration,
future: F,
) -> impl Future<Output = Result<F::Output, TimeoutError>> {
let sleep_future = runtime.sleep(duration);
Timeout {
future,
sleep_future,
}
}
impl<T, S> Future for Timeout<T, S>
where
T: Future,
S: Future<Output = ()>,
{
type Output = Result<T::Output, TimeoutError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(x) = this.future.poll(cx) {
return Poll::Ready(Ok(x));
}
match this.sleep_future.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
}
}
}
/// Pause until the wall-clock is at `when` or later, trying to
/// recover from clock jumps.
pub async fn sleep_until_wallclock_rt<R>(runtime: &R, when: SystemTime)
where
R: SleepProvider,
{
loop {
let now = SystemTime::now();
if now >= when {
return;
}
let delay = calc_next_delay(now, when);
runtime.sleep(delay).await;
}
}
/// Return the amount of time we should wait next, when running
/// sleep_until_wallclock().
///
/// (This is a separate function for testing.)
fn calc_next_delay(now: SystemTime, when: SystemTime) -> Duration {
/// We never sleep more than this much, in case our system clock jumps
const MAX_SLEEP: Duration = Duration::from_secs(600);
let remainder = when
.duration_since(now)
.unwrap_or_else(|_| Duration::from_secs(0));
std::cmp::min(MAX_SLEEP, remainder)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn sleep_delay() {
use calc_next_delay as calc;
let minute = Duration::from_secs(60);
let second = Duration::from_secs(1);
let start = SystemTime::now();
let target = start + 30 * minute;
assert_eq!(calc(start, target), minute * 10);
assert_eq!(calc(target + minute, target), minute * 0);
assert_eq!(calc(target, target), minute * 0);
assert_eq!(calc(target - second, target), second);
assert_eq!(calc(target - minute * 9, target), minute * 9);
assert_eq!(calc(target - minute * 11, target), minute * 10);
}
}
}

107
tor-rtcompat/src/timer.rs Normal file
View File

@ -0,0 +1,107 @@
use crate::traits::SleepProvider;
use async_trait::async_trait;
use futures::Future;
use pin_project::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
time::{Duration, SystemTime},
};
#[derive(Copy, Clone, Debug)]
pub struct TimeoutError;
impl std::error::Error for TimeoutError {}
impl std::fmt::Display for TimeoutError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Timeout expired")
}
}
#[async_trait]
pub trait SleepProviderExt: SleepProvider {
fn timeout<F: Future>(&self, duration: Duration, future: F) -> Timeout<F, Self::SleepFuture> {
let sleep_future = self.sleep(duration);
Timeout {
future,
sleep_future,
}
}
/// Pause until the wall-clock is at `when` or later, trying to
/// recover from clock jumps.
async fn sleep_until_wallclock(&self, when: SystemTime) {
loop {
let now = SystemTime::now();
if now >= when {
return;
}
let delay = calc_next_delay(now, when);
self.sleep(delay).await;
}
}
}
impl<T: SleepProvider> SleepProviderExt for T {}
#[pin_project]
pub struct Timeout<T, S> {
#[pin]
future: T,
#[pin]
sleep_future: S,
}
impl<T, S> Future for Timeout<T, S>
where
T: Future,
S: Future<Output = ()>,
{
type Output = Result<T::Output, TimeoutError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(x) = this.future.poll(cx) {
return Poll::Ready(Ok(x));
}
match this.sleep_future.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
}
}
}
/// Return the amount of time we should wait next, when running
/// sleep_until_wallclock().
///
/// (This is a separate function for testing.)
fn calc_next_delay(now: SystemTime, when: SystemTime) -> Duration {
/// We never sleep more than this much, in case our system clock jumps
const MAX_SLEEP: Duration = Duration::from_secs(600);
let remainder = when
.duration_since(now)
.unwrap_or_else(|_| Duration::from_secs(0));
std::cmp::min(MAX_SLEEP, remainder)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn sleep_delay() {
use calc_next_delay as calc;
let minute = Duration::from_secs(60);
let second = Duration::from_secs(1);
let start = SystemTime::now();
let target = start + 30 * minute;
assert_eq!(calc(start, target), minute * 10);
assert_eq!(calc(target + minute, target), minute * 0);
assert_eq!(calc(target, target), minute * 0);
assert_eq!(calc(target - second, target), second);
assert_eq!(calc(target - minute * 9, target), minute * 9);
assert_eq!(calc(target - minute * 11, target), minute * 10);
}
}