diff --git a/crates/tor-async-utils/Cargo.toml b/crates/tor-async-utils/Cargo.toml index e4d9d4e1b..b27ffbbc5 100644 --- a/crates/tor-async-utils/Cargo.toml +++ b/crates/tor-async-utils/Cargo.toml @@ -21,5 +21,5 @@ void = "1" [dev-dependencies] futures-await-test = "0.3.0" -tokio = { version = "1.7", features = ["macros", "rt", "rt-multi-thread", "time"] } +tokio = { version = "1.7", features = ["macros", "net", "rt", "rt-multi-thread", "time"] } diff --git a/crates/tor-async-utils/src/join_read_write.rs b/crates/tor-async-utils/src/join_read_write.rs new file mode 100644 index 000000000..3ee5bf428 --- /dev/null +++ b/crates/tor-async-utils/src/join_read_write.rs @@ -0,0 +1,95 @@ +//! Join a readable and writeable into a single `AsyncRead` + `AsyncWrite` + +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{AsyncRead, AsyncWrite}; +use pin_project::pin_project; + +/// Async readable/writeable that dispatches reads to `R` and writes to `W` +/// +/// `AsyncRead` is forwarded to `R`. +/// `AsyncWrite` is forwarded to `W`. +/// +/// [`JoinReadWrite::new()`] is the converse of +/// [`AsyncReadExt::split()`](futures::AsyncReadExt::split). +/// But, if `R` and `W` came from splitting a single `AsyncRead + AsyncWrite`, +/// you probably want the `reunite` or `unsplit` method, instead of `JoinReadWrite`. +/// +/// Does *not* implement any kind of flushing behaviour when switching between reading and writing. +/// +/// # Example +/// +/// ``` +/// # #[tokio::main] +/// # async fn main() { +/// use tor_async_utils::JoinReadWrite; +/// use futures::{AsyncReadExt as _, AsyncWriteExt as _}; +/// +/// let read = b"hello\n"; +/// let mut read = &read[..]; +/// let mut write = Vec::::new(); +/// +/// let mut joined = JoinReadWrite::new(read, write); +/// +/// let mut got = String::new(); +/// let _: usize = joined.read_to_string(&mut got).await.unwrap(); +/// assert_eq!(got, "hello\n"); +/// +/// let () = joined.write_all(b"some data").await.unwrap(); +/// +/// let (r, w) = joined.into_parts(); +/// assert_eq!(w, b"some data"); +/// # } +/// ``` +#[pin_project] +pub struct JoinReadWrite { + /// readable + #[pin] + r: R, + /// writeable + #[pin] + w: W, +} + +impl JoinReadWrite { + /// Join an `AsyncRead` and an `AsyncWrite` into a single `impl AsyncRead + AsyncWrite` + pub fn new(r: R, w: W) -> Self { + JoinReadWrite { r, w } + } + + /// Dismantle a `JoinReadWrite` into its constituent `AsyncRead` and `AsyncWrite` + pub fn into_parts(self) -> (R, W) { + let JoinReadWrite { r, w } = self; + (r, w) + } +} + +impl AsyncRead for JoinReadWrite { + fn poll_read( + self: Pin<&mut Self>, + c: &mut Context, + out: &mut [u8], + ) -> Poll> { + self.project().r.poll_read(c, out) + } +} + +impl AsyncWrite for JoinReadWrite { + fn poll_write( + self: Pin<&mut Self>, + c: &mut Context, + data: &[u8], + ) -> Poll> { + self.project().w.poll_write(c, data) + } + + fn poll_flush(self: Pin<&mut Self>, c: &mut Context) -> Poll> { + self.project().w.poll_flush(c) + } + + fn poll_close(self: Pin<&mut Self>, c: &mut Context) -> Poll> { + self.project().w.poll_close(c) + } +} diff --git a/crates/tor-async-utils/src/lib.rs b/crates/tor-async-utils/src/lib.rs index 125d1e5ef..3677059bb 100644 --- a/crates/tor-async-utils/src/lib.rs +++ b/crates/tor-async-utils/src/lib.rs @@ -38,9 +38,12 @@ #![allow(clippy::result_large_err)] // temporary workaround for arti#587 //! +mod join_read_write; mod sinkext; mod watch; +pub use join_read_write::*; + pub use sinkext::{SinkExt, SinkPrepareSendFuture, SinkSendable}; pub use watch::{DropNotifyEofSignallable, DropNotifyWatchSender, PostageWatchSenderExt};