From 9960064720a74691490ad0c78f834018bdc84771 Mon Sep 17 00:00:00 2001
From: eta
Date: Wed, 23 Feb 2022 13:59:03 +0000
Subject: [PATCH] examples/hook-tcp: add some comments, rework lifetimes a bit

Try to make the `hook-tcp` example a bit easier to read by adding/changing
comments, and renaming the lifetimes for `async_trait`-generated trait
methods.
---
 crates/arti-client/examples/hook-tcp.rs | 101 +++++++++++++++++-------
 1 file changed, 73 insertions(+), 28 deletions(-)

diff --git a/crates/arti-client/examples/hook-tcp.rs b/crates/arti-client/examples/hook-tcp.rs
index a008ffaa2..4a4db0370 100644
--- a/crates/arti-client/examples/hook-tcp.rs
+++ b/crates/arti-client/examples/hook-tcp.rs
@@ -1,3 +1,10 @@
+//! This example showcases using a custom [`TcpProvider`] to do custom actions before Arti initiates
+//! TCP connections, and after the connections are closed.
+//!
+//! This might be useful, for example, to dynamically open ports on a restrictive firewall or modify
+//! routing information. It would also be possible to adapt the example to make it proxy the TCP
+//! connections somehow, depending on your usecase.
+
 use std::future::Future;
 use std::io::Result as IoResult;
 use std::net::SocketAddr;
@@ -15,18 +22,15 @@ use tor_rtcompat::{CompoundRuntime, TcpListener, TcpProvider};
 
 use futures::io::{AsyncReadExt, AsyncWriteExt};
 
-// This example showcase using a custom TcpProvider to get a hook before connections are initiated
-// and after they are closed. This might be useful in situations where you open dynamicaly ports
-// on a very restrictive firewall, or set custom routing rules to force all traffic to arti somehow,
-// and don't want arti to be sent to itself.
-
 #[tokio::main]
 async fn main() -> Result<()> {
     tracing_subscriber::fmt::init();
 
     let config = TorClientConfig::default();
     let rt = TokioNativeTlsRuntime::current()?;
+    // Instantiate our custom TCP provider (see implementation below).
     let tcp_rt = CustomTcpProvider { inner: rt.clone() };
+    // Create a `CompoundRuntime`, swapping out the TCP part of the runtime for our custom one.
     let rt = CompoundRuntime::new(rt.clone(), rt, tcp_rt, NativeTlsProvider::default());
 
     eprintln!("connecting to Tor...");
@@ -52,28 +56,43 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
+/// A custom TCP provider that relies on an existing TCP provider (`inner`), but modifies its
+/// behavior.
 struct CustomTcpProvider<T> {
+    /// The underlying TCP provider.
     inner: T,
 }
 
+/// A custom TCP stream that wraps another TCP provider's TCP stream type, letting us do things
+/// when the stream is read from, written to, or closed.
 struct CustomTcpStream<T> {
+    /// The underlying TCP stream.
     inner: T,
+    /// The address of the remote peer at the other end of this stream.
     addr: SocketAddr,
+    /// The current state of the socket: whether it is open, in the process of closing, or closed.
     state: TcpState,
 }
 
+/// An enum representing states a TCP stream can be in.
 #[derive(PartialEq)]
 enum TcpState {
+    /// Stream is open.
     Open,
+    /// We've sent a close, but haven't received one.
     SendClosed,
+    /// We've received a close, but haven't sent one.
     RecvClosed,
+    /// Stream is fully closed.
     Closed,
 }
 
+/// A wrapper over a `TcpListener`.
 struct CustomTcpListener<T> {
     inner: T,
 }
 
+/// An `Incoming` type for our `CustomTcpListener`.
 struct CustomIncoming<T> {
     inner: T,
 }
@@ -85,16 +104,20 @@ where
     type TcpStream = CustomTcpStream<T::TcpStream>;
     type TcpListener = CustomTcpListener<T::TcpListener>;
 
-    // using a manual implementation is required to have Send+Sync when using reference to self
-    fn connect<'life0, 'life1, 'async_trait>(
-        &'life0 self,
-        addr: &'life1 SocketAddr,
-    ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'async_trait>>
+    // This is an async trait method (using the `async_trait` crate). We manually implement it
+    // here so that we don't borrow `self` for too long.
+    // (The lifetimes are explicit and somewhat ugly because that's how `async_trait` works.)
+    fn connect<'a, 'b, 'c>(
+        &'a self,
+        addr: &'b SocketAddr,
+    ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'c>>
     where
-        'life0: 'async_trait,
-        'life1: 'async_trait,
-        Self: 'async_trait,
+        'a: 'c,
+        'b: 'c,
+        Self: 'c,
     {
+        // Use the underlying TCP provider implementation to do the connection, and
+        // return our wrapper around it once done.
         println!("tcp connect to {}", addr);
         self.inner
             .connect(addr)
@@ -108,15 +131,18 @@ where
             .boxed()
     }
 
-    fn listen<'life0, 'life1, 'async_trait>(
-        &'life0 self,
-        addr: &'life1 SocketAddr,
-    ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'async_trait>>
+    // This is also an async trait method (see above).
+    fn listen<'a, 'b, 'c>(
+        &'a self,
+        addr: &'b SocketAddr,
+    ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'c>>
     where
-        'life0: 'async_trait,
-        'life1: 'async_trait,
-        Self: 'async_trait,
+        'a: 'c,
+        'b: 'c,
+        Self: 'c,
     {
+        // Use the underlying TCP provider implementation to make the listener, and
+        // return our wrapper around it once done.
         println!("tcp listen on {}", addr);
         self.inner
             .listen(addr)
@@ -125,6 +151,11 @@ where
     }
 }
 
+// We implement `AsyncRead` and `AsyncWrite` for our custom TCP stream object.
+// This implementation mostly uses the underlying stream's methods, but we insert some
+// code to check for a zero-byte read (indicating stream closure), and callers closing the
+// stream, and use that to update our `TcpState`.
+// When we detect that the stream is closed, we run some code (in this case, just a `println!`).
 impl<T> AsyncRead for CustomTcpStream<T>
 where
     T: AsyncRead + Unpin,
@@ -134,12 +165,20 @@ where
         cx: &mut Context<'_>,
         buf: &mut [u8],
     ) -> Poll<IoResult<usize>> {
+        // Call the underlying stream's method.
         let res = Pin::new(&mut self.inner).poll_read(cx, buf);
+
+        // Check for a zero-byte read, indicating closure.
         if let Poll::Ready(Ok(0)) = res {
             if !buf.is_empty() {
                 match self.state {
+                    // If we're already closed, do nothing.
                     TcpState::Closed | TcpState::RecvClosed => (),
+                    // We're open, and haven't tried to close the stream yet, so note that
+                    // the other side closed it.
                     TcpState::Open => self.state = TcpState::RecvClosed,
+                    // We've closed the stream on our end, and the other side has now closed it
+                    // too, so the stream is now fully closed.
                     TcpState::SendClosed => {
                         println!("closed a connecion to {}", self.addr);
                         self.state = TcpState::Closed;
@@ -150,12 +189,14 @@ where
         res
     }
 
+    // Do the same thing, but for `poll_read_vectored`.
     fn poll_read_vectored(
         mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [std::io::IoSliceMut<'_>],
    ) -> Poll<IoResult<usize>> {
        let res = Pin::new(&mut self.inner).poll_read_vectored(cx, bufs);
+
        if let Poll::Ready(Ok(0)) = res {
            if bufs.iter().any(|buf| !buf.is_empty()) {
                match self.state {
@@ -172,6 +213,8 @@ where
     }
 }
 
+// The only thing that's custom here is checking for closure. Everything else is just calling
+// `self.inner`.
 impl<T> AsyncWrite for CustomTcpStream<T>
 where
     T: AsyncWrite + Unpin,
@@ -195,7 +238,7 @@ where
                 TcpState::Closed | TcpState::SendClosed => (),
                 TcpState::Open => self.state = TcpState::SendClosed,
                 TcpState::RecvClosed => {
-                    println!("closed a connecion to {}", self.addr);
+                    println!("closed a connection to {}", self.addr);
                     self.state = TcpState::Closed;
                 }
             }
@@ -215,7 +258,7 @@ where
 impl<T> Drop for CustomTcpStream<T> {
     fn drop(&mut self) {
         if self.state != TcpState::Closed {
-            println!("closed a connecion to {}", self.addr);
+            println!("closed a connection to {}", self.addr);
         }
     }
 }
@@ -229,13 +272,15 @@ where
     type TcpStream = CustomTcpStream<T::TcpStream>;
     type Incoming = CustomIncoming<T::Incoming>;
 
-    fn accept<'life0, 'async_trait>(
-        &'life0 self,
-    ) -> Pin<Box<dyn Future<Output = IoResult<(Self::TcpStream, SocketAddr)>> + Send + 'async_trait>>
+    // This is also an async trait method (see earlier commentary).
+    fn accept<'a, 'b>(
+        &'a self,
+    ) -> Pin<Box<dyn Future<Output = IoResult<(Self::TcpStream, SocketAddr)>> + Send + 'b>>
     where
-        'life0: 'async_trait,
-        Self: 'async_trait,
+        'a: 'b,
+        Self: 'b,
     {
+        // As with other implementations, we just defer to `self.inner` and wrap the result.
        self.inner
            .accept()
            .inspect(|r| {
@@ -271,7 +316,7 @@ where
 impl<T, S> Stream for CustomIncoming<T>
 where
-    T: Stream<Item = IoResult<(S, SocketAddr)>> + std::marker::Unpin,
+    T: Stream<Item = IoResult<(S, SocketAddr)>> + Unpin,
 {
     type Item = IoResult<(CustomTcpStream<S>, SocketAddr)>;
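
A note on the lifetime rework: `#[async_trait]` turns an `async fn` in a trait into a method that returns a boxed `Send` future, naming the generated lifetimes `'life0`, `'life1`, and `'async_trait`; the patch writes that expansion out by hand under the shorter names `'a`, `'b`, and `'c`. Below is a minimal, self-contained sketch of the same pattern on a toy trait; the `Dial` and `LoggingDial` names are hypothetical and not part of Arti or tor-rtcompat.

use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;

trait Dial {
    // Under `#[async_trait]` this would be written as
    // `async fn dial(&self, addr: &SocketAddr) -> String;`,
    // and the macro would emit a signature like the one below, with the
    // lifetimes named 'life0, 'life1, and 'async_trait instead of 'a, 'b, 'c.
    fn dial<'a, 'b, 'c>(
        &'a self,
        addr: &'b SocketAddr,
    ) -> Pin<Box<dyn Future<Output = String> + Send + 'c>>
    where
        'a: 'c,
        'b: 'c,
        Self: 'c;
}

struct LoggingDial;

impl Dial for LoggingDial {
    fn dial<'a, 'b, 'c>(
        &'a self,
        addr: &'b SocketAddr,
    ) -> Pin<Box<dyn Future<Output = String> + Send + 'c>>
    where
        'a: 'c,
        'b: 'c,
        Self: 'c,
    {
        // Copy `addr` so the returned future owns its data instead of holding
        // the `&'b` borrow; `Box::pin` is also what `#[async_trait]` emits.
        let addr = *addr;
        Box::pin(async move { format!("dialing {}", addr) })
    }
}

Spelling the expansion out manually, as the example does, lets the implementor control exactly how long `self` and `addr` stay borrowed before the boxed future is returned, which is the reason the reworked comments give for not using the attribute form directly.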