diff --git a/crates/tor-chanmgr/src/factory.rs b/crates/tor-chanmgr/src/factory.rs index c4cb16488..ce3b512a9 100644 --- a/crates/tor-chanmgr/src/factory.rs +++ b/crates/tor-chanmgr/src/factory.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use async_trait::async_trait; use tor_error::{HasKind, HasRetryTime}; -use tor_linkspec::{OwnedChanTarget, PtTransportName}; +use tor_linkspec::{HasChanMethod, OwnedChanTarget, PtTransportName}; use tor_proto::channel::Channel; use tracing::debug; @@ -74,8 +74,85 @@ pub trait AbstractPtMgr { async fn factory_for_transport( &self, transport: &PtTransportName, - ) -> Result>, Arc>; + ) -> Result>, Arc>; } /// Alias for an Arc ChannelFactory with all of the traits that we require. pub(crate) type ArcFactory = Arc; + +/// Alias for an Arc PtMgr with all of the traits that we require. +pub(crate) type ArcPtMgr = Arc; + +#[async_trait] +impl

AbstractPtMgr for Option

+where + P: AbstractPtMgr + Send + Sync, +{ + async fn factory_for_transport( + &self, + transport: &PtTransportName, + ) -> Result>, Arc> { + match self { + Some(mgr) => mgr.factory_for_transport(transport).await, + None => Ok(None), + } + } +} + +/// A ChannelFactory built from an optional PtMgr to use for pluggable transports, and a +/// ChannelFactory to use for everything else. +#[derive(Clone)] +pub(crate) struct Factory { + #[cfg(feature = "pt-client")] + /// The PtMgr to use for pluggable transports + ptmgr: Option, + /// The factory to use for everything else + default_factory: ArcFactory, +} + +#[async_trait] +impl ChannelFactory for Factory { + async fn connect_via_transport(&self, target: &OwnedChanTarget) -> crate::Result { + use tor_linkspec::ChannelMethod::*; + let factory = match target.chan_method() { + Direct(_) => self.default_factory.clone(), + #[cfg(feature = "pt-client")] + Pluggable(a) => match self.ptmgr.as_ref() { + Some(mgr) => mgr + .factory_for_transport(a.transport()) + .await + .expect("TODO pt-client") + .ok_or_else(|| crate::Error::NoSuchTransport(a.transport().clone().into()))?, + None => return Err(crate::Error::NoSuchTransport(a.transport().clone().into())), + }, + }; + + factory.connect_via_transport(target).await + } +} + +impl Factory { + /// Create a new `Factory` that will try to use `ptmgr` to handle pluggable + /// transports requests, and `default_factory` to handle everything else. + pub(crate) fn new( + default_factory: ArcFactory, + #[cfg(feature = "pt-client")] ptmgr: Option, + ) -> Self { + Self { + default_factory, + #[cfg(feature = "pt-client")] + ptmgr, + } + } + + /// Replace the default factory in this object. + pub(crate) fn replace_default_factory(&mut self, factory: ArcFactory) { + self.default_factory = factory; + } + + #[cfg(feature = "pt-client")] + /// Replace the PtMgr in this object. + pub(crate) fn replace_ptmgr(&mut self, ptmgr: ArcPtMgr) { + self.ptmgr = Some(ptmgr); + } +} diff --git a/crates/tor-chanmgr/src/lib.rs b/crates/tor-chanmgr/src/lib.rs index b815f709d..d801aaf1f 100644 --- a/crates/tor-chanmgr/src/lib.rs +++ b/crates/tor-chanmgr/src/lib.rs @@ -48,7 +48,6 @@ mod testing; pub mod transport; use educe::Educe; -use factory::ArcFactory; use futures::select_biased; use futures::task::SpawnExt; use futures::StreamExt; @@ -81,7 +80,7 @@ use tor_rtcompat::scheduler::{TaskHandle, TaskSchedule}; /// get one if it exists. pub struct ChanMgr { /// Internal channel manager object that does the actual work. - mgr: mgr::AbstractChanMgr, + mgr: mgr::AbstractChanMgr, /// Stream of [`ConnStatus`] events. bootstrap_status: event::ConnStatusEvents, @@ -169,8 +168,12 @@ impl ChanMgr { let sender = Arc::new(std::sync::Mutex::new(sender)); let transport = transport::DefaultTransport::new(runtime.clone()); let builder = builder::ChanBuilder::new(runtime, transport, sender); - let builder: ArcFactory = Arc::new(builder); - let mgr = mgr::AbstractChanMgr::new(builder, config, dormancy, netparams); + let factory = factory::Factory::new( + Arc::new(builder), + #[cfg(feature = "pt-client")] + None, + ); + let mgr = mgr::AbstractChanMgr::new(factory, config, dormancy, netparams); ChanMgr { mgr, bootstrap_status: receiver, @@ -271,23 +274,25 @@ impl ChanMgr { /// /// This method can be used to e.g. tell Arti to use a proxy for /// outgoing connections. - pub fn set_default_transport(&self, _factory: impl factory::ChannelFactory) { - // TODO pt-client: Perhaps we actually want to remove this and have it - // be part of the constructor? The only way to actually implement it is - // to make the channel factory in AbstractChanMgr mutable, which seels a - // little ugly. Do we ever want to change this on a _running_ ChanMgr? - #![allow(clippy::missing_panics_doc, clippy::needless_pass_by_value)] - todo!("TODO pt-client: implement this.") + pub fn set_default_transport( + &self, + factory: impl factory::ChannelFactory + Send + Sync + 'static, + ) { + // TODO pt-client: Perhaps we actually want to take this as part of the constructor instead? + // TODO pt-client: It's not clear to me that we really need this method. + // TODO pt-client: Should this method take an ArcFactory instead? + self.mgr + .with_mut_builder(|f| f.replace_default_factory(Arc::new(factory))); } - /* - TODO pt-client: use AbstractPtMgr instead /// Replace the transport registry with one that may know about /// more transports. #[cfg(feature = "pt-client")] - pub fn set_transport_registry(&self, _registry: impl factory::TransportRegistry) { + pub fn set_pt_mgr(&self, ptmgr: impl factory::AbstractPtMgr + Send + Sync + 'static) { + // TODO pt-client: Should this method take an ArcPtMgr instead? + self.mgr + .with_mut_builder(|f| f.replace_ptmgr(Arc::new(ptmgr))); } - */ /// Watch for things that ought to change the configuration of all channels in the client /// diff --git a/crates/tor-chanmgr/src/mgr.rs b/crates/tor-chanmgr/src/mgr.rs index 00c2ca72b..ae7464bac 100644 --- a/crates/tor-chanmgr/src/mgr.rs +++ b/crates/tor-chanmgr/src/mgr.rs @@ -106,6 +106,14 @@ impl AbstractChanMgr { } } + /// Run a function to modify the channel builder in this object. + pub(crate) fn with_mut_builder(&self, func: F) + where + F: FnOnce(&mut CF), + { + self.channels.with_mut_builder(func); + } + /// Remove every unusable entry from this channel manager. #[cfg(test)] pub(crate) fn remove_unusable_entries(&self) -> Result<()> { diff --git a/crates/tor-chanmgr/src/mgr/state.rs b/crates/tor-chanmgr/src/mgr/state.rs index c8294d147..de577f4e0 100644 --- a/crates/tor-chanmgr/src/mgr/state.rs +++ b/crates/tor-chanmgr/src/mgr/state.rs @@ -272,11 +272,13 @@ impl MgrState { inner.builder.clone() } - /// Replace the builder stored in this state. - #[allow(dead_code)] //TODO pt-client: remove. - pub(crate) fn replace_builder(&self, builder: C) { + /// Run a function to modify the builder stored in this state. + pub(crate) fn with_mut_builder(&self, func: F) + where + F: FnOnce(&mut C), + { let mut inner = self.inner.lock().expect("lock poisoned"); - inner.builder = builder; + func(&mut inner.builder); } /// Run a function on the `ByRelayIds` that implements the map in this `MgrState`.