examples/hook-tcp: add some comments, rework lifetimes a bit
Try to make the `hook-tcp` example a bit easier to read by adding/changing comments, and renaming the lifetimes for `async_trait`-generated trait methods.
This commit is contained in:
parent
ce679ad72a
commit
9960064720
|
@ -1,3 +1,10 @@
|
|||
//! This example showcases using a custom [`TcpProvider`] to do custom actions before Arti initiates
|
||||
//! TCP connections, and after the connections are closed.
|
||||
//!
|
||||
//! This might be useful, for example, to dynamically open ports on a restrictive firewall or modify
|
||||
//! routing information. It would also be possible to adapt the example to make it proxy the TCP
|
||||
//! connections somehow, depending on your usecase.
|
||||
|
||||
use std::future::Future;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -15,18 +22,15 @@ use tor_rtcompat::{CompoundRuntime, TcpListener, TcpProvider};
|
|||
|
||||
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// This example showcase using a custom TcpProvider to get a hook before connections are initiated
|
||||
// and after they are closed. This might be useful in situations where you open dynamicaly ports
|
||||
// on a very restrictive firewall, or set custom routing rules to force all traffic to arti somehow,
|
||||
// and don't want arti to be sent to itself.
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = TorClientConfig::default();
|
||||
let rt = TokioNativeTlsRuntime::current()?;
|
||||
// Instantiate our custom TCP provider (see implementation below).
|
||||
let tcp_rt = CustomTcpProvider { inner: rt.clone() };
|
||||
// Create a `CompoundRuntime`, swapping out the TCP part of the runtime for our custom one.
|
||||
let rt = CompoundRuntime::new(rt.clone(), rt, tcp_rt, NativeTlsProvider::default());
|
||||
|
||||
eprintln!("connecting to Tor...");
|
||||
|
@ -52,28 +56,43 @@ async fn main() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// A custom TCP provider that relies on an existing TCP provider (`inner`), but modifies its
|
||||
/// behavior.
|
||||
struct CustomTcpProvider<T> {
|
||||
/// The underlying TCP provider.
|
||||
inner: T,
|
||||
}
|
||||
|
||||
/// A custom TCP stream that wraps another TCP provider's TCP stream type, letting us do things
|
||||
/// when the stream is read from, written to, or closed.
|
||||
struct CustomTcpStream<T> {
|
||||
/// The underlying TCP stream.
|
||||
inner: T,
|
||||
/// The address of the remote peer at the other end of this stream.
|
||||
addr: SocketAddr,
|
||||
/// The current state of the socket: whether it is open, in the process of closing, or closed.
|
||||
state: TcpState,
|
||||
}
|
||||
|
||||
/// An enum representing states a TCP stream can be in.
|
||||
#[derive(PartialEq)]
|
||||
enum TcpState {
|
||||
/// Stream is open.
|
||||
Open,
|
||||
/// We've sent a close, but haven't received one.
|
||||
SendClosed,
|
||||
/// We've received a close, but haven't sent one.
|
||||
RecvClosed,
|
||||
/// Stream is fully closed.
|
||||
Closed,
|
||||
}
|
||||
|
||||
/// A wrapper over a `TcpListener`.
|
||||
struct CustomTcpListener<T> {
|
||||
inner: T,
|
||||
}
|
||||
|
||||
/// An `Incoming` type for our `CustomTcpListener`.
|
||||
struct CustomIncoming<T> {
|
||||
inner: T,
|
||||
}
|
||||
|
@ -85,16 +104,20 @@ where
|
|||
type TcpStream = CustomTcpStream<T::TcpStream>;
|
||||
type TcpListener = CustomTcpListener<T::TcpListener>;
|
||||
|
||||
// using a manual implementation is required to have Send+Sync when using reference to self
|
||||
fn connect<'life0, 'life1, 'async_trait>(
|
||||
&'life0 self,
|
||||
addr: &'life1 SocketAddr,
|
||||
) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'async_trait>>
|
||||
// This is an async trait method (using the `async_trait` crate). We manually implement it
|
||||
// here so that we don't borrow `self` for too long.
|
||||
// (The lifetimes are explicit and somewhat ugly because that's how `async_trait` works.)
|
||||
fn connect<'a, 'b, 'c>(
|
||||
&'a self,
|
||||
addr: &'b SocketAddr,
|
||||
) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'c>>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
'life1: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
'a: 'c,
|
||||
'b: 'c,
|
||||
Self: 'c,
|
||||
{
|
||||
// Use the underlying TCP provider implementation to do the connection, and
|
||||
// return our wrapper around it once done.
|
||||
println!("tcp connect to {}", addr);
|
||||
self.inner
|
||||
.connect(addr)
|
||||
|
@ -108,15 +131,18 @@ where
|
|||
.boxed()
|
||||
}
|
||||
|
||||
fn listen<'life0, 'life1, 'async_trait>(
|
||||
&'life0 self,
|
||||
addr: &'life1 SocketAddr,
|
||||
) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'async_trait>>
|
||||
// This is also an async trait method (see above).
|
||||
fn listen<'a, 'b, 'c>(
|
||||
&'a self,
|
||||
addr: &'b SocketAddr,
|
||||
) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'c>>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
'life1: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
'a: 'c,
|
||||
'b: 'c,
|
||||
Self: 'c,
|
||||
{
|
||||
// Use the underlying TCP provider implementation to make the listener, and
|
||||
// return our wrapper around it once done.
|
||||
println!("tcp listen on {}", addr);
|
||||
self.inner
|
||||
.listen(addr)
|
||||
|
@ -125,6 +151,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// We implement `AsyncRead` and `AsyncWrite` for our custom TCP stream object.
|
||||
// This implementation mostly uses the underlying stream's methods, but we insert some
|
||||
// code to check for a zero-byte read (indicating stream closure), and callers closing the
|
||||
// stream, and use that to update our `TcpState`.
|
||||
// When we detect that the stream is closed, we run some code (in this case, just a `println!`).
|
||||
impl<T> AsyncRead for CustomTcpStream<T>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
|
@ -134,12 +165,20 @@ where
|
|||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<IoResult<usize>> {
|
||||
// Call the underlying stream's method.
|
||||
let res = Pin::new(&mut self.inner).poll_read(cx, buf);
|
||||
|
||||
// Check for a zero-byte read, indicating closure.
|
||||
if let Poll::Ready(Ok(0)) = res {
|
||||
if !buf.is_empty() {
|
||||
match self.state {
|
||||
// If we're already closed, do nothing.
|
||||
TcpState::Closed | TcpState::RecvClosed => (),
|
||||
// We're open, and haven't tried to close the stream yet, so note that
|
||||
// the other side closed it.
|
||||
TcpState::Open => self.state = TcpState::RecvClosed,
|
||||
// We've closed the stream on our end, and the other side has now closed it
|
||||
// too, so the stream is now fully closed.
|
||||
TcpState::SendClosed => {
|
||||
println!("closed a connecion to {}", self.addr);
|
||||
self.state = TcpState::Closed;
|
||||
|
@ -150,12 +189,14 @@ where
|
|||
res
|
||||
}
|
||||
|
||||
// Do the same thing, but for `poll_read_vectored`.
|
||||
fn poll_read_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &mut [std::io::IoSliceMut<'_>],
|
||||
) -> Poll<IoResult<usize>> {
|
||||
let res = Pin::new(&mut self.inner).poll_read_vectored(cx, bufs);
|
||||
|
||||
if let Poll::Ready(Ok(0)) = res {
|
||||
if bufs.iter().any(|buf| !buf.is_empty()) {
|
||||
match self.state {
|
||||
|
@ -172,6 +213,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// The only thing that's custom here is checking for closure. Everything else is just calling
|
||||
// `self.inner`.
|
||||
impl<T> AsyncWrite for CustomTcpStream<T>
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
|
@ -195,7 +238,7 @@ where
|
|||
TcpState::Closed | TcpState::SendClosed => (),
|
||||
TcpState::Open => self.state = TcpState::SendClosed,
|
||||
TcpState::RecvClosed => {
|
||||
println!("closed a connecion to {}", self.addr);
|
||||
println!("closed a connection to {}", self.addr);
|
||||
self.state = TcpState::Closed;
|
||||
}
|
||||
}
|
||||
|
@ -215,7 +258,7 @@ where
|
|||
impl<T> Drop for CustomTcpStream<T> {
|
||||
fn drop(&mut self) {
|
||||
if self.state != TcpState::Closed {
|
||||
println!("closed a connecion to {}", self.addr);
|
||||
println!("closed a connection to {}", self.addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -229,13 +272,15 @@ where
|
|||
type TcpStream = CustomTcpStream<T::TcpStream>;
|
||||
type Incoming = CustomIncoming<T::Incoming>;
|
||||
|
||||
fn accept<'life0, 'async_trait>(
|
||||
&'life0 self,
|
||||
) -> Pin<Box<dyn Future<Output = AcceptResult<Self::TcpStream>> + Send + 'async_trait>>
|
||||
// This is also an async trait method (see earlier commentary).
|
||||
fn accept<'a, 'b>(
|
||||
&'a self,
|
||||
) -> Pin<Box<dyn Future<Output = AcceptResult<Self::TcpStream>> + Send + 'b>>
|
||||
where
|
||||
'life0: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
'a: 'b,
|
||||
Self: 'b,
|
||||
{
|
||||
// As with other implementations, we just defer to `self.inner` and wrap the result.
|
||||
self.inner
|
||||
.accept()
|
||||
.inspect(|r| {
|
||||
|
@ -271,7 +316,7 @@ where
|
|||
|
||||
impl<T, S> Stream for CustomIncoming<T>
|
||||
where
|
||||
T: Stream<Item = IoResult<(S, SocketAddr)>> + std::marker::Unpin,
|
||||
T: Stream<Item = IoResult<(S, SocketAddr)>> + Unpin,
|
||||
{
|
||||
type Item = IoResult<(CustomTcpStream<S>, SocketAddr)>;
|
||||
|
||||
|
|
Loading…
Reference in New Issue