chanmgr: get rid of Arc around Channel

This commit is contained in:
Ian Jackson 2022-01-13 13:12:29 +00:00 committed by eta
parent b4761f8cfd
commit 9b723cba53
6 changed files with 119 additions and 96 deletions

View File

@ -8,7 +8,6 @@ use tor_rtcompat::{tls::TlsConnector, Runtime, TlsProvider};
use async_trait::async_trait;
use futures::task::SpawnExt;
use std::sync::Arc;
/// TLS-based channel builder.
///
@ -37,17 +36,15 @@ impl<R: Runtime> crate::mgr::ChannelFactory for ChanBuilder<R> {
type Channel = tor_proto::channel::Channel;
type BuildSpec = OwnedChanTarget;
async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result<Arc<Self::Channel>> {
async fn build_channel(&self, target: &Self::BuildSpec) -> crate::Result<Self::Channel> {
use tor_rtcompat::SleepProviderExt;
// TODO: make this an option. And make a better value.
let five_seconds = std::time::Duration::new(5, 0);
// FIXME(eta): there doesn't need to be an `Arc` here; `Channel` implements `Clone`!
self.runtime
.timeout(five_seconds, self.build_channel_notimeout(target))
.await?
.map(Arc::new)
}
}
@ -169,7 +166,7 @@ mod test {
// Create the channelbuilder that we want to test.
let builder = ChanBuilder::new(client_rt);
let (r1, r2): (Result<Arc<Channel>>, Result<LocalStream>) = futures::join!(
let (r1, r2): (Result<Channel>, Result<LocalStream>) = futures::join!(
async {
// client-side: build a channel!
builder.build_channel(&target).await

View File

@ -56,7 +56,6 @@ use tor_linkspec::{ChanTarget, OwnedChanTarget};
use tor_proto::channel::Channel;
pub use err::Error;
use std::sync::Arc;
use tor_rtcompat::Runtime;
@ -87,7 +86,7 @@ impl<R: Runtime> ChanMgr<R> {
/// If there is already a channel launch attempt in progress, this
/// function will wait until that launch is complete, and succeed
/// or fail depending on its outcome.
pub async fn get_or_launch<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<Arc<Channel>> {
pub async fn get_or_launch<T: ChanTarget + ?Sized>(&self, target: &T) -> Result<Channel> {
let ed_identity = target.ed_identity();
let targetinfo = OwnedChanTarget::from_chan_target(target);

View File

@ -6,14 +6,13 @@ use async_trait::async_trait;
use futures::channel::oneshot;
use futures::future::{FutureExt, Shared};
use std::hash::Hash;
use std::sync::Arc;
mod map;
/// Trait to describe as much of a
/// [`Channel`](tor_proto::channel::Channel) as `AbstractChanMgr`
/// needs to use.
pub(crate) trait AbstractChannel {
pub(crate) trait AbstractChannel: Clone {
/// Identity type for the other side of the channel.
type Ident: Hash + Eq + Clone;
/// Return this channel's identity.
@ -40,7 +39,7 @@ pub(crate) trait ChannelFactory {
/// and so on.
///
/// It should not retry; that is handled at a higher level.
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<Self::Channel>>;
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Self::Channel>;
}
/// A type- and network-agnostic implementation for
@ -62,11 +61,11 @@ pub(crate) struct AbstractChanMgr<CF: ChannelFactory> {
/// Type alias for a future that we wait on to see when a pending
/// channel is done or failed.
type Pending<C> = Shared<oneshot::Receiver<Result<Arc<C>>>>;
type Pending<C> = Shared<oneshot::Receiver<Result<C>>>;
/// Type alias for the sender we notify when we complete a channel (or
/// fail to complete it).
type Sending<C> = oneshot::Sender<Result<Arc<C>>>;
type Sending<C> = oneshot::Sender<Result<C>>;
impl<CF: ChannelFactory> AbstractChanMgr<CF> {
/// Make a new empty channel manager.
@ -85,7 +84,7 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
/// Helper: return the objects used to inform pending tasks
/// about a newly open or failed channel.
fn setup_launch<C>(&self) -> (map::ChannelState<C>, Sending<C>) {
fn setup_launch<C: Clone>(&self) -> (map::ChannelState<C>, Sending<C>) {
let (snd, rcv) = oneshot::channel();
let shared = rcv.shared();
(map::ChannelState::Building(shared), snd)
@ -104,7 +103,7 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
&self,
ident: <<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
target: CF::BuildSpec,
) -> Result<Arc<CF::Channel>> {
) -> Result<CF::Channel> {
use map::ChannelState::*;
/// Possible actions that we'll decide to take based on the
@ -117,7 +116,7 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
/// We're going to wait for it to finish.
Wait(Pending<C>),
/// We found a usable channel. We're going to return it.
Return(Result<Arc<C>>),
Return(Result<C>),
}
/// How many times do we try?
const N_ATTEMPTS: usize = 2;
@ -134,7 +133,7 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
Some(Open(ref ch)) => {
if ch.is_usable() {
// Good channel. Return it.
let action = Action::Return(Ok(Arc::clone(ch)));
let action = Action::Return(Ok(ch.clone()));
(oldstate, action)
} else {
// Unusable channel. Move to the Building
@ -186,11 +185,10 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
Ok(chan) => {
// The channel got built: remember it, tell the
// others, and return it.
self.channels
.replace(ident.clone(), Open(Arc::clone(&chan)))?;
self.channels.replace(ident.clone(), Open(chan.clone()))?;
// It's okay if all the receivers went away:
// that means that nobody was waiting for this channel.
let _ignore_err = send.send(Ok(Arc::clone(&chan)));
let _ignore_err = send.send(Ok(chan.clone()));
return Ok(chan);
}
Err(e) => {
@ -214,10 +212,10 @@ impl<CF: ChannelFactory> AbstractChanMgr<CF> {
pub(crate) fn get_nowait(
&self,
ident: &<<CF as ChannelFactory>::Channel as AbstractChannel>::Ident,
) -> Option<Arc<CF::Channel>> {
) -> Option<CF::Channel> {
use map::ChannelState::*;
match self.channels.get(ident) {
Ok(Some(Open(ref ch))) if ch.is_usable() => Some(Arc::clone(ch)),
Ok(Some(Open(ref ch))) if ch.is_usable() => Some(ch.clone()),
_ => None,
}
}
@ -231,6 +229,7 @@ mod test {
use futures::join;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tor_rtcompat::{task::yield_now, test_with_one_runtime, Runtime};
@ -239,11 +238,18 @@ mod test {
runtime: RT,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct FakeChannel {
ident: u32,
mood: char,
closing: AtomicBool,
closing: Arc<AtomicBool>,
detect_reuse: Arc<char>,
}
impl PartialEq for FakeChannel {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.detect_reuse, &other.detect_reuse)
}
}
impl AbstractChannel for FakeChannel {
@ -273,7 +279,7 @@ mod test {
type Channel = FakeChannel;
type BuildSpec = (u32, char);
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<Arc<FakeChannel>> {
async fn build_channel(&self, target: &Self::BuildSpec) -> Result<FakeChannel> {
yield_now().await;
let (ident, mood) = *target;
match mood {
@ -285,11 +291,12 @@ mod test {
}
_ => {}
}
Ok(Arc::new(FakeChannel {
Ok(FakeChannel {
ident,
mood,
closing: AtomicBool::new(false),
}))
closing: Arc::new(AtomicBool::new(false)),
detect_reuse: Default::default(),
})
}
}
@ -302,10 +309,10 @@ mod test {
let chan1 = mgr.get_or_launch(413, target).await.unwrap();
let chan2 = mgr.get_or_launch(413, target).await.unwrap();
assert!(Arc::ptr_eq(&chan1, &chan2));
assert_eq!(chan1, chan2);
let chan3 = mgr.get_nowait(&413).unwrap();
assert!(Arc::ptr_eq(&chan1, &chan3));
assert_eq!(chan1, chan3);
});
}
@ -349,9 +356,9 @@ mod test {
let err_a = ch86a.unwrap_err();
let err_b = ch86b.unwrap_err();
assert!(Arc::ptr_eq(&ch3a, &ch3b));
assert!(Arc::ptr_eq(&ch44a, &ch44b));
assert!(!Arc::ptr_eq(&ch44a, &ch3a));
assert_eq!(ch3a, ch3b);
assert_eq!(ch44a, ch44b);
assert_ne!(ch44a, ch3a);
assert!(matches!(err_a, Error::UnusableTarget(_)));
assert!(matches!(err_b, Error::UnusableTarget(_)));
@ -378,7 +385,7 @@ mod test {
ch5.start_closing();
let ch3_new = mgr.get_or_launch(3, (3, 'b')).await.unwrap();
assert!(!Arc::ptr_eq(&ch3, &ch3_new));
assert_ne!(ch3, ch3_new);
assert_eq!(ch3_new.mood, 'b');
mgr.remove_unusable_entries().unwrap();

View File

@ -4,7 +4,6 @@ use super::{AbstractChannel, Pending};
use crate::{Error, Result};
use std::collections::{hash_map, HashMap};
use std::sync::Arc;
/// A map from channel id to channel state.
///
@ -34,7 +33,7 @@ pub(crate) enum ChannelState<C> {
/// This channel might not be usable: it might be closing or
/// broken. We need to check its is_usable() method before
/// yielding it to the user.
Open(Arc<C>),
Open(C),
/// A channel that's getting built.
Building(Pending<C>),
/// A temporary invalid state.
@ -44,13 +43,13 @@ pub(crate) enum ChannelState<C> {
Poisoned(Priv),
}
impl<C> ChannelState<C> {
impl<C: Clone> ChannelState<C> {
/// Create a new shallow copy of this ChannelState.
#[cfg(test)]
fn clone_ref(&self) -> Result<Self> {
use ChannelState::*;
match self {
Open(chan) => Ok(Open(Arc::clone(chan))),
Open(chan) => Ok(Open(chan.clone())),
Building(pending) => Ok(Building(pending.clone())),
Poisoned(_) => Err(Error::Internal("Poisoned state in channel map")),
}
@ -59,9 +58,9 @@ impl<C> ChannelState<C> {
/// For testing: either give the Open channel inside this state,
/// or panic if there is none.
#[cfg(test)]
fn unwrap_open(&self) -> Arc<C> {
fn unwrap_open(&self) -> C {
match self {
ChannelState::Open(chan) => Arc::clone(chan),
ChannelState::Open(chan) => chan.clone(),
_ => panic!("Not an open channel"),
}
}
@ -185,7 +184,7 @@ impl<C: AbstractChannel> ChannelMap<C> {
mod test {
#![allow(clippy::unwrap_used)]
use super::*;
#[derive(Eq, PartialEq, Debug)]
#[derive(Eq, PartialEq, Clone, Debug)]
struct FakeChannel {
ident: &'static str,
usable: bool,
@ -200,16 +199,16 @@ mod test {
}
}
fn ch(ident: &'static str) -> ChannelState<FakeChannel> {
ChannelState::Open(Arc::new(FakeChannel {
ChannelState::Open(FakeChannel {
ident,
usable: true,
}))
})
}
fn closed(ident: &'static str) -> ChannelState<FakeChannel> {
ChannelState::Open(Arc::new(FakeChannel {
ChannelState::Open(FakeChannel {
ident,
usable: false,
}))
})
}
#[test]

View File

@ -95,8 +95,31 @@ type CellFrame<T> = futures_codec::Framed<T, crate::channel::codec::ChannelCodec
/// An open client channel, ready to send and receive Tor cells.
///
/// A channel is a direct connection to a Tor relay, implemented using TLS.
///
/// This struct is a frontend that can be used to send cells (using the `Sink<ChanCell>`
/// impl and otherwise control the channel. The main state is in the Reactor object.
/// `Channel` is cheap to clone.
///
/// (Users need a mutable reference because of the types in `Sink`, and ultimately because
/// `cell_tx: mpsc::Sender` doesn't work without mut.
#[derive(Clone, Debug)]
pub struct Channel {
/// A channel used to send control messages to the Reactor.
control: mpsc::UnboundedSender<CtrlMsg>,
/// A channel used to send cells to the Reactor.
cell_tx: mpsc::Sender<ChanCell>,
/// Information shared with the reactor
details: Arc<ChannelDetails>,
}
/// This is information shared between the reactor and the frontend.
///
/// This exists to make `Channel` cheap to clone, which is desirable because every circuit wants
/// an owned mutable `Channel`.
///
/// `control` can't be here because we rely on it getting dropped when the last user goes away.
#[derive(Debug)]
pub(crate) struct ChannelDetails {
/// A unique identifier for this channel.
unique_id: UniqId,
/// Validated Ed25519 identity for this peer.
@ -104,11 +127,7 @@ pub struct Channel {
/// Validated RSA identity for this peer.
rsa_id: RsaIdentity,
/// If true, this channel is closing.
closed: Arc<AtomicBool>,
/// A channel used to send control messages to the Reactor.
control: mpsc::UnboundedSender<CtrlMsg>,
/// A channel used to send cells to the Reactor.
cell_tx: mpsc::Sender<ChanCell>,
closed: AtomicBool,
}
impl Sink<ChanCell> for Channel {
@ -123,7 +142,7 @@ impl Sink<ChanCell> for Channel {
fn start_send(self: Pin<&mut Self>, cell: ChanCell) -> Result<()> {
let this = self.get_mut();
if this.closed.load(Ordering::SeqCst) {
if this.details.closed.load(Ordering::SeqCst) {
return Err(Error::ChannelClosed);
}
this.check_cell(&cell)?;
@ -133,7 +152,7 @@ impl Sink<ChanCell> for Channel {
Relay(_) | Padding(_) | VPadding(_) => {} // too frequent to log.
_ => trace!(
"{}: Sending {} for {}",
this.unique_id,
this.details.unique_id,
cell.msg().cmd(),
cell.circid()
),
@ -225,15 +244,20 @@ impl Channel {
let (control_tx, control_rx) = mpsc::unbounded();
let (cell_tx, cell_rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
let closed = Arc::new(AtomicBool::new(false));
let closed = AtomicBool::new(false);
let channel = Channel {
let details = ChannelDetails {
unique_id,
ed25519_id,
rsa_id,
closed: Arc::clone(&closed),
closed,
};
let details = Arc::new(details);
let channel = Channel {
control: control_tx,
cell_tx,
details: Arc::clone(&details),
};
let reactor = Reactor {
@ -242,10 +266,9 @@ impl Channel {
input: futures::StreamExt::fuse(stream),
output: sink,
circs: circmap,
unique_id,
closed,
circ_unique_id_ctx: CircUniqIdContext::new(),
link_protocol,
details,
};
(channel, reactor)
@ -253,17 +276,17 @@ impl Channel {
/// Return a process-unique identifier for this channel.
pub fn unique_id(&self) -> UniqId {
self.unique_id
self.details.unique_id
}
/// Return the Ed25519 identity for the peer of this channel.
pub fn peer_ed25519_id(&self) -> &Ed25519Identity {
&self.ed25519_id
&self.details.ed25519_id
}
/// Return the (legacy) RSA identity for the peer of this channel.
pub fn peer_rsa_id(&self) -> &RsaIdentity {
&self.rsa_id
&self.details.rsa_id
}
/// Return an error if this channel is somehow mismatched with the
@ -290,7 +313,7 @@ impl Channel {
/// Return true if this channel is closed and therefore unusable.
pub fn is_closing(&self) -> bool {
self.closed.load(Ordering::SeqCst)
self.details.closed.load(Ordering::SeqCst)
}
/// Check whether a cell type is permissible to be _sent_ on an
@ -402,13 +425,18 @@ pub(crate) mod test {
/// Make a new fake reactor-less channel. For testing only, obviously.
pub(crate) fn fake_channel() -> Channel {
let unique_id = UniqId::new();
Channel {
let details = Arc::new(ChannelDetails {
unique_id,
ed25519_id: [6_u8; 32].into(),
rsa_id: [10_u8; 20].into(),
closed: Arc::new(AtomicBool::new(false)),
closed: AtomicBool::new(false),
});
Channel {
control: mpsc::unbounded().0,
cell_tx: mpsc::channel(CHANNEL_BUFFER_SIZE).0,
details,
}
}

View File

@ -7,7 +7,6 @@
//! or in the error handling behavior.
use super::circmap::{CircEnt, CircMap};
use super::UniqId;
use crate::circuit::halfcirc::HalfCirc;
use crate::util::err::ReactorError;
use crate::{Error, Result};
@ -21,12 +20,13 @@ use futures::stream::Stream;
use futures::Sink;
use std::convert::TryInto;
use std::fmt;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Poll;
use crate::channel::unique_id;
use crate::channel::{unique_id, ChannelDetails};
use crate::circuit::celltypes::{ClientCircChanMsg, CreateResponse};
use tracing::{debug, trace};
@ -80,10 +80,8 @@ pub struct Reactor {
pub(super) output: BoxedChannelSink,
/// A map from circuit ID to Sinks on which we can deliver cells.
pub(super) circs: CircMap,
/// Logging identifier for this channel
pub(super) unique_id: UniqId,
/// If true, this channel is closing.
pub(super) closed: Arc<AtomicBool>,
/// Information shared with the frontend
pub(super) details: Arc<ChannelDetails>,
/// Context for allocating unique circuit log identifiers.
pub(super) circ_unique_id_ctx: unique_id::CircUniqIdContext,
/// What link protocol is the channel using?
@ -91,6 +89,16 @@ pub struct Reactor {
pub(super) link_protocol: u16,
}
/// Allows us to just say debug!("{}: Reactor did a thing", &self, ...)
///
/// There is no risk of confusion because no-one would try to print a
/// Reactor for some other reason.
impl fmt::Display for Reactor {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.details.unique_id, f)
}
}
impl Reactor {
/// Launch the reactor, and run until the channel closes or we
/// encounter an error.
@ -98,10 +106,10 @@ impl Reactor {
/// Once this function returns, the channel is dead, and can't be
/// used again.
pub async fn run(mut self) -> Result<()> {
if self.closed.load(Ordering::SeqCst) {
if self.details.closed.load(Ordering::SeqCst) {
return Err(Error::ChannelClosed);
}
debug!("{}: Running reactor", self.unique_id);
debug!("{}: Running reactor", &self);
let result: Result<()> = loop {
match self.run_once().await {
Ok(()) => (),
@ -109,8 +117,8 @@ impl Reactor {
Err(ReactorError::Err(e)) => break Err(e),
}
};
debug!("{}: Reactor stopped: {:?}", self.unique_id, result);
self.closed.store(true, Ordering::SeqCst);
debug!("{}: Reactor stopped: {:?}", &self, result);
self.details.closed.store(true, Ordering::SeqCst);
result
}
@ -204,7 +212,7 @@ impl Reactor {
/// Handle a CtrlMsg other than Shutdown.
async fn handle_control(&mut self, msg: CtrlMsg) -> Result<()> {
trace!("{}: reactor received {:?}", self.unique_id, msg);
trace!("{}: reactor received {:?}", &self, msg);
match msg {
CtrlMsg::Shutdown => panic!(), // was handled in reactor loop.
CtrlMsg::CloseCircuit(id) => self.outbound_destroy_circ(id).await?,
@ -214,7 +222,7 @@ impl Reactor {
tx,
} => {
let mut rng = rand::thread_rng();
let my_unique_id = self.unique_id;
let my_unique_id = self.details.unique_id;
let circ_unique_id = self.circ_unique_id_ctx.next(my_unique_id);
let ret: Result<_> = self
.circs
@ -234,7 +242,7 @@ impl Reactor {
match msg {
Relay(_) | Padding(_) | VPadding(_) => {} // too frequent to log.
_ => trace!("{}: received {} for {}", self.unique_id, msg.cmd(), circid),
_ => trace!("{}: received {} for {}", &self, msg.cmd(), circid),
}
match msg {
@ -316,11 +324,7 @@ impl Reactor {
// If the circuit is waiting for CREATED, tell it that it
// won't get one.
Some(CircEnt::Opening(oneshot, _)) => {
trace!(
"{}: Passing destroy to pending circuit {}",
self.unique_id,
circid
);
trace!("{}: Passing destroy to pending circuit {}", &self, circid);
oneshot
.send(msg.try_into()?)
// TODO(nickm) I think that this one actually means the other side
@ -333,11 +337,7 @@ impl Reactor {
}
// It's an open circuit: tell it that it got a DESTROY cell.
Some(CircEnt::Open(mut sink)) => {
trace!(
"{}: Passing destroy to open circuit {}",
self.unique_id,
circid
);
trace!("{}: Passing destroy to open circuit {}", &self, circid);
sink.send(msg.try_into()?)
.await
// TODO(nickm) I think that this one actually means the other side
@ -350,11 +350,7 @@ impl Reactor {
Some(CircEnt::DestroySent(_)) => Ok(()),
// Got a DESTROY cell for a circuit we don't have.
None => {
trace!(
"{}: Destroy for nonexistent circuit {}",
self.unique_id,
circid
);
trace!("{}: Destroy for nonexistent circuit {}", &self, circid);
Err(Error::ChanProto("Destroy for nonexistent circuit".into()))
}
}
@ -369,11 +365,7 @@ impl Reactor {
/// Called when a circuit goes away: sends a DESTROY cell and removes
/// the circuit.
async fn outbound_destroy_circ(&mut self, id: CircId) -> Result<()> {
trace!(
"{}: Circuit {} is gone; sending DESTROY",
self.unique_id,
id
);
trace!("{}: Circuit {} is gone; sending DESTROY", &self, id);
// Remove the circuit's entry from the map: nothing more
// can be done with it.
// TODO: It would be great to have a tighter upper bound for
@ -391,6 +383,7 @@ impl Reactor {
pub(crate) mod test {
#![allow(clippy::unwrap_used)]
use super::*;
use crate::channel::UniqId;
use crate::circuit::CircParameters;
use futures::sink::SinkExt;
use futures::stream::StreamExt;