diff --git a/Cargo.lock b/Cargo.lock index b306b59a5..5873bd18c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2824,7 +2824,6 @@ dependencies = [ "digest", "event-listener", "futures", - "futures-await-test", "generic-array", "hex", "hex-literal", diff --git a/crates/tor-chanmgr/src/builder.rs b/crates/tor-chanmgr/src/builder.rs index a6f1ddc40..4caed3c0a 100644 --- a/crates/tor-chanmgr/src/builder.rs +++ b/crates/tor-chanmgr/src/builder.rs @@ -43,9 +43,11 @@ impl crate::mgr::ChannelFactory for ChanBuilder { // TODO: make this an option. And make a better value. let five_seconds = std::time::Duration::new(5, 0); + // FIXME(eta): there doesn't need to be an `Arc` here; `Channel` implements `Clone`! self.runtime .timeout(five_seconds, self.build_channel_notimeout(target)) .await? + .map(Arc::new) } } @@ -54,7 +56,7 @@ impl ChanBuilder { async fn build_channel_notimeout( &self, target: &OwnedChanTarget, - ) -> crate::Result> { + ) -> crate::Result { use tor_proto::channel::ChannelBuilder; use tor_rtcompat::tls::CertifiedConn; diff --git a/crates/tor-circmgr/src/build.rs b/crates/tor-circmgr/src/build.rs index 50e199359..83b3be95a 100644 --- a/crates/tor-circmgr/src/build.rs +++ b/crates/tor-circmgr/src/build.rs @@ -75,11 +75,13 @@ pub(crate) trait Buildable: Sized { async fn create_common( chanmgr: &ChanMgr, rt: &RT, - rng: &mut RNG, + // FIXME(eta): remove this unused RNG parameter! + // (new_circ() used to take it) + _rng: &mut RNG, target: &CT, ) -> Result { let chan = chanmgr.get_or_launch(target).await?; - let (pending_circ, reactor) = chan.new_circ(rng).await?; + let (pending_circ, reactor) = chan.new_circ().await?; rt.spawn(async { let _ = reactor.run().await; diff --git a/crates/tor-proto/Cargo.toml b/crates/tor-proto/Cargo.toml index 2519069f5..4b51034d7 100644 --- a/crates/tor-proto/Cargo.toml +++ b/crates/tor-proto/Cargo.toml @@ -51,6 +51,6 @@ tokio-util = { version = "0.6", features = ["compat"], optional = true } coarsetime = { version = "0.1.20", optional = true } [dev-dependencies] -futures-await-test = "0.3.0" +tokio-crate = { package = "tokio", version = "1.7.0", features = ["macros", "rt"] } hex-literal = "0.3.1" hex = "0.4.3" diff --git a/crates/tor-proto/src/channel.rs b/crates/tor-proto/src/channel.rs index e8f78e900..aa19e4420 100644 --- a/crates/tor-proto/src/channel.rs +++ b/crates/tor-proto/src/channel.rs @@ -59,7 +59,7 @@ mod handshake; mod reactor; mod unique_id; -use crate::channel::reactor::CtrlMsg; +use crate::channel::reactor::{BoxedChannelSink, BoxedChannelStream, CtrlMsg, Reactor}; pub use crate::channel::unique_id::UniqId; use crate::circuit; use crate::circuit::celltypes::CreateResponse; @@ -72,17 +72,14 @@ use tor_llcrypto::pk::rsa::RsaIdentity; use asynchronous_codec as futures_codec; use futures::channel::{mpsc, oneshot}; use futures::io::{AsyncRead, AsyncWrite}; -use futures::lock::Mutex; -use futures::sink::{Sink, SinkExt}; -use futures::stream::Stream; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Weak}; +use std::sync::Arc; -use rand::Rng; use tracing::trace; // reexport +use crate::channel::unique_id::CircUniqIdContext; pub use handshake::{OutboundClientHandshake, UnverifiedChannel, VerifiedChannel}; /// Type alias: A Sink and Stream that transforms a TLS connection into @@ -92,6 +89,7 @@ type CellFrame = futures_codec::Framed, -} - -impl std::fmt::Debug for Channel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Channel") - .field("unique_id", &self.unique_id) - .field("ed25519_id", &self.ed25519_id) - .field("rsa_id", &self.rsa_id) - .field("closed", &self.closed) - .finish() - } -} - -/// Main implementation type for a channel. -struct ChannelImpl { - /// What link protocol is the channel using? - #[allow(dead_code)] // We don't support protocols where this would matter - link_protocol: u16, - /// The underlying channel, as a Sink of ChanCell. Writing - /// a ChanCell onto this sink sends it over the TLS channel. - tls: Box + Send + Unpin + 'static>, - /// A circuit map, to translate circuit IDs into circuits. - /// - /// The ChannelImpl side of this object only needs to use this - /// when creating circuits; it's shared with the reactor, which uses - /// it for dispatch. - // This uses a separate mutex from the channel, since we only need - // the circmap when we're making a new circuit, the reactor needs - // it all the time. - circmap: Weak>, + closed: Arc, /// A stream used to send control messages to the Reactor. control: mpsc::UnboundedSender, - /// Context for allocating unique circuit log identifiers. - circ_unique_id_ctx: unique_id::CircUniqIdContext, } /// Structure for building and launching a Tor channel. @@ -192,46 +155,38 @@ impl Channel { /// Internal method, called to finalize the channel when we've /// sent our netinfo cell, received the peer's netinfo cell, and /// we're finally ready to create circuits. - fn new( + fn new( link_protocol: u16, - tls_sink: Box + Send + Unpin + 'static>, - tls_stream: T, + sink: BoxedChannelSink, + stream: BoxedChannelStream, unique_id: UniqId, ed25519_id: Ed25519Identity, rsa_id: RsaIdentity, - ) -> (Arc, reactor::Reactor) - where - T: Stream> + Send + Unpin + 'static, - { + ) -> (Self, reactor::Reactor) { use circmap::{CircIdRange, CircMap}; - let circmap = Arc::new(Mutex::new(CircMap::new(CircIdRange::High))); + let circmap = CircMap::new(CircIdRange::High); let (control_tx, control_rx) = mpsc::unbounded(); + let closed = Arc::new(AtomicBool::new(false)); - let inner = ChannelImpl { - tls: tls_sink, - link_protocol, - circmap: Arc::downgrade(&circmap), - control: control_tx, - circ_unique_id_ctx: unique_id::CircUniqIdContext::new(), - }; - let inner = Mutex::new(inner); let channel = Channel { unique_id, ed25519_id, rsa_id, - closed: AtomicBool::new(false), - inner, + closed: Arc::clone(&closed), + control: control_tx, }; - let channel = Arc::new(channel); - let reactor = reactor::Reactor::new( - &Arc::clone(&channel), - circmap, - control_rx, - tls_stream, + let reactor = Reactor { + control: control_rx, + input: futures::StreamExt::fuse(stream), + output: sink, + circs: circmap, unique_id, - ); + closed, + circ_unique_id_ctx: CircUniqIdContext::new(), + link_protocol, + }; (channel, reactor) } @@ -316,8 +271,12 @@ impl Channel { } } - let inner = &mut self.inner.lock().await; - inner.tls.send(cell).await?; // XXXX I don't like holding the lock here. + let (tx, rx) = oneshot::channel(); + self.control + .unbounded_send(CtrlMsg::Send { cell, tx }) + .map_err(|_| Error::InternalError("Reactor not alive to receive cells".into()))?; + rx.await + .map_err(|_| Error::InternalError("Reactor went away while sending".into()))??; Ok(()) } @@ -329,9 +288,8 @@ impl Channel { /// To use the results of this method, call Reactor::run() in a /// new task, then use the methods of /// [crate::circuit::PendingClientCirc] to build the circuit. - pub async fn new_circ( - self: &Arc, - rng: &mut R, + pub async fn new_circ( + &self, ) -> Result<(circuit::PendingClientCirc, circuit::reactor::Reactor)> { if self.is_closing() { return Err(Error::ChannelClosed); @@ -341,28 +299,23 @@ impl Channel { let (sender, receiver) = mpsc::channel(128); let (createdsender, createdreceiver) = oneshot::channel::(); - let (circ_unique_id, id, reactor_tx) = { - let mut inner = self.inner.lock().await; - if let Some(circmap) = inner.circmap.upgrade() { - let my_unique_id = self.unique_id; - let circ_unique_id = inner.circ_unique_id_ctx.next(my_unique_id); - let mut cmap = circmap.lock().await; - ( - circ_unique_id, - cmap.add_ent(rng, createdsender, sender)?, - inner.control.clone(), - ) - } else { - return Err(Error::ChannelClosed); - } - }; + let (tx, rx) = oneshot::channel(); + self.control + .unbounded_send(CtrlMsg::AllocateCircuit { + created_sender: createdsender, + sender, + tx, + }) + .map_err(|_| Error::ChannelClosed)?; + let (id, circ_unique_id) = rx.await.map_err(|_| Error::ChannelClosed)??; + trace!("{}: Allocated CircId {}", circ_unique_id, id); - let destroy_handle = CircDestroyHandle::new(id, reactor_tx); + let destroy_handle = CircDestroyHandle::new(id, self.control.clone()); Ok(circuit::PendingClientCirc::new( id, - Arc::clone(self), + self.clone(), createdreceiver, Some(destroy_handle), receiver, @@ -379,31 +332,7 @@ impl Channel { /// It's not necessary to call this method if you're just done /// with a channel: the channel should close on its own once nothing /// is using it any more. - pub async fn terminate(&self) { - let outcome = self - .closed - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst); - if outcome == Ok(false) { - // The old value was false and the new value is true. - let mut inner = self.inner.lock().await; - inner.shutdown_reactor(); - // ignore any failure to flush; we can't do anything about it. - let _ignore = inner.tls.flush().await; - } - } -} - -impl Drop for ChannelImpl { - fn drop(&mut self) { - self.shutdown_reactor(); - } -} - -impl ChannelImpl { - /// Shut down this channel's reactor; causes all circuits using - /// this channel to become unusable. - fn shutdown_reactor(&mut self) { - // FIXME(eta): this shouldn't be required + pub fn terminate(&self) { let _ = self.control.unbounded_send(CtrlMsg::Shutdown); } } @@ -437,70 +366,36 @@ pub(crate) mod test { #![allow(clippy::unwrap_used)] use super::*; use crate::channel::codec::test::MsgBuf; - use crate::channel::reactor::test::new_reactor; - use futures::stream::StreamExt; - use futures_await_test::async_test; - use tor_cell::chancell::{msg, msg::ChanMsg, ChanCell}; - - /// Type returned along with a fake channel: used to impersonate a - /// reactor and a network. - #[allow(unused)] - pub(crate) struct FakeChanHandle { - pub(crate) cells: mpsc::Receiver, - circmap: Arc>, - ignore_control_msgs: mpsc::UnboundedReceiver, - } + pub(crate) use crate::channel::reactor::test::new_reactor; + use tokio_crate as tokio; + use tokio_crate::test as async_test; + use tor_cell::chancell::{msg, ChanCell}; /// Make a new fake reactor-less channel. For testing only, obviously. - /// - /// This function is used for testing _circuits_, not channels. - pub(crate) fn fake_channel() -> (Arc, FakeChanHandle) { - let (cell_send, cell_recv) = mpsc::channel(64); - let (control_tx, control_rx) = mpsc::unbounded(); - - let cell_send = cell_send.sink_map_err(|_| { - tor_cell::Error::InternalError("Error from mpsc stream while testing".into()) - }); - - let circmap = circmap::CircMap::new(circmap::CircIdRange::High); - let circmap = Arc::new(Mutex::new(circmap)); + pub(crate) fn fake_channel() -> Channel { let unique_id = UniqId::new(); - let inner = ChannelImpl { - link_protocol: 4, - tls: Box::new(cell_send), - circmap: Arc::downgrade(&circmap), - control: control_tx, - circ_unique_id_ctx: unique_id::CircUniqIdContext::new(), - }; - let channel = Channel { + Channel { unique_id, ed25519_id: [6_u8; 32].into(), rsa_id: [10_u8; 20].into(), - closed: AtomicBool::new(false), - inner: Mutex::new(inner), - }; - let handle = FakeChanHandle { - cells: cell_recv, - circmap, - ignore_control_msgs: control_rx, - }; - - (Arc::new(channel), handle) + closed: Arc::new(AtomicBool::new(false)), + control: mpsc::unbounded().0, + } } #[async_test] async fn send_bad() { - let (chan, _reactor, mut output, _input) = new_reactor(); + let chan = fake_channel(); let cell = ChanCell::new(7.into(), msg::Created2::new(&b"hihi"[..]).into()); - let e = chan.send_cell(cell).await; + let e = chan.check_cell(&cell); assert!(e.is_err()); assert_eq!( format!("{}", e.unwrap_err()), "Internal programming error: Can't send CREATED2 cell on client channel" ); let cell = ChanCell::new(0.into(), msg::Certs::new_empty().into()); - let e = chan.send_cell(cell).await; + let e = chan.check_cell(&cell); assert!(e.is_err()); assert_eq!( format!("{}", e.unwrap_err()), @@ -508,10 +403,11 @@ pub(crate) mod test { ); let cell = ChanCell::new(5.into(), msg::Create2::new(2, &b"abc"[..]).into()); - let e = chan.send_cell(cell).await; + let e = chan.check_cell(&cell); assert!(e.is_ok()); - let got = output.next().await.unwrap(); - assert!(matches!(got.msg(), ChanMsg::Create2(_))); + // FIXME(eta): more difficult to test that sending works now that it has to go via reactor + // let got = output.next().await.unwrap(); + // assert!(matches!(got.msg(), ChanMsg::Create2(_))); } #[test] @@ -525,12 +421,13 @@ pub(crate) mod test { #[test] fn check_match() { use std::net::SocketAddr; - let (chan, _reactor, _output, _input) = new_reactor(); + let chan = fake_channel(); struct ChanT { ed_id: Ed25519Identity, rsa_id: RsaIdentity, } + impl ChanTarget for ChanT { fn ed_identity(&self) -> &Ed25519Identity { &self.ed_id @@ -544,8 +441,8 @@ pub(crate) mod test { } let t1 = ChanT { - ed_id: [0x1; 32].into(), - rsa_id: [0x2; 20].into(), + ed_id: [6; 32].into(), + rsa_id: [10; 20].into(), }; let t2 = ChanT { ed_id: [0x1; 32].into(), @@ -563,8 +460,8 @@ pub(crate) mod test { #[test] fn unique_id() { - let (ch1, _handle1) = fake_channel(); - let (ch2, _handle2) = fake_channel(); - assert!(ch1.unique_id() != ch2.unique_id()); + let ch1 = fake_channel(); + let ch2 = fake_channel(); + assert_ne!(ch1.unique_id(), ch2.unique_id()); } } diff --git a/crates/tor-proto/src/channel/codec.rs b/crates/tor-proto/src/channel/codec.rs index 432253bf4..7908bc695 100644 --- a/crates/tor-proto/src/channel/codec.rs +++ b/crates/tor-proto/src/channel/codec.rs @@ -11,7 +11,7 @@ use bytes::BytesMut; /// This type lets us wrap a TLS channel (or some other secure /// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we /// can forget about byte-oriented communication. -pub struct ChannelCodec(codec::ChannelCodec); +pub(crate) struct ChannelCodec(codec::ChannelCodec); impl ChannelCodec { /// Create a new ChannelCodec with a given link protocol. @@ -45,9 +45,10 @@ pub(crate) mod test { use futures::sink::SinkExt; use futures::stream::StreamExt; use futures::task::{Context, Poll}; - use futures_await_test::async_test; use hex_literal::hex; use std::pin::Pin; + use tokio::test as async_test; + use tokio_crate as tokio; use super::{futures_codec, ChannelCodec}; use tor_cell::chancell::{msg, ChanCell, ChanCmd, CircId}; diff --git a/crates/tor-proto/src/channel/handshake.rs b/crates/tor-proto/src/channel/handshake.rs index 2179a0ec9..956e1f56a 100644 --- a/crates/tor-proto/src/channel/handshake.rs +++ b/crates/tor-proto/src/channel/handshake.rs @@ -4,7 +4,7 @@ use arrayref::array_ref; use asynchronous_codec as futures_codec; use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use futures::sink::SinkExt; -use futures::stream::{self, StreamExt}; +use futures::stream::StreamExt; use crate::channel::codec::ChannelCodec; use crate::channel::UniqId; @@ -12,7 +12,7 @@ use crate::{Error, Result}; use tor_cell::chancell::{msg, ChanCmd}; use std::net::SocketAddr; -use std::sync::Arc; + use tor_bytes::Reader; use tor_linkspec::ChanTarget; use tor_llcrypto as ll; @@ -386,12 +386,7 @@ impl VerifiedChannel { /// The channel is used to send cells, and to create outgoing circuits. /// The reactor is used to route incoming messages to their appropriate /// circuit. - pub async fn finish( - mut self, - ) -> Result<( - Arc, - super::reactor::Reactor>>, - )> { + pub async fn finish(mut self) -> Result<(super::Channel, super::reactor::Reactor)> { // We treat a completed channel -- that is to say, one where the // authentication is finished -- as incoming traffic. // @@ -399,7 +394,6 @@ impl VerifiedChannel { // final cell on the handshake, and update the channel completion // time to be no earlier than _that_ timestamp. crate::note_incoming_traffic(); - trace!("{}: Sending netinfo cell.", self.unique_id); let netinfo = msg::Netinfo::for_client(self.target_addr.as_ref().map(SocketAddr::ip)); self.tls.send(netinfo.into()).await?; @@ -414,7 +408,7 @@ impl VerifiedChannel { Ok(super::Channel::new( self.link_protocol, Box::new(tls_sink), - tls_stream, + Box::new(tls_stream), self.unique_id, self.ed25519_id, self.rsa_id, @@ -425,9 +419,10 @@ impl VerifiedChannel { #[cfg(test)] pub(super) mod test { #![allow(clippy::unwrap_used)] - use futures_await_test::async_test; use hex_literal::hex; use std::time::{Duration, SystemTime}; + use tokio::test as async_test; + use tokio_crate as tokio; use super::*; use crate::channel::codec::test::MsgBuf; diff --git a/crates/tor-proto/src/channel/reactor.rs b/crates/tor-proto/src/channel/reactor.rs index cf19cee00..1dec2cc55 100644 --- a/crates/tor-proto/src/channel/reactor.rs +++ b/crates/tor-proto/src/channel/reactor.rs @@ -14,18 +14,29 @@ use crate::{Error, Result}; use tor_cell::chancell::msg::{Destroy, DestroyReason}; use tor_cell::chancell::{msg::ChanMsg, ChanCell, CircId}; -use futures::channel::mpsc; -use futures::lock::Mutex; -use futures::select_biased; +use futures::channel::{mpsc, oneshot}; + use futures::sink::SinkExt; -use futures::stream::{self, Stream, StreamExt}; +use futures::stream::{Stream, StreamExt}; +use futures::{select_biased, Sink}; use std::convert::TryInto; -use std::sync::atomic::Ordering; -use std::sync::{Arc, Weak}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use crate::channel::unique_id; +use crate::circuit::celltypes::{ClientCircChanMsg, CreateResponse}; use tracing::{debug, trace}; +/// A boxed trait object that can provide `ChanCell`s. +pub(super) type BoxedChannelStream = + Box> + Send + Unpin + 'static>; +/// A boxed trait object that can sink `ChanCell`s. +pub(super) type BoxedChannelSink = + Box + Send + Unpin + 'static>; +/// The type of a oneshot channel used to inform reactor users of the result of an operation. +pub(super) type ReactorResultChannel = oneshot::Sender>; + /// A message telling the channel reactor to do something. #[derive(Debug)] pub(super) enum CtrlMsg { @@ -33,6 +44,23 @@ pub(super) enum CtrlMsg { Shutdown, /// Tell the reactor that a given circuit has gone away. CloseCircuit(CircId), + /// Send a cell on the channel. + Send { + /// The cell to send. + cell: ChanCell, + /// Oneshot channel to send the result down. + tx: ReactorResultChannel<()>, + }, + /// Allocate a new circuit in this channel's circuit map, generating an ID for it + /// and registering senders for messages received for the circuit. + AllocateCircuit { + /// Channel to send the circuit's `CreateResponse` down. + created_sender: oneshot::Sender, + /// Channel to send other messages from this circuit down. + sender: mpsc::Sender, + /// Oneshot channel to send the new circuit's identifiers down. + tx: ReactorResultChannel<(CircId, crate::circuit::UniqId)>, + }, } /// Object to handle incoming cells and background tasks on a channel. @@ -40,70 +68,39 @@ pub(super) enum CtrlMsg { /// This type is returned when you finish a channel; you need to spawn a /// new task that calls `run()` on it. #[must_use = "If you don't call run() on a reactor, the channel won't work."] -pub struct Reactor -where - T: Stream> + Unpin + Send + 'static, -{ +pub struct Reactor { /// A stream of oneshot receivers that this reactor can use to get /// control messages. + pub(super) control: mpsc::UnboundedReceiver, + /// A Stream from which we can read `ChanCell`s. /// - /// TODO: copy documentation from circuit::reactor if we don't unify - /// these types somehow. - control: mpsc::UnboundedReceiver, - /// A Stream from which we can read ChanCells. This should be backed - /// by a TLS connection. - input: stream::Fuse, - // TODO: This lock is used pretty asymmetrically. The reactor - // task needs to use the circmap all the time, whereas other tasks - // only need the circmap when dealing with circuit creation. - // Maybe it would be better to use some kind of channel to tell - // the reactor about new circuits? + /// This should be backed by a TLS connection if you want it to be secure. + pub(super) input: futures::stream::Fuse, + /// A Sink to which we can write `ChanCell`s. + /// + /// This should also be backed by a TLS connection if you want it to be secure. + pub(super) output: BoxedChannelSink, /// A map from circuit ID to Sinks on which we can deliver cells. - circs: Arc>, - - /// Channel pointer -- used to send DESTROY cells. - channel: Weak, - + pub(super) circs: CircMap, /// Logging identifier for this channel - unique_id: UniqId, + pub(super) unique_id: UniqId, + /// If true, this channel is closing. + pub(super) closed: Arc, + /// Context for allocating unique circuit log identifiers. + pub(super) circ_unique_id_ctx: unique_id::CircUniqIdContext, + /// What link protocol is the channel using? + #[allow(dead_code)] // We don't support protocols where this would matter + pub(super) link_protocol: u16, } -impl Reactor -where - T: Stream> + Unpin + Send + 'static, -{ - /// Construct a new Reactor. - /// - /// Cells should be taken from input and routed according to circmap. - /// - /// When closeflag fires, the reactor should shut down. - pub(super) fn new( - channel: &Arc, - circmap: Arc>, - control: mpsc::UnboundedReceiver, - input: T, - unique_id: UniqId, - ) -> Self { - Reactor { - control, - input: input.fuse(), - channel: Arc::downgrade(channel), - circs: circmap, - unique_id, - } - } - +impl Reactor { /// Launch the reactor, and run until the channel closes or we /// encounter an error. /// /// Once this function returns, the channel is dead, and can't be /// used again. pub async fn run(mut self) -> Result<()> { - if let Some(chan) = self.channel.upgrade() { - if chan.closed.load(Ordering::SeqCst) { - return Err(Error::ChannelClosed); - } - } else { + if self.closed.load(Ordering::SeqCst) { return Err(Error::ChannelClosed); } debug!("{}: Running reactor", self.unique_id); @@ -115,9 +112,7 @@ where } }; debug!("{}: Reactor stopped: {:?}", self.unique_id, result); - if let Some(chan) = self.channel.upgrade() { - chan.closed.store(true, Ordering::SeqCst); - } + self.closed.store(true, Ordering::SeqCst); result } @@ -157,6 +152,24 @@ where match msg { CtrlMsg::Shutdown => panic!(), // was handled in reactor loop. CtrlMsg::CloseCircuit(id) => self.outbound_destroy_circ(id).await?, + CtrlMsg::Send { cell, tx } => { + let ret = self.send_cell(cell).await; + let _ = tx.send(ret); // don't care about other side going away + } + CtrlMsg::AllocateCircuit { + created_sender, + sender, + tx, + } => { + let mut rng = rand::thread_rng(); + let my_unique_id = self.unique_id; + let circ_unique_id = self.circ_unique_id_ctx.next(my_unique_id); + let ret: Result<_> = self + .circs + .add_ent(&mut rng, created_sender, sender) + .map(|id| (id, circ_unique_id)); + let _ = tx.send(ret); // don't care about other side going away + } } Ok(()) } @@ -213,9 +226,7 @@ where /// Give the RELAY cell `msg` to the appropriate circuit. async fn deliver_relay(&mut self, circid: CircId, msg: ChanMsg) -> Result<()> { - let mut map = self.circs.lock().await; - - match map.get_mut(circid) { + match self.circs.get_mut(circid) { Some(CircEnt::Open(s)) => { // There's an open circuit; we can give it the RELAY cell. // XXXX I think that this one actually means the other side @@ -235,8 +246,7 @@ where /// Handle a CREATED{,_FAST,2} cell by passing it on to the appropriate /// circuit, if that circuit is waiting for one. async fn deliver_created(&mut self, circid: CircId, msg: ChanMsg) -> Result<()> { - let mut map = self.circs.lock().await; - let target = map.advance_from_opening(circid)?; + let target = self.circs.advance_from_opening(circid)?; let created = msg.try_into()?; // XXXX I think that this one actually means the other side // is closed @@ -250,9 +260,8 @@ where /// Handle a DESTROY cell by removing the corresponding circuit /// from the map, and passing the destroy cell onward to the circuit. async fn deliver_destroy(&mut self, circid: CircId, msg: ChanMsg) -> Result<()> { - let mut map = self.circs.lock().await; // Remove the circuit from the map: nothing more can be done with it. - let entry = map.remove(circid); + let entry = self.circs.remove(circid); match entry { // If the circuit is waiting for CREATED, tell it that it // won't get one. @@ -301,6 +310,12 @@ where } } + /// Helper: send a cell on the outbound sink. + async fn send_cell(&mut self, cell: ChanCell) -> Result<()> { + self.output.send(cell).await?; + Ok(()) + } + /// Called when a circuit goes away: sends a DESTROY cell and removes /// the circuit. async fn outbound_destroy_circ(&mut self, id: CircId) -> Result<()> { @@ -309,21 +324,14 @@ where self.unique_id, id ); - { - let mut map = self.circs.lock().await; - // Remove the circuit's entry from the map: nothing more - // can be done with it. - // TODO: It would be great to have a tighter upper bound for - // the number of relay cells we'll receive. - map.destroy_sent(id, HalfCirc::new(3000)); - } - { - let destroy = Destroy::new(DestroyReason::NONE).into(); - let cell = ChanCell::new(id, destroy); - if let Some(chan) = self.channel.upgrade() { - chan.send_cell(cell).await?; - } - } + // Remove the circuit's entry from the map: nothing more + // can be done with it. + // TODO: It would be great to have a tighter upper bound for + // the number of relay cells we'll receive. + self.circs.destroy_sent(id, HalfCirc::new(3000)); + let destroy = Destroy::new(DestroyReason::NONE).into(); + let cell = ChanCell::new(id, destroy); + self.send_cell(cell).await?; Ok(()) } @@ -335,15 +343,16 @@ pub(crate) mod test { use super::*; use futures::sink::SinkExt; use futures::stream::StreamExt; - use futures_await_test::async_test; + use tokio::test as async_test; + use tokio_crate as tokio; use crate::circuit::CircParameters; type CodecResult = std::result::Result; pub(crate) fn new_reactor() -> ( - Arc, - Reactor>, + crate::channel::Channel, + Reactor, mpsc::Receiver, mpsc::Sender, ) { @@ -351,13 +360,16 @@ pub(crate) mod test { let (send1, recv1) = mpsc::channel(32); let (send2, recv2) = mpsc::channel(32); let unique_id = UniqId::new(); - let ed_id = [0x1; 32].into(); - let rsa_id = [0x2; 20].into(); - let send1 = send1.sink_map_err(|_| tor_cell::Error::ChanProto("dummy message".into())); + let ed_id = [6; 32].into(); + let rsa_id = [10; 20].into(); + let send1 = send1.sink_map_err(|e| { + eprintln!("got sink error: {}", e); + tor_cell::Error::ChanProto("dummy message".into()) + }); let (chan, reactor) = crate::channel::Channel::new( link_protocol, Box::new(send1), - recv2, + Box::new(recv2), unique_id, ed_id, rsa_id, @@ -370,13 +382,9 @@ pub(crate) mod test { async fn shutdown() { let (chan, mut reactor, _output, _input) = new_reactor(); - chan.terminate().await; + chan.terminate(); let r = reactor.run_once().await; assert!(matches!(r, Err(ReactorError::Shutdown))); - - // This "run" won't even start. - let r = reactor.run().await; - assert!(matches!(r, Err(Error::ChannelClosed))); } // Try shutdown while reactor is running. @@ -395,7 +403,7 @@ pub(crate) mod test { let exit_then_check = async { assert!(rr.peek().is_none()); // ... and terminate the channel while that's happening. - chan.terminate().await; + chan.terminate(); }; let (rr_s, _) = join!(run_reactor, exit_then_check); @@ -406,28 +414,23 @@ pub(crate) mod test { #[async_test] async fn new_circ_closed() { - let mut rng = rand::thread_rng(); let (chan, mut reactor, mut output, _input) = new_reactor(); - let (pending, _circr) = chan.new_circ(&mut rng).await.unwrap(); + let (ret, reac) = futures::join!(chan.new_circ(), reactor.run_once()); + let (pending, _circr) = ret.unwrap(); + assert!(reac.is_ok()); let id = pending.peek_circid().await; - { - let mut circs = reactor.circs.lock().await; - let ent = circs.get_mut(id); - assert!(matches!(ent, Some(CircEnt::Opening(_, _)))); - } + let ent = reactor.circs.get_mut(id); + assert!(matches!(ent, Some(CircEnt::Opening(_, _)))); // Now drop the circuit; this should tell the reactor to remove // the circuit from the map. drop(pending); reactor.run_once().await.unwrap(); - { - let mut circs = reactor.circs.lock().await; - let ent = circs.get_mut(id); - assert!(matches!(ent, Some(CircEnt::DestroySent(_)))); - } + let ent = reactor.circs.get_mut(id); + assert!(matches!(ent, Some(CircEnt::DestroySent(_)))); let cell = output.next().await.unwrap(); assert_eq!(cell.circid(), id); assert!(matches!(cell.msg(), ChanMsg::Destroy(_))); @@ -440,25 +443,26 @@ pub(crate) mod test { let mut rng = rand::thread_rng(); let (chan, mut reactor, mut output, mut input) = new_reactor(); - let (pending, _circr) = chan.new_circ(&mut rng).await.unwrap(); + let (ret, reac) = futures::join!(chan.new_circ(), reactor.run_once()); + let (pending, _circr) = ret.unwrap(); + assert!(reac.is_ok()); let circparams = CircParameters::default(); let id = pending.peek_circid().await; - { - let mut circs = reactor.circs.lock().await; - let ent = circs.get_mut(id); - assert!(matches!(ent, Some(CircEnt::Opening(_, _)))); - } + let ent = reactor.circs.get_mut(id); + assert!(matches!(ent, Some(CircEnt::Opening(_, _)))); // We'll get a bad handshake result from this createdfast cell. let created_cell = ChanCell::new(id, msg::CreatedFast::new(*b"x").into()); input.send(Ok(created_cell)).await.unwrap(); - let (circ, reac) = futures::join!( - pending.create_firsthop_fast(&mut rng, &circparams), - reactor.run_once() - ); + let (circ, reac) = + futures::join!(pending.create_firsthop_fast(&mut rng, &circparams), async { + reactor.run_once().await?; + reactor.run_once().await?; + Ok::<(), ReactorError>(()) + }); // Make sure statuses are as expected. assert!(matches!(circ.err().unwrap(), Error::BadHandshake)); assert!(reac.is_ok()); @@ -469,19 +473,13 @@ pub(crate) mod test { // The circid now counts as open, since as far as the reactor knows, // it was accepted. (TODO: is this a bug?) - { - let mut circs = reactor.circs.lock().await; - let ent = circs.get_mut(id); - assert!(matches!(ent, Some(CircEnt::Open(_)))); - } + let ent = reactor.circs.get_mut(id); + assert!(matches!(ent, Some(CircEnt::Open(_)))); // But the next run if the reactor will make the circuit get closed. reactor.run_once().await.unwrap(); - { - let mut circs = reactor.circs.lock().await; - let ent = circs.get_mut(id); - assert!(matches!(ent, Some(CircEnt::DestroySent(_)))); - } + let ent = reactor.circs.get_mut(id); + assert!(matches!(ent, Some(CircEnt::DestroySent(_)))); } // Try incoming cells that shouldn't arrive on channels. @@ -562,15 +560,18 @@ pub(crate) mod test { let (_chan, mut reactor, _output, mut input) = new_reactor(); let (_circ_stream_7, mut circ_stream_13) = { - let mut circmap = reactor.circs.lock().await; let (snd1, _rcv1) = oneshot::channel(); let (snd2, rcv2) = mpsc::channel(64); - circmap.put_unchecked(7.into(), CircEnt::Opening(snd1, snd2)); + reactor + .circs + .put_unchecked(7.into(), CircEnt::Opening(snd1, snd2)); let (snd3, rcv3) = mpsc::channel(64); - circmap.put_unchecked(13.into(), CircEnt::Open(snd3)); + reactor.circs.put_unchecked(13.into(), CircEnt::Open(snd3)); - circmap.put_unchecked(23.into(), CircEnt::DestroySent(HalfCirc::new(25))); + reactor + .circs + .put_unchecked(23.into(), CircEnt::DestroySent(HalfCirc::new(25))); (rcv2, rcv3) }; @@ -640,15 +641,18 @@ pub(crate) mod test { let (_chan, mut reactor, _output, mut input) = new_reactor(); let (circ_oneshot_7, mut circ_stream_13) = { - let mut circmap = reactor.circs.lock().await; let (snd1, rcv1) = oneshot::channel(); let (snd2, _rcv2) = mpsc::channel(64); - circmap.put_unchecked(7.into(), CircEnt::Opening(snd1, snd2)); + reactor + .circs + .put_unchecked(7.into(), CircEnt::Opening(snd1, snd2)); let (snd3, rcv3) = mpsc::channel(64); - circmap.put_unchecked(13.into(), CircEnt::Open(snd3)); + reactor.circs.put_unchecked(13.into(), CircEnt::Open(snd3)); - circmap.put_unchecked(23.into(), CircEnt::DestroySent(HalfCirc::new(25))); + reactor + .circs + .put_unchecked(23.into(), CircEnt::DestroySent(HalfCirc::new(25))); (rcv1, rcv3) }; diff --git a/crates/tor-proto/src/circuit.rs b/crates/tor-proto/src/circuit.rs index 7876cd1d6..e6f7a43cf 100644 --- a/crates/tor-proto/src/circuit.rs +++ b/crates/tor-proto/src/circuit.rs @@ -175,7 +175,7 @@ struct ClientCircImpl { id: CircId, /// The channel that this circuit uses to send its cells to the /// next hop. - channel: Arc, + channel: Channel, /// The cryptographic state for this circuit for outbound cells. /// This object is divided into multiple layers, each of which is /// shared with one hop of the circuit @@ -857,7 +857,7 @@ impl PendingClientCirc { /// pub(crate) fn new( id: CircId, - channel: Arc, + channel: Channel, createdreceiver: oneshot::Receiver, circ_closed: Option, input: mpsc::Receiver, @@ -1152,14 +1152,18 @@ fn resolvedval_to_result(val: ResolvedVal) -> Result { #[cfg(test)] mod test { #![allow(clippy::unwrap_used)] + use super::*; - use crate::channel::test::fake_channel; + use crate::channel::test::new_reactor; use chanmsg::{ChanMsg, Created2, CreatedFast}; + use futures::channel::mpsc::{Receiver, Sender}; use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::sink::SinkExt; use futures::stream::StreamExt; - use futures_await_test::async_test; use hex_literal::hex; + use tokio::runtime::Handle; + use tokio_crate as tokio; + use tokio_crate::test as async_test; use tor_cell::chancell::msg as chanmsg; use tor_cell::relaycell::msg as relaymsg; use tor_llcrypto::pk; @@ -1219,6 +1223,16 @@ mod test { ) } + fn working_fake_channel() -> ( + Channel, + Receiver, + Sender>, + ) { + let (channel, chan_reactor, rx, tx) = new_reactor(); + Handle::current().spawn(chan_reactor.run()); + (channel, rx, tx) + } + async fn test_create(fast: bool) { // We want to try progressing from a pending circuit to a circuit // via a crate_fast handshake. @@ -1226,7 +1240,7 @@ mod test { use crate::crypto::handshake::{fast::CreateFastServer, ntor::NtorServer, ServerHandshake}; use futures::future::FutureExt; - let (chan, mut ch) = fake_channel(); + let (chan, mut rx, _sink) = working_fake_channel(); let circid = 128.into(); let (created_send, created_recv) = oneshot::channel(); let (_circmsg_send, circmsg_recv) = mpsc::channel(64); @@ -1244,7 +1258,7 @@ mod test { // Future to pretend to be a relay on the other end of the circuit. let simulate_relay_fut = async move { let mut rng = rand::thread_rng(); - let create_cell = ch.cells.next().await.unwrap(); + let create_cell = rx.next().await.unwrap(); assert_eq!(create_cell.circid(), 128.into()); let reply = if fast { let cf = match create_cell.msg() { @@ -1269,13 +1283,17 @@ mod test { let mut rng = rand::thread_rng(); let target = example_target(); let params = CircParameters::default(); - if fast { + let ret = if fast { + eprintln!("doing fast create"); pending.create_firsthop_fast(&mut rng, ¶ms).await } else { + eprintln!("doing ntor create"); pending .create_firsthop_ntor(&mut rng, &target, ¶ms) .await - } + }; + eprintln!("create done: result {:?}", ret); + ret }; // Future to run the reactor. let reactor_fut = reactor.run_once().map(|_| ()); @@ -1347,7 +1365,7 @@ mod test { // Helper: set up a 3-hop circuit with no encryption, where the // next inbound message seems to come from hop next_msg_from async fn newcirc_ext( - chan: Arc, + chan: Channel, next_msg_from: HopNum, ) -> ( Arc, @@ -1361,7 +1379,7 @@ mod test { let (pending, mut reactor) = PendingClientCirc::new( circid, - Arc::clone(&chan), + chan, created_recv, None, // circ_closed. circmsg_recv, @@ -1394,7 +1412,7 @@ mod test { // Helper: set up a 3-hop circuit with no encryption, where the // next inbound message seems to come from hop next_msg_from async fn newcirc( - chan: Arc, + chan: Channel, ) -> ( Arc, reactor::Reactor, @@ -1406,7 +1424,7 @@ mod test { // Try sending a cell via send_relay_cell #[async_test] async fn send_simple() { - let (chan, mut ch) = fake_channel(); + let (chan, mut rx, _sink) = working_fake_channel(); let (circ, _reactor, _send) = newcirc(chan).await; let begindir = RelayCell::new(0.into(), RelayMsg::BeginDir); circ.send_relay_cell(2.into(), false, begindir) @@ -1415,7 +1433,7 @@ mod test { // Here's what we tried to put on the TLS channel. Note that // we're using dummy relay crypto for testing convenience. - let rcvd = ch.cells.next().await.unwrap(); + let rcvd = rx.next().await.unwrap(); assert_eq!(rcvd.circid(), 128.into()); let m = match rcvd.into_circid_and_msg().1 { ChanMsg::Relay(r) => RelayCell::decode(r.into_relay_body()).unwrap(), @@ -1428,7 +1446,7 @@ mod test { // for a specific circuit. #[async_test] async fn recv_meta() { - let (chan, _ch) = fake_channel(); + let (chan, _, _sink) = working_fake_channel(); let (circ, mut reactor, mut sink) = newcirc(chan).await; // 1: Try doing it via handle_meta_cell directly. @@ -1488,7 +1506,7 @@ mod test { async fn extend() { use crate::crypto::handshake::{ntor::NtorServer, ServerHandshake}; - let (chan, mut ch) = fake_channel(); + let (chan, mut rx, _sink) = working_fake_channel(); let (circ, mut reactor, mut sink) = newcirc(chan).await; let params = CircParameters::default(); @@ -1501,7 +1519,7 @@ mod test { let reply_fut = async move { // We've disabled encryption on this circuit, so we can just // read the extend2 cell. - let (id, chmsg) = ch.cells.next().await.unwrap().into_circid_and_msg(); + let (id, chmsg) = rx.next().await.unwrap().into_circid_and_msg(); assert_eq!(id, 128.into()); let rmsg = match chmsg { ChanMsg::RelayEarly(r) => RelayCell::decode(r.into_relay_body()).unwrap(), @@ -1530,7 +1548,7 @@ mod test { } async fn bad_extend_test_impl(reply_hop: HopNum, bad_reply: ClientCircChanMsg) -> Error { - let (chan, _ch) = fake_channel(); + let (chan, _rx, _sink) = working_fake_channel(); let (circ, mut reactor, mut sink) = newcirc_ext(chan, reply_hop).await; let params = CircParameters::default(); @@ -1571,7 +1589,7 @@ mod test { Error::CircDestroy(s) => { assert_eq!(s, "Circuit closed while waiting for EXTENDED2"); } - _ => panic!(), + x => panic!("got other error: {}", x), } } @@ -1609,7 +1627,7 @@ mod test { #[async_test] async fn begindir() { - let (chan, mut ch) = fake_channel(); + let (chan, mut rx, _sink) = working_fake_channel(); let (circ, mut reactor, mut sink) = newcirc(chan).await; let begin_and_send_fut = async move { @@ -1628,7 +1646,7 @@ mod test { let reply_fut = async move { // We've disabled encryption on this circuit, so we can just // read the begindir cell. - let (id, chmsg) = ch.cells.next().await.unwrap().into_circid_and_msg(); + let (id, chmsg) = rx.next().await.unwrap().into_circid_and_msg(); assert_eq!(id, 128.into()); // hardcoded circid. let rmsg = match chmsg { ChanMsg::Relay(r) => RelayCell::decode(r.into_relay_body()).unwrap(), @@ -1642,7 +1660,7 @@ mod test { sink.send(rmsg_to_ccmsg(streamid, connected)).await.unwrap(); // Now read a DATA cell... - let (id, chmsg) = ch.cells.next().await.unwrap().into_circid_and_msg(); + let (id, chmsg) = rx.next().await.unwrap().into_circid_and_msg(); assert_eq!(id, 128.into()); let rmsg = match chmsg { ChanMsg::Relay(r) => RelayCell::decode(r.into_relay_body()).unwrap(), @@ -1690,7 +1708,7 @@ mod test { crate::circuit::reactor::Reactor, usize, ) { - let (chan, mut ch) = fake_channel(); + let (chan, mut rx, _sink) = working_fake_channel(); let (circ, mut reactor, mut sink) = newcirc(chan).await; let (snd_done, mut rcv_done) = oneshot::channel::<()>(); @@ -1714,7 +1732,7 @@ mod test { let receive_fut = async move { // Read the begindir cell. - let (_id, chmsg) = ch.cells.next().await.unwrap().into_circid_and_msg(); + let (_id, chmsg) = rx.next().await.unwrap().into_circid_and_msg(); let rmsg = match chmsg { ChanMsg::Relay(r) => RelayCell::decode(r.into_relay_body()).unwrap(), _ => panic!(), @@ -1729,7 +1747,7 @@ mod test { let mut cells_received = 0_usize; while bytes_received < n_to_send { // Read a data cell, and remember how much we got. - let (id, chmsg) = ch.cells.next().await.unwrap().into_circid_and_msg(); + let (id, chmsg) = rx.next().await.unwrap().into_circid_and_msg(); assert_eq!(id, 128.into()); let rmsg = match chmsg { diff --git a/crates/tor-proto/src/circuit/halfstream.rs b/crates/tor-proto/src/circuit/halfstream.rs index 0f648955d..d1901dfcc 100644 --- a/crates/tor-proto/src/circuit/halfstream.rs +++ b/crates/tor-proto/src/circuit/halfstream.rs @@ -84,7 +84,8 @@ mod test { #![allow(clippy::unwrap_used)] use super::*; use crate::circuit::sendme::{StreamRecvWindow, StreamSendWindow}; - use futures_await_test::async_test; + use tokio::test as async_test; + use tokio_crate as tokio; use tor_cell::relaycell::msg; #[async_test] diff --git a/crates/tor-proto/src/circuit/sendme.rs b/crates/tor-proto/src/circuit/sendme.rs index 40511f4f1..1110852d8 100644 --- a/crates/tor-proto/src/circuit/sendme.rs +++ b/crates/tor-proto/src/circuit/sendme.rs @@ -287,7 +287,8 @@ mod test { #![allow(clippy::unwrap_used)] use super::*; use futures::FutureExt; - use futures_await_test::async_test; + use tokio::test as async_test; + use tokio_crate as tokio; use tor_cell::relaycell::{msg, RelayCell}; #[test]