ChanMgr: Reorganize factory, builder, transport code.
There is no actual code change here: just movement.
This commit is contained in:
parent
a77312a6ec
commit
fe2d44d10a
|
@ -1,28 +1,21 @@
|
|||
//! Implement a concrete type to build channels.
|
||||
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::factory::{ChannelFactory, TransportHelper};
|
||||
use crate::factory::ChannelFactory;
|
||||
use crate::transport::TransportHelper;
|
||||
use crate::{event::ChanMgrEventSender, Error};
|
||||
|
||||
use safelog::sensitive as sv;
|
||||
use std::time::Duration;
|
||||
use tor_error::{bad_api_usage, internal};
|
||||
use tor_linkspec::{ChannelMethod, HasChanMethod, HasRelayIds, OwnedChanTarget};
|
||||
use tor_error::internal;
|
||||
use tor_linkspec::{HasChanMethod, HasRelayIds, OwnedChanTarget};
|
||||
use tor_llcrypto::pk;
|
||||
use tor_proto::channel::params::ChannelPaddingInstructionsUpdates;
|
||||
use tor_rtcompat::{tls::TlsConnector, Runtime, TcpProvider, TlsProvider};
|
||||
use tor_rtcompat::{tls::TlsConnector, Runtime, TlsProvider};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::task::SpawnExt;
|
||||
use futures::StreamExt;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
|
||||
/// Time to wait between starting parallel connections to the same relay.
|
||||
static CONNECTION_DELAY: Duration = Duration::from_millis(150);
|
||||
|
||||
/// TLS-based channel builder.
|
||||
///
|
||||
|
@ -61,130 +54,6 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<CF> crate::mgr::AbstractChannelFactory for CF
|
||||
where
|
||||
CF: ChannelFactory + Sync,
|
||||
{
|
||||
type Channel = tor_proto::channel::Channel;
|
||||
type BuildSpec = OwnedChanTarget;
|
||||
|
||||
async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result<Self::Channel> {
|
||||
self.connect_via_transport(target).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
|
||||
///
|
||||
/// This implements a basic version of RFC 8305 "happy eyeballs".
|
||||
async fn connect_to_one<R: Runtime>(
|
||||
rt: &R,
|
||||
addrs: &[SocketAddr],
|
||||
) -> crate::Result<(<R as TcpProvider>::TcpStream, SocketAddr)> {
|
||||
// We need *some* addresses to connect to.
|
||||
if addrs.is_empty() {
|
||||
return Err(Error::UnusableTarget(bad_api_usage!(
|
||||
"No addresses for chosen relay"
|
||||
)));
|
||||
}
|
||||
|
||||
// Turn each address into a future that waits (i * CONNECTION_DELAY), then
|
||||
// attempts to connect to the address using the runtime (where i is the
|
||||
// array index). Shove all of these into a `FuturesUnordered`, polling them
|
||||
// simultaneously and returning the results in completion order.
|
||||
//
|
||||
// This is basically the concurrent-connection stuff from RFC 8305, ish.
|
||||
// TODO(eta): sort the addresses first?
|
||||
let mut connections = addrs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, a)| {
|
||||
let delay = rt.sleep(CONNECTION_DELAY * i as u32);
|
||||
delay.then(move |_| {
|
||||
tracing::debug!("Connecting to {}", a);
|
||||
rt.connect(a)
|
||||
.map_ok(move |stream| (stream, *a))
|
||||
.map_err(move |e| (e, *a))
|
||||
})
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
|
||||
let mut ret = None;
|
||||
let mut errors = vec![];
|
||||
|
||||
while let Some(result) = connections.next().await {
|
||||
match result {
|
||||
Ok(s) => {
|
||||
// We got a stream (and address).
|
||||
ret = Some(s);
|
||||
break;
|
||||
}
|
||||
Err((e, a)) => {
|
||||
// We got a failure on one of the streams. Store the error.
|
||||
// TODO(eta): ideally we'd start the next connection attempt immediately.
|
||||
tracing::warn!("Connection to {} failed: {}", sv(a), e);
|
||||
errors.push((e, a));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we don't continue trying to make connections.
|
||||
drop(connections);
|
||||
|
||||
ret.ok_or_else(|| Error::ChannelBuild {
|
||||
addresses: errors.into_iter().map(|(e, a)| (a, Arc::new(e))).collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// A default transport object that opens TCP connections for a
|
||||
/// `ChannelMethod::Direct`.
|
||||
///
|
||||
/// It opens almost-simultaneous parallel TCP connections to each address, and
|
||||
/// chooses the first one to succeed.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct DefaultTransport<R: Runtime> {
|
||||
/// The runtime that we use for connecting.
|
||||
runtime: R,
|
||||
}
|
||||
|
||||
impl<R: Runtime> DefaultTransport<R> {
|
||||
/// Construct a new DefaultTransport
|
||||
pub(crate) fn new(runtime: R) -> Self {
|
||||
Self { runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<R: Runtime> crate::factory::TransportHelper for DefaultTransport<R> {
|
||||
type Stream = <R as TcpProvider>::TcpStream;
|
||||
|
||||
/// Implements the transport: makes a TCP connection (possibly
|
||||
/// tunneled over whatever protocol) if possible.
|
||||
async fn connect(
|
||||
&self,
|
||||
target: &OwnedChanTarget,
|
||||
) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
|
||||
let direct_addrs: Vec<_> = match target.chan_method() {
|
||||
ChannelMethod::Direct(addrs) => addrs,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => {
|
||||
return Err(Error::UnusableTarget(bad_api_usage!(
|
||||
"Used default transport implementation for an unsupported transport."
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
|
||||
let using_target = match target.restrict_addr(&addr) {
|
||||
Ok(v) => v,
|
||||
Err(v) => v,
|
||||
};
|
||||
|
||||
Ok((using_target, stream))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<R: Runtime, H: TransportHelper> ChannelFactory for ChanBuilder<R, H>
|
||||
where
|
||||
|
@ -364,11 +233,11 @@ mod test {
|
|||
};
|
||||
use pk::ed25519::Ed25519Identity;
|
||||
use pk::rsa::RsaIdentity;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{net::SocketAddr, str::FromStr};
|
||||
use tor_linkspec::ChannelMethod;
|
||||
use tor_proto::channel::Channel;
|
||||
use tor_rtcompat::{test_with_one_runtime, SleepProviderExt, TcpListener};
|
||||
use tor_rtcompat::{test_with_one_runtime, TcpListener};
|
||||
use tor_rtmock::{io::LocalStream, net::MockNetwork, MockSleepRuntime};
|
||||
|
||||
// Make sure that the builder can build a real channel. To test
|
||||
|
@ -418,7 +287,7 @@ mod test {
|
|||
|
||||
// Create the channel builder that we want to test.
|
||||
let (snd, _rcv) = crate::event::channel();
|
||||
let transport = DefaultTransport::new(client_rt.clone());
|
||||
let transport = crate::transport::DefaultTransport::new(client_rt.clone());
|
||||
let builder = ChanBuilder::new(client_rt, transport, Arc::new(Mutex::new(snd)));
|
||||
|
||||
let (r1, r2): (Result<Channel>, Result<LocalStream>) = futures::join!(
|
||||
|
@ -454,87 +323,5 @@ mod test {
|
|||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connect_one() {
|
||||
let client_addr = "192.0.1.16".parse().unwrap();
|
||||
// We'll put a "relay" at this address
|
||||
let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
|
||||
// We'll put nothing at this address, to generate errors.
|
||||
let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
|
||||
// Well put a black hole at this address, to generate timeouts.
|
||||
let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
|
||||
// We'll put a "relay" at this address too
|
||||
let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();
|
||||
|
||||
test_with_one_runtime!(|rt| async move {
|
||||
// Stub out the internet so that this connection can work.
|
||||
let network = MockNetwork::new();
|
||||
|
||||
// Set up a client and server runtime with a given IP
|
||||
let client_rt = network
|
||||
.builder()
|
||||
.add_address(client_addr)
|
||||
.runtime(rt.clone());
|
||||
let server_rt = network
|
||||
.builder()
|
||||
.add_address(addr1.ip())
|
||||
.add_address(addr4.ip())
|
||||
.runtime(rt.clone());
|
||||
let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
|
||||
let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
|
||||
// TODO: Because this test doesn't mock time, there will actually be
|
||||
// delays as we wait for connections to this address to time out. It
|
||||
// would be good to use MockSleepProvider instead, once we figure
|
||||
// out how to make it both reliable and convenient.
|
||||
network.add_blackhole(addr3).unwrap();
|
||||
|
||||
// No addresses? Can't succeed.
|
||||
let failure = connect_to_one(&client_rt, &[]).await;
|
||||
assert!(failure.is_err());
|
||||
|
||||
// Connect to a set of addresses including addr1? That's a success.
|
||||
for addresses in [
|
||||
&[addr1][..],
|
||||
&[addr1, addr2][..],
|
||||
&[addr2, addr1][..],
|
||||
&[addr1, addr3][..],
|
||||
&[addr3, addr1][..],
|
||||
&[addr1, addr2, addr3][..],
|
||||
&[addr3, addr2, addr1][..],
|
||||
] {
|
||||
let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap();
|
||||
assert_eq!(addr, addr1);
|
||||
}
|
||||
|
||||
// Connect to a set of addresses including addr2 but not addr1?
|
||||
// That's an error of one kind or another.
|
||||
for addresses in [
|
||||
&[addr2][..],
|
||||
&[addr2, addr3][..],
|
||||
&[addr3, addr2][..],
|
||||
&[addr3][..],
|
||||
] {
|
||||
let expect_timeout = addresses.contains(&addr3);
|
||||
let failure = rt
|
||||
.timeout(
|
||||
Duration::from_millis(300),
|
||||
connect_to_one(&client_rt, addresses),
|
||||
)
|
||||
.await;
|
||||
if expect_timeout {
|
||||
assert!(failure.is_err());
|
||||
} else {
|
||||
assert!(failure.unwrap().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to addr1 and addr4? The first one should win.
|
||||
let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap();
|
||||
assert_eq!(addr, addr1);
|
||||
let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap();
|
||||
assert_eq!(addr, addr4);
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: Write tests for timeout logic, once there is smarter logic.
|
||||
}
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
//! Traits and code to define different mechanisms for building Channels to
|
||||
//! different kinds of targets.
|
||||
|
||||
pub(crate) mod registry;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::Error;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use tor_linkspec::{HasChanMethod, OwnedChanTarget, TransportId};
|
||||
use tor_linkspec::OwnedChanTarget;
|
||||
use tor_proto::channel::Channel;
|
||||
|
||||
pub use registry::TransportRegistry;
|
||||
|
||||
/// An object that knows how to build Channels to ChanTargets.
|
||||
///
|
||||
/// This trait must be object-safe.
|
||||
|
@ -26,63 +27,6 @@ pub trait ChannelFactory {
|
|||
async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result<Channel>;
|
||||
}
|
||||
|
||||
/// A more convenient API for defining transports. This type's role is to let
|
||||
/// the implementor just define a replacement way to pass bytes around, and
|
||||
/// return something that we can use in place of a TcpStream.
|
||||
///
|
||||
/// This is the trait you should probably implement if you want to define a new
|
||||
/// [`ChannelFactory`] that performs Tor over TLS over some stream-like type,
|
||||
/// and you only want to define the stream-like type.
|
||||
///
|
||||
/// To convert a [`TransportHelper`] into a [`ChannelFactory`], wrap it in a ChannelBuilder.
|
||||
#[async_trait]
|
||||
pub trait TransportHelper {
|
||||
/// The type of the resulting stream.
|
||||
type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static;
|
||||
|
||||
/// Implements the transport: makes a TCP connection (possibly
|
||||
/// tunneled over whatever protocol) if possible.
|
||||
///
|
||||
/// This method does does not necessarily handle retries or timeouts,
|
||||
/// although some of its implementations may.
|
||||
///
|
||||
/// This method does not necessarily handle every kind of transport.
|
||||
/// If the caller provides a target with the wrong [`TransportId`], this
|
||||
/// method should return [`Error::NoSuchTransport`].
|
||||
async fn connect(
|
||||
&self,
|
||||
target: &OwnedChanTarget,
|
||||
) -> crate::Result<(OwnedChanTarget, Self::Stream)>;
|
||||
}
|
||||
|
||||
/// An object that knows about one or more ChannelFactories.
|
||||
pub trait TransportRegistry {
|
||||
/// Return a ChannelFactory that can make connections via a chosen
|
||||
/// transport, if we know one.
|
||||
//
|
||||
// TODO pt-client: This might need to return an Arc instead of a reference
|
||||
fn get_factory(&self, transport: &TransportId) -> Option<&(dyn ChannelFactory + Sync)>;
|
||||
}
|
||||
|
||||
/// Helper type: Wrap a `TransportRegistry` so that it can be used as a
|
||||
/// `ChannelFactory`.
|
||||
///
|
||||
/// (This has to be a new type, or else the blanket implementation of
|
||||
/// `ChannelFactory` for `TransportHelper` would conflict.)
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct RegistryAsFactory<R: TransportRegistry>(R);
|
||||
|
||||
#[async_trait]
|
||||
impl<R: TransportRegistry + Sync> ChannelFactory for RegistryAsFactory<R> {
|
||||
async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result<Channel> {
|
||||
let method = target.chan_method();
|
||||
let id = method.transport_id();
|
||||
let factory = self.0.get_factory(&id).ok_or(Error::NoSuchTransport(id))?;
|
||||
|
||||
factory.connect_via_transport(target).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> ChannelFactory for Arc<(dyn ChannelFactory + Send + Sync + 'a)> {
|
||||
async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result<Channel> {
|
||||
|
@ -96,3 +40,16 @@ impl<'a> ChannelFactory for Box<(dyn ChannelFactory + Send + Sync + 'a)> {
|
|||
self.as_ref().connect_via_transport(target).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<CF> crate::mgr::AbstractChannelFactory for CF
|
||||
where
|
||||
CF: ChannelFactory + Sync,
|
||||
{
|
||||
type Channel = tor_proto::channel::Channel;
|
||||
type BuildSpec = OwnedChanTarget;
|
||||
|
||||
async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result<Self::Channel> {
|
||||
self.connect_via_transport(target).await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
//! Implement a registry for different kinds of transports.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tor_linkspec::{HasChanMethod, OwnedChanTarget, TransportId};
|
||||
use tor_proto::channel::Channel;
|
||||
|
||||
use crate::Error;
|
||||
|
||||
use super::ChannelFactory;
|
||||
|
||||
/// An object that knows about one or more ChannelFactories.
|
||||
pub trait TransportRegistry {
|
||||
/// Return a ChannelFactory that can make connections via a chosen
|
||||
/// transport, if we know one.
|
||||
//
|
||||
// TODO pt-client: This might need to return an Arc instead of a reference
|
||||
fn get_factory(&self, transport: &TransportId) -> Option<&(dyn ChannelFactory + Sync)>;
|
||||
}
|
||||
|
||||
/// Helper type: Wrap a `TransportRegistry` so that it can be used as a
|
||||
/// `ChannelFactory`.
|
||||
///
|
||||
/// (This has to be a new type, or else the blanket implementation of
|
||||
/// `ChannelFactory` for `TransportHelper` would conflict.)
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct RegistryAsFactory<R: TransportRegistry>(R);
|
||||
|
||||
#[async_trait]
|
||||
impl<R: TransportRegistry + Sync> ChannelFactory for RegistryAsFactory<R> {
|
||||
async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result<Channel> {
|
||||
let method = target.chan_method();
|
||||
let id = method.transport_id();
|
||||
let factory = self.0.get_factory(&id).ok_or(Error::NoSuchTransport(id))?;
|
||||
|
||||
factory.connect_via_transport(target).await
|
||||
}
|
||||
}
|
|
@ -44,6 +44,7 @@ pub mod factory;
|
|||
mod mgr;
|
||||
#[cfg(test)]
|
||||
mod testing;
|
||||
pub mod transport;
|
||||
|
||||
use educe::Educe;
|
||||
use factory::ChannelFactory;
|
||||
|
@ -165,7 +166,7 @@ impl<R: Runtime> ChanMgr<R> {
|
|||
{
|
||||
let (sender, receiver) = event::channel();
|
||||
let sender = Arc::new(std::sync::Mutex::new(sender));
|
||||
let transport = builder::DefaultTransport::new(runtime.clone());
|
||||
let transport = transport::DefaultTransport::new(runtime.clone());
|
||||
let builder = builder::ChanBuilder::new(runtime, transport, sender);
|
||||
let builder: Box<dyn ChannelFactory + Send + Sync + 'static> = Box::new(builder);
|
||||
let mgr = mgr::AbstractChanMgr::new(builder, config, dormancy, netparams);
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
//! Code to define the notion of a "Transport" and implement a default transport.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use tor_linkspec::OwnedChanTarget;
|
||||
|
||||
pub(crate) mod default;
|
||||
|
||||
pub(crate) use default::DefaultTransport;
|
||||
|
||||
/// A more convenient API for defining transports. This type's role is to let
|
||||
/// the implementor just define a replacement way to pass bytes around, and
|
||||
/// return something that we can use in place of a TcpStream.
|
||||
///
|
||||
/// This is the trait you should probably implement if you want to define a new
|
||||
/// [`ChannelFactory`] that performs Tor over TLS over some stream-like type,
|
||||
/// and you only want to define the stream-like type.
|
||||
///
|
||||
/// To convert a [`TransportHelper`] into a [`ChannelFactory`], wrap it in a ChannelBuilder.
|
||||
#[async_trait]
|
||||
pub trait TransportHelper {
|
||||
/// The type of the resulting stream.
|
||||
type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static;
|
||||
|
||||
/// Implements the transport: makes a TCP connection (possibly
|
||||
/// tunneled over whatever protocol) if possible.
|
||||
///
|
||||
/// This method does does not necessarily handle retries or timeouts,
|
||||
/// although some of its implementations may.
|
||||
///
|
||||
/// This method does not necessarily handle every kind of transport.
|
||||
/// If the caller provides a target with the wrong [`TransportId`], this
|
||||
/// method should return [`Error::NoSuchTransport`].
|
||||
async fn connect(
|
||||
&self,
|
||||
target: &OwnedChanTarget,
|
||||
) -> crate::Result<(OwnedChanTarget, Self::Stream)>;
|
||||
}
|
|
@ -0,0 +1,228 @@
|
|||
//! Implement the default transport, which opens TCP connections using a
|
||||
//! happy-eyeballs style parallel algorithm.
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
|
||||
use safelog::sensitive as sv;
|
||||
use tor_error::bad_api_usage;
|
||||
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
|
||||
use tor_rtcompat::{Runtime, TcpProvider};
|
||||
|
||||
use crate::Error;
|
||||
|
||||
/// A default transport object that opens TCP connections for a
|
||||
/// `ChannelMethod::Direct`.
|
||||
///
|
||||
/// It opens almost-simultaneous parallel TCP connections to each address, and
|
||||
/// chooses the first one to succeed.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct DefaultTransport<R: Runtime> {
|
||||
/// The runtime that we use for connecting.
|
||||
runtime: R,
|
||||
}
|
||||
|
||||
impl<R: Runtime> DefaultTransport<R> {
|
||||
/// Construct a new DefaultTransport
|
||||
pub(crate) fn new(runtime: R) -> Self {
|
||||
Self { runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<R: Runtime> crate::transport::TransportHelper for DefaultTransport<R> {
|
||||
type Stream = <R as TcpProvider>::TcpStream;
|
||||
|
||||
/// Implements the transport: makes a TCP connection (possibly
|
||||
/// tunneled over whatever protocol) if possible.
|
||||
async fn connect(
|
||||
&self,
|
||||
target: &OwnedChanTarget,
|
||||
) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
|
||||
let direct_addrs: Vec<_> = match target.chan_method() {
|
||||
ChannelMethod::Direct(addrs) => addrs,
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => {
|
||||
return Err(Error::UnusableTarget(bad_api_usage!(
|
||||
"Used default transport implementation for an unsupported transport."
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
|
||||
let using_target = match target.restrict_addr(&addr) {
|
||||
Ok(v) => v,
|
||||
Err(v) => v,
|
||||
};
|
||||
|
||||
Ok((using_target, stream))
|
||||
}
|
||||
}
|
||||
|
||||
/// Time to wait between starting parallel connections to the same relay.
|
||||
static CONNECTION_DELAY: Duration = Duration::from_millis(150);
|
||||
|
||||
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
|
||||
///
|
||||
/// This implements a basic version of RFC 8305 "happy eyeballs".
|
||||
async fn connect_to_one<R: Runtime>(
|
||||
rt: &R,
|
||||
addrs: &[SocketAddr],
|
||||
) -> crate::Result<(<R as TcpProvider>::TcpStream, SocketAddr)> {
|
||||
// We need *some* addresses to connect to.
|
||||
if addrs.is_empty() {
|
||||
return Err(Error::UnusableTarget(bad_api_usage!(
|
||||
"No addresses for chosen relay"
|
||||
)));
|
||||
}
|
||||
|
||||
// Turn each address into a future that waits (i * CONNECTION_DELAY), then
|
||||
// attempts to connect to the address using the runtime (where i is the
|
||||
// array index). Shove all of these into a `FuturesUnordered`, polling them
|
||||
// simultaneously and returning the results in completion order.
|
||||
//
|
||||
// This is basically the concurrent-connection stuff from RFC 8305, ish.
|
||||
// TODO(eta): sort the addresses first?
|
||||
let mut connections = addrs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, a)| {
|
||||
let delay = rt.sleep(CONNECTION_DELAY * i as u32);
|
||||
delay.then(move |_| {
|
||||
tracing::debug!("Connecting to {}", a);
|
||||
rt.connect(a)
|
||||
.map_ok(move |stream| (stream, *a))
|
||||
.map_err(move |e| (e, *a))
|
||||
})
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>();
|
||||
|
||||
let mut ret = None;
|
||||
let mut errors = vec![];
|
||||
|
||||
while let Some(result) = connections.next().await {
|
||||
match result {
|
||||
Ok(s) => {
|
||||
// We got a stream (and address).
|
||||
ret = Some(s);
|
||||
break;
|
||||
}
|
||||
Err((e, a)) => {
|
||||
// We got a failure on one of the streams. Store the error.
|
||||
// TODO(eta): ideally we'd start the next connection attempt immediately.
|
||||
tracing::warn!("Connection to {} failed: {}", sv(a), e);
|
||||
errors.push((e, a));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we don't continue trying to make connections.
|
||||
drop(connections);
|
||||
|
||||
ret.ok_or_else(|| Error::ChannelBuild {
|
||||
addresses: errors.into_iter().map(|(e, a)| (a, Arc::new(e))).collect(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
// @@ begin test lint list maintained by maint/add_warning @@
|
||||
#![allow(clippy::bool_assert_comparison)]
|
||||
#![allow(clippy::clone_on_copy)]
|
||||
#![allow(clippy::dbg_macro)]
|
||||
#![allow(clippy::print_stderr)]
|
||||
#![allow(clippy::print_stdout)]
|
||||
#![allow(clippy::single_char_pattern)]
|
||||
#![allow(clippy::unwrap_used)]
|
||||
//! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use tor_rtcompat::{test_with_one_runtime, SleepProviderExt};
|
||||
use tor_rtmock::net::MockNetwork;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_connect_one() {
|
||||
let client_addr = "192.0.1.16".parse().unwrap();
|
||||
// We'll put a "relay" at this address
|
||||
let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
|
||||
// We'll put nothing at this address, to generate errors.
|
||||
let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
|
||||
// Well put a black hole at this address, to generate timeouts.
|
||||
let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
|
||||
// We'll put a "relay" at this address too
|
||||
let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();
|
||||
|
||||
test_with_one_runtime!(|rt| async move {
|
||||
// Stub out the internet so that this connection can work.
|
||||
let network = MockNetwork::new();
|
||||
|
||||
// Set up a client and server runtime with a given IP
|
||||
let client_rt = network
|
||||
.builder()
|
||||
.add_address(client_addr)
|
||||
.runtime(rt.clone());
|
||||
let server_rt = network
|
||||
.builder()
|
||||
.add_address(addr1.ip())
|
||||
.add_address(addr4.ip())
|
||||
.runtime(rt.clone());
|
||||
let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
|
||||
let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
|
||||
// TODO: Because this test doesn't mock time, there will actually be
|
||||
// delays as we wait for connections to this address to time out. It
|
||||
// would be good to use MockSleepProvider instead, once we figure
|
||||
// out how to make it both reliable and convenient.
|
||||
network.add_blackhole(addr3).unwrap();
|
||||
|
||||
// No addresses? Can't succeed.
|
||||
let failure = connect_to_one(&client_rt, &[]).await;
|
||||
assert!(failure.is_err());
|
||||
|
||||
// Connect to a set of addresses including addr1? That's a success.
|
||||
for addresses in [
|
||||
&[addr1][..],
|
||||
&[addr1, addr2][..],
|
||||
&[addr2, addr1][..],
|
||||
&[addr1, addr3][..],
|
||||
&[addr3, addr1][..],
|
||||
&[addr1, addr2, addr3][..],
|
||||
&[addr3, addr2, addr1][..],
|
||||
] {
|
||||
let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap();
|
||||
assert_eq!(addr, addr1);
|
||||
}
|
||||
|
||||
// Connect to a set of addresses including addr2 but not addr1?
|
||||
// That's an error of one kind or another.
|
||||
for addresses in [
|
||||
&[addr2][..],
|
||||
&[addr2, addr3][..],
|
||||
&[addr3, addr2][..],
|
||||
&[addr3][..],
|
||||
] {
|
||||
let expect_timeout = addresses.contains(&addr3);
|
||||
let failure = rt
|
||||
.timeout(
|
||||
Duration::from_millis(300),
|
||||
connect_to_one(&client_rt, addresses),
|
||||
)
|
||||
.await;
|
||||
if expect_timeout {
|
||||
assert!(failure.is_err());
|
||||
} else {
|
||||
assert!(failure.unwrap().is_err());
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to addr1 and addr4? The first one should win.
|
||||
let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap();
|
||||
assert_eq!(addr, addr1);
|
||||
let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap();
|
||||
assert_eq!(addr, addr4);
|
||||
});
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue