From 68dc2f56ceec8aa0eaabeb0394d32be1a040dc58 Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Thu, 22 Apr 2021 10:05:47 -0400 Subject: [PATCH 1/5] Begin refactoring ChanMgr to have better separation of concerns Additionally: Use futures::future::Shared instead of event_listener. --- tor-chanmgr/src/err.rs | 25 +++++++ tor-chanmgr/src/lib.rs | 1 + tor-chanmgr/src/mgr.rs | 142 +++++++++++++++++++++++++++++++++++++ tor-chanmgr/src/mgr/map.rs | 122 +++++++++++++++++++++++++++++++ 4 files changed, 290 insertions(+) create mode 100644 tor-chanmgr/src/mgr.rs create mode 100644 tor-chanmgr/src/mgr/map.rs diff --git a/tor-chanmgr/src/err.rs b/tor-chanmgr/src/err.rs index a6d280ec3..0563945d1 100644 --- a/tor-chanmgr/src/err.rs +++ b/tor-chanmgr/src/err.rs @@ -32,6 +32,10 @@ pub enum Error { /// An internal error of some kind that should never occur. #[error("Internal error: {0}")] Internal(&'static str), + + /// We were waiting for a channel to complete, but it failed. + #[error("Pending channel failed to open: {0}")] + PendingChanFailed(#[from] PendingChanError), } impl From for Error { @@ -39,3 +43,24 @@ impl From for Error { Error::Internal("Couldn't spawn channel reactor") } } + +impl From> for Error { + fn from(_: std::sync::PoisonError) -> Error { + Error::Internal("Thread failed while holding lock") + } +} + +/// An error transmitted by a future that trying to build a channel. +#[derive(Debug, Clone)] +pub struct PendingChanError(String); +impl std::error::Error for PendingChanError {} +impl std::fmt::Display for PendingChanError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} +impl From for PendingChanError { + fn from(e: Error) -> PendingChanError { + PendingChanError(e.to_string()) + } +} diff --git a/tor-chanmgr/src/lib.rs b/tor-chanmgr/src/lib.rs index 7c7415c70..fda663ebd 100644 --- a/tor-chanmgr/src/lib.rs +++ b/tor-chanmgr/src/lib.rs @@ -13,6 +13,7 @@ mod connect; mod err; +mod mgr; #[cfg(test)] pub(crate) mod testing; pub mod transport; diff --git a/tor-chanmgr/src/mgr.rs b/tor-chanmgr/src/mgr.rs new file mode 100644 index 000000000..f7750f54a --- /dev/null +++ b/tor-chanmgr/src/mgr.rs @@ -0,0 +1,142 @@ +#![allow(unused)] +#![allow(missing_docs)] +#![allow(clippy::missing_docs_in_private_items)] + +use crate::err::PendingChanError; +use crate::Result; +use crate::TargetInfo; + +use async_trait::async_trait; +use futures::channel::oneshot; +use futures::future::{FutureExt, Shared}; +use futures::task::Spawn; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Arc; +use tor_llcrypto::pk::ed25519::Ed25519Identity; +use tor_rtcompat::Runtime; + +mod map; + +pub(crate) trait AbstractChannel { + type Ident: Hash + Eq + Clone; + fn ident(&self) -> &Self::Ident; + fn is_usable(&self) -> bool; +} + +#[async_trait] +pub(crate) trait ChannelFactory { + type Channel: AbstractChannel; + type BuildSpec; + + async fn build_channel( + &self, + runtime: &(dyn Spawn + Sync), + target: &Self::BuildSpec, + ) -> Result>; +} + +pub(crate) struct AbstractChannelMgr { + /// Abstract runtime, used to launch tasks, create network + /// connections via TCP and TLS, and run timeouts. + runtime: RT, + + /// A 'connector' object that we use to create channels. + connector: CF, + + /// A map from ed25519 identity to channel, or to pending channel status. + channels: map::ChannelMap, +} + +type PendResult = std::result::Result; + +type Pending = Shared>>>; +type Sending = oneshot::Sender>>; + +impl AbstractChannelMgr { + pub(crate) fn new(runtime: RT, connector: CF) -> Self { + AbstractChannelMgr { + runtime, + connector, + channels: map::ChannelMap::new(), + } + } + + pub fn remove_unusable_entries(&self) -> Result<()> { + self.channels.remove_unusable() + } + + fn setup_launch(&self) -> (map::ChannelState, Sending) { + let (snd, rcv) = oneshot::channel(); + let shared = rcv.shared(); + (map::ChannelState::Building(shared), snd) + } + + pub async fn get_or_launch( + &self, + ident: <::Channel as AbstractChannel>::Ident, + target: CF::BuildSpec, + ) -> Result> { + use map::ChannelState::*; + enum Action { + Launch(Sending), + Wait(Pending), + Return(Arc), + } + const N_ATTEMPTS: usize = 2; + + 'retry: for _ in 0..N_ATTEMPTS { + let action = self + .channels + .change_state(&ident, |oldstate| match oldstate { + Some(Open(ref ch)) => { + if ch.is_usable() { + let action = Action::Return(Arc::clone(ch)); + (oldstate, action) + } else { + let (newstate, send) = self.setup_launch(); + let action = Action::Launch(send); + (Some(newstate), action) + } + } + Some(Building(ref pending)) => { + let action = Action::Wait(pending.clone()); + (oldstate, action) + } + Some(Poisoned(_)) => panic!(), + None => { + let (newstate, send) = self.setup_launch(); + let action = Action::Launch(send); + (Some(newstate), action) + } + })?; + + match action { + Action::Return(v) => { + return Ok(v); + } + Action::Wait(pend) => match pend.await { + Ok(Ok(chan)) => return Ok(chan), + _ => continue 'retry, + }, + Action::Launch(send) => { + match self.connector.build_channel(&self.runtime, &target).await { + Ok(chan) => { + self.channels + .replace(ident.clone(), Open(Arc::clone(&chan)))?; + send.send(Ok(Arc::clone(&chan))); + return Ok(chan); + } + Err(e) => { + self.channels.remove(&ident)?; + send.send(Err(e.into())); + continue 'retry; + } + } + } + } + } + + Err(crate::Error::ChanTimeout) // not quite right. XXXX + } +} diff --git a/tor-chanmgr/src/mgr/map.rs b/tor-chanmgr/src/mgr/map.rs new file mode 100644 index 000000000..86f2cb0b7 --- /dev/null +++ b/tor-chanmgr/src/mgr/map.rs @@ -0,0 +1,122 @@ +use super::{AbstractChannel, Pending}; +use crate::Result; + +use tor_llcrypto::pk::ed25519::Ed25519Identity; + +use std::collections::{hash_map, HashMap}; +use std::sync::Arc; + +pub(crate) struct ChannelMap { + /// A map from identity to channel, or to pending channel status. + /// + /// (Danger: this uses a blocking mutex close to async code. This mutex + /// must never be held while an await is happening.) + channels: std::sync::Mutex>>, +} + +// used to ensure that only this module can construct a ChannelState::Poisoned. +pub struct Priv { + _unused: (), +} + +pub(crate) enum ChannelState { + Open(Arc), + Building(Pending), + // XXXX explain what this is for. + Poisoned(Priv), +} + +impl ChannelState { + pub(super) fn clone_ref(&self) -> Self { + use ChannelState::*; + match self { + Open(chan) => Open(Arc::clone(chan)), + Building(pending) => Building(pending.clone()), + Poisoned(_) => panic!(), + } + } +} + +impl ChannelState { + /// DOCDOC returns true if identity COULD BE `ident` + fn check_ident(&self, ident: &C::Ident) -> bool { + match self { + ChannelState::Open(chan) => chan.ident() == ident, + ChannelState::Poisoned(_) => false, + ChannelState::Building(_) => true, + } + } +} + +impl ChannelMap { + pub(crate) fn new() -> Self { + ChannelMap { + channels: std::sync::Mutex::new(HashMap::new()), + } + } + + pub(crate) fn get(&self, ident: &C::Ident) -> Result>> { + let map = self.channels.lock()?; + Ok(map.get(ident).map(ChannelState::clone_ref)) + } + + pub(crate) fn replace( + &self, + ident: C::Ident, + newval: ChannelState, + ) -> Result>> { + assert!(newval.check_ident(&ident)); + let mut map = self.channels.lock()?; + Ok(map.insert(ident, newval)) + } + + pub(crate) fn remove(&self, ident: &C::Ident) -> Result>> { + let mut map = self.channels.lock()?; + Ok(map.remove(ident)) + } + + pub(crate) fn remove_unusable(&self) -> Result<()> { + let mut map = self.channels.lock()?; + map.retain(|_, state| match state { + ChannelState::Poisoned(_) => panic!(), + ChannelState::Open(ch) => ch.is_usable(), + ChannelState::Building(_) => true, + }); + Ok(()) + } + + pub(crate) fn change_state(&self, ident: &C::Ident, func: F) -> Result + where + F: FnOnce(Option>) -> (Option>, V), + { + use hash_map::Entry::*; + let mut map = self.channels.lock()?; + let mut entry = map.entry(ident.clone()); + match entry { + Occupied(mut occupied) => { + // DOCDOC explain what's up here. + let mut oldent = ChannelState::Poisoned(Priv { _unused: () }); + std::mem::swap(occupied.get_mut(), &mut oldent); + let (newval, output) = func(Some(oldent)); + match newval { + Some(mut newent) => { + assert!(newent.check_ident(ident)); + std::mem::swap(occupied.get_mut(), &mut newent); + } + None => { + occupied.remove(); + } + }; + Ok(output) + } + Vacant(vacant) => { + let (newval, output) = func(None); + if let Some(newent) = newval { + assert!(newent.check_ident(ident)); + vacant.insert(newent); + } + Ok(output) + } + } + } +} From e3db6678f65f9e7787cd96dc695683b0fbbbe232 Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Thu, 22 Apr 2021 15:48:05 -0400 Subject: [PATCH 2/5] Start on tests for reactored ChanMsg. --- tor-chanmgr/src/err.rs | 8 +- tor-chanmgr/src/mgr.rs | 262 +++++++++++++++++++++++++++++++------ tor-chanmgr/src/mgr/map.rs | 134 ++++++++++++++++++- 3 files changed, 352 insertions(+), 52 deletions(-) diff --git a/tor-chanmgr/src/err.rs b/tor-chanmgr/src/err.rs index 0563945d1..77951195a 100644 --- a/tor-chanmgr/src/err.rs +++ b/tor-chanmgr/src/err.rs @@ -17,10 +17,6 @@ pub enum Error { #[error("Channel timed out")] ChanTimeout, - /// An internal error or assumption violation in the TLS implementation. - #[error("Invalid TLS connection")] - InvalidTls, - /// A protocol error while making a channel #[error("Protocol error while opening a channel: {0}")] Proto(#[from] tor_proto::Error), @@ -59,8 +55,8 @@ impl std::fmt::Display for PendingChanError { write!(f, "{}", self.0) } } -impl From for PendingChanError { - fn from(e: Error) -> PendingChanError { +impl From<&Error> for PendingChanError { + fn from(e: &Error) -> PendingChanError { PendingChanError(e.to_string()) } } diff --git a/tor-chanmgr/src/mgr.rs b/tor-chanmgr/src/mgr.rs index f7750f54a..dd059a39c 100644 --- a/tor-chanmgr/src/mgr.rs +++ b/tor-chanmgr/src/mgr.rs @@ -1,20 +1,12 @@ -#![allow(unused)] -#![allow(missing_docs)] -#![allow(clippy::missing_docs_in_private_items)] - +#![allow(dead_code)] use crate::err::PendingChanError; -use crate::Result; -use crate::TargetInfo; +use crate::{Error, Result}; use async_trait::async_trait; use futures::channel::oneshot; use futures::future::{FutureExt, Shared}; -use futures::task::Spawn; -use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; -use tor_llcrypto::pk::ed25519::Ed25519Identity; -use tor_rtcompat::Runtime; mod map; @@ -29,18 +21,10 @@ pub(crate) trait ChannelFactory { type Channel: AbstractChannel; type BuildSpec; - async fn build_channel( - &self, - runtime: &(dyn Spawn + Sync), - target: &Self::BuildSpec, - ) -> Result>; + async fn build_channel(&self, target: &Self::BuildSpec) -> Result>; } -pub(crate) struct AbstractChannelMgr { - /// Abstract runtime, used to launch tasks, create network - /// connections via TCP and TLS, and run timeouts. - runtime: RT, - +pub(crate) struct AbstractChannelMgr { /// A 'connector' object that we use to create channels. connector: CF, @@ -48,15 +32,14 @@ pub(crate) struct AbstractChannelMgr { channels: map::ChannelMap, } -type PendResult = std::result::Result; +type PendResult = std::result::Result; type Pending = Shared>>>; type Sending = oneshot::Sender>>; -impl AbstractChannelMgr { - pub(crate) fn new(runtime: RT, connector: CF) -> Self { +impl AbstractChannelMgr { + pub(crate) fn new(connector: CF) -> Self { AbstractChannelMgr { - runtime, connector, channels: map::ChannelMap::new(), } @@ -85,7 +68,11 @@ impl AbstractChannelMgr { } const N_ATTEMPTS: usize = 2; - 'retry: for _ in 0..N_ATTEMPTS { + // XXXX It would be neat to use tor_retry instead, but it's + // too tied to anyhow right now. + let mut last_err = Err(Error::Internal("Error was never set!?")); + + for _ in 0..N_ATTEMPTS { let action = self .channels .change_state(&ident, |oldstate| match oldstate { @@ -117,26 +104,217 @@ impl AbstractChannelMgr { } Action::Wait(pend) => match pend.await { Ok(Ok(chan)) => return Ok(chan), - _ => continue 'retry, - }, - Action::Launch(send) => { - match self.connector.build_channel(&self.runtime, &target).await { - Ok(chan) => { - self.channels - .replace(ident.clone(), Open(Arc::clone(&chan)))?; - send.send(Ok(Arc::clone(&chan))); - return Ok(chan); - } - Err(e) => { - self.channels.remove(&ident)?; - send.send(Err(e.into())); - continue 'retry; - } + Ok(Err(e)) => { + last_err = Err(e.into()); } - } + Err(_) => { + last_err = Err(Error::Internal("channel build task disappeared")); + } + }, + Action::Launch(send) => match self.connector.build_channel(&target).await { + Ok(chan) => { + self.channels + .replace(ident.clone(), Open(Arc::clone(&chan)))?; + // It's okay if all the receivers went away: + // that means that nobody was waiting for this channel. + let _ignore_err = send.send(Ok(Arc::clone(&chan))); + return Ok(chan); + } + Err(e) => { + self.channels.remove(&ident)?; + // (As above) + let _ignore_err = send.send(Err((&e).into())); + last_err = Err(e); + } + }, } } - Err(crate::Error::ChanTimeout) // not quite right. XXXX + last_err + } + + #[cfg(test)] + pub fn get_nowait( + &self, + ident: &<::Channel as AbstractChannel>::Ident, + ) -> Option> { + use map::ChannelState::*; + match self.channels.get(ident) { + Ok(Some(Open(ref ch))) if ch.is_usable() => Some(Arc::clone(ch)), + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::Error; + + use futures::join; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::time::Duration; + + use tor_rtcompat::{task::yield_now, test_with_runtime, Runtime}; + + struct FakeChannelFactory { + runtime: RT, + } + + #[derive(Debug)] + struct FakeChannel { + ident: u32, + mood: char, + closing: AtomicBool, + } + + impl AbstractChannel for FakeChannel { + type Ident = u32; + fn ident(&self) -> &u32 { + &self.ident + } + fn is_usable(&self) -> bool { + !self.closing.load(Ordering::SeqCst) + } + } + + impl FakeChannel { + fn start_closing(&self) { + self.closing.store(true, Ordering::SeqCst); + } + } + + impl FakeChannelFactory { + fn new(runtime: RT) -> Self { + FakeChannelFactory { runtime } + } + } + + #[async_trait] + impl ChannelFactory for FakeChannelFactory { + type Channel = FakeChannel; + type BuildSpec = (u32, char); + + async fn build_channel(&self, target: &Self::BuildSpec) -> Result> { + yield_now().await; + let (ident, mood) = *target; + match mood { + // "X" means never connect. + '❌' | '🔥' => return Err(Error::UnusableTarget("emoji".into())), + // "zzz" means wait for 15 seconds then succeed. + '💤' => { + self.runtime.sleep(Duration::new(15, 0)).await; + } + _ => {} + } + Ok(Arc::new(FakeChannel { + ident, + mood, + closing: AtomicBool::new(false), + })) + } + } + + #[test] + fn connect_one_ok() { + test_with_runtime(|runtime| async { + let cf = FakeChannelFactory::new(runtime); + let mgr = AbstractChannelMgr::new(cf); + let target = (413, '!'); + let chan1 = mgr.get_or_launch(413, target.clone()).await.unwrap(); + let chan2 = mgr.get_or_launch(413, target.clone()).await.unwrap(); + + assert!(Arc::ptr_eq(&chan1, &chan2)); + + let chan3 = mgr.get_nowait(&413).unwrap(); + assert!(Arc::ptr_eq(&chan1, &chan3)); + }); + } + + #[test] + fn connect_one_fail() { + test_with_runtime(|runtime| async { + let cf = FakeChannelFactory::new(runtime); + let mgr = AbstractChannelMgr::new(cf); + + // This is set up to always fail. + let target = (999, '❌'); + let res1 = mgr.get_or_launch(999, target).await; + dbg!(&res1); + assert!(matches!(res1, Err(Error::UnusableTarget(_)))); + + let chan3 = mgr.get_nowait(&999); + assert!(chan3.is_none()); + }); + } + + #[test] + fn test_concurrent() { + test_with_runtime(|runtime| async { + let cf = FakeChannelFactory::new(runtime); + let mgr = AbstractChannelMgr::new(cf); + + // TODO XXXX: figure out how to make these actually run + // concurrently. Right now it seems that they don't actually + // interact. + let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!( + mgr.get_or_launch(3, (3, 'a')), + mgr.get_or_launch(3, (3, 'b')), + mgr.get_or_launch(44, (44, 'a')), + mgr.get_or_launch(44, (44, 'b')), + mgr.get_or_launch(86, (86, '❌')), + mgr.get_or_launch(86, (86, '🔥')), + ); + let ch3a = ch3a.unwrap(); + let ch3b = ch3b.unwrap(); + let ch44a = ch44a.unwrap(); + let ch44b = ch44b.unwrap(); + let err_a = ch86a.unwrap_err(); + let err_b = ch86b.unwrap_err(); + + assert!(Arc::ptr_eq(&ch3a, &ch3b)); + assert!(Arc::ptr_eq(&ch44a, &ch44b)); + assert!(!Arc::ptr_eq(&ch44a, &ch3a)); + + assert!(matches!( + err_a, + Error::UnusableTarget(_) | Error::PendingChanFailed(_) + )); + assert!(matches!( + err_b, + Error::UnusableTarget(_) | Error::PendingChanFailed(_) + )); + }); + } + + #[test] + fn unusable_entries() { + test_with_runtime(|runtime| async { + let cf = FakeChannelFactory::new(runtime); + let mgr = AbstractChannelMgr::new(cf); + + let (ch3, ch4, ch5) = join!( + mgr.get_or_launch(3, (3, 'a')), + mgr.get_or_launch(4, (4, 'a')), + mgr.get_or_launch(5, (5, 'a')), + ); + + let ch3 = ch3.unwrap(); + let _ch4 = ch4.unwrap(); + let ch5 = ch5.unwrap(); + + ch3.start_closing(); + ch5.start_closing(); + + let ch3_new = mgr.get_or_launch(3, (3, 'b')).await.unwrap(); + assert!(!Arc::ptr_eq(&ch3, &ch3_new)); + assert_eq!(ch3_new.mood, 'b'); + + mgr.remove_unusable_entries().unwrap(); + + assert!(mgr.get_nowait(&3).is_some()); + assert!(mgr.get_nowait(&4).is_some()); + assert!(mgr.get_nowait(&5).is_none()); + }) } } diff --git a/tor-chanmgr/src/mgr/map.rs b/tor-chanmgr/src/mgr/map.rs index 86f2cb0b7..1ec75e8da 100644 --- a/tor-chanmgr/src/mgr/map.rs +++ b/tor-chanmgr/src/mgr/map.rs @@ -1,8 +1,6 @@ use super::{AbstractChannel, Pending}; use crate::Result; -use tor_llcrypto::pk::ed25519::Ed25519Identity; - use std::collections::{hash_map, HashMap}; use std::sync::Arc; @@ -27,7 +25,7 @@ pub(crate) enum ChannelState { } impl ChannelState { - pub(super) fn clone_ref(&self) -> Self { + fn clone_ref(&self) -> Self { use ChannelState::*; match self { Open(chan) => Open(Arc::clone(chan)), @@ -35,6 +33,14 @@ impl ChannelState { Poisoned(_) => panic!(), } } + + #[cfg(test)] + fn unwrap_open(&self) -> Arc { + match self { + ChannelState::Open(chan) => Arc::clone(chan), + _ => panic!("Not an oppen channel"), + } + } } impl ChannelState { @@ -91,7 +97,7 @@ impl ChannelMap { { use hash_map::Entry::*; let mut map = self.channels.lock()?; - let mut entry = map.entry(ident.clone()); + let entry = map.entry(ident.clone()); match entry { Occupied(mut occupied) => { // DOCDOC explain what's up here. @@ -120,3 +126,123 @@ impl ChannelMap { } } } + +#[cfg(test)] +mod test { + use super::*; + #[derive(Eq, PartialEq, Debug)] + struct FakeChannel { + ident: &'static str, + usable: bool, + } + impl AbstractChannel for FakeChannel { + type Ident = u8; + fn ident(&self) -> &Self::Ident { + &self.ident.as_bytes()[0] + } + fn is_usable(&self) -> bool { + self.usable + } + } + fn ch(ident: &'static str) -> ChannelState { + ChannelState::Open(Arc::new(FakeChannel { + ident, + usable: true, + })) + } + fn closed(ident: &'static str) -> ChannelState { + ChannelState::Open(Arc::new(FakeChannel { + ident, + usable: false, + })) + } + + #[test] + fn simple_ops() { + let map = ChannelMap::new(); + use ChannelState::Open; + + assert!(map.replace(b'h', ch("hello")).unwrap().is_none()); + assert!(map.replace(b'w', ch("wello")).unwrap().is_none()); + + match map.get(&b'h') { + Ok(Some(Open(chan))) if chan.ident == "hello" => {} + _ => panic!(), + } + + assert!(map.get(&b'W').unwrap().is_none()); + + match map.replace(b'h', ch("hebbo")) { + Ok(Some(Open(chan))) if chan.ident == "hello" => {} + _ => panic!(), + } + + assert!(map.remove(&b'Z').unwrap().is_none()); + match map.remove(&b'h') { + Ok(Some(Open(chan))) if chan.ident == "hebbo" => {} + _ => panic!(), + } + } + + #[test] + fn rmv_unusable() { + let map = ChannelMap::new(); + + map.replace(b'm', closed("machen")).unwrap(); + map.replace(b'f', ch("feinen")).unwrap(); + map.replace(b'w', closed("wir")).unwrap(); + map.replace(b'F', ch("Fug")).unwrap(); + + map.remove_unusable().unwrap(); + + assert!(map.get(&b'm').unwrap().is_none()); + assert!(map.get(&b'w').unwrap().is_none()); + assert!(map.get(&b'f').unwrap().is_some()); + assert!(map.get(&b'F').unwrap().is_some()); + } + + #[test] + fn change() { + let map = ChannelMap::new(); + + map.replace(b'w', ch("wir")).unwrap(); + map.replace(b'm', ch("machen")).unwrap(); + map.replace(b'f', ch("feinen")).unwrap(); + map.replace(b'F', ch("Fug")).unwrap(); + + // Replace Some with Some. + let (old, v) = map + .change_state(&b'F', |state| (Some(ch("FUG")), (state, 99_u8))) + .unwrap(); + assert_eq!(old.unwrap().unwrap_open().ident, "Fug"); + assert_eq!(v, 99); + assert_eq!(map.get(&b'F').unwrap().unwrap().unwrap_open().ident, "FUG"); + + // Replace Some with None. + let (old, v) = map + .change_state(&b'f', |state| (None, (state, 123_u8))) + .unwrap(); + assert_eq!(old.unwrap().unwrap_open().ident, "feinen"); + assert_eq!(v, 123); + assert!(map.get(&b'f').unwrap().is_none()); + + // Replace None with Some. + let (old, v) = map + .change_state(&b'G', |state| (Some(ch("Geheimnisse")), (state, "Hi"))) + .unwrap(); + assert!(old.is_none()); + assert_eq!(v, "Hi"); + assert_eq!( + map.get(&b'G').unwrap().unwrap().unwrap_open().ident, + "Geheimnisse" + ); + + // Replace None with None + let (old, v) = map + .change_state(&b'Q', |state| (None, (state, "---"))) + .unwrap(); + assert!(old.is_none()); + assert_eq!(v, "---"); + assert!(map.get(&b'Q').unwrap().is_none()); + } +} From 8c75da128ae4a3f7c21dc343747914c3a318bc1c Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Fri, 23 Apr 2021 11:50:01 -0400 Subject: [PATCH 3/5] Full documentations on the new parts of ChanMgr. --- tor-chanmgr/src/mgr.rs | 106 +++++++++++++++++++++++++++++++++---- tor-chanmgr/src/mgr/map.rs | 75 +++++++++++++++++++++----- 2 files changed, 157 insertions(+), 24 deletions(-) diff --git a/tor-chanmgr/src/mgr.rs b/tor-chanmgr/src/mgr.rs index dd059a39c..7a39c1e70 100644 --- a/tor-chanmgr/src/mgr.rs +++ b/tor-chanmgr/src/mgr.rs @@ -1,3 +1,5 @@ +//! Abstract implementation of a channel maanger + #![allow(dead_code)] use crate::err::PendingChanError; use crate::{Error, Result}; @@ -10,21 +12,49 @@ use std::sync::Arc; mod map; +/// Trait to describe as much of a +/// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr` +/// needs to use. pub(crate) trait AbstractChannel { + /// Identity type for the other side of the channel. type Ident: Hash + Eq + Clone; + /// Return this channel's identity. fn ident(&self) -> &Self::Ident; + /// Return true if this channel is usable. + /// + /// A channel might be unusable because it is closed, because it has + /// hit a bug, or for some other reason. We don't return unusable + /// channels back to the user. fn is_usable(&self) -> bool; } +/// Trait to describe how channels are created. #[async_trait] pub(crate) trait ChannelFactory { + /// The type of channel that this factory can build. type Channel: AbstractChannel; + /// Type that explains how to build a channel. type BuildSpec; + /// Construct a new channel to the destination described at `target`. + /// + /// This function must take care of all timeouts, error detection, + /// and so on. + /// + /// It should not retry; that is handled at a higher level. async fn build_channel(&self, target: &Self::BuildSpec) -> Result>; } -pub(crate) struct AbstractChannelMgr { +/// A type- and network-agnostic implementation for +/// [`ChanMgr`](crate::ChanMgr). +/// +/// This type does the work of keeping track of open channels and +/// pending channel requests, launching requests as needed, waiting +/// for pending requests, and so forth. +/// +/// The actual job of launching conenctions is deferred to a ChannelFactory +/// type. +pub(crate) struct AbstractChanMgr { /// A 'connector' object that we use to create channels. connector: CF, @@ -32,40 +62,70 @@ pub(crate) struct AbstractChannelMgr { channels: map::ChannelMap, } +/// A Result whose error is a [`PendingChanError`]. +/// +/// We need a separate type here because [`Error`] doesn't implement `Clone`. type PendResult = std::result::Result; +/// Type alias for a future that we wait on to see when a pending +/// channel is done or failed. type Pending = Shared>>>; + +/// Type alias for the sender we notify when we complete a channel (or +/// fail to complete it). type Sending = oneshot::Sender>>; -impl AbstractChannelMgr { +impl AbstractChanMgr { + /// Make a new empty channel manager. pub(crate) fn new(connector: CF) -> Self { - AbstractChannelMgr { + AbstractChanMgr { connector, channels: map::ChannelMap::new(), } } + /// Remove every unusable entry from this channel manager. pub fn remove_unusable_entries(&self) -> Result<()> { self.channels.remove_unusable() } + /// Helper: return the objects used to inform pending tasks + /// about a newly open or failed channel. fn setup_launch(&self) -> (map::ChannelState, Sending) { let (snd, rcv) = oneshot::channel(); let shared = rcv.shared(); (map::ChannelState::Building(shared), snd) } + /// Get a channel whose identity is `ident`. + /// + /// If a usable channel exists with that identity, return it. + /// + /// If no such channel exists already, and none is in progress, + /// launch a new request using `target`, which must match `ident`. + /// + /// If no such channel exists already, but we have one that's in + /// progress, wait for it to succeed or fail. pub async fn get_or_launch( &self, ident: <::Channel as AbstractChannel>::Ident, target: CF::BuildSpec, ) -> Result> { use map::ChannelState::*; + + /// Possible actions that we'll decide to take based on the + /// channel's initial state. enum Action { + /// We found no channel. We're going to launch a new one, + /// then tell everybody about it. Launch(Sending), + /// We found an in-progress attempt at making a channel. + /// We're going to wait for it to finish. Wait(Pending), - Return(Arc), + /// We found a usable channel. We're going to return it. + Return(Result>), } + /// How many times do we try? const N_ATTEMPTS: usize = 2; // XXXX It would be neat to use tor_retry instead, but it's @@ -73,14 +133,19 @@ impl AbstractChannelMgr { let mut last_err = Err(Error::Internal("Error was never set!?")); for _ in 0..N_ATTEMPTS { + // First, see what state we're in, and what we should do + // about it. let action = self .channels .change_state(&ident, |oldstate| match oldstate { Some(Open(ref ch)) => { if ch.is_usable() { - let action = Action::Return(Arc::clone(ch)); + // Good channel. Return it. + let action = Action::Return(Ok(Arc::clone(ch))); (oldstate, action) } else { + // Unusable channel. Move to the Building + // state and launch a new channel. let (newstate, send) = self.setup_launch(); let action = Action::Launch(send); (Some(newstate), action) @@ -90,18 +155,30 @@ impl AbstractChannelMgr { let action = Action::Wait(pending.clone()); (oldstate, action) } - Some(Poisoned(_)) => panic!(), + Some(Poisoned(_)) => { + // We should never be able to see this state; this + // is a bug. + ( + None, + Action::Return(Err(Error::Internal("Found a poisoned entry"))), + ) + } None => { + // No channel. Move to the Building + // state and launch a new channel. let (newstate, send) = self.setup_launch(); let action = Action::Launch(send); (Some(newstate), action) } })?; + // Now we act based on the channel. match action { + // Easy case: we have an error or a channel to return. Action::Return(v) => { - return Ok(v); + return v; } + // There's an in-progress channel. Wait for it. Action::Wait(pend) => match pend.await { Ok(Ok(chan)) => return Ok(chan), Ok(Err(e)) => { @@ -111,8 +188,11 @@ impl AbstractChannelMgr { last_err = Err(Error::Internal("channel build task disappeared")); } }, + // We need to launch a channel. Action::Launch(send) => match self.connector.build_channel(&target).await { Ok(chan) => { + // The channel got built: remember it, tell the + // others, and return it. self.channels .replace(ident.clone(), Open(Arc::clone(&chan)))?; // It's okay if all the receivers went away: @@ -121,6 +201,8 @@ impl AbstractChannelMgr { return Ok(chan); } Err(e) => { + // The channel failed. Make it non-pending, tell the + // others, and set the error. self.channels.remove(&ident)?; // (As above) let _ignore_err = send.send(Err((&e).into())); @@ -133,6 +215,8 @@ impl AbstractChannelMgr { last_err } + /// Test only: return the current open usable channel with a given + /// `ident`, if any. #[cfg(test)] pub fn get_nowait( &self, @@ -219,7 +303,7 @@ mod test { fn connect_one_ok() { test_with_runtime(|runtime| async { let cf = FakeChannelFactory::new(runtime); - let mgr = AbstractChannelMgr::new(cf); + let mgr = AbstractChanMgr::new(cf); let target = (413, '!'); let chan1 = mgr.get_or_launch(413, target.clone()).await.unwrap(); let chan2 = mgr.get_or_launch(413, target.clone()).await.unwrap(); @@ -235,7 +319,7 @@ mod test { fn connect_one_fail() { test_with_runtime(|runtime| async { let cf = FakeChannelFactory::new(runtime); - let mgr = AbstractChannelMgr::new(cf); + let mgr = AbstractChanMgr::new(cf); // This is set up to always fail. let target = (999, '❌'); @@ -252,7 +336,7 @@ mod test { fn test_concurrent() { test_with_runtime(|runtime| async { let cf = FakeChannelFactory::new(runtime); - let mgr = AbstractChannelMgr::new(cf); + let mgr = AbstractChanMgr::new(cf); // TODO XXXX: figure out how to make these actually run // concurrently. Right now it seems that they don't actually @@ -291,7 +375,7 @@ mod test { fn unusable_entries() { test_with_runtime(|runtime| async { let cf = FakeChannelFactory::new(runtime); - let mgr = AbstractChannelMgr::new(cf); + let mgr = AbstractChanMgr::new(cf); let (ch3, ch4, ch5) = join!( mgr.get_or_launch(3, (3, 'a')), diff --git a/tor-chanmgr/src/mgr/map.rs b/tor-chanmgr/src/mgr/map.rs index 1ec75e8da..40fcf80d6 100644 --- a/tor-chanmgr/src/mgr/map.rs +++ b/tor-chanmgr/src/mgr/map.rs @@ -1,9 +1,17 @@ +//! Simple implementaiton for the internal map state of a ChanMgr. + use super::{AbstractChannel, Pending}; -use crate::Result; +use crate::{Error, Result}; use std::collections::{hash_map, HashMap}; use std::sync::Arc; +/// A map from channel id to channel state. +/// +/// We make this a separate type instead of just using +/// `Mutex>` to limit the amount of code that can see and +/// lock the Mutex here. (We're using a blocking mutex close to async +/// code, so we need to be careful.) pub(crate) struct ChannelMap { /// A map from identity to channel, or to pending channel status. /// @@ -12,19 +20,32 @@ pub(crate) struct ChannelMap { channels: std::sync::Mutex>>, } -// used to ensure that only this module can construct a ChannelState::Poisoned. -pub struct Priv { +/// Structure that can only be constructed from within this module. +/// Used to make sure that only we can construct ChannelState::Poisoned. +pub(crate) struct Priv { + /// (This field is private) _unused: (), } +/// The state of a channel (or channel build attempt) within a map. pub(crate) enum ChannelState { + /// An open channel. + /// + /// This channel might not be usable: it might be closing or + /// broken. We need to check its is_usable() method before + /// yielding it to the user. Open(Arc), + /// A channel that's getting built. Building(Pending), - // XXXX explain what this is for. + /// A temporary invalid state. + /// + /// We insert this into the map temporarily as a placeholder in + /// `change_state()`. Poisoned(Priv), } impl ChannelState { + /// Create a new shallow copy of this ChannelState. fn clone_ref(&self) -> Self { use ChannelState::*; match self { @@ -34,53 +55,68 @@ impl ChannelState { } } + /// For testing: either give the Open channnel inside this state, + /// or panic if there is none. #[cfg(test)] fn unwrap_open(&self) -> Arc { match self { ChannelState::Open(chan) => Arc::clone(chan), - _ => panic!("Not an oppen channel"), + _ => panic!("Not an open channel"), } } } impl ChannelState { - /// DOCDOC returns true if identity COULD BE `ident` - fn check_ident(&self, ident: &C::Ident) -> bool { + /// Return an error if `ident`is definitely not a matching + /// matching identity for this state. + fn check_ident(&self, ident: &C::Ident) -> Result<()> { match self { - ChannelState::Open(chan) => chan.ident() == ident, - ChannelState::Poisoned(_) => false, - ChannelState::Building(_) => true, + ChannelState::Open(chan) => { + if chan.ident() == ident { + Ok(()) + } else { + Err(Error::Internal("Identity mismatch")) + } + } + ChannelState::Poisoned(_) => Err(Error::Internal("Poisoned state in channel map")), + ChannelState::Building(_) => Ok(()), } } } impl ChannelMap { + /// Create a new empty ChannelMap. pub(crate) fn new() -> Self { ChannelMap { channels: std::sync::Mutex::new(HashMap::new()), } } + /// Return the channel state for the given identity, if any. pub(crate) fn get(&self, ident: &C::Ident) -> Result>> { let map = self.channels.lock()?; Ok(map.get(ident).map(ChannelState::clone_ref)) } + /// Replace the channel state for `ident` with `newval`, and return the + /// previous value if any. pub(crate) fn replace( &self, ident: C::Ident, newval: ChannelState, ) -> Result>> { - assert!(newval.check_ident(&ident)); + newval.check_ident(&ident)?; let mut map = self.channels.lock()?; Ok(map.insert(ident, newval)) } + /// Remove and return the state for `ident`, if any. pub(crate) fn remove(&self, ident: &C::Ident) -> Result>> { let mut map = self.channels.lock()?; Ok(map.remove(ident)) } + /// Remove every unusable state from the map. pub(crate) fn remove_unusable(&self) -> Result<()> { let mut map = self.channels.lock()?; map.retain(|_, state| match state { @@ -91,6 +127,19 @@ impl ChannelMap { Ok(()) } + /// Replace the state whose identity is `ident` with a new state. + /// + /// The provided function `func` is invoked on the old state (if + /// any), and must return a tuple containing an optional new + /// state, and an arbitrary return value for this function. + /// + /// Because `func` is run while holding the lock on this object, + /// it should be fast and nonblocking. In return, you can be sure + /// that it's running atomically with respect to other accessesors + /// of this map. + /// + /// If `func` panics, this map will become unusable and future + /// accesses will fail. pub(crate) fn change_state(&self, ident: &C::Ident, func: F) -> Result where F: FnOnce(Option>) -> (Option>, V), @@ -106,7 +155,7 @@ impl ChannelMap { let (newval, output) = func(Some(oldent)); match newval { Some(mut newent) => { - assert!(newent.check_ident(ident)); + newent.check_ident(ident)?; // XXX leaves it poisoned std::mem::swap(occupied.get_mut(), &mut newent); } None => { @@ -118,7 +167,7 @@ impl ChannelMap { Vacant(vacant) => { let (newval, output) = func(None); if let Some(newent) = newval { - assert!(newent.check_ident(ident)); + newent.check_ident(ident)?; vacant.insert(newent); } Ok(output) From 12c45882f079bbd6de39ac5956ca4cc6954d6b6f Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Fri, 23 Apr 2021 12:34:33 -0400 Subject: [PATCH 4/5] Refactor and simplify ChanMgr to use AbstractChanMgr. --- tor-chanmgr/src/builder.rs | 148 +++++++++ tor-chanmgr/src/connect.rs | 101 ------ tor-chanmgr/src/err.rs | 6 + tor-chanmgr/src/lib.rs | 410 +------------------------ tor-chanmgr/src/testing.rs | 92 ------ tor-chanmgr/src/transport.rs | 33 -- tor-chanmgr/src/transport/nativetls.rs | 54 ---- tor-client/src/client.rs | 7 +- tor-proto/src/channel.rs | 5 + 9 files changed, 176 insertions(+), 680 deletions(-) create mode 100644 tor-chanmgr/src/builder.rs delete mode 100644 tor-chanmgr/src/connect.rs delete mode 100644 tor-chanmgr/src/testing.rs delete mode 100644 tor-chanmgr/src/transport.rs delete mode 100644 tor-chanmgr/src/transport/nativetls.rs diff --git a/tor-chanmgr/src/builder.rs b/tor-chanmgr/src/builder.rs new file mode 100644 index 000000000..1cfd9d906 --- /dev/null +++ b/tor-chanmgr/src/builder.rs @@ -0,0 +1,148 @@ +//! Implement a concrete type to build channels. + +use crate::Error; + +use tor_linkspec::ChanTarget; +use tor_llcrypto::pk; +use tor_rtcompat::{tls::TlsConnector, Runtime, TlsProvider}; + +use async_trait::async_trait; +use futures::task::SpawnExt; +use std::net::SocketAddr; +use std::sync::Arc; + +/// TLS-based channel builder. +/// +/// This is a separate type so that we can keep our channel management +/// code network-agnostic. +pub(crate) struct ChanBuilder { + /// Asynchronous runtime for TLS, TCP, spawning, and timeouts. + runtime: R, + /// Object to build TLS connections. + tls_connector: ::Connector, +} + +impl ChanBuilder { + /// Construct a new ChanBuilder. + pub(crate) fn new(runtime: R) -> Self { + let tls_connector = runtime.tls_connector(); + ChanBuilder { + runtime, + tls_connector, + } + } +} + +#[async_trait] +impl crate::mgr::ChannelFactory for ChanBuilder { + type Channel = tor_proto::channel::Channel; + type BuildSpec = TargetInfo; + + async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result> { + use tor_rtcompat::SleepProviderExt; + + // TODO: make this an option. And make a better value. + let five_seconds = std::time::Duration::new(5, 0); + + self.runtime + .timeout(five_seconds, self.build_channel_notimeout(target)) + .await? + } +} + +impl ChanBuilder { + /// As build_channel, but don't include a timeout. + async fn build_channel_notimeout( + &self, + target: &TargetInfo, + ) -> crate::Result> { + use tor_proto::channel::ChannelBuilder; + use tor_rtcompat::tls::CertifiedConn; + + // 1. Negotiate the TLS connection. + + // TODO: This just uses the first address. Instead we could be smarter, + // or use "happy eyeballs, or whatever. Maybe we will want to + // refactor as we do so? + let addr = target + .addrs() + .get(0) + .ok_or_else(|| Error::UnusableTarget("No addresses for chosen relay".into()))?; + + log::info!("Negotiating TLS with {}", addr); + + // TODO: add a random hostname here if it will be used for SNI? + let tls = self + .tls_connector + .connect_unvalidated(addr, "ignored") + .await?; + + let peer_cert = tls + .peer_certificate()? + .ok_or(Error::Internal("TLS connection with no peer certificate"))?; + + // 2. Set up the channel. + let mut builder = ChannelBuilder::new(); + builder.set_declared_addr(*addr); + let chan = builder.launch(tls).connect().await?; + let chan = chan.check(target, &peer_cert)?; + let (chan, reactor) = chan.finish().await?; + + // 3. Launch a task to run the channel reactor. + self.runtime.spawn(async { + let _ = reactor.run().await; + })?; + Ok(chan) + } +} + +impl crate::mgr::AbstractChannel for tor_proto::channel::Channel { + type Ident = pk::ed25519::Ed25519Identity; + fn ident(&self) -> &Self::Ident { + self.peer_ed25519_id() + } + fn is_usable(&self) -> bool { + !self.is_closing() + } +} + +/// TargetInfo is a summary of a [`ChanTarget`] that we can pass to +/// [`ChanBuilder::build_channel`]. +/// +/// This is a separate type since we can't declare ChanBuilder as having +/// a parameterized method in today's Rust. +#[derive(Debug, Clone)] +pub(crate) struct TargetInfo { + /// Copy of the addresses from the underlying ChanTarget. + addrs: Vec, + /// Copy of the ed25519 id from the underlying ChanTarget. + ed_identity: pk::ed25519::Ed25519Identity, + /// Copy of the rsa id from the underlying ChanTarget. + rsa_identity: pk::rsa::RsaIdentity, +} + +impl ChanTarget for TargetInfo { + fn addrs(&self) -> &[SocketAddr] { + &self.addrs[..] + } + fn ed_identity(&self) -> &pk::ed25519::Ed25519Identity { + &self.ed_identity + } + fn rsa_identity(&self) -> &pk::rsa::RsaIdentity { + &self.rsa_identity + } +} + +impl TargetInfo { + /// Construct a TargetInfo from a given ChanTarget. + pub(crate) fn from_chan_target(target: &C) -> Self + where + C: ChanTarget + ?Sized, + { + TargetInfo { + addrs: target.addrs().to_vec(), + ed_identity: *target.ed_identity(), + rsa_identity: *target.rsa_identity(), + } + } +} diff --git a/tor-chanmgr/src/connect.rs b/tor-chanmgr/src/connect.rs deleted file mode 100644 index 1ee2e1e7c..000000000 --- a/tor-chanmgr/src/connect.rs +++ /dev/null @@ -1,101 +0,0 @@ -//! Trait and implementation for a "Connector" type. -//! -//! The `Connector` trait is internal to the tor-chanmgr crate, and -//! helps us avoid having `ChanMgr` be polymorphic on transport type. -//! Instead, it can hold a boxed Connector. - -use crate::transport::Transport; -use crate::{Error, Result}; - -use tor_linkspec::ChanTarget; -use tor_llcrypto::pk; - -#[cfg(test)] -use crate::testing::{FakeChannel as Channel, FakeChannelBuilder as ChannelBuilder}; -#[cfg(not(test))] -use tor_proto::channel::{Channel, ChannelBuilder}; - -use async_trait::async_trait; -use futures::task::{Spawn, SpawnExt}; -use std::net::SocketAddr; -use std::sync::Arc; - -/// A Connector knows how to make a channel given the summarized information -/// from a ChanTarget. -#[async_trait] -pub(crate) trait Connector { - /// Create a new channel to `target`, trying exactly once, not timing out. - async fn build_channel( - &self, - runtime: &(dyn Spawn + Sync), - target: &TargetInfo, - ) -> Result>; -} - -// Every Transport is automatically a Connector. -#[async_trait] -impl Connector for TR { - async fn build_channel( - &self, - runtime: &(dyn Spawn + Sync), - target: &TargetInfo, - ) -> Result> { - use tor_rtcompat::tls::CertifiedConn; - let (addr, tls) = self.connect(target).await?; - - let peer_cert = tls - .peer_certificate()? - .ok_or(Error::Internal("TLS connection with no peer certificate"))?; - - let mut builder = ChannelBuilder::new(); - builder.set_declared_addr(addr); - let chan = builder.launch(tls).connect().await?; - let chan = chan.check(target, &peer_cert)?; - let (chan, reactor) = chan.finish().await?; - - runtime.spawn(async { - let _ = reactor.run().await; - })?; - Ok(chan) - } -} - -/// TargetInfo is a summary of a [`ChanTarget`] that we can pass to -/// [`Connector::build_channel`]. -/// -/// This is a separate type since we can't declare Connector as having -/// a method that's parameterized in today's Rust. -pub(crate) struct TargetInfo { - /// Copy of the addresses from the underlying ChanTarget. - addrs: Vec, - /// Copy of the ed25519 id from the underlying ChanTarget. - ed_identity: pk::ed25519::Ed25519Identity, - /// Copy of the rsa id from the underlying ChanTarget. - rsa_identity: pk::rsa::RsaIdentity, -} - -impl ChanTarget for TargetInfo { - fn addrs(&self) -> &[SocketAddr] { - &self.addrs[..] - } - fn ed_identity(&self) -> &pk::ed25519::Ed25519Identity { - &self.ed_identity - } - fn rsa_identity(&self) -> &pk::rsa::RsaIdentity { - &self.rsa_identity - } -} - -impl TargetInfo { - /// Construct a TargetInfo from a given ChanTarget. - pub(crate) fn from_chan_target(target: &C) -> Self - where - C: ChanTarget + ?Sized, - { - TargetInfo { - addrs: target.addrs().to_vec(), - ed_identity: *target.ed_identity(), - rsa_identity: *target.rsa_identity(), - } - } -} diff --git a/tor-chanmgr/src/err.rs b/tor-chanmgr/src/err.rs index 77951195a..551cab583 100644 --- a/tor-chanmgr/src/err.rs +++ b/tor-chanmgr/src/err.rs @@ -40,6 +40,12 @@ impl From for Error { } } +impl From for Error { + fn from(_: tor_rtcompat::TimeoutError) -> Error { + Error::ChanTimeout + } +} + impl From> for Error { fn from(_: std::sync::PoisonError) -> Error { Error::Internal("Thread failed while holding lock") diff --git a/tor-chanmgr/src/lib.rs b/tor-chanmgr/src/lib.rs index fda663ebd..5e7ef1043 100644 --- a/tor-chanmgr/src/lib.rs +++ b/tor-chanmgr/src/lib.rs @@ -11,32 +11,17 @@ #![deny(missing_docs)] #![deny(clippy::missing_docs_in_private_items)] -mod connect; +mod builder; mod err; mod mgr; -#[cfg(test)] -pub(crate) mod testing; -pub mod transport; - -use crate::connect::{Connector, TargetInfo}; -use crate::transport::Transport; use tor_linkspec::ChanTarget; -use tor_llcrypto::pk::ed25519::Ed25519Identity; - -#[cfg(test)] -use testing::FakeChannel as Channel; -#[cfg(not(test))] use tor_proto::channel::Channel; -use futures::lock::Mutex; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - pub use err::Error; +use std::sync::Arc; -use tor_rtcompat::{Runtime, SleepProviderExt}; +use tor_rtcompat::Runtime; /// A Result as returned by this crate. pub type Result = std::result::Result; @@ -47,63 +32,16 @@ pub type Result = std::result::Result; /// Use the [ChanMgr::get_or_launch] function to craete a new channel, or /// get one if it exists. pub struct ChanMgr { - /// Map from Ed25519 identity to channel state. - /// - /// Note that eventually we might want to have this be only - /// _canonical_ connections (those whose address matches the - /// relay's official address) and we might want this to be indexed - /// by pluggable transport too. But since right now only - /// client-initiated channels are supported, and pluggable - /// transports are not supported, this structure is fine. - /// - /// Note that other Channels may exist that are not indexed here. - channels: Mutex>, - - /// Object used to create TLS connections to relays. - connector: Box, - - /// DOCDOC - runtime: R, -} - -/// Possible states for a managed channel -enum ChannelState { - /// The channel is open, authenticated, and canonical: we can give - /// it out as needed. - Open(Arc), - /// Some task is building the channel, and will notify all - /// listeners on this event on success or failure. - Building(Arc), + /// Internal channel manager object that does the actual work. + mgr: mgr::AbstractChanMgr>, } impl ChanMgr { - /// Construct a new channel manager. It will use `transport` to construct - /// TLS streams, and `spawn` to launch reactor tasks. - pub fn new(runtime: R, transport: TR) -> Self - where - TR: Transport + Send + Sync + 'static, - { - let connector = Box::new(transport); - ChanMgr { - channels: Mutex::new(HashMap::new()), - connector, - runtime, - } - } - - /// Helper: Return the channel if it matches the target; otherwise - /// return an error. - /// - /// We need to do this check since it's theoretically possible for - /// a channel to (for example) match the Ed25519 key of the - /// target, but not the RSA key. - fn check_chan_match( - &self, - target: &T, - ch: Arc, - ) -> Result> { - ch.check_match(target)?; - Ok(ch) + /// Construct a new channel manager. + pub fn new(runtime: R) -> Self { + let builder = builder::ChanBuilder::new(runtime); + let mgr = mgr::AbstractChanMgr::new(builder); + ChanMgr { mgr } } /// Try to get a suitable channel to the provided `target`, @@ -114,328 +52,12 @@ impl ChanMgr { /// or fail depending on its outcome. pub async fn get_or_launch(&self, target: &T) -> Result> { let ed_identity = target.ed_identity(); - use ChannelState::*; + let targetinfo = builder::TargetInfo::from_chan_target(target); - // Look up the current cache entry. - let (should_launch, event) = { - let mut channels = self.channels.lock().await; - let state = channels.get(ed_identity); - - match state { - Some(Open(ch)) => { - if ch.is_closing() { - // duplicate with below. XXXXX - let e = Arc::new(event_listener::Event::new()); - let state = Building(Arc::clone(&e)); - channels.insert(*ed_identity, state); - (true, e) - } else { - return self.check_chan_match(target, Arc::clone(ch)); - } - } - Some(Building(e)) => (false, Arc::clone(e)), - None => { - let e = Arc::new(event_listener::Event::new()); - let state = Building(Arc::clone(&e)); - channels.insert(*ed_identity, state); - (true, e) - } - } - }; - - if should_launch { - let result = self.build_channel(target).await; - { - let mut channels = self.channels.lock().await; - match &result { - Ok(ch) => { - channels.insert(*ed_identity, Open(Arc::clone(ch))); - } - Err(_) => { - channels.remove(ed_identity); - } - } - } - event.notify(usize::MAX); - result - } else { - event.listen().await; - let chan = self - .get_nowait_by_ed_id(ed_identity) - .await - .ok_or(Error::PendingFailed)?; - self.check_chan_match(target, chan) - } - } - - /// Helper: construct a new channel for a target. - async fn build_channel(&self, target: &T) -> Result> { - // XXXX make this a parameter. - let timeout = Duration::new(5, 0); - - let result = self - .runtime - .timeout(timeout, self.build_channel_once(target)) - .await; - - match result { - Ok(Ok(chan)) => Ok(chan), - Ok(Err(e)) => Err(e), - Err(_) => Err(Error::ChanTimeout), - } - } - - /// Helper: construct a new channel for a target, trying only once, - /// and not timing out. - async fn build_channel_once(&self, target: &T) -> Result> { - let target = TargetInfo::from_chan_target(target); - self.connector.build_channel(&self.runtime, &target).await - } - - /// Helper: Get the Channel with the given Ed25519 identity, if there - /// is one. - async fn get_nowait_by_ed_id(&self, ed_id: &Ed25519Identity) -> Option> { - use ChannelState::*; - let channels = self.channels.lock().await; - match channels.get(ed_id) { - Some(Open(ch)) => Some(Arc::clone(ch)), - _ => None, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - use tor_llcrypto::pk::rsa::RsaIdentity; - - use async_trait::async_trait; - use futures::io::{AsyncRead, AsyncWrite}; - use futures::join; - use futures::task::Context; - use std::io::Result as IoResult; - use std::net::SocketAddr; - use std::pin::Pin; - use std::task::Poll; - - use tor_rtcompat::test_with_runtime; - - struct FakeTransport; - struct FakeConnection; - - #[async_trait] - impl crate::transport::Transport for FakeTransport { - type Connection = FakeConnection; - async fn connect( - &self, - t: &T, - ) -> Result<(std::net::SocketAddr, FakeConnection)> { - let addr = t.addrs().get(0).unwrap(); - if addr.port() == 1337 { - Err(Error::UnusableTarget("too leet!".into()).into()) - } else { - Ok((*addr, FakeConnection)) - } - } - } - - impl tor_rtcompat::tls::CertifiedConn for FakeConnection { - fn peer_certificate(&self) -> IoResult>> { - Ok(Some(vec![])) - } - } - impl AsyncRead for FakeConnection { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { - Poll::Ready(Ok(0)) - } - } - impl AsyncWrite for FakeConnection { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Ready(Ok(0)) - } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - fn poll_close( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - struct Target { - addr: [std::net::SocketAddr; 1], - ed_id: Ed25519Identity, - rsa_id: RsaIdentity, - } - impl ChanTarget for Target { - fn addrs(&self) -> &[SocketAddr] { - &self.addr[..] - } - fn ed_identity(&self) -> &Ed25519Identity { - &self.ed_id - } - fn rsa_identity(&self) -> &RsaIdentity { - &self.rsa_id - } - } - - #[test] - fn connect_one_ok() { - test_with_runtime(|runtime| async { - let mgr = ChanMgr::new(runtime, FakeTransport); - let target = Target { - addr: ["127.0.0.1:443".parse().unwrap()], - ed_id: [3; 32].into(), - rsa_id: [2; 20].into(), - }; - let chan1 = mgr.get_or_launch(&target).await.unwrap(); - let chan2 = mgr.get_or_launch(&target).await.unwrap(); - - assert!(chan1.same_channel(&chan2)); - - { - let channels = mgr.channels.lock().await; - let entry = channels.get(&[3; 32].into()); - match entry { - Some(ChannelState::Open(c)) => assert!(c.same_channel(&chan1)), - _ => panic!(), - } - } - - let chan3 = mgr.get_nowait_by_ed_id(&[3; 32].into()).await; - assert!(chan3.unwrap().same_channel(&chan1)); - }); - } - - #[test] - fn connect_one_fail() { - test_with_runtime(|runtime| async { - let mgr = ChanMgr::new(runtime, FakeTransport); - // port 1337 is set up to always fail in FakeTransport. - let target = Target { - addr: ["127.0.0.1:1337".parse().unwrap()], - ed_id: [3; 32].into(), - rsa_id: [2; 20].into(), - }; - - let res1 = mgr.get_or_launch(&target).await; - assert!(matches!(res1.unwrap_err(), Error::UnusableTarget(_))); - - // port 8686 is set up to always fail in FakeTransport. - let target = Target { - addr: ["127.0.0.1:8686".parse().unwrap()], - ed_id: [4; 32].into(), - rsa_id: [2; 20].into(), - }; - - let res1 = mgr.get_or_launch(&target).await; - assert!(matches!(res1.unwrap_err(), Error::Proto(_))); - - let chan3 = mgr.get_nowait_by_ed_id(&[4; 32].into()).await; - assert!(chan3.is_none()); - }); - } - - #[test] - fn test_concurrent() { - test_with_runtime(|runtime| async { - let mgr = ChanMgr::new(runtime, FakeTransport); - let target3 = Target { - addr: ["127.0.0.1:99".parse().unwrap()], - ed_id: [3; 32].into(), - rsa_id: [2; 20].into(), - }; - let target44 = Target { - addr: ["127.0.0.2:99".parse().unwrap()], - ed_id: [44; 32].into(), // note different ed key. - rsa_id: [2; 20].into(), - }; - let target_bad = Target { - addr: ["127.0.0.1:8686".parse().unwrap()], - ed_id: [66; 32].into(), - rsa_id: [2; 20].into(), - }; - - // TODO XXXX: figure out how to make these actually run - // concurrently. Right now it seems that they don't actually - // interact. - let (ch3a, ch3b, ch44a, ch44b, ch86a, ch86b) = join!( - mgr.get_or_launch(&target3), - mgr.get_or_launch(&target3), - mgr.get_or_launch(&target44), - mgr.get_or_launch(&target44), - mgr.get_or_launch(&target_bad), - mgr.get_or_launch(&target_bad), - ); - let ch3a = ch3a.unwrap(); - let ch3b = ch3b.unwrap(); - let ch44a = ch44a.unwrap(); - let ch44b = ch44b.unwrap(); - let err_a = ch86a.unwrap_err(); - let err_b = ch86b.unwrap_err(); - - assert!(ch3a.same_channel(&ch3b)); - assert!(ch44a.same_channel(&ch44b)); - assert!(!ch3a.same_channel(&ch44b)); - - assert!(matches!(err_a, Error::Proto(_))); - assert!(matches!(err_b, Error::Proto(_))); - }); - } - - #[test] - fn test_stall() { - test_with_runtime(|runtime| async { - use futures::FutureExt; - - let mgr = ChanMgr::new(runtime, FakeTransport); - let target = Target { - addr: ["127.0.0.1:99".parse().unwrap()], - ed_id: [12; 32].into(), - rsa_id: [2; 20].into(), - }; - - { - let mut channels = mgr.channels.lock().await; - let e = Arc::new(event_listener::Event::new()); - let state = ChannelState::Building(Arc::clone(&e)); - channels.insert([12; 32].into(), state); - } - - let h = mgr.get_or_launch(&target); - - assert!(h.now_or_never().is_none()); - }); - } - - #[test] - fn connect_two_closing() { - test_with_runtime(|runtime| async { - let mgr = ChanMgr::new(runtime, FakeTransport); - let target = Target { - addr: ["127.0.0.1:443".parse().unwrap()], - ed_id: [3; 32].into(), - rsa_id: [2; 20].into(), - }; - let chan1 = mgr.get_or_launch(&target).await.unwrap(); - chan1.mark_closing(); - let chan2 = mgr.get_or_launch(&target).await.unwrap(); - - assert!(!chan1.same_channel(&chan2)); - }); + let chan = self.mgr.get_or_launch(*ed_identity, targetinfo).await?; + // Double-check the match to make sure that the RSA identity is + // what we wanted too. + chan.check_match(target)?; + Ok(chan) } } diff --git a/tor-chanmgr/src/testing.rs b/tor-chanmgr/src/testing.rs deleted file mode 100644 index f74c789b8..000000000 --- a/tor-chanmgr/src/testing.rs +++ /dev/null @@ -1,92 +0,0 @@ -//! Testing stubs for the chahannel manager code. Only enabled -//! with `cfg(test)`. - -#![allow(missing_docs)] -#![allow(clippy::missing_docs_in_private_items)] - -use crate::{Error, Result}; -use tor_linkspec::ChanTarget; -use tor_llcrypto::pk::rsa::RsaIdentity; - -use std::net::SocketAddr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -#[derive(Debug)] -pub struct FakeChannel { - chan: FakeChannelInner, -} - -#[derive(Debug)] -pub(crate) struct FakeChannelInner { - closing: AtomicBool, - want_rsa_id: Option, - addr: SocketAddr, -} - -#[derive(Debug)] -pub(crate) struct FakeChannelBuilder { - addr: Option, -} - -#[derive(Debug)] -pub(crate) struct FakeReactor {} - -impl FakeChannelBuilder { - pub fn new() -> Self { - FakeChannelBuilder { addr: None } - } - pub fn set_declared_addr(&mut self, addr: SocketAddr) { - self.addr = Some(addr); - } - pub fn launch(self, _ignore: T) -> FakeChannel { - FakeChannel::new(self.addr.unwrap()) - } -} - -impl FakeChannel { - pub fn new(addr: SocketAddr) -> Self { - let inner = FakeChannelInner { - closing: false.into(), - want_rsa_id: None, - addr, - }; - FakeChannel { chan: inner } - } - pub async fn connect(self) -> Result { - Ok(self) - } - pub fn check(self, _target: &T, _cert: &[u8]) -> Result { - if self.chan.addr.port() == 8686 { - Err(tor_proto::Error::ChanProto("86ed".into()).into()) - } else { - Ok(self) - } - } - pub fn same_channel(self: &Arc, other: &Arc) -> bool { - Arc::ptr_eq(self, other) - } - pub(crate) async fn finish(self) -> Result<(Arc, FakeReactor)> { - Ok((Arc::new(self), FakeReactor {})) - } - pub fn is_closing(&self) -> bool { - self.chan.closing.load(Ordering::SeqCst) - } - pub fn mark_closing(&self) { - self.chan.closing.store(true, Ordering::SeqCst) - } - pub fn check_match(&self, target: &T) -> Result<()> { - if let Some(ref id) = self.chan.want_rsa_id { - if id != target.rsa_identity() { - return Err(Error::UnusableTarget("Wrong RSA".into()).into()); - } - } - Ok(()) - } -} - -impl FakeReactor { - pub async fn run(self) -> Result<()> { - Ok(()) - } -} diff --git a/tor-chanmgr/src/transport.rs b/tor-chanmgr/src/transport.rs deleted file mode 100644 index 2820045d4..000000000 --- a/tor-chanmgr/src/transport.rs +++ /dev/null @@ -1,33 +0,0 @@ -//! Types for launching TLS connections to relays -//! -//! TODO: Perhaps this type is no longer needed, and we should just -//! use the TlsConnector trait in tor_rtcompat. - -pub mod nativetls; - -use crate::Result; - -use tor_linkspec::ChanTarget; -use tor_rtcompat::tls::CertifiedConn; - -use async_trait::async_trait; -use futures::io::{AsyncRead, AsyncWrite}; - -/// A Transport knows how to build a TLS connection to a relay, in a way -/// that Tor can use. -/// -/// Tor doesn't expect to get any particular hostname or sequence of -/// certificates in the reply; it only expects that the peer certificate -/// will later be authenticated inside the Tor handshake. -#[async_trait] -pub trait Transport { - /// The type that will be returned by this transport. This should - /// be an asynchronous TLS connection. - type Connection: AsyncRead + AsyncWrite + Send + Unpin + CertifiedConn + 'static; - - /// Try to connect to a given relay. - async fn connect( - &self, - target: &T, - ) -> Result<(std::net::SocketAddr, Self::Connection)>; -} diff --git a/tor-chanmgr/src/transport/nativetls.rs b/tor-chanmgr/src/transport/nativetls.rs deleted file mode 100644 index a128af126..000000000 --- a/tor-chanmgr/src/transport/nativetls.rs +++ /dev/null @@ -1,54 +0,0 @@ -//! Build TLS connections using the async_native_tls or tokio_native_tls crate. - -// XXXX-A2 This should get refactored significantly. Probably we should have -// a boxed-connection-factory type that we can use instead. Once we have a -// pluggable designn, we'll really need something like that. -// -// Probably, much of this code should move into tor-rtcompat, or a new -// crate similar to tor-rtcompat, that can handle our TLS drama. - -use super::Transport; -use crate::{Error, Result}; -use tor_linkspec::ChanTarget; -use tor_rtcompat::tls::TlsConnector; - -use async_trait::async_trait; - -use log::info; - -/// A Transport that uses a connector based on native_tls. -pub struct NativeTlsTransport { - /// connector object used to build TLS connections - connector: C, -} - -impl NativeTlsTransport { - /// Construct a new NativeTlsTransport. - pub fn new(connector: C) -> Result { - Ok(NativeTlsTransport { connector }) - } -} - -#[async_trait] -impl Transport for NativeTlsTransport { - type Connection = C::Conn; - - async fn connect( - &self, - target: &T, - ) -> Result<(std::net::SocketAddr, Self::Connection)> { - // TODO: This just uses the first address. Instead we could be smarter, - // or use "happy eyeballs, or whatever. Maybe we will want to - // refactor as we do so? - let addr = target - .addrs() - .get(0) - .ok_or_else(|| Error::UnusableTarget("No addresses for chosen relay".into()))?; - - info!("Negotiating TLS with {}", addr); - - // TODO: add a random hostname here if it will be used for SNI? - let connection = self.connector.connect_unvalidated(addr, "ignored").await?; - Ok((*addr, connection)) - } -} diff --git a/tor-client/src/client.rs b/tor-client/src/client.rs index e11424387..20d65d754 100644 --- a/tor-client/src/client.rs +++ b/tor-client/src/client.rs @@ -3,7 +3,6 @@ //! To construct a client, run the `TorClient::bootstrap()` method. //! Once the client is bootstrapped, you can make connections over the Tor //! network using `TorClient::connect()`. -use tor_chanmgr::transport::nativetls::NativeTlsTransport; use tor_circmgr::TargetPort; use tor_dirmgr::NetDirConfig; use tor_proto::circuit::IpVersionPreference; @@ -78,11 +77,7 @@ impl TorClient { /// Return a client once there is enough directory material to /// connect safely over the Tor network. pub async fn bootstrap(runtime: R, dircfg: NetDirConfig) -> Result> { - let transport = { - let connector = runtime.tls_connector(); - NativeTlsTransport::new(connector)? - }; - let chanmgr = Arc::new(tor_chanmgr::ChanMgr::new(runtime.clone(), transport)); + let chanmgr = Arc::new(tor_chanmgr::ChanMgr::new(runtime.clone())); let circmgr = Arc::new(tor_circmgr::CircMgr::new( runtime.clone(), Arc::clone(&chanmgr), diff --git a/tor-proto/src/channel.rs b/tor-proto/src/channel.rs index 742595fdb..1df8abc3f 100644 --- a/tor-proto/src/channel.rs +++ b/tor-proto/src/channel.rs @@ -236,6 +236,11 @@ impl Channel { self.unique_id } + /// Return the Ed25519 identity for the peer of this channel. + pub fn peer_ed25519_id(&self) -> &Ed25519Identity { + &self.ed25519_id + } + /// Return an error if this channel is somehow mismatched with the /// given target. pub fn check_match(&self, target: &T) -> Result<()> { From 579767f6709f4d988eb62bb554f3de912ca818e8 Mon Sep 17 00:00:00 2001 From: Nick Mathewson Date: Fri, 23 Apr 2021 12:55:53 -0400 Subject: [PATCH 5/5] Another test in ChanMgr, for TargetInfo. --- tor-chanmgr/src/builder.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tor-chanmgr/src/builder.rs b/tor-chanmgr/src/builder.rs index 1cfd9d906..50c871408 100644 --- a/tor-chanmgr/src/builder.rs +++ b/tor-chanmgr/src/builder.rs @@ -146,3 +146,22 @@ impl TargetInfo { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn targetinfo() { + let ti = TargetInfo { + addrs: vec!["127.0.0.1:11".parse().unwrap()], + ed_identity: [42; 32].into(), + rsa_identity: [45; 20].into(), + }; + + let ti2 = TargetInfo::from_chan_target(&ti); + assert_eq!(ti.addrs, ti2.addrs); + assert_eq!(ti.ed_identity, ti2.ed_identity); + assert_eq!(ti.rsa_identity, ti2.rsa_identity); + } +}