diff --git a/crates/tor-chanmgr/src/builder.rs b/crates/tor-chanmgr/src/builder.rs index 8168bbd7f..5b16ce004 100644 --- a/crates/tor-chanmgr/src/builder.rs +++ b/crates/tor-chanmgr/src/builder.rs @@ -1,28 +1,21 @@ //! Implement a concrete type to build channels. use std::io; -use std::net::SocketAddr; use std::sync::{Arc, Mutex}; -use crate::factory::{ChannelFactory, TransportHelper}; +use crate::factory::ChannelFactory; +use crate::transport::TransportHelper; use crate::{event::ChanMgrEventSender, Error}; -use safelog::sensitive as sv; use std::time::Duration; -use tor_error::{bad_api_usage, internal}; -use tor_linkspec::{ChannelMethod, HasChanMethod, HasRelayIds, OwnedChanTarget}; +use tor_error::internal; +use tor_linkspec::{HasChanMethod, HasRelayIds, OwnedChanTarget}; use tor_llcrypto::pk; use tor_proto::channel::params::ChannelPaddingInstructionsUpdates; -use tor_rtcompat::{tls::TlsConnector, Runtime, TcpProvider, TlsProvider}; +use tor_rtcompat::{tls::TlsConnector, Runtime, TlsProvider}; use async_trait::async_trait; -use futures::stream::FuturesUnordered; use futures::task::SpawnExt; -use futures::StreamExt; -use futures::{FutureExt, TryFutureExt}; - -/// Time to wait between starting parallel connections to the same relay. -static CONNECTION_DELAY: Duration = Duration::from_millis(150); /// TLS-based channel builder. /// @@ -61,130 +54,6 @@ where } } } - -#[async_trait] -impl crate::mgr::AbstractChannelFactory for CF -where - CF: ChannelFactory + Sync, -{ - type Channel = tor_proto::channel::Channel; - type BuildSpec = OwnedChanTarget; - - async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result { - self.connect_via_transport(target).await - } -} - -/// Connect to one of the addresses in `addrs` by running connections in parallel until one works. -/// -/// This implements a basic version of RFC 8305 "happy eyeballs". -async fn connect_to_one( - rt: &R, - addrs: &[SocketAddr], -) -> crate::Result<(::TcpStream, SocketAddr)> { - // We need *some* addresses to connect to. - if addrs.is_empty() { - return Err(Error::UnusableTarget(bad_api_usage!( - "No addresses for chosen relay" - ))); - } - - // Turn each address into a future that waits (i * CONNECTION_DELAY), then - // attempts to connect to the address using the runtime (where i is the - // array index). Shove all of these into a `FuturesUnordered`, polling them - // simultaneously and returning the results in completion order. - // - // This is basically the concurrent-connection stuff from RFC 8305, ish. - // TODO(eta): sort the addresses first? - let mut connections = addrs - .iter() - .enumerate() - .map(|(i, a)| { - let delay = rt.sleep(CONNECTION_DELAY * i as u32); - delay.then(move |_| { - tracing::debug!("Connecting to {}", a); - rt.connect(a) - .map_ok(move |stream| (stream, *a)) - .map_err(move |e| (e, *a)) - }) - }) - .collect::>(); - - let mut ret = None; - let mut errors = vec![]; - - while let Some(result) = connections.next().await { - match result { - Ok(s) => { - // We got a stream (and address). - ret = Some(s); - break; - } - Err((e, a)) => { - // We got a failure on one of the streams. Store the error. - // TODO(eta): ideally we'd start the next connection attempt immediately. - tracing::warn!("Connection to {} failed: {}", sv(a), e); - errors.push((e, a)); - } - } - } - - // Ensure we don't continue trying to make connections. - drop(connections); - - ret.ok_or_else(|| Error::ChannelBuild { - addresses: errors.into_iter().map(|(e, a)| (a, Arc::new(e))).collect(), - }) -} - -/// A default transport object that opens TCP connections for a -/// `ChannelMethod::Direct`. -/// -/// It opens almost-simultaneous parallel TCP connections to each address, and -/// chooses the first one to succeed. -#[derive(Clone, Debug)] -pub(crate) struct DefaultTransport { - /// The runtime that we use for connecting. - runtime: R, -} - -impl DefaultTransport { - /// Construct a new DefaultTransport - pub(crate) fn new(runtime: R) -> Self { - Self { runtime } - } -} - -#[async_trait] -impl crate::factory::TransportHelper for DefaultTransport { - type Stream = ::TcpStream; - - /// Implements the transport: makes a TCP connection (possibly - /// tunneled over whatever protocol) if possible. - async fn connect( - &self, - target: &OwnedChanTarget, - ) -> crate::Result<(OwnedChanTarget, Self::Stream)> { - let direct_addrs: Vec<_> = match target.chan_method() { - ChannelMethod::Direct(addrs) => addrs, - #[allow(unreachable_patterns)] - _ => { - return Err(Error::UnusableTarget(bad_api_usage!( - "Used default transport implementation for an unsupported transport." - ))) - } - }; - - let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?; - let using_target = match target.restrict_addr(&addr) { - Ok(v) => v, - Err(v) => v, - }; - - Ok((using_target, stream)) - } -} - #[async_trait] impl ChannelFactory for ChanBuilder where @@ -364,11 +233,11 @@ mod test { }; use pk::ed25519::Ed25519Identity; use pk::rsa::RsaIdentity; + use std::net::SocketAddr; use std::time::{Duration, SystemTime}; - use std::{net::SocketAddr, str::FromStr}; use tor_linkspec::ChannelMethod; use tor_proto::channel::Channel; - use tor_rtcompat::{test_with_one_runtime, SleepProviderExt, TcpListener}; + use tor_rtcompat::{test_with_one_runtime, TcpListener}; use tor_rtmock::{io::LocalStream, net::MockNetwork, MockSleepRuntime}; // Make sure that the builder can build a real channel. To test @@ -418,7 +287,7 @@ mod test { // Create the channel builder that we want to test. let (snd, _rcv) = crate::event::channel(); - let transport = DefaultTransport::new(client_rt.clone()); + let transport = crate::transport::DefaultTransport::new(client_rt.clone()); let builder = ChanBuilder::new(client_rt, transport, Arc::new(Mutex::new(snd))); let (r1, r2): (Result, Result) = futures::join!( @@ -454,87 +323,5 @@ mod test { }) } - #[test] - fn test_connect_one() { - let client_addr = "192.0.1.16".parse().unwrap(); - // We'll put a "relay" at this address - let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap(); - // We'll put nothing at this address, to generate errors. - let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap(); - // Well put a black hole at this address, to generate timeouts. - let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap(); - // We'll put a "relay" at this address too - let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap(); - - test_with_one_runtime!(|rt| async move { - // Stub out the internet so that this connection can work. - let network = MockNetwork::new(); - - // Set up a client and server runtime with a given IP - let client_rt = network - .builder() - .add_address(client_addr) - .runtime(rt.clone()); - let server_rt = network - .builder() - .add_address(addr1.ip()) - .add_address(addr4.ip()) - .runtime(rt.clone()); - let _listener = server_rt.mock_net().listen(&addr1).await.unwrap(); - let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap(); - // TODO: Because this test doesn't mock time, there will actually be - // delays as we wait for connections to this address to time out. It - // would be good to use MockSleepProvider instead, once we figure - // out how to make it both reliable and convenient. - network.add_blackhole(addr3).unwrap(); - - // No addresses? Can't succeed. - let failure = connect_to_one(&client_rt, &[]).await; - assert!(failure.is_err()); - - // Connect to a set of addresses including addr1? That's a success. - for addresses in [ - &[addr1][..], - &[addr1, addr2][..], - &[addr2, addr1][..], - &[addr1, addr3][..], - &[addr3, addr1][..], - &[addr1, addr2, addr3][..], - &[addr3, addr2, addr1][..], - ] { - let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap(); - assert_eq!(addr, addr1); - } - - // Connect to a set of addresses including addr2 but not addr1? - // That's an error of one kind or another. - for addresses in [ - &[addr2][..], - &[addr2, addr3][..], - &[addr3, addr2][..], - &[addr3][..], - ] { - let expect_timeout = addresses.contains(&addr3); - let failure = rt - .timeout( - Duration::from_millis(300), - connect_to_one(&client_rt, addresses), - ) - .await; - if expect_timeout { - assert!(failure.is_err()); - } else { - assert!(failure.unwrap().is_err()); - } - } - - // Connect to addr1 and addr4? The first one should win. - let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap(); - assert_eq!(addr, addr1); - let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap(); - assert_eq!(addr, addr4); - }); - } - // TODO: Write tests for timeout logic, once there is smarter logic. } diff --git a/crates/tor-chanmgr/src/factory.rs b/crates/tor-chanmgr/src/factory.rs index b9e1270d4..2d60e3c2f 100644 --- a/crates/tor-chanmgr/src/factory.rs +++ b/crates/tor-chanmgr/src/factory.rs @@ -1,15 +1,16 @@ //! Traits and code to define different mechanisms for building Channels to //! different kinds of targets. +pub(crate) mod registry; + use std::sync::Arc; -use crate::Error; - use async_trait::async_trait; -use futures::{AsyncRead, AsyncWrite}; -use tor_linkspec::{HasChanMethod, OwnedChanTarget, TransportId}; +use tor_linkspec::OwnedChanTarget; use tor_proto::channel::Channel; +pub use registry::TransportRegistry; + /// An object that knows how to build Channels to ChanTargets. /// /// This trait must be object-safe. @@ -26,63 +27,6 @@ pub trait ChannelFactory { async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result; } -/// A more convenient API for defining transports. This type's role is to let -/// the implementor just define a replacement way to pass bytes around, and -/// return something that we can use in place of a TcpStream. -/// -/// This is the trait you should probably implement if you want to define a new -/// [`ChannelFactory`] that performs Tor over TLS over some stream-like type, -/// and you only want to define the stream-like type. -/// -/// To convert a [`TransportHelper`] into a [`ChannelFactory`], wrap it in a ChannelBuilder. -#[async_trait] -pub trait TransportHelper { - /// The type of the resulting stream. - type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static; - - /// Implements the transport: makes a TCP connection (possibly - /// tunneled over whatever protocol) if possible. - /// - /// This method does does not necessarily handle retries or timeouts, - /// although some of its implementations may. - /// - /// This method does not necessarily handle every kind of transport. - /// If the caller provides a target with the wrong [`TransportId`], this - /// method should return [`Error::NoSuchTransport`]. - async fn connect( - &self, - target: &OwnedChanTarget, - ) -> crate::Result<(OwnedChanTarget, Self::Stream)>; -} - -/// An object that knows about one or more ChannelFactories. -pub trait TransportRegistry { - /// Return a ChannelFactory that can make connections via a chosen - /// transport, if we know one. - // - // TODO pt-client: This might need to return an Arc instead of a reference - fn get_factory(&self, transport: &TransportId) -> Option<&(dyn ChannelFactory + Sync)>; -} - -/// Helper type: Wrap a `TransportRegistry` so that it can be used as a -/// `ChannelFactory`. -/// -/// (This has to be a new type, or else the blanket implementation of -/// `ChannelFactory` for `TransportHelper` would conflict.) -#[derive(Clone, Debug)] -pub(crate) struct RegistryAsFactory(R); - -#[async_trait] -impl ChannelFactory for RegistryAsFactory { - async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result { - let method = target.chan_method(); - let id = method.transport_id(); - let factory = self.0.get_factory(&id).ok_or(Error::NoSuchTransport(id))?; - - factory.connect_via_transport(target).await - } -} - #[async_trait] impl<'a> ChannelFactory for Arc<(dyn ChannelFactory + Send + Sync + 'a)> { async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result { @@ -96,3 +40,16 @@ impl<'a> ChannelFactory for Box<(dyn ChannelFactory + Send + Sync + 'a)> { self.as_ref().connect_via_transport(target).await } } + +#[async_trait] +impl crate::mgr::AbstractChannelFactory for CF +where + CF: ChannelFactory + Sync, +{ + type Channel = tor_proto::channel::Channel; + type BuildSpec = OwnedChanTarget; + + async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result { + self.connect_via_transport(target).await + } +} diff --git a/crates/tor-chanmgr/src/factory/registry.rs b/crates/tor-chanmgr/src/factory/registry.rs new file mode 100644 index 000000000..ab8000fdd --- /dev/null +++ b/crates/tor-chanmgr/src/factory/registry.rs @@ -0,0 +1,37 @@ +//! Implement a registry for different kinds of transports. + +use async_trait::async_trait; +use tor_linkspec::{HasChanMethod, OwnedChanTarget, TransportId}; +use tor_proto::channel::Channel; + +use crate::Error; + +use super::ChannelFactory; + +/// An object that knows about one or more ChannelFactories. +pub trait TransportRegistry { + /// Return a ChannelFactory that can make connections via a chosen + /// transport, if we know one. + // + // TODO pt-client: This might need to return an Arc instead of a reference + fn get_factory(&self, transport: &TransportId) -> Option<&(dyn ChannelFactory + Sync)>; +} + +/// Helper type: Wrap a `TransportRegistry` so that it can be used as a +/// `ChannelFactory`. +/// +/// (This has to be a new type, or else the blanket implementation of +/// `ChannelFactory` for `TransportHelper` would conflict.) +#[derive(Clone, Debug)] +pub(crate) struct RegistryAsFactory(R); + +#[async_trait] +impl ChannelFactory for RegistryAsFactory { + async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result { + let method = target.chan_method(); + let id = method.transport_id(); + let factory = self.0.get_factory(&id).ok_or(Error::NoSuchTransport(id))?; + + factory.connect_via_transport(target).await + } +} diff --git a/crates/tor-chanmgr/src/lib.rs b/crates/tor-chanmgr/src/lib.rs index 04b2e4481..d077e35a1 100644 --- a/crates/tor-chanmgr/src/lib.rs +++ b/crates/tor-chanmgr/src/lib.rs @@ -44,6 +44,7 @@ pub mod factory; mod mgr; #[cfg(test)] mod testing; +pub mod transport; use educe::Educe; use factory::ChannelFactory; @@ -165,7 +166,7 @@ impl ChanMgr { { let (sender, receiver) = event::channel(); let sender = Arc::new(std::sync::Mutex::new(sender)); - let transport = builder::DefaultTransport::new(runtime.clone()); + let transport = transport::DefaultTransport::new(runtime.clone()); let builder = builder::ChanBuilder::new(runtime, transport, sender); let builder: Box = Box::new(builder); let mgr = mgr::AbstractChanMgr::new(builder, config, dormancy, netparams); diff --git a/crates/tor-chanmgr/src/transport.rs b/crates/tor-chanmgr/src/transport.rs new file mode 100644 index 000000000..b737841e6 --- /dev/null +++ b/crates/tor-chanmgr/src/transport.rs @@ -0,0 +1,38 @@ +//! Code to define the notion of a "Transport" and implement a default transport. + +use async_trait::async_trait; +use futures::{AsyncRead, AsyncWrite}; +use tor_linkspec::OwnedChanTarget; + +pub(crate) mod default; + +pub(crate) use default::DefaultTransport; + +/// A more convenient API for defining transports. This type's role is to let +/// the implementor just define a replacement way to pass bytes around, and +/// return something that we can use in place of a TcpStream. +/// +/// This is the trait you should probably implement if you want to define a new +/// [`ChannelFactory`] that performs Tor over TLS over some stream-like type, +/// and you only want to define the stream-like type. +/// +/// To convert a [`TransportHelper`] into a [`ChannelFactory`], wrap it in a ChannelBuilder. +#[async_trait] +pub trait TransportHelper { + /// The type of the resulting stream. + type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static; + + /// Implements the transport: makes a TCP connection (possibly + /// tunneled over whatever protocol) if possible. + /// + /// This method does does not necessarily handle retries or timeouts, + /// although some of its implementations may. + /// + /// This method does not necessarily handle every kind of transport. + /// If the caller provides a target with the wrong [`TransportId`], this + /// method should return [`Error::NoSuchTransport`]. + async fn connect( + &self, + target: &OwnedChanTarget, + ) -> crate::Result<(OwnedChanTarget, Self::Stream)>; +} diff --git a/crates/tor-chanmgr/src/transport/default.rs b/crates/tor-chanmgr/src/transport/default.rs new file mode 100644 index 000000000..8d5b6a59c --- /dev/null +++ b/crates/tor-chanmgr/src/transport/default.rs @@ -0,0 +1,228 @@ +//! Implement the default transport, which opens TCP connections using a +//! happy-eyeballs style parallel algorithm. + +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt}; +use safelog::sensitive as sv; +use tor_error::bad_api_usage; +use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget}; +use tor_rtcompat::{Runtime, TcpProvider}; + +use crate::Error; + +/// A default transport object that opens TCP connections for a +/// `ChannelMethod::Direct`. +/// +/// It opens almost-simultaneous parallel TCP connections to each address, and +/// chooses the first one to succeed. +#[derive(Clone, Debug)] +pub(crate) struct DefaultTransport { + /// The runtime that we use for connecting. + runtime: R, +} + +impl DefaultTransport { + /// Construct a new DefaultTransport + pub(crate) fn new(runtime: R) -> Self { + Self { runtime } + } +} + +#[async_trait] +impl crate::transport::TransportHelper for DefaultTransport { + type Stream = ::TcpStream; + + /// Implements the transport: makes a TCP connection (possibly + /// tunneled over whatever protocol) if possible. + async fn connect( + &self, + target: &OwnedChanTarget, + ) -> crate::Result<(OwnedChanTarget, Self::Stream)> { + let direct_addrs: Vec<_> = match target.chan_method() { + ChannelMethod::Direct(addrs) => addrs, + #[allow(unreachable_patterns)] + _ => { + return Err(Error::UnusableTarget(bad_api_usage!( + "Used default transport implementation for an unsupported transport." + ))) + } + }; + + let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?; + let using_target = match target.restrict_addr(&addr) { + Ok(v) => v, + Err(v) => v, + }; + + Ok((using_target, stream)) + } +} + +/// Time to wait between starting parallel connections to the same relay. +static CONNECTION_DELAY: Duration = Duration::from_millis(150); + +/// Connect to one of the addresses in `addrs` by running connections in parallel until one works. +/// +/// This implements a basic version of RFC 8305 "happy eyeballs". +async fn connect_to_one( + rt: &R, + addrs: &[SocketAddr], +) -> crate::Result<(::TcpStream, SocketAddr)> { + // We need *some* addresses to connect to. + if addrs.is_empty() { + return Err(Error::UnusableTarget(bad_api_usage!( + "No addresses for chosen relay" + ))); + } + + // Turn each address into a future that waits (i * CONNECTION_DELAY), then + // attempts to connect to the address using the runtime (where i is the + // array index). Shove all of these into a `FuturesUnordered`, polling them + // simultaneously and returning the results in completion order. + // + // This is basically the concurrent-connection stuff from RFC 8305, ish. + // TODO(eta): sort the addresses first? + let mut connections = addrs + .iter() + .enumerate() + .map(|(i, a)| { + let delay = rt.sleep(CONNECTION_DELAY * i as u32); + delay.then(move |_| { + tracing::debug!("Connecting to {}", a); + rt.connect(a) + .map_ok(move |stream| (stream, *a)) + .map_err(move |e| (e, *a)) + }) + }) + .collect::>(); + + let mut ret = None; + let mut errors = vec![]; + + while let Some(result) = connections.next().await { + match result { + Ok(s) => { + // We got a stream (and address). + ret = Some(s); + break; + } + Err((e, a)) => { + // We got a failure on one of the streams. Store the error. + // TODO(eta): ideally we'd start the next connection attempt immediately. + tracing::warn!("Connection to {} failed: {}", sv(a), e); + errors.push((e, a)); + } + } + } + + // Ensure we don't continue trying to make connections. + drop(connections); + + ret.ok_or_else(|| Error::ChannelBuild { + addresses: errors.into_iter().map(|(e, a)| (a, Arc::new(e))).collect(), + }) +} + +#[cfg(test)] +mod test { + // @@ begin test lint list maintained by maint/add_warning @@ + #![allow(clippy::bool_assert_comparison)] + #![allow(clippy::clone_on_copy)] + #![allow(clippy::dbg_macro)] + #![allow(clippy::print_stderr)] + #![allow(clippy::print_stdout)] + #![allow(clippy::single_char_pattern)] + #![allow(clippy::unwrap_used)] + //! + + use std::str::FromStr; + + use tor_rtcompat::{test_with_one_runtime, SleepProviderExt}; + use tor_rtmock::net::MockNetwork; + + use super::*; + + #[test] + fn test_connect_one() { + let client_addr = "192.0.1.16".parse().unwrap(); + // We'll put a "relay" at this address + let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap(); + // We'll put nothing at this address, to generate errors. + let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap(); + // Well put a black hole at this address, to generate timeouts. + let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap(); + // We'll put a "relay" at this address too + let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap(); + + test_with_one_runtime!(|rt| async move { + // Stub out the internet so that this connection can work. + let network = MockNetwork::new(); + + // Set up a client and server runtime with a given IP + let client_rt = network + .builder() + .add_address(client_addr) + .runtime(rt.clone()); + let server_rt = network + .builder() + .add_address(addr1.ip()) + .add_address(addr4.ip()) + .runtime(rt.clone()); + let _listener = server_rt.mock_net().listen(&addr1).await.unwrap(); + let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap(); + // TODO: Because this test doesn't mock time, there will actually be + // delays as we wait for connections to this address to time out. It + // would be good to use MockSleepProvider instead, once we figure + // out how to make it both reliable and convenient. + network.add_blackhole(addr3).unwrap(); + + // No addresses? Can't succeed. + let failure = connect_to_one(&client_rt, &[]).await; + assert!(failure.is_err()); + + // Connect to a set of addresses including addr1? That's a success. + for addresses in [ + &[addr1][..], + &[addr1, addr2][..], + &[addr2, addr1][..], + &[addr1, addr3][..], + &[addr3, addr1][..], + &[addr1, addr2, addr3][..], + &[addr3, addr2, addr1][..], + ] { + let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap(); + assert_eq!(addr, addr1); + } + + // Connect to a set of addresses including addr2 but not addr1? + // That's an error of one kind or another. + for addresses in [ + &[addr2][..], + &[addr2, addr3][..], + &[addr3, addr2][..], + &[addr3][..], + ] { + let expect_timeout = addresses.contains(&addr3); + let failure = rt + .timeout( + Duration::from_millis(300), + connect_to_one(&client_rt, addresses), + ) + .await; + if expect_timeout { + assert!(failure.is_err()); + } else { + assert!(failure.unwrap().is_err()); + } + } + + // Connect to addr1 and addr4? The first one should win. + let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap(); + assert_eq!(addr, addr1); + let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap(); + assert_eq!(addr, addr4); + }); + } +}