diff --git a/Cargo.lock b/Cargo.lock index 1aad63304..d0fb45bb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2742,6 +2742,8 @@ dependencies = [ "rand_core 0.6.3", "subtle", "thiserror", + "tokio", + "tokio-util", "tor-bytes", "tor-cell", "tor-cert", diff --git a/crates/tor-client/Cargo.toml b/crates/tor-client/Cargo.toml index b6619c099..8fb8d4e2e 100644 --- a/crates/tor-client/Cargo.toml +++ b/crates/tor-client/Cargo.toml @@ -13,7 +13,7 @@ repository="https://gitlab.torproject.org/tpo/core/arti.git/" [features] default = [ "tokio" ] async-std = [ "tor-rtcompat/async-std" ] -tokio = [ "tor-rtcompat/tokio" ] +tokio = [ "tor-rtcompat/tokio", "tor-proto/tokio" ] static = [ "tor-rtcompat/static", "tor-dirmgr/static" ] experimental-api = [] diff --git a/crates/tor-proto/Cargo.toml b/crates/tor-proto/Cargo.toml index 9c26a5b5c..23051b440 100644 --- a/crates/tor-proto/Cargo.toml +++ b/crates/tor-proto/Cargo.toml @@ -14,6 +14,7 @@ repository="https://gitlab.torproject.org/tpo/core/arti.git/" default = [] hs = [] ntor_v3 = [] +tokio = [ "tokio-crate", "tokio-util" ] [dependencies] tor-llcrypto = { path="../tor-llcrypto", version="0.0.0" } @@ -43,6 +44,9 @@ thiserror = "1.0.24" typenum = "1.13.0" zeroize = "1.3.0" +tokio-crate = { package = "tokio", version = "1.7.0", optional = true } +tokio-util = { version = "0.6", features = ["compat"], optional = true } + [dev-dependencies] futures-await-test = "0.3.0" hex-literal = "0.3.1" diff --git a/crates/tor-proto/src/stream/data.rs b/crates/tor-proto/src/stream/data.rs index 3768fc195..10555d338 100644 --- a/crates/tor-proto/src/stream/data.rs +++ b/crates/tor-proto/src/stream/data.rs @@ -9,6 +9,13 @@ use futures::io::{AsyncRead, AsyncWrite}; use futures::task::{Context, Poll}; use futures::Future; +#[cfg(feature = "tokio")] +use tokio_crate::io::ReadBuf; +#[cfg(feature = "tokio")] +use tokio_crate::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}; +#[cfg(feature = "tokio")] +use tokio_util::compat::FuturesAsyncReadCompatExt; + use std::io::Result as IoResult; use std::pin::Pin; use std::sync::Arc; @@ -94,6 +101,17 @@ impl AsyncRead for DataStream { } } +#[cfg(feature = "tokio")] +impl TokioAsyncRead for DataStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + TokioAsyncRead::poll_read(Pin::new(&mut self.compat()), cx, buf) + } +} + impl AsyncWrite for DataStream { fn poll_write( mut self: Pin<&mut Self>, @@ -110,6 +128,21 @@ impl AsyncWrite for DataStream { } } +#[cfg(feature = "tokio")] +impl TokioAsyncWrite for DataStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + TokioAsyncWrite::poll_write(Pin::new(&mut self.compat()), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat()), cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat()), cx) + } +} + /// An enumeration for the state of a DataWriter. /// /// We have to use an enum here because, for as long as we're waiting