diff --git a/Cargo.lock b/Cargo.lock index bc4e5ba36..57569292d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2679,7 +2679,6 @@ name = "tor-circmgr" version = "0.0.3" dependencies = [ "async-trait", - "async_executors", "bounded-vec-deque", "derive_builder", "futures", diff --git a/crates/tor-circmgr/Cargo.toml b/crates/tor-circmgr/Cargo.toml index 196b0f2f9..3783d0a90 100644 --- a/crates/tor-circmgr/Cargo.toml +++ b/crates/tor-circmgr/Cargo.toml @@ -44,7 +44,6 @@ thiserror = "1" weak-table = "0.3.0" [dev-dependencies] -async_executors = { version = "0.4", default-features = false, features = [ "tokio_tp" ] } futures-await-test = "0.3.0" tor-rtmock = { path="../tor-rtmock", version = "0.0.3"} tor-guardmgr = { path="../tor-guardmgr", version = "0.0.3", features=["testing"]} diff --git a/crates/tor-circmgr/src/lib.rs b/crates/tor-circmgr/src/lib.rs index 5d8d51380..fdc6501b5 100644 --- a/crates/tor-circmgr/src/lib.rs +++ b/crates/tor-circmgr/src/lib.rs @@ -440,7 +440,7 @@ mod test { /// Helper type used to help type inference. pub(crate) type OptDummyGuardMgr<'a> = - Option<&'a tor_guardmgr::GuardMgr>; + Option<&'a tor_guardmgr::GuardMgr>; #[test] fn get_params() { diff --git a/crates/tor-rtcompat/src/async_std.rs b/crates/tor-rtcompat/src/async_std.rs index 891e3db39..f78d5f782 100644 --- a/crates/tor-rtcompat/src/async_std.rs +++ b/crates/tor-rtcompat/src/async_std.rs @@ -1,9 +1,24 @@ //! Entry points for use with async_std runtimes. -pub use crate::impls::async_std::create_runtime as create_async_std_runtime; -use crate::SpawnBlocking; +pub use crate::impls::async_std::create_runtime as create_runtime_impl; +use crate::{compound::CompoundRuntime, SpawnBlocking}; -/// A [`Runtime`](crate::Runtime) powered by async-std. -pub use async_executors::AsyncStd as AsyncStdRuntime; +use crate::impls::async_std::NativeTlsAsyncStd; + +use async_executors::AsyncStd; + +/// A [`Runtime`](crate::Runtime) powered by `async_std` and `native_tls`. +#[derive(Clone)] +pub struct AsyncStdRuntime { + /// The actual runtime object. + inner: Inner, +} + +/// Implementation type for AsyncStdRuntime. +type Inner = CompoundRuntime; + +crate::opaque::implement_opaque_runtime! { + AsyncStdRuntime { inner : Inner } +} /// Return a new async-std-based [`Runtime`](crate::Runtime). /// @@ -12,7 +27,10 @@ pub use async_executors::AsyncStd as AsyncStdRuntime; /// runtime. pub fn create_runtime() -> std::io::Result { - Ok(create_async_std_runtime()) + let rt = create_runtime_impl(); + Ok(AsyncStdRuntime { + inner: CompoundRuntime::new(rt, rt, rt, NativeTlsAsyncStd::default()), + }) } /// Try to return an instance of the currently running async_std @@ -28,6 +46,6 @@ where P: FnOnce(AsyncStdRuntime) -> F, F: futures::Future, { - let runtime = create_async_std_runtime(); - runtime.block_on(func(runtime)) + let runtime = current_runtime().expect("Couldn't get global async_std runtime?"); + runtime.clone().block_on(func(runtime)) } diff --git a/crates/tor-rtcompat/src/compound.rs b/crates/tor-rtcompat/src/compound.rs new file mode 100644 index 000000000..f5bd94ab4 --- /dev/null +++ b/crates/tor-rtcompat/src/compound.rs @@ -0,0 +1,131 @@ +//! Define a [`CompoundRuntime`] part that can be built from several component +//! pieces. + +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use crate::traits::*; +use async_trait::async_trait; +use futures::{future::FutureObj, task::Spawn}; +use std::io::Result as IoResult; + +/// A runtime made of several parts, each of which implements one trait-group. +/// +/// The `SpawnR` component should implements [`Spawn`] and [`SpawnBlocking`]; +/// the `SleepR` component should implement [`SleepProvider`]; the `TcpR` +/// component should implement [`TcpProvider`]; and the `TlsR` component should +/// implement [`TlsProvider`]. +/// +/// You can use this structure to create new runtimes in two ways: either by +/// overriding a single part of an existing runtime, or by building an entirely +/// new runtime from pieces. +#[derive(Clone)] +pub struct CompoundRuntime { + /// The actual collection of Runtime objects. + /// + /// We wrap this in an Arc rather than requiring that each item implement + /// Clone, though we could change our minds later on. + inner: Arc>, +} + +/// A collection of objects implementing that traits that make up a [`Runtime`] +struct Inner { + /// A `Spawn` and `SpawnBlocking` implementation. + spawn: SpawnR, + /// A `SleepProvider` implementation. + sleep: SleepR, + /// A `TcpProvider` implementation + tcp: TcpR, + /// A `TcpProvider` implementation. + tls: TlsR, +} + +impl CompoundRuntime { + /// Construct a new CompoundRuntime from its components. + pub fn new(spawn: SpawnR, sleep: SleepR, tcp: TcpR, tls: TlsR) -> Self { + CompoundRuntime { + inner: Arc::new(Inner { + spawn, + sleep, + tcp, + tls, + }), + } + } +} + +impl Spawn for CompoundRuntime +where + SpawnR: Spawn, +{ + #[inline] + fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> { + self.inner.spawn.spawn_obj(future) + } +} + +impl SpawnBlocking for CompoundRuntime +where + SpawnR: SpawnBlocking, +{ + #[inline] + fn block_on(&self, future: F) -> F::Output { + self.inner.spawn.block_on(future) + } +} + +impl SleepProvider for CompoundRuntime +where + SleepR: SleepProvider, +{ + type SleepFuture = SleepR::SleepFuture; + + #[inline] + fn sleep(&self, duration: Duration) -> Self::SleepFuture { + self.inner.sleep.sleep(duration) + } +} + +#[async_trait] +impl TcpProvider for CompoundRuntime +where + TcpR: TcpProvider, + SpawnR: Send + Sync + 'static, + SleepR: Send + Sync + 'static, + TcpR: Send + Sync + 'static, + TlsR: Send + Sync + 'static, +{ + type TcpStream = TcpR::TcpStream; + + type TcpListener = TcpR::TcpListener; + + #[inline] + async fn connect(&self, addr: &SocketAddr) -> IoResult { + self.inner.tcp.connect(addr).await + } + + #[inline] + async fn listen(&self, addr: &SocketAddr) -> IoResult { + self.inner.tcp.listen(addr).await + } +} + +impl TlsProvider + for CompoundRuntime +where + TcpR: TcpProvider, + TlsR: TlsProvider, +{ + type Connector = TlsR::Connector; + type TlsStream = TlsR::TlsStream; + + #[inline] + fn tls_connector(&self) -> Self::Connector { + self.inner.tls.tls_connector() + } +} + +impl std::fmt::Debug for CompoundRuntime { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CompoundRuntime").finish_non_exhaustive() + } +} diff --git a/crates/tor-rtcompat/src/impls/async_std.rs b/crates/tor-rtcompat/src/impls/async_std.rs index 8712f1357..2621960b9 100644 --- a/crates/tor-rtcompat/src/impls/async_std.rs +++ b/crates/tor-rtcompat/src/impls/async_std.rs @@ -205,7 +205,12 @@ impl SpawnBlocking for async_executors::AsyncStd { } } -impl TlsProvider for async_executors::AsyncStd { +/// A TlsProvider that uses native_tls and works with the AsyncStd executor. +#[derive(Clone, Debug, Default)] +#[non_exhaustive] +pub struct NativeTlsAsyncStd {} + +impl TlsProvider for NativeTlsAsyncStd { type TlsStream = tls::TlsStream; type Connector = tls::TlsConnector; diff --git a/crates/tor-rtcompat/src/impls/tokio.rs b/crates/tor-rtcompat/src/impls/tokio.rs index e1d02679a..a7e318457 100644 --- a/crates/tor-rtcompat/src/impls/tokio.rs +++ b/crates/tor-rtcompat/src/impls/tokio.rs @@ -219,6 +219,29 @@ mod tls { // ============================== +/// A TlsProvider that uses native_tls and works with the Tokio executor. +#[derive(Clone, Debug, Default)] +#[non_exhaustive] +pub struct NativeTlsTokio {} + +impl TlsProvider for NativeTlsTokio { + type TlsStream = tls::TlsStream; + type Connector = tls::TlsConnector; + + fn tls_connector(&self) -> tls::TlsConnector { + let mut builder = native_tls::TlsConnector::builder(); + // These function names are scary, but they just mean that we + // aren't checking whether the signer of this cert + // participates in the web PKI, and we aren't checking the + // hostname in the cert. + builder + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true); + + builder.try_into().expect("Couldn't build a TLS connector!") + } +} + use crate::traits::*; use async_trait::async_trait; use futures::Future; @@ -237,24 +260,6 @@ macro_rules! implement_traits_for { } } - impl TlsProvider for $runtime { - type TlsStream = tls::TlsStream; - type Connector = tls::TlsConnector; - - fn tls_connector(&self) -> tls::TlsConnector { - let mut builder = native_tls::TlsConnector::builder(); - // These function names are scary, but they just mean that we - // aren't checking whether the signer of this cert - // participates in the web PKI, and we aren't checking the - // hostname in the cert. - builder - .danger_accept_invalid_certs(true) - .danger_accept_invalid_hostnames(true); - - builder.try_into().expect("Couldn't build a TLS connector!") - } - } - #[async_trait] impl crate::traits::TcpProvider for $runtime { type TcpStream = net::TcpStream; @@ -273,7 +278,7 @@ macro_rules! implement_traits_for { } /// Create and return a new Tokio multithreaded runtime. -pub fn create_runtime() -> IoResult { +pub(crate) fn create_runtime() -> IoResult { let mut builder = async_executors::TokioTpBuilder::new(); builder.tokio_builder().enable_all(); builder.build() @@ -281,11 +286,16 @@ pub fn create_runtime() -> IoResult { /// Wrapper around a Handle to a tokio runtime. /// +/// Ideally, this type would go away, and we would just use +/// `tokio::runtime::Handle` directly. Unfortunately, we can't implement +/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need +/// to define a new type here. +/// /// # Limitations /// -/// Note that Arti requires that the runtime should have working -/// implementations for Tokio's time, net, and io facilities, but we have -/// no good way to check that when creating this object. +/// Note that Arti requires that the runtime should have working implementations +/// for Tokio's time, net, and io facilities, but we have no good way to check +/// that when creating this object. #[derive(Clone, Debug)] pub struct TokioRuntimeHandle { /// The underlying Handle. @@ -300,7 +310,7 @@ impl TokioRuntimeHandle { /// Note that Arti requires that the runtime should have working /// implementations for Tokio's time, net, and io facilities, but we have /// no good way to check that when creating this object. - pub fn new(handle: tokio_crate::runtime::Handle) -> Self { + pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self { handle.into() } } diff --git a/crates/tor-rtcompat/src/lib.rs b/crates/tor-rtcompat/src/lib.rs index 71ab84414..da7ca44d2 100644 --- a/crates/tor-rtcompat/src/lib.rs +++ b/crates/tor-rtcompat/src/lib.rs @@ -61,8 +61,8 @@ //! using for anything besides Arti, you can use [`create_runtime()`]. //! //! * If you want to explicitly construct a runtime with a specific -//! backend, you can do so with [`async_std::create_async_std_runtime`] or -//! [`tokio::create_tokio_runtime`]. Or if you have already constructed a +//! backend, you can do so with [`async_std::create_runtime`] or +//! [`tokio::create_runtime`]. Or if you have already constructed a //! tokio runtime that you want to use, you can wrap it as a //! [`Runtime`] explicitly with [`tokio::TokioRuntimeHandle`]. //! @@ -143,6 +143,8 @@ pub(crate) mod impls; pub mod task; +mod compound; +mod opaque; mod timer; mod traits; @@ -167,6 +169,8 @@ pub mod tokio; #[cfg(feature = "async-std")] pub mod async_std; +pub use compound::CompoundRuntime; + /// Try to return an instance of the currently running [`Runtime`]. /// /// # Limitations @@ -204,7 +208,7 @@ pub fn current_user_runtime() -> std::io::Result { /// /// Tokio users may want to avoid this function and instead make a /// runtime using [`current_user_runtime()`] or -/// [`tokio::TokioRuntimeHandle::new()`]: this function always _builds_ a +/// [`tokio::current_runtime()`]: this function always _builds_ a /// runtime, and if you already have a runtime, that isn't what you /// want with Tokio. /// diff --git a/crates/tor-rtcompat/src/opaque.rs b/crates/tor-rtcompat/src/opaque.rs new file mode 100644 index 000000000..111b73773 --- /dev/null +++ b/crates/tor-rtcompat/src/opaque.rs @@ -0,0 +1,75 @@ +//! Declare a macro for making opaque runtime wrappers. + +/// Implement delegating implementations of the runtime traits for a type $t +/// whose member $r implements Runtime. Used to hide the details of the +/// implementation of $t. +#[allow(unused)] // Can be unused if no runtimes are declared. +macro_rules! implement_opaque_runtime { +{ + $t:ty { $member:ident : $mty:ty } +} => { + + impl futures::task::Spawn for $t { + #[inline] + fn spawn_obj(&self, future: futures::future::FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> { + self.$member.spawn_obj(future) + } + } + + impl $crate::traits::SpawnBlocking for $t { + #[inline] + fn block_on(&self, future: F) -> F::Output { + self.$member.block_on(future) + } + + } + + impl $crate::traits::SleepProvider for $t { + type SleepFuture = <$mty as $crate::traits::SleepProvider>::SleepFuture; + #[inline] + fn sleep(&self, duration: std::time::Duration) -> Self::SleepFuture { + self.$member.sleep(duration) + } + } + + #[async_trait::async_trait] + impl $crate::traits::TcpProvider for $t { + type TcpStream = <$mty as $crate::traits::TcpProvider>::TcpStream; + type TcpListener = <$mty as $crate::traits::TcpProvider>::TcpListener; + #[inline] + async fn connect(&self, addr: &std::net::SocketAddr) -> std::io::Result { + self.$member.connect(addr).await + } + #[inline] + async fn listen(&self, addr: &std::net::SocketAddr) -> std::io::Result { + self.$member.listen(addr).await + } + } + + impl $crate::traits::TlsProvider<<$t as $crate::traits::TcpProvider>::TcpStream> for $t { + type Connector = <$mty as $crate::traits::TlsProvider<<$t as $crate::traits::TcpProvider>::TcpStream>>::Connector; + type TlsStream = <$mty as $crate::traits::TlsProvider<<$t as $crate::traits::TcpProvider>::TcpStream>>::TlsStream; + #[inline] + fn tls_connector(&self) -> Self::Connector { + self.$member.tls_connector() + } + } + + impl std::fmt::Debug for $t { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(stringify!($t)).finish_non_exhaustive() + } + } + + // This boilerplate will fail unless $t implements Runtime. + const _ : () = { + fn assert_runtime() {} + fn check() { + assert_runtime::<$t>(); + } + }; +} +} + +#[allow(unused)] // Can be unused if no runtimes are declared. +pub(crate) use implement_opaque_runtime; diff --git a/crates/tor-rtcompat/src/tokio.rs b/crates/tor-rtcompat/src/tokio.rs index 19653624c..35ec45110 100644 --- a/crates/tor-rtcompat/src/tokio.rs +++ b/crates/tor-rtcompat/src/tokio.rs @@ -1,9 +1,59 @@ //! Entry points for use with Tokio runtimes. -pub use crate::impls::tokio::create_runtime as create_tokio_runtime; -pub use crate::impls::tokio::TokioRuntimeHandle; +use crate::impls::tokio::{NativeTlsTokio, TokioRuntimeHandle as Handle}; +use async_executors::TokioTp; -use crate::Runtime; -use std::io::{Error as IoError, ErrorKind}; +use crate::{CompoundRuntime, Runtime, SpawnBlocking}; +use std::io::{Error as IoError, ErrorKind, Result as IoResult}; + +/// A [`Runtime`] built around a Handle to a tokio runtime, and `native_tls`. +/// +/// # Limitations +/// +/// Note that Arti requires that the runtime should have working +/// implementations for Tokio's time, net, and io facilities, but we have +/// no good way to check that when creating this object. +#[derive(Clone)] +pub struct TokioRuntimeHandle { + /// The actual [`CompoundRuntime`] that implements this. + inner: HandleInner, +} + +/// Implementation type for a TokioRuntimeHandle. +type HandleInner = CompoundRuntime; + +/// A [`Runtime`] built around an owned `TokioTp` executor, and `native_tls`. +#[derive(Clone)] +pub struct TokioRuntime { + /// The actual [`CompoundRuntime`] that implements this. + inner: TokioRuntimeInner, +} + +/// Implementation type for TokioRuntime. +type TokioRuntimeInner = CompoundRuntime; + +crate::opaque::implement_opaque_runtime! { + TokioRuntimeHandle { inner : HandleInner } +} + +crate::opaque::implement_opaque_runtime! { + TokioRuntime { inner : TokioRuntimeInner } +} + +impl From for TokioRuntimeHandle { + fn from(h: tokio_crate::runtime::Handle) -> Self { + let h = Handle::new(h); + TokioRuntimeHandle { + inner: CompoundRuntime::new(h.clone(), h.clone(), h, NativeTlsTokio::default()), + } + } +} + +/// Create and return a new Tokio multithreaded runtime. +fn create_tokio_runtime() -> IoResult { + crate::impls::tokio::create_runtime().map(|r| TokioRuntime { + inner: CompoundRuntime::new(r.clone(), r.clone(), r, NativeTlsTokio::default()), + }) +} /// Create a new Tokio-based [`Runtime`]. /// @@ -12,8 +62,7 @@ use std::io::{Error as IoError, ErrorKind}; /// runtime. /// /// Tokio users may want to avoid this function and instead make a -/// runtime using [`current_runtime()`] or -/// [`TokioRuntimeHandle::new()`]: this function always _builds_ a +/// runtime using [`current_runtime()`]: this function always _builds_ a /// runtime, and if you already have a runtime, that isn't what you /// want with Tokio. pub fn create_runtime() -> std::io::Result { @@ -35,7 +84,10 @@ pub fn create_runtime() -> std::io::Result { pub fn current_runtime() -> std::io::Result { let handle = tokio_crate::runtime::Handle::try_current() .map_err(|e| IoError::new(ErrorKind::Other, e))?; - Ok(TokioRuntimeHandle::new(handle)) + let h = Handle::new(handle); + Ok(TokioRuntimeHandle { + inner: CompoundRuntime::new(h.clone(), h.clone(), h, NativeTlsTokio::default()), + }) } /// Run a test function using a freshly created tokio runtime. @@ -45,9 +97,9 @@ pub fn current_runtime() -> std::io::Result { /// Panics if we can't create a tokio runtime. pub fn test_with_runtime(func: P) -> O where - P: FnOnce(async_executors::TokioTp) -> F, + P: FnOnce(TokioRuntime) -> F, F: futures::Future, { let runtime = create_tokio_runtime().expect("Failed to create a tokio runtime"); - runtime.block_on(func(runtime.clone())) + runtime.clone().block_on(func(runtime)) }