Merge branch 'sleep' into 'main'

Plumb a SleepProvider (now Clone + ....) into Channel

See merge request tpo/core/arti!569
This commit is contained in:
Ian Jackson 2022-06-08 10:46:37 +00:00
commit d202c3e9ca
13 changed files with 143 additions and 64 deletions

View File

@ -61,6 +61,7 @@ async fn main() -> Result<()> {
/// A custom TCP provider that relies on an existing TCP provider (`inner`), but modifies its
/// behavior.
#[derive(Clone)]
struct CustomTcpProvider<T> {
/// The underlying TCP provider.
inner: T,

View File

@ -182,7 +182,10 @@ impl<R: Runtime> ChanBuilder<R> {
let mut builder = ChannelBuilder::new();
builder.set_declared_addr(addr);
let chan = builder
.launch(tls)
.launch(
tls,
self.runtime.clone(), /* TODO provide ZST SleepProvider instead */
)
.connect(|| self.runtime.wallclock())
.await
.map_err(Error::from_proto_no_skew)?;

View File

@ -44,6 +44,7 @@ tor-error = { path = "../tor-error", version = "0.3.1" }
tor-linkspec = { path = "../tor-linkspec", version = "0.3.0" }
tor-llcrypto = { path = "../tor-llcrypto", version = "0.3.0" }
tor-protover = { path = "../tor-protover", version = "0.3.0" }
tor-rtcompat = { path = "../tor-rtcompat", version = "0.4.0" }
tracing = "0.1.18"
typenum = "1.12"
zeroize = "1"

View File

@ -0,0 +1 @@
BREAKING: Channels now require a SleepProvider

View File

@ -75,6 +75,7 @@ use tor_error::internal;
use tor_linkspec::{ChanTarget, OwnedChanTarget};
use tor_llcrypto::pk::ed25519::Ed25519Identity;
use tor_llcrypto::pk::rsa::RsaIdentity;
use tor_rtcompat::SleepProvider;
use asynchronous_codec as futures_codec;
use futures::channel::{mpsc, oneshot};
@ -225,11 +226,12 @@ impl ChannelBuilder {
/// authentication info from the relay: call `check()` on the result
/// to check that. Finally, to finish the handshake, call `finish()`
/// on the result of _that_.
pub fn launch<T>(self, tls: T) -> OutboundClientHandshake<T>
pub fn launch<T, S>(self, tls: T, sleep_prov: S) -> OutboundClientHandshake<T, S>
where
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
S: SleepProvider,
{
handshake::OutboundClientHandshake::new(tls, self.target)
handshake::OutboundClientHandshake::new(tls, self.target, sleep_prov)
}
}
@ -239,14 +241,18 @@ impl Channel {
/// Internal method, called to finalize the channel when we've
/// sent our netinfo cell, received the peer's netinfo cell, and
/// we're finally ready to create circuits.
fn new(
fn new<S>(
link_protocol: u16,
sink: BoxedChannelSink,
stream: BoxedChannelStream,
unique_id: UniqId,
peer_id: OwnedChanTarget,
clock_skew: ClockSkew,
) -> (Self, reactor::Reactor) {
sleep_prov: S,
) -> (Self, reactor::Reactor<S>)
where
S: SleepProvider,
{
use circmap::{CircIdRange, CircMap};
let circmap = CircMap::new(CircIdRange::High);
@ -281,6 +287,7 @@ impl Channel {
circ_unique_id_ctx: CircUniqIdContext::new(),
link_protocol,
details,
sleep_prov,
};
(channel, reactor)
@ -461,6 +468,7 @@ pub(crate) mod test {
use crate::channel::codec::test::MsgBuf;
pub(crate) use crate::channel::reactor::test::new_reactor;
use tor_cell::chancell::{msg, ChanCell};
use tor_rtcompat::PreferredRuntime;
/// Make a new fake reactor-less channel. For testing only, obviously.
pub(crate) fn fake_channel(details: Arc<ChannelDetails>) -> Channel {
@ -513,10 +521,11 @@ pub(crate) mod test {
#[test]
fn chanbuilder() {
let rt = PreferredRuntime::create().unwrap();
let mut builder = ChannelBuilder::default();
builder.set_declared_addr("127.0.0.1:9001".parse().unwrap());
let tls = MsgBuf::new(&b""[..]);
let _outbound = builder.launch(tls);
let _outbound = builder.launch(tls, rt);
}
#[test]

View File

@ -12,6 +12,7 @@ use crate::channel::UniqId;
use crate::util::skew::ClockSkew;
use crate::{Error, Result};
use tor_cell::chancell::{msg, ChanCmd};
use tor_rtcompat::SleepProvider;
use std::net::SocketAddr;
use std::sync::Arc;
@ -34,7 +35,13 @@ use tracing::{debug, trace};
static LINK_PROTOCOLS: &[u16] = &[4];
/// A raw client channel on which nothing has been done.
pub struct OutboundClientHandshake<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
pub struct OutboundClientHandshake<
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
S: SleepProvider,
> {
/// Runtime handle (insofar as we need it)
sleep_prov: S,
/// Underlying TLS stream.
///
/// (We don't enforce that this is actually TLS, but if it isn't, the
@ -51,7 +58,9 @@ pub struct OutboundClientHandshake<T: AsyncRead + AsyncWrite + Send + Unpin + 's
/// A client channel on which versions have been negotiated and the
/// relay's handshake has been read, but where the certs have not
/// been checked.
pub struct UnverifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
pub struct UnverifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider> {
/// Runtime handle (insofar as we need it)
sleep_prov: S,
/// The negotiated link protocol. Must be a member of LINK_PROTOCOLS
link_protocol: u16,
/// The Source+Sink on which we're reading and writing cells.
@ -79,7 +88,9 @@ pub struct UnverifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>
/// This type is separate from UnverifiedChannel, since finishing the
/// handshake requires a bunch of CPU, and you might want to do it as
/// a separate task or after a yield.
pub struct VerifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
pub struct VerifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider> {
/// Runtime handle (insofar as we need it)
sleep_prov: S,
/// The negotiated link protocol.
link_protocol: u16,
/// The Source+Sink on which we're reading and writing cells.
@ -105,13 +116,16 @@ fn codec_err_to_handshake(err: CodecError) -> Error {
}
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> OutboundClientHandshake<T> {
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider>
OutboundClientHandshake<T, S>
{
/// Construct a new OutboundClientHandshake.
pub(crate) fn new(tls: T, target_addr: Option<SocketAddr>) -> Self {
pub(crate) fn new(tls: T, target_addr: Option<SocketAddr>, sleep_prov: S) -> Self {
Self {
tls,
target_addr,
unique_id: UniqId::new(),
sleep_prov,
}
}
@ -120,7 +134,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> OutboundClientHandshake
///
/// Takes a function that reports the current time. In theory, this can just be
/// `SystemTime::now()`.
pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedChannel<T>>
pub async fn connect<F>(mut self, now_fn: F) -> Result<UnverifiedChannel<T, S>>
where
F: FnOnce() -> SystemTime,
{
@ -265,13 +279,14 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> OutboundClientHandshake
clock_skew,
target_addr: self.target_addr,
unique_id: self.unique_id,
sleep_prov: self.sleep_prov.clone(),
})
}
}
}
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> UnverifiedChannel<T> {
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider> UnverifiedChannel<T, S> {
/// Return the reported clock skew from this handshake.
///
/// Note that the skew reported by this function might not be "true": the
@ -302,7 +317,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> UnverifiedChannel<T> {
peer: &U,
peer_cert: &[u8],
now: Option<std::time::SystemTime>,
) -> Result<VerifiedChannel<T>> {
) -> Result<VerifiedChannel<T, S>> {
let peer_cert_sha256 = ll::d::Sha256::digest(peer_cert);
self.check_internal(peer, &peer_cert_sha256[..], now)
}
@ -314,7 +329,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> UnverifiedChannel<T> {
peer: &U,
peer_cert_sha256: &[u8],
now: Option<SystemTime>,
) -> Result<VerifiedChannel<T>> {
) -> Result<VerifiedChannel<T, S>> {
use tor_cert::CertType;
use tor_checkable::*;
@ -510,18 +525,19 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> UnverifiedChannel<T> {
ed25519_id,
rsa_id,
clock_skew: self.clock_skew,
sleep_prov: self.sleep_prov,
})
}
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> VerifiedChannel<T> {
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider> VerifiedChannel<T, S> {
/// Send a 'Netinfo' message to the relay to finish the handshake,
/// and create an open channel and reactor.
///
/// The channel is used to send cells, and to create outgoing circuits.
/// The reactor is used to route incoming messages to their appropriate
/// circuit.
pub async fn finish(mut self) -> Result<(super::Channel, super::reactor::Reactor)> {
pub async fn finish(mut self) -> Result<(super::Channel, super::reactor::Reactor<S>)> {
// We treat a completed channel -- that is to say, one where the
// authentication is finished -- as incoming traffic.
//
@ -556,6 +572,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> VerifiedChannel<T> {
self.unique_id,
peer_id,
self.clock_skew,
self.sleep_prov,
))
}
}
@ -570,6 +587,7 @@ pub(super) mod test {
use crate::channel::codec::test::MsgBuf;
use crate::Result;
use tor_cell::chancell::msg;
use tor_rtcompat::{PreferredRuntime, Runtime};
const VERSIONS: &[u8] = &hex!("0000 07 0006 0003 0004 0005");
// no certificates in this cell, but connect() doesn't care.
@ -606,7 +624,7 @@ pub(super) mod test {
#[test]
fn connect_ok() -> Result<()> {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1217696400);
let mut buf = Vec::new();
// versions cell
@ -616,7 +634,7 @@ pub(super) mod test {
// netinfo cell -- quite minimal.
add_padded(&mut buf, NETINFO_PREFIX);
let mb = MsgBuf::new(&buf[..]);
let handshake = OutboundClientHandshake::new(mb, None);
let handshake = OutboundClientHandshake::new(mb, None, rt.clone());
let unverified = handshake.connect(|| now).await?;
assert_eq!(unverified.link_protocol, 4);
@ -632,7 +650,7 @@ pub(super) mod test {
buf.extend_from_slice(VPADDING);
add_padded(&mut buf, NETINFO_PREFIX_WITH_TIME);
let mb = MsgBuf::new(&buf[..]);
let handshake = OutboundClientHandshake::new(mb, None);
let handshake = OutboundClientHandshake::new(mb, None, rt.clone());
let unverified = handshake.connect(|| now).await?;
// Correct timestamp in the NETINFO, so no skew.
assert_eq!(unverified.clock_skew(), ClockSkew::None);
@ -640,7 +658,7 @@ pub(super) mod test {
// Now pretend our clock is fast.
let now2 = now + Duration::from_secs(3600);
let mb = MsgBuf::new(&buf[..]);
let handshake = OutboundClientHandshake::new(mb, None);
let handshake = OutboundClientHandshake::new(mb, None, rt.clone());
let unverified = handshake.connect(|| now2).await?;
assert_eq!(
unverified.clock_skew(),
@ -651,23 +669,26 @@ pub(super) mod test {
})
}
async fn connect_err<T: Into<Vec<u8>>>(input: T) -> Error {
async fn connect_err<T: Into<Vec<u8>>, S>(input: T, sleep_prov: S) -> Error
where
S: SleepProvider,
{
let mb = MsgBuf::new(input);
let handshake = OutboundClientHandshake::new(mb, None);
let handshake = OutboundClientHandshake::new(mb, None, sleep_prov);
handshake.connect(SystemTime::now).await.err().unwrap()
}
#[test]
fn connect_badver() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
let err = connect_err(&b"HTTP://"[..]).await;
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let err = connect_err(&b"HTTP://"[..], rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
"handshake protocol violation: Doesn't seem to be a tor relay"
);
let err = connect_err(&hex!("0000 07 0004 1234 ffff")[..]).await;
let err = connect_err(&hex!("0000 07 0004 1234 ffff")[..], rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
@ -678,25 +699,25 @@ pub(super) mod test {
#[test]
fn connect_cellparse() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let mut buf = Vec::new();
buf.extend_from_slice(VERSIONS);
// Here's a certs cell that will fail.
buf.extend_from_slice(&hex!("00000000 81 0001 01")[..]);
let err = connect_err(buf).await;
let err = connect_err(buf, rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
});
}
#[test]
fn connect_duplicates() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let mut buf = Vec::new();
buf.extend_from_slice(VERSIONS);
buf.extend_from_slice(NOCERTS);
buf.extend_from_slice(NOCERTS);
add_netinfo(&mut buf);
let err = connect_err(buf).await;
let err = connect_err(buf, rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
@ -709,7 +730,7 @@ pub(super) mod test {
buf.extend_from_slice(AUTHCHALLENGE);
buf.extend_from_slice(AUTHCHALLENGE);
add_netinfo(&mut buf);
let err = connect_err(buf).await;
let err = connect_err(buf, rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
@ -720,11 +741,11 @@ pub(super) mod test {
#[test]
fn connect_missing_certs() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let mut buf = Vec::new();
buf.extend_from_slice(VERSIONS);
add_netinfo(&mut buf);
let err = connect_err(buf).await;
let err = connect_err(buf, rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
@ -735,12 +756,12 @@ pub(super) mod test {
#[test]
fn connect_misplaced_cell() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let mut buf = Vec::new();
buf.extend_from_slice(VERSIONS);
// here's a create cell.
add_padded(&mut buf, &hex!("00000001 01")[..]);
let err = connect_err(buf).await;
let err = connect_err(buf, rt.clone()).await;
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
@ -749,7 +770,10 @@ pub(super) mod test {
});
}
fn make_unverified(certs: msg::Certs) -> UnverifiedChannel<MsgBuf> {
fn make_unverified<R>(certs: msg::Certs, runtime: R) -> UnverifiedChannel<MsgBuf, R>
where
R: Runtime,
{
let localhost = std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
let netinfo_cell = msg::Netinfo::for_client(Some(localhost));
let clock_skew = ClockSkew::None;
@ -761,6 +785,7 @@ pub(super) mod test {
clock_skew,
target_addr: None,
unique_id: UniqId::new(),
sleep_prov: runtime,
}
}
@ -785,14 +810,18 @@ pub(super) mod test {
SystemTime::UNIX_EPOCH + Duration::new(1601143280, 0)
}
fn certs_test(
fn certs_test<R>(
certs: msg::Certs,
when: Option<SystemTime>,
peer_ed: &[u8],
peer_rsa: &[u8],
peer_cert_sha256: &[u8],
) -> Result<VerifiedChannel<MsgBuf>> {
let unver = make_unverified(certs);
runtime: &R,
) -> Result<VerifiedChannel<MsgBuf, R>>
where
R: Runtime,
{
let unver = make_unverified(certs, runtime.clone());
let ed = Ed25519Identity::from_bytes(peer_ed).unwrap();
let rsa = RsaIdentity::from_bytes(peer_rsa).unwrap();
let chan = DummyChanTarget { ed, rsa };
@ -802,12 +831,14 @@ pub(super) mod test {
// no certs at all!
#[test]
fn certs_none() {
let rt = PreferredRuntime::create().unwrap();
let err = certs_test(
msg::Certs::new_empty(),
None,
&[0_u8; 32],
&[0_u8; 20],
&[0_u8; 128],
&rt,
)
.err()
.unwrap();
@ -819,6 +850,7 @@ pub(super) mod test {
#[test]
fn certs_good() {
let rt = PreferredRuntime::create().unwrap();
let mut certs = msg::Certs::new_empty();
certs.push_cert_body(2.into(), certs::CERT_T2);
@ -831,12 +863,14 @@ pub(super) mod test {
certs::PEER_ED,
certs::PEER_RSA,
certs::PEER_CERT_DIGEST,
&rt,
);
let _ = res.unwrap();
}
#[test]
fn certs_missing() {
let rt = PreferredRuntime::create().unwrap();
let all_certs = [
(2, certs::CERT_T2, "Couldn't find RSA identity key"),
(7, certs::CERT_T7, "No RSA->Ed crosscert"),
@ -862,6 +896,7 @@ pub(super) mod test {
certs::PEER_ED,
certs::PEER_RSA,
certs::PEER_CERT_DIGEST,
&rt,
)
.err()
.unwrap();
@ -875,6 +910,7 @@ pub(super) mod test {
#[test]
fn certs_wrongtarget() {
let rt = PreferredRuntime::create().unwrap();
let mut certs = msg::Certs::new_empty();
certs.push_cert_body(2.into(), certs::CERT_T2);
certs.push_cert_body(5.into(), certs::CERT_T5);
@ -886,6 +922,7 @@ pub(super) mod test {
&[0x10; 32],
certs::PEER_RSA,
certs::PEER_CERT_DIGEST,
&rt,
)
.err()
.unwrap();
@ -901,6 +938,7 @@ pub(super) mod test {
certs::PEER_ED,
&[0x99; 20],
certs::PEER_CERT_DIGEST,
&rt,
)
.err()
.unwrap();
@ -916,6 +954,7 @@ pub(super) mod test {
certs::PEER_ED,
certs::PEER_RSA,
&[0; 32],
&rt,
)
.err()
.unwrap();
@ -928,6 +967,7 @@ pub(super) mod test {
#[test]
fn certs_badsig() {
let rt = PreferredRuntime::create().unwrap();
fn munge(inp: &[u8]) -> Vec<u8> {
let mut v: Vec<u8> = inp.into();
v[inp.len() - 1] ^= 0x10;
@ -944,6 +984,7 @@ pub(super) mod test {
certs::PEER_ED,
certs::PEER_RSA,
certs::PEER_CERT_DIGEST,
&rt,
)
.err()
.unwrap();
@ -964,6 +1005,7 @@ pub(super) mod test {
certs::PEER_ED,
certs::PEER_RSA,
certs::PEER_CERT_DIGEST,
&rt,
)
.err()
.unwrap();
@ -999,7 +1041,7 @@ pub(super) mod test {
#[test]
fn test_finish() {
tor_rtcompat::test_with_one_runtime!(|_rt| async move {
tor_rtcompat::test_with_one_runtime!(|rt| async move {
let ed25519_id = [3_u8; 32].into();
let rsa_id = [4_u8; 20].into();
let peer_addr = "127.1.1.2:443".parse().unwrap();
@ -1011,6 +1053,7 @@ pub(super) mod test {
ed25519_id,
rsa_id,
clock_skew: ClockSkew::None,
sleep_prov: rt,
};
let (_chan, _reactor) = ver.finish().await.unwrap();

View File

@ -13,6 +13,7 @@ use crate::{Error, Result};
use tor_basic_utils::futures::SinkExt as _;
use tor_cell::chancell::msg::{Destroy, DestroyReason};
use tor_cell::chancell::{msg::ChanMsg, ChanCell, CircId};
use tor_rtcompat::SleepProvider;
use futures::channel::{mpsc, oneshot};
@ -74,7 +75,7 @@ pub(super) enum CtrlMsg {
/// This type is returned when you finish a channel; you need to spawn a
/// new task that calls `run()` on it.
#[must_use = "If you don't call run() on a reactor, the channel won't work."]
pub struct Reactor {
pub struct Reactor<S: SleepProvider> {
/// A receiver for control messages from `Channel` objects.
pub(super) control: mpsc::UnboundedReceiver<CtrlMsg>,
/// A receiver for cells to be sent on this reactor's sink.
@ -98,19 +99,22 @@ pub struct Reactor {
/// What link protocol is the channel using?
#[allow(dead_code)] // We don't support protocols where this would matter
pub(super) link_protocol: u16,
/// Sleep Provider (dummy for now, this is going to be in the padding timer)
#[allow(dead_code)]
pub(super) sleep_prov: S,
}
/// Allows us to just say debug!("{}: Reactor did a thing", &self, ...)
///
/// There is no risk of confusion because no-one would try to print a
/// Reactor for some other reason.
impl fmt::Display for Reactor {
impl<S: SleepProvider> fmt::Display for Reactor<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.details.unique_id, f)
}
}
impl Reactor {
impl<S: SleepProvider> Reactor<S> {
/// Launch the reactor, and run until the channel closes or we
/// encounter an error.
///
@ -367,12 +371,15 @@ pub(crate) mod test {
use futures::stream::StreamExt;
use futures::task::SpawnExt;
use tor_linkspec::OwnedChanTarget;
use tor_rtcompat::Runtime;
type CodecResult = std::result::Result<ChanCell, CodecError>;
pub(crate) fn new_reactor() -> (
pub(crate) fn new_reactor<R: Runtime>(
runtime: R,
) -> (
crate::channel::Channel,
Reactor,
Reactor<R>,
mpsc::Receiver<ChanCell>,
mpsc::Sender<CodecResult>,
) {
@ -392,6 +399,7 @@ pub(crate) mod test {
unique_id,
dummy_target,
crate::ClockSkew::None,
runtime,
);
(chan, reactor, recv1, send2)
}
@ -399,8 +407,8 @@ pub(crate) mod test {
// Try shutdown from inside run_once..
#[test]
fn shutdown() {
tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
let (chan, mut reactor, _output, _input) = new_reactor();
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
let (chan, mut reactor, _output, _input) = new_reactor(rt);
chan.terminate();
let r = reactor.run_once().await;
@ -411,13 +419,13 @@ pub(crate) mod test {
// Try shutdown while reactor is running.
#[test]
fn shutdown2() {
tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
// TODO: Ask a rust person if this is how to do this.
use futures::future::FutureExt;
use futures::join;
let (chan, reactor, _output, _input) = new_reactor();
let (chan, reactor, _output, _input) = new_reactor(rt);
// Let's get the reactor running...
let run_reactor = reactor.run().map(|x| x.is_ok()).shared();
@ -439,7 +447,7 @@ pub(crate) mod test {
#[test]
fn new_circ_closed() {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
let (chan, mut reactor, mut output, _input) = new_reactor();
let (chan, mut reactor, mut output, _input) = new_reactor(rt.clone());
assert!(chan.duration_unused().is_some()); // unused yet
let (ret, reac) = futures::join!(chan.new_circ(), reactor.run_once());
@ -479,7 +487,7 @@ pub(crate) mod test {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
use tor_cell::chancell::msg;
let (chan, mut reactor, mut output, mut input) = new_reactor();
let (chan, mut reactor, mut output, mut input) = new_reactor(rt.clone());
let (ret, reac) = futures::join!(chan.new_circ(), reactor.run_once());
let (pending, circr) = ret.unwrap();
@ -527,9 +535,9 @@ pub(crate) mod test {
// Try incoming cells that shouldn't arrive on channels.
#[test]
fn bad_cells() {
tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
use tor_cell::chancell::msg;
let (_chan, mut reactor, _output, mut input) = new_reactor();
let (_chan, mut reactor, _output, mut input) = new_reactor(rt);
// We shouldn't get create cells, ever.
let create_cell = msg::Create2::new(4, *b"hihi").into();
@ -597,12 +605,12 @@ pub(crate) mod test {
#[test]
fn deliver_relay() {
tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
use crate::circuit::celltypes::ClientCircChanMsg;
use futures::channel::oneshot;
use tor_cell::chancell::msg;
let (_chan, mut reactor, _output, mut input) = new_reactor();
let (_chan, mut reactor, _output, mut input) = new_reactor(rt);
let (_circ_stream_7, mut circ_stream_13) = {
let (snd1, _rcv1) = oneshot::channel();
@ -680,12 +688,12 @@ pub(crate) mod test {
#[test]
fn deliver_destroy() {
tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
tor_rtcompat::test_with_all_runtimes!(|rt| async move {
use crate::circuit::celltypes::*;
use futures::channel::oneshot;
use tor_cell::chancell::msg;
let (_chan, mut reactor, _output, mut input) = new_reactor();
let (_chan, mut reactor, _output, mut input) = new_reactor(rt);
let (circ_oneshot_7, mut circ_stream_13) = {
let (snd1, rcv1) = oneshot::channel();

View File

@ -747,7 +747,7 @@ mod test {
Receiver<ChanCell>,
Sender<std::result::Result<ChanCell, CodecError>>,
) {
let (channel, chan_reactor, rx, tx) = new_reactor();
let (channel, chan_reactor, rx, tx) = new_reactor(rt.clone());
rt.spawn(async {
let _ignore = chan_reactor.run().await;
})

View File

@ -0,0 +1 @@
BREAKING: Runtime subtraits (eg SleepProvider) are now all Clone + Send + Sync + 'static

View File

@ -72,6 +72,10 @@ where
impl<SpawnR, SleepR, TcpR, TlsR, UdpR> BlockOn for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR, UdpR>
where
SpawnR: BlockOn,
SleepR: Clone + Send + Sync + 'static,
TcpR: Clone + Send + Sync + 'static,
TlsR: Clone + Send + Sync + 'static,
UdpR: Clone + Send + Sync + 'static,
{
#[inline]
fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
@ -83,6 +87,10 @@ impl<SpawnR, SleepR, TcpR, TlsR, UdpR> SleepProvider
for CompoundRuntime<SpawnR, SleepR, TcpR, TlsR, UdpR>
where
SleepR: SleepProvider,
SpawnR: Clone + Send + Sync + 'static,
TcpR: Clone + Send + Sync + 'static,
TlsR: Clone + Send + Sync + 'static,
UdpR: Clone + Send + Sync + 'static,
{
type SleepFuture = SleepR::SleepFuture;
@ -133,6 +141,9 @@ impl<SpawnR, SleepR, TcpR, TlsR, UdpR, S> TlsProvider<S>
where
TcpR: TcpProvider,
TlsR: TlsProvider<S>,
SleepR: Clone + Send + Sync + 'static,
SpawnR: Clone + Send + Sync + 'static,
UdpR: Clone + Send + Sync + 'static,
{
type Connector = TlsR::Connector;
type TlsStream = TlsR::TlsStream;

View File

@ -13,7 +13,7 @@ use std::{
/// A [`TlsProvider`] that uses `native_tls`.
///
/// It supports wrapping any reasonable stream type that implements `AsyncRead` + `AsyncWrite`.
#[derive(Default)]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct NativeTlsProvider {}

View File

@ -15,6 +15,7 @@ use std::{
/// A [`TlsProvider`] that uses `rustls`.
///
/// It supports wrapping any reasonable stream type that implements `AsyncRead` + `AsyncWrite`.
#[derive(Clone)]
#[non_exhaustive]
pub struct RustlsProvider {
/// Inner `ClientConfig` logic used to create connectors.

View File

@ -79,7 +79,7 @@ impl<T> Runtime for T where
/// Every `SleepProvider` also implements
/// [`SleepProviderExt`](crate::SleepProviderExt); see that trait
/// for other useful functions.
pub trait SleepProvider {
pub trait SleepProvider: Clone + Send + Sync + 'static {
/// A future returned by [`SleepProvider::sleep()`]
type SleepFuture: Future<Output = ()> + Send + 'static;
/// Return a future that will be ready after `duration` has
@ -128,7 +128,7 @@ pub trait SleepProvider {
}
/// Trait for a runtime that can block on a future.
pub trait BlockOn {
pub trait BlockOn: Clone + Send + Sync + 'static {
/// Run `future` until it is ready, and return its output.
fn block_on<F: Future>(&self, future: F) -> F::Output;
}
@ -142,7 +142,7 @@ pub trait BlockOn {
// TODO: Use of async_trait is not ideal, since we have to box with every
// call. Still, async_io basically makes that necessary :/
#[async_trait]
pub trait TcpProvider {
pub trait TcpProvider: Clone + Send + Sync + 'static {
/// The type for the TCP connections returned by [`Self::connect()`].
type TcpStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static;
/// The type for the TCP listeners returned by [`Self::listen()`].
@ -187,7 +187,7 @@ pub trait TcpListener {
/// Trait for a runtime that can send and receive UDP datagrams.
#[async_trait]
pub trait UdpProvider {
pub trait UdpProvider: Clone + Send + Sync + 'static {
/// The type of Udp Socket returned by [`Self::bind()`]
type UdpSocket: UdpSocket + Send + Sync + Unpin + 'static;
@ -270,7 +270,7 @@ pub trait TlsConnector<S> {
/// See the [`TlsConnector`] documentation for a discussion of the Tor-specific
/// limitations of this trait: If you are implementing something other than Tor,
/// this is **not** the functionality you want.
pub trait TlsProvider<S> {
pub trait TlsProvider<S>: Clone + Send + Sync + 'static {
/// The Connector object that this provider can return.
type Connector: TlsConnector<S, Conn = Self::TlsStream> + Send + Sync + Unpin;