Merge branch 'chanmgr_refactor_and_test'
This is an apparent test coverage reduction in chanmgr, but that's only apparent: it actually throws away a lot of untested code that was previously hidden.
This commit is contained in:
commit
ddd7ce0074
|
@ -0,0 +1,167 @@
|
|||
//! Implement a concrete type to build channels.
|
||||
|
||||
use crate::Error;
|
||||
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_llcrypto::pk;
|
||||
use tor_rtcompat::{tls::TlsConnector, Runtime, TlsProvider};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::task::SpawnExt;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// TLS-based channel builder.
|
||||
///
|
||||
/// This is a separate type so that we can keep our channel management
|
||||
/// code network-agnostic.
|
||||
pub(crate) struct ChanBuilder<R: Runtime> {
|
||||
/// Asynchronous runtime for TLS, TCP, spawning, and timeouts.
|
||||
runtime: R,
|
||||
/// Object to build TLS connections.
|
||||
tls_connector: <R as TlsProvider>::Connector,
|
||||
}
|
||||
|
||||
impl<R: Runtime> ChanBuilder<R> {
|
||||
/// Construct a new ChanBuilder.
|
||||
pub(crate) fn new(runtime: R) -> Self {
|
||||
let tls_connector = runtime.tls_connector();
|
||||
ChanBuilder {
|
||||
runtime,
|
||||
tls_connector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<R: Runtime> crate::mgr::ChannelFactory for ChanBuilder<R> {
|
||||
type Channel = tor_proto::channel::Channel;
|
||||
type BuildSpec = TargetInfo;
|
||||
|
||||
async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result<Arc<Self::Channel>> {
|
||||
use tor_rtcompat::SleepProviderExt;
|
||||
|
||||
// TODO: make this an option. And make a better value.
|
||||
let five_seconds = std::time::Duration::new(5, 0);
|
||||
|
||||
self.runtime
|
||||
.timeout(five_seconds, self.build_channel_notimeout(target))
|
||||
.await?
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> ChanBuilder<R> {
|
||||
/// As build_channel, but don't include a timeout.
|
||||
async fn build_channel_notimeout(
|
||||
&self,
|
||||
target: &TargetInfo,
|
||||
) -> crate::Result<Arc<tor_proto::channel::Channel>> {
|
||||
use tor_proto::channel::ChannelBuilder;
|
||||
use tor_rtcompat::tls::CertifiedConn;
|
||||
|
||||
// 1. Negotiate the TLS connection.
|
||||
|
||||
// TODO: This just uses the first address. Instead we could be smarter,
|
||||
// or use "happy eyeballs, or whatever. Maybe we will want to
|
||||
// refactor as we do so?
|
||||
let addr = target
|
||||
.addrs()
|
||||
.get(0)
|
||||
.ok_or_else(|| Error::UnusableTarget("No addresses for chosen relay".into()))?;
|
||||
|
||||
log::info!("Negotiating TLS with {}", addr);
|
||||
|
||||
// TODO: add a random hostname here if it will be used for SNI?
|
||||
let tls = self
|
||||
.tls_connector
|
||||
.connect_unvalidated(addr, "ignored")
|
||||
.await?;
|
||||
|
||||
let peer_cert = tls
|
||||
.peer_certificate()?
|
||||
.ok_or(Error::Internal("TLS connection with no peer certificate"))?;
|
||||
|
||||
// 2. Set up the channel.
|
||||
let mut builder = ChannelBuilder::new();
|
||||
builder.set_declared_addr(*addr);
|
||||
let chan = builder.launch(tls).connect().await?;
|
||||
let chan = chan.check(target, &peer_cert)?;
|
||||
let (chan, reactor) = chan.finish().await?;
|
||||
|
||||
// 3. Launch a task to run the channel reactor.
|
||||
self.runtime.spawn(async {
|
||||
let _ = reactor.run().await;
|
||||
})?;
|
||||
Ok(chan)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::mgr::AbstractChannel for tor_proto::channel::Channel {
|
||||
type Ident = pk::ed25519::Ed25519Identity;
|
||||
fn ident(&self) -> &Self::Ident {
|
||||
self.peer_ed25519_id()
|
||||
}
|
||||
fn is_usable(&self) -> bool {
|
||||
!self.is_closing()
|
||||
}
|
||||
}
|
||||
|
||||
/// TargetInfo is a summary of a [`ChanTarget`] that we can pass to
|
||||
/// [`ChanBuilder::build_channel`].
|
||||
///
|
||||
/// This is a separate type since we can't declare ChanBuilder as having
|
||||
/// a parameterized method in today's Rust.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct TargetInfo {
|
||||
/// Copy of the addresses from the underlying ChanTarget.
|
||||
addrs: Vec<SocketAddr>,
|
||||
/// Copy of the ed25519 id from the underlying ChanTarget.
|
||||
ed_identity: pk::ed25519::Ed25519Identity,
|
||||
/// Copy of the rsa id from the underlying ChanTarget.
|
||||
rsa_identity: pk::rsa::RsaIdentity,
|
||||
}
|
||||
|
||||
impl ChanTarget for TargetInfo {
|
||||
fn addrs(&self) -> &[SocketAddr] {
|
||||
&self.addrs[..]
|
||||
}
|
||||
fn ed_identity(&self) -> &pk::ed25519::Ed25519Identity {
|
||||
&self.ed_identity
|
||||
}
|
||||
fn rsa_identity(&self) -> &pk::rsa::RsaIdentity {
|
||||
&self.rsa_identity
|
||||
}
|
||||
}
|
||||
|
||||
impl TargetInfo {
|
||||
/// Construct a TargetInfo from a given ChanTarget.
|
||||
pub(crate) fn from_chan_target<C>(target: &C) -> Self
|
||||
where
|
||||
C: ChanTarget + ?Sized,
|
||||
{
|
||||
TargetInfo {
|
||||
addrs: target.addrs().to_vec(),
|
||||
ed_identity: *target.ed_identity(),
|
||||
rsa_identity: *target.rsa_identity(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn targetinfo() {
|
||||
let ti = TargetInfo {
|
||||
addrs: vec!["127.0.0.1:11".parse().unwrap()],
|
||||
ed_identity: [42; 32].into(),
|
||||
rsa_identity: [45; 20].into(),
|
||||
};
|
||||
|
||||
let ti2 = TargetInfo::from_chan_target(&ti);
|
||||
assert_eq!(ti.addrs, ti2.addrs);
|
||||
assert_eq!(ti.ed_identity, ti2.ed_identity);
|
||||
assert_eq!(ti.rsa_identity, ti2.rsa_identity);
|
||||
}
|
||||
}
|
|
@ -1,101 +0,0 @@
|
|||
//! Trait and implementation for a "Connector" type.
|
||||
//!
|
||||
//! The `Connector` trait is internal to the tor-chanmgr crate, and
|
||||
//! helps us avoid having `ChanMgr` be polymorphic on transport type.
|
||||
//! Instead, it can hold a boxed Connector.
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_llcrypto::pk;
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::testing::{FakeChannel as Channel, FakeChannelBuilder as ChannelBuilder};
|
||||
#[cfg(not(test))]
|
||||
use tor_proto::channel::{Channel, ChannelBuilder};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::task::{Spawn, SpawnExt};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A Connector knows how to make a channel given the summarized information
|
||||
/// from a ChanTarget.
|
||||
#[async_trait]
|
||||
pub(crate) trait Connector {
|
||||
/// Create a new channel to `target`, trying exactly once, not timing out.
|
||||
async fn build_channel(
|
||||
&self,
|
||||
runtime: &(dyn Spawn + Sync),
|
||||
target: &TargetInfo,
|
||||
) -> Result<Arc<Channel>>;
|
||||
}
|
||||
|
||||
// Every Transport is automatically a Connector.
|
||||
#[async_trait]
|
||||
impl<TR: Transport + Send + Sync> Connector for TR {
|
||||
async fn build_channel(
|
||||
&self,
|
||||
runtime: &(dyn Spawn + Sync),
|
||||
target: &TargetInfo,
|
||||
) -> Result<Arc<Channel>> {
|
||||
use tor_rtcompat::tls::CertifiedConn;
|
||||
let (addr, tls) = self.connect(target).await?;
|
||||
|
||||
let peer_cert = tls
|
||||
.peer_certificate()?
|
||||
.ok_or(Error::Internal("TLS connection with no peer certificate"))?;
|
||||
|
||||
let mut builder = ChannelBuilder::new();
|
||||
builder.set_declared_addr(addr);
|
||||
let chan = builder.launch(tls).connect().await?;
|
||||
let chan = chan.check(target, &peer_cert)?;
|
||||
let (chan, reactor) = chan.finish().await?;
|
||||
|
||||
runtime.spawn(async {
|
||||
let _ = reactor.run().await;
|
||||
})?;
|
||||
Ok(chan)
|
||||
}
|
||||
}
|
||||
|
||||
/// TargetInfo is a summary of a [`ChanTarget`] that we can pass to
|
||||
/// [`Connector::build_channel`].
|
||||
///
|
||||
/// This is a separate type since we can't declare Connector as having
|
||||
/// a method that's parameterized in today's Rust.
|
||||
pub(crate) struct TargetInfo {
|
||||
/// Copy of the addresses from the underlying ChanTarget.
|
||||
addrs: Vec<SocketAddr>,
|
||||
/// Copy of the ed25519 id from the underlying ChanTarget.
|
||||
ed_identity: pk::ed25519::Ed25519Identity,
|
||||
/// Copy of the rsa id from the underlying ChanTarget.
|
||||
rsa_identity: pk::rsa::RsaIdentity,
|
||||
}
|
||||
|
||||
impl ChanTarget for TargetInfo {
|
||||
fn addrs(&self) -> &[SocketAddr] {
|
||||
&self.addrs[..]
|
||||
}
|
||||
fn ed_identity(&self) -> &pk::ed25519::Ed25519Identity {
|
||||
&self.ed_identity
|
||||
}
|
||||
fn rsa_identity(&self) -> &pk::rsa::RsaIdentity {
|
||||
&self.rsa_identity
|
||||
}
|
||||
}
|
||||
|
||||
impl TargetInfo {
|
||||
/// Construct a TargetInfo from a given ChanTarget.
|
||||
pub(crate) fn from_chan_target<C>(target: &C) -> Self
|
||||
where
|
||||
C: ChanTarget + ?Sized,
|
||||
{
|
||||
TargetInfo {
|
||||
addrs: target.addrs().to_vec(),
|
||||
ed_identity: *target.ed_identity(),
|
||||
rsa_identity: *target.rsa_identity(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,10 +17,6 @@ pub enum Error {
|
|||
#[error("Channel timed out")]
|
||||
ChanTimeout,
|
||||
|
||||
/// An internal error or assumption violation in the TLS implementation.
|
||||
#[error("Invalid TLS connection")]
|
||||
InvalidTls,
|
||||
|
||||
/// A protocol error while making a channel
|
||||
#[error("Protocol error while opening a channel: {0}")]
|
||||
Proto(#[from] tor_proto::Error),
|
||||
|
@ -32,6 +28,10 @@ pub enum Error {
|
|||
/// An internal error of some kind that should never occur.
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(&'static str),
|
||||
|
||||
/// We were waiting for a channel to complete, but it failed.
|
||||
#[error("Pending channel failed to open: {0}")]
|
||||
PendingChanFailed(#[from] PendingChanError),
|
||||
}
|
||||
|
||||
impl From<futures::task::SpawnError> for Error {
|
||||
|
@ -39,3 +39,30 @@ impl From<futures::task::SpawnError> for Error {
|
|||
Error::Internal("Couldn't spawn channel reactor")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tor_rtcompat::TimeoutError> for Error {
|
||||
fn from(_: tor_rtcompat::TimeoutError) -> Error {
|
||||
Error::ChanTimeout
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<std::sync::PoisonError<T>> for Error {
|
||||
fn from(_: std::sync::PoisonError<T>) -> Error {
|
||||
Error::Internal("Thread failed while holding lock")
|
||||
}
|
||||
}
|
||||
|
||||
/// An error transmitted by a future that trying to build a channel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingChanError(String);
|
||||
impl std::error::Error for PendingChanError {}
|
||||
impl std::fmt::Display for PendingChanError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
impl From<&Error> for PendingChanError {
|
||||
fn from(e: &Error) -> PendingChanError {
|
||||
PendingChanError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,31 +11,17 @@
|
|||
#![deny(missing_docs)]
|
||||
#![deny(clippy::missing_docs_in_private_items)]
|
||||
|
||||
mod connect;
|
||||
mod builder;
|
||||
mod err;
|
||||
#[cfg(test)]
|
||||
pub(crate) mod testing;
|
||||
pub mod transport;
|
||||
|
||||
use crate::connect::{Connector, TargetInfo};
|
||||
use crate::transport::Transport;
|
||||
mod mgr;
|
||||
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_llcrypto::pk::ed25519::Ed25519Identity;
|
||||
|
||||
#[cfg(test)]
|
||||
use testing::FakeChannel as Channel;
|
||||
#[cfg(not(test))]
|
||||
use tor_proto::channel::Channel;
|
||||
|
||||
use futures::lock::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub use err::Error;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tor_rtcompat::{Runtime, SleepProviderExt};
|
||||
use tor_rtcompat::Runtime;
|
||||
|
||||
/// A Result as returned by this crate.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
@ -46,63 +32,16 @@ pub type Result<T> = std::result::Result<T, Error>;
|
|||
/// Use the [ChanMgr::get_or_launch] function to craete a new channel, or
|
||||
/// get one if it exists.
|
||||
pub struct ChanMgr<R: Runtime> {
|
||||
/// Map from Ed25519 identity to channel state.
|
||||
///
|
||||
/// Note that eventually we might want to have this be only
|
||||
/// _canonical_ connections (those whose address matches the
|
||||
/// relay's official address) and we might want this to be indexed
|
||||
/// by pluggable transport too. But since right now only
|
||||
/// client-initiated channels are supported, and pluggable
|
||||
/// transports are not supported, this structure is fine.
|
||||
///
|
||||
/// Note that other Channels may exist that are not indexed here.
|
||||
channels: Mutex<HashMap<Ed25519Identity, ChannelState>>,
|
||||
|
||||
/// Object used to create TLS connections to relays.
|
||||
connector: Box<dyn Connector + Sync + Send + 'static>,
|
||||
|
||||
/// DOCDOC
|
||||
runtime: R,
|
||||
}
|
||||
|
||||
/// Possible states for a managed channel
|
||||
enum ChannelState {
|
||||
/// The channel is open, authenticated, and canonical: we can give
|
||||
/// it out as needed.
|
||||
Open(Arc<Channel>),
|
||||
/// Some task is building the channel, and will notify all
|
||||
/// listeners on this event on success or failure.
|
||||
Building(Arc<event_listener::Event>),
|
||||
/// Internal channel manager object that does the actual work.
|
||||
mgr: mgr::AbstractChanMgr<builder::ChanBuilder<R>>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> ChanMgr<R> {
|
||||
/// Construct a new channel manager. It will use `transport` to construct
|
||||
/// TLS streams, and `spawn` to launch reactor tasks.
|
||||
pub fn new<TR>(runtime: R, transport: TR) -> Self
|
||||
where
|
||||
TR: Transport + Send + Sync + 'static,
|
||||
{
|
||||
let connector = Box::new(transport);
|
||||
ChanMgr {
|
||||
channels: Mutex::new(HashMap::new()),
|
||||
connector,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: Return the channel if it matches the target; otherwise
|
||||
/// return an error.
|
||||
///
|
||||
/// We need to do this check since it's theoretically possible for
|
||||
/// a channel to (for example) match the Ed25519 key of the
|
||||
/// target, but not the RSA key.
|
||||
fn check_chan_match<T: ChanTarget + ?Sized>(
|
||||
&self,
|
||||
target: &T,
|
||||
ch: Arc<Channel>,
|
||||
) -> Result<Arc<Channel>> {
|
||||
ch.check_match(target)?;
|
||||
Ok(ch)
|
||||
/// Construct a new channel manager.
|
||||
pub fn new(runtime: R) -> Self {
|
||||
let builder = builder::ChanBuilder::new(runtime);
|
||||
let mgr = mgr::AbstractChanMgr::new(builder);
|
||||
ChanMgr { mgr }
|
||||
}
|
||||
|
||||
/// Try to get a suitable channel to the provided `target`,
|
||||
|
@ -113,328 +52,12 @@ impl<R: Runtime> ChanMgr<R> {
|
|||
/// or fail depending on its outcome.
|
||||
pub async fn get_or_launch<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<Arc<Channel>> {
|
||||
let ed_identity = target.ed_identity();
|
||||
use ChannelState::*;
|
||||
let targetinfo = builder::TargetInfo::from_chan_target(target);
|
||||
|
||||
// Look up the current cache entry.
|
||||
let (should_launch, event) = {
|
||||
let mut channels = self.channels.lock().await;
|
||||
let state = channels.get(ed_identity);
|
||||
|
||||
match state {
|
||||
Some(Open(ch)) => {
|
||||
if ch.is_closing() {
|
||||
// duplicate with below. XXXXX
|
||||
let e = Arc::new(event_listener::Event::new());
|
||||
let state = Building(Arc::clone(&e));
|
||||
channels.insert(*ed_identity, state);
|
||||
(true, e)
|
||||
} else {
|
||||
return self.check_chan_match(target, Arc::clone(ch));
|
||||
}
|
||||
}
|
||||
Some(Building(e)) => (false, Arc::clone(e)),
|
||||
None => {
|
||||
let e = Arc::new(event_listener::Event::new());
|
||||
let state = Building(Arc::clone(&e));
|
||||
channels.insert(*ed_identity, state);
|
||||
(true, e)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if should_launch {
|
||||
let result = self.build_channel(target).await;
|
||||
{
|
||||
let mut channels = self.channels.lock().await;
|
||||
match &result {
|
||||
Ok(ch) => {
|
||||
channels.insert(*ed_identity, Open(Arc::clone(ch)));
|
||||
}
|
||||
Err(_) => {
|
||||
channels.remove(ed_identity);
|
||||
}
|
||||
}
|
||||
}
|
||||
event.notify(usize::MAX);
|
||||
result
|
||||
} else {
|
||||
event.listen().await;
|
||||
let chan = self
|
||||
.get_nowait_by_ed_id(ed_identity)
|
||||
.await
|
||||
.ok_or(Error::PendingFailed)?;
|
||||
self.check_chan_match(target, chan)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: construct a new channel for a target.
|
||||
async fn build_channel<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<Arc<Channel>> {
|
||||
// XXXX make this a parameter.
|
||||
let timeout = Duration::new(5, 0);
|
||||
|
||||
let result = self
|
||||
.runtime
|
||||
.timeout(timeout, self.build_channel_once(target))
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(chan)) => Ok(chan),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(_) => Err(Error::ChanTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: construct a new channel for a target, trying only once,
|
||||
/// and not timing out.
|
||||
async fn build_channel_once<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<Arc<Channel>> {
|
||||
let target = TargetInfo::from_chan_target(target);
|
||||
self.connector.build_channel(&self.runtime, &target).await
|
||||
}
|
||||
|
||||
/// Helper: Get the Channel with the given Ed25519 identity, if there
|
||||
/// is one.
|
||||
async fn get_nowait_by_ed_id(&self, ed_id: &Ed25519Identity) -> Option<Arc<Channel>> {
|
||||
use ChannelState::*;
|
||||
let channels = self.channels.lock().await;
|
||||
match channels.get(ed_id) {
|
||||
Some(Open(ch)) => Some(Arc::clone(ch)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
use tor_llcrypto::pk::rsa::RsaIdentity;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::io::{AsyncRead, AsyncWrite};
|
||||
use futures::join;
|
||||
use futures::task::Context;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
|
||||
use tor_rtcompat::test_with_runtime;
|
||||
|
||||
struct FakeTransport;
|
||||
struct FakeConnection;
|
||||
|
||||
#[async_trait]
|
||||
impl crate::transport::Transport for FakeTransport {
|
||||
type Connection = FakeConnection;
|
||||
async fn connect<T: ChanTarget + Sync + ?Sized>(
|
||||
&self,
|
||||
t: &T,
|
||||
) -> Result<(std::net::SocketAddr, FakeConnection)> {
|
||||
let addr = t.addrs().get(0).unwrap();
|
||||
if addr.port() == 1337 {
|
||||
Err(Error::UnusableTarget("too leet!".into()).into())
|
||||
} else {
|
||||
Ok((*addr, FakeConnection))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tor_rtcompat::tls::CertifiedConn for FakeConnection {
|
||||
fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
|
||||
Ok(Some(vec![]))
|
||||
}
|
||||
}
|
||||
impl AsyncRead for FakeConnection {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut [u8],
|
||||
) -> Poll<std::result::Result<usize, std::io::Error>> {
|
||||
Poll::Ready(Ok(0))
|
||||
}
|
||||
}
|
||||
impl AsyncWrite for FakeConnection {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<std::result::Result<usize, std::io::Error>> {
|
||||
Poll::Ready(Ok(0))
|
||||
}
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct Target {
|
||||
addr: [std::net::SocketAddr; 1],
|
||||
ed_id: Ed25519Identity,
|
||||
rsa_id: RsaIdentity,
|
||||
}
|
||||
impl ChanTarget for Target {
|
||||
fn addrs(&self) -> &[SocketAddr] {
|
||||
&self.addr[..]
|
||||
}
|
||||
fn ed_identity(&self) -> &Ed25519Identity {
|
||||
&self.ed_id
|
||||
}
|
||||
fn rsa_identity(&self) -> &RsaIdentity {
|
||||
&self.rsa_id
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_one_ok() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let mgr = ChanMgr::new(runtime, FakeTransport);
|
||||
let target = Target {
|
||||
addr: ["127.0.0.1:443".parse().unwrap()],
|
||||
ed_id: [3; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
let chan1 = mgr.get_or_launch(&target).await.unwrap();
|
||||
let chan2 = mgr.get_or_launch(&target).await.unwrap();
|
||||
|
||||
assert!(chan1.same_channel(&chan2));
|
||||
|
||||
{
|
||||
let channels = mgr.channels.lock().await;
|
||||
let entry = channels.get(&[3; 32].into());
|
||||
match entry {
|
||||
Some(ChannelState::Open(c)) => assert!(c.same_channel(&chan1)),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
let chan3 = mgr.get_nowait_by_ed_id(&[3; 32].into()).await;
|
||||
assert!(chan3.unwrap().same_channel(&chan1));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_one_fail() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let mgr = ChanMgr::new(runtime, FakeTransport);
|
||||
// port 1337 is set up to always fail in FakeTransport.
|
||||
let target = Target {
|
||||
addr: ["127.0.0.1:1337".parse().unwrap()],
|
||||
ed_id: [3; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
|
||||
let res1 = mgr.get_or_launch(&target).await;
|
||||
assert!(matches!(res1.unwrap_err(), Error::UnusableTarget(_)));
|
||||
|
||||
// port 8686 is set up to always fail in FakeTransport.
|
||||
let target = Target {
|
||||
addr: ["127.0.0.1:8686".parse().unwrap()],
|
||||
ed_id: [4; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
|
||||
let res1 = mgr.get_or_launch(&target).await;
|
||||
assert!(matches!(res1.unwrap_err(), Error::Proto(_)));
|
||||
|
||||
let chan3 = mgr.get_nowait_by_ed_id(&[4; 32].into()).await;
|
||||
assert!(chan3.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let mgr = ChanMgr::new(runtime, FakeTransport);
|
||||
let target3 = Target {
|
||||
addr: ["127.0.0.1:99".parse().unwrap()],
|
||||
ed_id: [3; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
let target44 = Target {
|
||||
addr: ["127.0.0.2:99".parse().unwrap()],
|
||||
ed_id: [44; 32].into(), // note different ed key.
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
let target_bad = Target {
|
||||
addr: ["127.0.0.1:8686".parse().unwrap()],
|
||||
ed_id: [66; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
|
||||
// TODO XXXX: figure out how to make these actually run
|
||||
// concurrently. Right now it seems that they don't actually
|
||||
// interact.
|
||||
let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!(
|
||||
mgr.get_or_launch(&target3),
|
||||
mgr.get_or_launch(&target3),
|
||||
mgr.get_or_launch(&target44),
|
||||
mgr.get_or_launch(&target44),
|
||||
mgr.get_or_launch(&target_bad),
|
||||
mgr.get_or_launch(&target_bad),
|
||||
);
|
||||
let ch3a = ch3a.unwrap();
|
||||
let ch3b = ch3b.unwrap();
|
||||
let ch44a = ch44a.unwrap();
|
||||
let ch44b = ch44b.unwrap();
|
||||
let err_a = ch86a.unwrap_err();
|
||||
let err_b = ch86b.unwrap_err();
|
||||
|
||||
assert!(ch3a.same_channel(&ch3b));
|
||||
assert!(ch44a.same_channel(&ch44b));
|
||||
assert!(!ch3a.same_channel(&ch44b));
|
||||
|
||||
assert!(matches!(err_a, Error::Proto(_)));
|
||||
assert!(matches!(err_b, Error::Proto(_)));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stall() {
|
||||
test_with_runtime(|runtime| async {
|
||||
use futures::FutureExt;
|
||||
|
||||
let mgr = ChanMgr::new(runtime, FakeTransport);
|
||||
let target = Target {
|
||||
addr: ["127.0.0.1:99".parse().unwrap()],
|
||||
ed_id: [12; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut channels = mgr.channels.lock().await;
|
||||
let e = Arc::new(event_listener::Event::new());
|
||||
let state = ChannelState::Building(Arc::clone(&e));
|
||||
channels.insert([12; 32].into(), state);
|
||||
}
|
||||
|
||||
let h = mgr.get_or_launch(&target);
|
||||
|
||||
assert!(h.now_or_never().is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_two_closing() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let mgr = ChanMgr::new(runtime, FakeTransport);
|
||||
let target = Target {
|
||||
addr: ["127.0.0.1:443".parse().unwrap()],
|
||||
ed_id: [3; 32].into(),
|
||||
rsa_id: [2; 20].into(),
|
||||
};
|
||||
let chan1 = mgr.get_or_launch(&target).await.unwrap();
|
||||
chan1.mark_closing();
|
||||
let chan2 = mgr.get_or_launch(&target).await.unwrap();
|
||||
|
||||
assert!(!chan1.same_channel(&chan2));
|
||||
});
|
||||
let chan = self.mgr.get_or_launch(*ed_identity, targetinfo).await?;
|
||||
// Double-check the match to make sure that the RSA identity is
|
||||
// what we wanted too.
|
||||
chan.check_match(target)?;
|
||||
Ok(chan)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,404 @@
|
|||
//! Abstract implementation of a channel maanger
|
||||
|
||||
#![allow(dead_code)]
|
||||
use crate::err::PendingChanError;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::{FutureExt, Shared};
|
||||
use std::hash::Hash;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod map;
|
||||
|
||||
/// Trait to describe as much of a
|
||||
/// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr`
|
||||
/// needs to use.
|
||||
pub(crate) trait AbstractChannel {
|
||||
/// Identity type for the other side of the channel.
|
||||
type Ident: Hash + Eq + Clone;
|
||||
/// Return this channel's identity.
|
||||
fn ident(&self) -> &Self::Ident;
|
||||
/// Return true if this channel is usable.
|
||||
///
|
||||
/// A channel might be unusable because it is closed, because it has
|
||||
/// hit a bug, or for some other reason. We don't return unusable
|
||||
/// channels back to the user.
|
||||
fn is_usable(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Trait to describe how channels are created.
|
||||
#[async_trait]
|
||||
pub(crate) trait ChannelFactory {
|
||||
/// The type of channel that this factory can build.
|
||||
type Channel: AbstractChannel;
|
||||
/// Type that explains how to build a channel.
|
||||
type BuildSpec;
|
||||
|
||||
/// Construct a new channel to the destination described at `target`.
|
||||
///
|
||||
/// This function must take care of all timeouts, error detection,
|
||||
/// and so on.
|
||||
///
|
||||
/// It should not retry; that is handled at a higher level.
|
||||
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>;
|
||||
}
|
||||
|
||||
/// A type- and network-agnostic implementation for
|
||||
/// [`ChanMgr`](crate::ChanMgr).
|
||||
///
|
||||
/// This type does the work of keeping track of open channels and
|
||||
/// pending channel requests, launching requests as needed, waiting
|
||||
/// for pending requests, and so forth.
|
||||
///
|
||||
/// The actual job of launching conenctions is deferred to a ChannelFactory
|
||||
/// type.
|
||||
pub(crate) struct AbstractChanMgr<CF: ChannelFactory> {
|
||||
/// A 'connector' object that we use to create channels.
|
||||
connector: CF,
|
||||
|
||||
/// A map from ed25519 identity to channel, or to pending channel status.
|
||||
channels: map::ChannelMap<CF::Channel>,
|
||||
}
|
||||
|
||||
/// A Result whose error is a [`PendingChanError`].
|
||||
///
|
||||
/// We need a separate type here because [`Error`] doesn't implement `Clone`.
|
||||
type PendResult<T> = std::result::Result<T, PendingChanError>;
|
||||
|
||||
/// Type alias for a future that we wait on to see when a pending
|
||||
/// channel is done or failed.
|
||||
type Pending<C> = Shared<oneshot::Receiver<PendResult<Arc<C>>>>;
|
||||
|
||||
/// Type alias for the sender we notify when we complete a channel (or
|
||||
/// fail to complete it).
|
||||
type Sending<C> = oneshot::Sender<PendResult<Arc<C>>>;
|
||||
|
||||
impl<CF: ChannelFactory> AbstractChanMgr<CF> {
|
||||
/// Make a new empty channel manager.
|
||||
pub(crate) fn new(connector: CF) -> Self {
|
||||
AbstractChanMgr {
|
||||
connector,
|
||||
channels: map::ChannelMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove every unusable entry from this channel manager.
|
||||
pub fn remove_unusable_entries(&self) -> Result<()> {
|
||||
self.channels.remove_unusable()
|
||||
}
|
||||
|
||||
/// Helper: return the objects used to inform pending tasks
|
||||
/// about a newly open or failed channel.
|
||||
fn setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>) {
|
||||
let (snd, rcv) = oneshot::channel();
|
||||
let shared = rcv.shared();
|
||||
(map::ChannelState::Building(shared), snd)
|
||||
}
|
||||
|
||||
/// Get a channel whose identity is `ident`.
|
||||
///
|
||||
/// If a usable channel exists with that identity, return it.
|
||||
///
|
||||
/// If no such channel exists already, and none is in progress,
|
||||
/// launch a new request using `target`, which must match `ident`.
|
||||
///
|
||||
/// If no such channel exists already, but we have one that's in
|
||||
/// progress, wait for it to succeed or fail.
|
||||
pub async fn get_or_launch(
|
||||
&self,
|
||||
ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
|
||||
target: CF::BuildSpec,
|
||||
) -> Result<Arc<CF::Channel>> {
|
||||
use map::ChannelState::*;
|
||||
|
||||
/// Possible actions that we'll decide to take based on the
|
||||
/// channel's initial state.
|
||||
enum Action<C> {
|
||||
/// We found no channel. We're going to launch a new one,
|
||||
/// then tell everybody about it.
|
||||
Launch(Sending<C>),
|
||||
/// We found an in-progress attempt at making a channel.
|
||||
/// We're going to wait for it to finish.
|
||||
Wait(Pending<C>),
|
||||
/// We found a usable channel. We're going to return it.
|
||||
Return(Result<Arc<C>>),
|
||||
}
|
||||
/// How many times do we try?
|
||||
const N_ATTEMPTS: usize = 2;
|
||||
|
||||
// XXXX It would be neat to use tor_retry instead, but it's
|
||||
// too tied to anyhow right now.
|
||||
let mut last_err = Err(Error::Internal("Error was never set!?"));
|
||||
|
||||
for _ in 0..N_ATTEMPTS {
|
||||
// First, see what state we're in, and what we should do
|
||||
// about it.
|
||||
let action = self
|
||||
.channels
|
||||
.change_state(&ident, |oldstate| match oldstate {
|
||||
Some(Open(ref ch)) => {
|
||||
if ch.is_usable() {
|
||||
// Good channel. Return it.
|
||||
let action = Action::Return(Ok(Arc::clone(ch)));
|
||||
(oldstate, action)
|
||||
} else {
|
||||
// Unusable channel. Move to the Building
|
||||
// state and launch a new channel.
|
||||
let (newstate, send) = self.setup_launch();
|
||||
let action = Action::Launch(send);
|
||||
(Some(newstate), action)
|
||||
}
|
||||
}
|
||||
Some(Building(ref pending)) => {
|
||||
let action = Action::Wait(pending.clone());
|
||||
(oldstate, action)
|
||||
}
|
||||
Some(Poisoned(_)) => {
|
||||
// We should never be able to see this state; this
|
||||
// is a bug.
|
||||
(
|
||||
None,
|
||||
Action::Return(Err(Error::Internal("Found a poisoned entry"))),
|
||||
)
|
||||
}
|
||||
None => {
|
||||
// No channel. Move to the Building
|
||||
// state and launch a new channel.
|
||||
let (newstate, send) = self.setup_launch();
|
||||
let action = Action::Launch(send);
|
||||
(Some(newstate), action)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Now we act based on the channel.
|
||||
match action {
|
||||
// Easy case: we have an error or a channel to return.
|
||||
Action::Return(v) => {
|
||||
return v;
|
||||
}
|
||||
// There's an in-progress channel. Wait for it.
|
||||
Action::Wait(pend) => match pend.await {
|
||||
Ok(Ok(chan)) => return Ok(chan),
|
||||
Ok(Err(e)) => {
|
||||
last_err = Err(e.into());
|
||||
}
|
||||
Err(_) => {
|
||||
last_err = Err(Error::Internal("channel build task disappeared"));
|
||||
}
|
||||
},
|
||||
// We need to launch a channel.
|
||||
Action::Launch(send) => match self.connector.build_channel(&target).await {
|
||||
Ok(chan) => {
|
||||
// The channel got built: remember it, tell the
|
||||
// others, and return it.
|
||||
self.channels
|
||||
.replace(ident.clone(), Open(Arc::clone(&chan)))?;
|
||||
// It's okay if all the receivers went away:
|
||||
// that means that nobody was waiting for this channel.
|
||||
let _ignore_err = send.send(Ok(Arc::clone(&chan)));
|
||||
return Ok(chan);
|
||||
}
|
||||
Err(e) => {
|
||||
// The channel failed. Make it non-pending, tell the
|
||||
// others, and set the error.
|
||||
self.channels.remove(&ident)?;
|
||||
// (As above)
|
||||
let _ignore_err = send.send(Err((&e).into()));
|
||||
last_err = Err(e);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
last_err
|
||||
}
|
||||
|
||||
/// Test only: return the current open usable channel with a given
|
||||
/// `ident`, if any.
|
||||
#[cfg(test)]
|
||||
pub fn get_nowait(
|
||||
&self,
|
||||
ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
|
||||
) -> Option<Arc<CF::Channel>> {
|
||||
use map::ChannelState::*;
|
||||
match self.channels.get(ident) {
|
||||
Ok(Some(Open(ref ch))) if ch.is_usable() => Some(Arc::clone(ch)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::Error;
|
||||
|
||||
use futures::join;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use tor_rtcompat::{task::yield_now, test_with_runtime, Runtime};
|
||||
|
||||
struct FakeChannelFactory<RT> {
|
||||
runtime: RT,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FakeChannel {
|
||||
ident: u32,
|
||||
mood: char,
|
||||
closing: AtomicBool,
|
||||
}
|
||||
|
||||
impl AbstractChannel for FakeChannel {
|
||||
type Ident = u32;
|
||||
fn ident(&self) -> &u32 {
|
||||
&self.ident
|
||||
}
|
||||
fn is_usable(&self) -> bool {
|
||||
!self.closing.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel {
|
||||
fn start_closing(&self) {
|
||||
self.closing.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl<RT: Runtime> FakeChannelFactory<RT> {
|
||||
fn new(runtime: RT) -> Self {
|
||||
FakeChannelFactory { runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<RT: Runtime> ChannelFactory for FakeChannelFactory<RT> {
|
||||
type Channel = FakeChannel;
|
||||
type BuildSpec = (u32, char);
|
||||
|
||||
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<FakeChannel>> {
|
||||
yield_now().await;
|
||||
let (ident, mood) = *target;
|
||||
match mood {
|
||||
// "X" means never connect.
|
||||
'❌' | '🔥' => return Err(Error::UnusableTarget("emoji".into())),
|
||||
// "zzz" means wait for 15 seconds then succeed.
|
||||
'💤' => {
|
||||
self.runtime.sleep(Duration::new(15, 0)).await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(Arc::new(FakeChannel {
|
||||
ident,
|
||||
mood,
|
||||
closing: AtomicBool::new(false),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_one_ok() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let cf = FakeChannelFactory::new(runtime);
|
||||
let mgr = AbstractChanMgr::new(cf);
|
||||
let target = (413, '!');
|
||||
let chan1 = mgr.get_or_launch(413, target.clone()).await.unwrap();
|
||||
let chan2 = mgr.get_or_launch(413, target.clone()).await.unwrap();
|
||||
|
||||
assert!(Arc::ptr_eq(&chan1, &chan2));
|
||||
|
||||
let chan3 = mgr.get_nowait(&413).unwrap();
|
||||
assert!(Arc::ptr_eq(&chan1, &chan3));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_one_fail() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let cf = FakeChannelFactory::new(runtime);
|
||||
let mgr = AbstractChanMgr::new(cf);
|
||||
|
||||
// This is set up to always fail.
|
||||
let target = (999, '❌');
|
||||
let res1 = mgr.get_or_launch(999, target).await;
|
||||
dbg!(&res1);
|
||||
assert!(matches!(res1, Err(Error::UnusableTarget(_))));
|
||||
|
||||
let chan3 = mgr.get_nowait(&999);
|
||||
assert!(chan3.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let cf = FakeChannelFactory::new(runtime);
|
||||
let mgr = AbstractChanMgr::new(cf);
|
||||
|
||||
// TODO XXXX: figure out how to make these actually run
|
||||
// concurrently. Right now it seems that they don't actually
|
||||
// interact.
|
||||
let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!(
|
||||
mgr.get_or_launch(3, (3, 'a')),
|
||||
mgr.get_or_launch(3, (3, 'b')),
|
||||
mgr.get_or_launch(44, (44, 'a')),
|
||||
mgr.get_or_launch(44, (44, 'b')),
|
||||
mgr.get_or_launch(86, (86, '❌')),
|
||||
mgr.get_or_launch(86, (86, '🔥')),
|
||||
);
|
||||
let ch3a = ch3a.unwrap();
|
||||
let ch3b = ch3b.unwrap();
|
||||
let ch44a = ch44a.unwrap();
|
||||
let ch44b = ch44b.unwrap();
|
||||
let err_a = ch86a.unwrap_err();
|
||||
let err_b = ch86b.unwrap_err();
|
||||
|
||||
assert!(Arc::ptr_eq(&ch3a, &ch3b));
|
||||
assert!(Arc::ptr_eq(&ch44a, &ch44b));
|
||||
assert!(!Arc::ptr_eq(&ch44a, &ch3a));
|
||||
|
||||
assert!(matches!(
|
||||
err_a,
|
||||
Error::UnusableTarget(_) | Error::PendingChanFailed(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
err_b,
|
||||
Error::UnusableTarget(_) | Error::PendingChanFailed(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unusable_entries() {
|
||||
test_with_runtime(|runtime| async {
|
||||
let cf = FakeChannelFactory::new(runtime);
|
||||
let mgr = AbstractChanMgr::new(cf);
|
||||
|
||||
let (ch3, ch4, ch5) = join!(
|
||||
mgr.get_or_launch(3, (3, 'a')),
|
||||
mgr.get_or_launch(4, (4, 'a')),
|
||||
mgr.get_or_launch(5, (5, 'a')),
|
||||
);
|
||||
|
||||
let ch3 = ch3.unwrap();
|
||||
let _ch4 = ch4.unwrap();
|
||||
let ch5 = ch5.unwrap();
|
||||
|
||||
ch3.start_closing();
|
||||
ch5.start_closing();
|
||||
|
||||
let ch3_new = mgr.get_or_launch(3, (3, 'b')).await.unwrap();
|
||||
assert!(!Arc::ptr_eq(&ch3, &ch3_new));
|
||||
assert_eq!(ch3_new.mood, 'b');
|
||||
|
||||
mgr.remove_unusable_entries().unwrap();
|
||||
|
||||
assert!(mgr.get_nowait(&3).is_some());
|
||||
assert!(mgr.get_nowait(&4).is_some());
|
||||
assert!(mgr.get_nowait(&5).is_none());
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,297 @@
|
|||
//! Simple implementaiton for the internal map state of a ChanMgr.
|
||||
|
||||
use super::{AbstractChannel, Pending};
|
||||
use crate::{Error, Result};
|
||||
|
||||
use std::collections::{hash_map, HashMap};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A map from channel id to channel state.
|
||||
///
|
||||
/// We make this a separate type instead of just using
|
||||
/// `Mutex<HashMap<...>>` to limit the amount of code that can see and
|
||||
/// lock the Mutex here. (We're using a blocking mutex close to async
|
||||
/// code, so we need to be careful.)
|
||||
pub(crate) struct ChannelMap<C: AbstractChannel> {
|
||||
/// A map from identity to channel, or to pending channel status.
|
||||
///
|
||||
/// (Danger: this uses a blocking mutex close to async code. This mutex
|
||||
/// must never be held while an await is happening.)
|
||||
channels: std::sync::Mutex<HashMap<C::Ident, ChannelState<C>>>,
|
||||
}
|
||||
|
||||
/// Structure that can only be constructed from within this module.
|
||||
/// Used to make sure that only we can construct ChannelState::Poisoned.
|
||||
pub(crate) struct Priv {
|
||||
/// (This field is private)
|
||||
_unused: (),
|
||||
}
|
||||
|
||||
/// The state of a channel (or channel build attempt) within a map.
|
||||
pub(crate) enum ChannelState<C> {
|
||||
/// An open channel.
|
||||
///
|
||||
/// This channel might not be usable: it might be closing or
|
||||
/// broken. We need to check its is_usable() method before
|
||||
/// yielding it to the user.
|
||||
Open(Arc<C>),
|
||||
/// A channel that's getting built.
|
||||
Building(Pending<C>),
|
||||
/// A temporary invalid state.
|
||||
///
|
||||
/// We insert this into the map temporarily as a placeholder in
|
||||
/// `change_state()`.
|
||||
Poisoned(Priv),
|
||||
}
|
||||
|
||||
impl<C> ChannelState<C> {
|
||||
/// Create a new shallow copy of this ChannelState.
|
||||
fn clone_ref(&self) -> Self {
|
||||
use ChannelState::*;
|
||||
match self {
|
||||
Open(chan) => Open(Arc::clone(chan)),
|
||||
Building(pending) => Building(pending.clone()),
|
||||
Poisoned(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// For testing: either give the Open channnel inside this state,
|
||||
/// or panic if there is none.
|
||||
#[cfg(test)]
|
||||
fn unwrap_open(&self) -> Arc<C> {
|
||||
match self {
|
||||
ChannelState::Open(chan) => Arc::clone(chan),
|
||||
_ => panic!("Not an open channel"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: AbstractChannel> ChannelState<C> {
|
||||
/// Return an error if `ident`is definitely not a matching
|
||||
/// matching identity for this state.
|
||||
fn check_ident(&self, ident: &C::Ident) -> Result<()> {
|
||||
match self {
|
||||
ChannelState::Open(chan) => {
|
||||
if chan.ident() == ident {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Internal("Identity mismatch"))
|
||||
}
|
||||
}
|
||||
ChannelState::Poisoned(_) => Err(Error::Internal("Poisoned state in channel map")),
|
||||
ChannelState::Building(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: AbstractChannel> ChannelMap<C> {
|
||||
/// Create a new empty ChannelMap.
|
||||
pub(crate) fn new() -> Self {
|
||||
ChannelMap {
|
||||
channels: std::sync::Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the channel state for the given identity, if any.
|
||||
pub(crate) fn get(&self, ident: &C::Ident) -> Result<Option<ChannelState<C>>> {
|
||||
let map = self.channels.lock()?;
|
||||
Ok(map.get(ident).map(ChannelState::clone_ref))
|
||||
}
|
||||
|
||||
/// Replace the channel state for `ident` with `newval`, and return the
|
||||
/// previous value if any.
|
||||
pub(crate) fn replace(
|
||||
&self,
|
||||
ident: C::Ident,
|
||||
newval: ChannelState<C>,
|
||||
) -> Result<Option<ChannelState<C>>> {
|
||||
newval.check_ident(&ident)?;
|
||||
let mut map = self.channels.lock()?;
|
||||
Ok(map.insert(ident, newval))
|
||||
}
|
||||
|
||||
/// Remove and return the state for `ident`, if any.
|
||||
pub(crate) fn remove(&self, ident: &C::Ident) -> Result<Option<ChannelState<C>>> {
|
||||
let mut map = self.channels.lock()?;
|
||||
Ok(map.remove(ident))
|
||||
}
|
||||
|
||||
/// Remove every unusable state from the map.
|
||||
pub(crate) fn remove_unusable(&self) -> Result<()> {
|
||||
let mut map = self.channels.lock()?;
|
||||
map.retain(|_, state| match state {
|
||||
ChannelState::Poisoned(_) => panic!(),
|
||||
ChannelState::Open(ch) => ch.is_usable(),
|
||||
ChannelState::Building(_) => true,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replace the state whose identity is `ident` with a new state.
|
||||
///
|
||||
/// The provided function `func` is invoked on the old state (if
|
||||
/// any), and must return a tuple containing an optional new
|
||||
/// state, and an arbitrary return value for this function.
|
||||
///
|
||||
/// Because `func` is run while holding the lock on this object,
|
||||
/// it should be fast and nonblocking. In return, you can be sure
|
||||
/// that it's running atomically with respect to other accessesors
|
||||
/// of this map.
|
||||
///
|
||||
/// If `func` panics, this map will become unusable and future
|
||||
/// accesses will fail.
|
||||
pub(crate) fn change_state<F, V>(&self, ident: &C::Ident, func: F) -> Result<V>
|
||||
where
|
||||
F: FnOnce(Option<ChannelState<C>>) -> (Option<ChannelState<C>>, V),
|
||||
{
|
||||
use hash_map::Entry::*;
|
||||
let mut map = self.channels.lock()?;
|
||||
let entry = map.entry(ident.clone());
|
||||
match entry {
|
||||
Occupied(mut occupied) => {
|
||||
// DOCDOC explain what's up here.
|
||||
let mut oldent = ChannelState::Poisoned(Priv { _unused: () });
|
||||
std::mem::swap(occupied.get_mut(), &mut oldent);
|
||||
let (newval, output) = func(Some(oldent));
|
||||
match newval {
|
||||
Some(mut newent) => {
|
||||
newent.check_ident(ident)?; // XXX leaves it poisoned
|
||||
std::mem::swap(occupied.get_mut(), &mut newent);
|
||||
}
|
||||
None => {
|
||||
occupied.remove();
|
||||
}
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
Vacant(vacant) => {
|
||||
let (newval, output) = func(None);
|
||||
if let Some(newent) = newval {
|
||||
newent.check_ident(ident)?;
|
||||
vacant.insert(newent);
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
struct FakeChannel {
|
||||
ident: &'static str,
|
||||
usable: bool,
|
||||
}
|
||||
impl AbstractChannel for FakeChannel {
|
||||
type Ident = u8;
|
||||
fn ident(&self) -> &Self::Ident {
|
||||
&self.ident.as_bytes()[0]
|
||||
}
|
||||
fn is_usable(&self) -> bool {
|
||||
self.usable
|
||||
}
|
||||
}
|
||||
fn ch(ident: &'static str) -> ChannelState<FakeChannel> {
|
||||
ChannelState::Open(Arc::new(FakeChannel {
|
||||
ident,
|
||||
usable: true,
|
||||
}))
|
||||
}
|
||||
fn closed(ident: &'static str) -> ChannelState<FakeChannel> {
|
||||
ChannelState::Open(Arc::new(FakeChannel {
|
||||
ident,
|
||||
usable: false,
|
||||
}))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple_ops() {
|
||||
let map = ChannelMap::new();
|
||||
use ChannelState::Open;
|
||||
|
||||
assert!(map.replace(b'h', ch("hello")).unwrap().is_none());
|
||||
assert!(map.replace(b'w', ch("wello")).unwrap().is_none());
|
||||
|
||||
match map.get(&b'h') {
|
||||
Ok(Some(Open(chan))) if chan.ident == "hello" => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
|
||||
assert!(map.get(&b'W').unwrap().is_none());
|
||||
|
||||
match map.replace(b'h', ch("hebbo")) {
|
||||
Ok(Some(Open(chan))) if chan.ident == "hello" => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
|
||||
assert!(map.remove(&b'Z').unwrap().is_none());
|
||||
match map.remove(&b'h') {
|
||||
Ok(Some(Open(chan))) if chan.ident == "hebbo" => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rmv_unusable() {
|
||||
let map = ChannelMap::new();
|
||||
|
||||
map.replace(b'm', closed("machen")).unwrap();
|
||||
map.replace(b'f', ch("feinen")).unwrap();
|
||||
map.replace(b'w', closed("wir")).unwrap();
|
||||
map.replace(b'F', ch("Fug")).unwrap();
|
||||
|
||||
map.remove_unusable().unwrap();
|
||||
|
||||
assert!(map.get(&b'm').unwrap().is_none());
|
||||
assert!(map.get(&b'w').unwrap().is_none());
|
||||
assert!(map.get(&b'f').unwrap().is_some());
|
||||
assert!(map.get(&b'F').unwrap().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn change() {
|
||||
let map = ChannelMap::new();
|
||||
|
||||
map.replace(b'w', ch("wir")).unwrap();
|
||||
map.replace(b'm', ch("machen")).unwrap();
|
||||
map.replace(b'f', ch("feinen")).unwrap();
|
||||
map.replace(b'F', ch("Fug")).unwrap();
|
||||
|
||||
// Replace Some with Some.
|
||||
let (old, v) = map
|
||||
.change_state(&b'F', |state| (Some(ch("FUG")), (state, 99_u8)))
|
||||
.unwrap();
|
||||
assert_eq!(old.unwrap().unwrap_open().ident, "Fug");
|
||||
assert_eq!(v, 99);
|
||||
assert_eq!(map.get(&b'F').unwrap().unwrap().unwrap_open().ident, "FUG");
|
||||
|
||||
// Replace Some with None.
|
||||
let (old, v) = map
|
||||
.change_state(&b'f', |state| (None, (state, 123_u8)))
|
||||
.unwrap();
|
||||
assert_eq!(old.unwrap().unwrap_open().ident, "feinen");
|
||||
assert_eq!(v, 123);
|
||||
assert!(map.get(&b'f').unwrap().is_none());
|
||||
|
||||
// Replace None with Some.
|
||||
let (old, v) = map
|
||||
.change_state(&b'G', |state| (Some(ch("Geheimnisse")), (state, "Hi")))
|
||||
.unwrap();
|
||||
assert!(old.is_none());
|
||||
assert_eq!(v, "Hi");
|
||||
assert_eq!(
|
||||
map.get(&b'G').unwrap().unwrap().unwrap_open().ident,
|
||||
"Geheimnisse"
|
||||
);
|
||||
|
||||
// Replace None with None
|
||||
let (old, v) = map
|
||||
.change_state(&b'Q', |state| (None, (state, "---")))
|
||||
.unwrap();
|
||||
assert!(old.is_none());
|
||||
assert_eq!(v, "---");
|
||||
assert!(map.get(&b'Q').unwrap().is_none());
|
||||
}
|
||||
}
|
|
@ -1,92 +0,0 @@
|
|||
//! Testing stubs for the chahannel manager code. Only enabled
|
||||
//! with `cfg(test)`.
|
||||
|
||||
#![allow(missing_docs)]
|
||||
#![allow(clippy::missing_docs_in_private_items)]
|
||||
|
||||
use crate::{Error, Result};
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_llcrypto::pk::rsa::RsaIdentity;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FakeChannel {
|
||||
chan: FakeChannelInner,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct FakeChannelInner {
|
||||
closing: AtomicBool,
|
||||
want_rsa_id: Option<RsaIdentity>,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct FakeChannelBuilder {
|
||||
addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct FakeReactor {}
|
||||
|
||||
impl FakeChannelBuilder {
|
||||
pub fn new() -> Self {
|
||||
FakeChannelBuilder { addr: None }
|
||||
}
|
||||
pub fn set_declared_addr(&mut self, addr: SocketAddr) {
|
||||
self.addr = Some(addr);
|
||||
}
|
||||
pub fn launch<T>(self, _ignore: T) -> FakeChannel {
|
||||
FakeChannel::new(self.addr.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeChannel {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
let inner = FakeChannelInner {
|
||||
closing: false.into(),
|
||||
want_rsa_id: None,
|
||||
addr,
|
||||
};
|
||||
FakeChannel { chan: inner }
|
||||
}
|
||||
pub async fn connect(self) -> Result<Self> {
|
||||
Ok(self)
|
||||
}
|
||||
pub fn check<T: ChanTarget + ?Sized>(self, _target: &T, _cert: &[u8]) -> Result<Self> {
|
||||
if self.chan.addr.port() == 8686 {
|
||||
Err(tor_proto::Error::ChanProto("86ed".into()).into())
|
||||
} else {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
pub fn same_channel(self: &Arc<Self>, other: &Arc<FakeChannel>) -> bool {
|
||||
Arc::ptr_eq(self, other)
|
||||
}
|
||||
pub(crate) async fn finish(self) -> Result<(Arc<Self>, FakeReactor)> {
|
||||
Ok((Arc::new(self), FakeReactor {}))
|
||||
}
|
||||
pub fn is_closing(&self) -> bool {
|
||||
self.chan.closing.load(Ordering::SeqCst)
|
||||
}
|
||||
pub fn mark_closing(&self) {
|
||||
self.chan.closing.store(true, Ordering::SeqCst)
|
||||
}
|
||||
pub fn check_match<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<()> {
|
||||
if let Some(ref id) = self.chan.want_rsa_id {
|
||||
if id != target.rsa_identity() {
|
||||
return Err(Error::UnusableTarget("Wrong RSA".into()).into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeReactor {
|
||||
pub async fn run(self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
//! Types for launching TLS connections to relays
|
||||
//!
|
||||
//! TODO: Perhaps this type is no longer needed, and we should just
|
||||
//! use the TlsConnector trait in tor_rtcompat.
|
||||
|
||||
pub mod nativetls;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_rtcompat::tls::CertifiedConn;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// A Transport knows how to build a TLS connection to a relay, in a way
|
||||
/// that Tor can use.
|
||||
///
|
||||
/// Tor doesn't expect to get any particular hostname or sequence of
|
||||
/// certificates in the reply; it only expects that the peer certificate
|
||||
/// will later be authenticated inside the Tor handshake.
|
||||
#[async_trait]
|
||||
pub trait Transport {
|
||||
/// The type that will be returned by this transport. This should
|
||||
/// be an asynchronous TLS connection.
|
||||
type Connection: AsyncRead + AsyncWrite + Send + Unpin + CertifiedConn + 'static;
|
||||
|
||||
/// Try to connect to a given relay.
|
||||
async fn connect<T: ChanTarget + Sync + ?Sized>(
|
||||
&self,
|
||||
target: &T,
|
||||
) -> Result<(std::net::SocketAddr, Self::Connection)>;
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
//! Build TLS connections using the async_native_tls or tokio_native_tls crate.
|
||||
|
||||
// XXXX-A2 This should get refactored significantly. Probably we should have
|
||||
// a boxed-connection-factory type that we can use instead. Once we have a
|
||||
// pluggable designn, we'll really need something like that.
|
||||
//
|
||||
// Probably, much of this code should move into tor-rtcompat, or a new
|
||||
// crate similar to tor-rtcompat, that can handle our TLS drama.
|
||||
|
||||
use super::Transport;
|
||||
use crate::{Error, Result};
|
||||
use tor_linkspec::ChanTarget;
|
||||
use tor_rtcompat::tls::TlsConnector;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use log::info;
|
||||
|
||||
/// A Transport that uses a connector based on native_tls.
|
||||
pub struct NativeTlsTransport<C: TlsConnector> {
|
||||
/// connector object used to build TLS connections
|
||||
connector: C,
|
||||
}
|
||||
|
||||
impl<C: TlsConnector> NativeTlsTransport<C> {
|
||||
/// Construct a new NativeTlsTransport.
|
||||
pub fn new(connector: C) -> Result<Self> {
|
||||
Ok(NativeTlsTransport { connector })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<C: TlsConnector + Send + Sync + Unpin> Transport for NativeTlsTransport<C> {
|
||||
type Connection = C::Conn;
|
||||
|
||||
async fn connect<T: ChanTarget + Sync + ?Sized>(
|
||||
&self,
|
||||
target: &T,
|
||||
) -> Result<(std::net::SocketAddr, Self::Connection)> {
|
||||
// TODO: This just uses the first address. Instead we could be smarter,
|
||||
// or use "happy eyeballs, or whatever. Maybe we will want to
|
||||
// refactor as we do so?
|
||||
let addr = target
|
||||
.addrs()
|
||||
.get(0)
|
||||
.ok_or_else(|| Error::UnusableTarget("No addresses for chosen relay".into()))?;
|
||||
|
||||
info!("Negotiating TLS with {}", addr);
|
||||
|
||||
// TODO: add a random hostname here if it will be used for SNI?
|
||||
let connection = self.connector.connect_unvalidated(addr, "ignored").await?;
|
||||
Ok((*addr, connection))
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@
|
|||
//! To construct a client, run the `TorClient::bootstrap()` method.
|
||||
//! Once the client is bootstrapped, you can make connections over the Tor
|
||||
//! network using `TorClient::connect()`.
|
||||
use tor_chanmgr::transport::nativetls::NativeTlsTransport;
|
||||
use tor_circmgr::TargetPort;
|
||||
use tor_dirmgr::NetDirConfig;
|
||||
use tor_proto::circuit::IpVersionPreference;
|
||||
|
@ -78,11 +77,7 @@ impl<R: Runtime> TorClient<R> {
|
|||
/// Return a client once there is enough directory material to
|
||||
/// connect safely over the Tor network.
|
||||
pub async fn bootstrap(runtime: R, dircfg: NetDirConfig) -> Result<TorClient<R>> {
|
||||
let transport = {
|
||||
let connector = runtime.tls_connector();
|
||||
NativeTlsTransport::new(connector)?
|
||||
};
|
||||
let chanmgr = Arc::new(tor_chanmgr::ChanMgr::new(runtime.clone(), transport));
|
||||
let chanmgr = Arc::new(tor_chanmgr::ChanMgr::new(runtime.clone()));
|
||||
let circmgr = Arc::new(tor_circmgr::CircMgr::new(
|
||||
runtime.clone(),
|
||||
Arc::clone(&chanmgr),
|
||||
|
|
|
@ -236,6 +236,11 @@ impl Channel {
|
|||
self.unique_id
|
||||
}
|
||||
|
||||
/// Return the Ed25519 identity for the peer of this channel.
|
||||
pub fn peer_ed25519_id(&self) -> &Ed25519Identity {
|
||||
&self.ed25519_id
|
||||
}
|
||||
|
||||
/// Return an error if this channel is somehow mismatched with the
|
||||
/// given target.
|
||||
pub fn check_match<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<()> {
|
||||
|
|
Loading…
Reference in New Issue