Make TlsConnector wrap TCP connections, not create its own

`tor-rtcompat`'s `TlsConnector` trait previously included a method to
create a TLS-over-TCP connection, which implied creating a TCP stream
inside that method. This commit changes that, and makes the function
wrap a TCP stream, as returned from the runtime's `TcpProvider` trait
implementation, instead.

This means you can actually override `TcpProvider` and have it apply to
*all* connections Arti makes, which is useful for issues like arti#235
and other cases where you want to have a custom TCP stream
implementation.

This required updating the mock TCP/TLS types in `tor-rtmock` slightly;
due to the change in API, we now store whether a `LocalStream` should
actually be a TLS stream inside the stream itself, and check this
property on reads/writes in order to detect misuse. The fake TLS wrapper
checks this property and removes it in order to "wrap" the stream,
making reads and writes work again.
This commit is contained in:
eta 2021-12-02 11:44:29 +00:00
parent 36f6a61f05
commit b14c5f370e
10 changed files with 108 additions and 76 deletions

View File

@ -18,7 +18,7 @@ pub(crate) struct ChanBuilder<R: Runtime> {
/// Asynchronous runtime for TLS, TCP, spawning, and timeouts.
runtime: R,
/// Object to build TLS connections.
tls_connector: <R as TlsProvider>::Connector,
tls_connector: <R as TlsProvider<R::TcpStream>>::Connector,
}
impl<R: Runtime> ChanBuilder<R> {
@ -72,10 +72,13 @@ impl<R: Runtime> ChanBuilder<R> {
tracing::info!("Negotiating TLS with {}", addr);
// Establish a TCP connection.
let stream = self.runtime.connect(addr).await?;
// TODO: add a random hostname here if it will be used for SNI?
let tls = self
.tls_connector
.connect_unvalidated(addr, "ignored")
.negotiate_unvalidated(stream, "ignored")
.await?;
let peer_cert = tls

View File

@ -913,7 +913,7 @@ mod test {
// Pick a guard and mark it as confirmed.
let id1 = guards.sample[0].clone();
guards.record_success(&id1, &params, t1);
assert_eq!(&guards.confirmed, &[id1.clone()]);
assert_eq!(&guards.confirmed, &[id1]);
let one_day = Duration::from_secs(86400);
guards.expire_old_guards(&params, t1 + one_day * 30);

View File

@ -121,7 +121,6 @@ mod tls {
use std::convert::TryFrom;
use std::io::{Error as IoError, Result as IoResult};
use std::net::SocketAddr;
/// The TLS-over-TCP type returned by this module.
#[allow(unreachable_pub)] // not actually unreachable; depends on features
@ -142,16 +141,14 @@ mod tls {
}
#[async_trait]
impl crate::traits::TlsConnector for TlsConnector {
impl crate::traits::TlsConnector<TcpStream> for TlsConnector {
type Conn = TlsStream;
async fn connect_unvalidated(
async fn negotiate_unvalidated(
&self,
addr: &SocketAddr,
stream: TcpStream,
hostname: &str,
) -> IoResult<Self::Conn> {
let stream = TcpStream::connect(addr).await?;
let conn = self
.connector
.connect(hostname, stream)
@ -183,6 +180,7 @@ mod tls {
// ==============================
use async_std_crate::net::TcpStream;
use futures::{Future, FutureExt};
use std::pin::Pin;
use std::time::Duration;
@ -207,7 +205,7 @@ impl SpawnBlocking for async_executors::AsyncStd {
}
}
impl TlsProvider for async_executors::AsyncStd {
impl TlsProvider<TcpStream> for async_executors::AsyncStd {
type TlsStream = tls::TlsStream;
type Connector = tls::TlsConnector;

View File

@ -28,6 +28,22 @@ mod net {
/// Underlying tokio_util::compat::Compat wrapper.
s: Compat<TokioTcpStream>,
}
impl TcpStream {
/// Get a reference to the underlying tokio `TcpStream`.
pub fn get_ref(&self) -> &TokioTcpStream {
self.s.get_ref()
}
/// Get a mutable reference to the underlying tokio `TcpStream`.
pub fn get_mut(&mut self) -> &mut TokioTcpStream {
self.s.get_mut()
}
/// Convert this type into its underlying tokio `TcpStream`.
pub fn into_inner(self) -> TokioTcpStream {
self.s.into_inner()
}
}
impl From<TokioTcpStream> for TcpStream {
fn from(s: TokioTcpStream) -> TcpStream {
let s = s.compat();
@ -111,9 +127,9 @@ mod tls {
use futures::io::{AsyncRead, AsyncWrite};
use crate::impls::tokio::net::TcpStream;
use std::convert::TryFrom;
use std::io::{Error as IoError, Result as IoResult};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -139,16 +155,15 @@ mod tls {
}
#[async_trait]
impl crate::traits::TlsConnector for TlsConnector {
impl crate::traits::TlsConnector<TcpStream> for TlsConnector {
type Conn = TlsStream;
async fn connect_unvalidated(
async fn negotiate_unvalidated(
&self,
addr: &SocketAddr,
stream: TcpStream,
hostname: &str,
) -> IoResult<Self::Conn> {
let stream = tokio_crate::net::TcpStream::connect(addr).await?;
let stream = stream.into_inner();
let conn = self
.connector
.connect(hostname, stream)
@ -222,7 +237,7 @@ macro_rules! implement_traits_for {
}
}
impl TlsProvider for $runtime {
impl TlsProvider<net::TcpStream> for $runtime {
type TlsStream = tls::TlsStream;
type Connector = tls::TlsConnector;

View File

@ -202,7 +202,8 @@ fn simple_tls<R: Runtime>(runtime: &R) -> IoResult<()> {
runtime.block_on(async {
let text = b"I Suddenly Dont Understand Anything";
let mut buf = vec![0_u8; text.len()];
let mut conn = connector.connect_unvalidated(&addr, "Kan.Aya").await?;
let conn = runtime.connect(&addr).await?;
let mut conn = connector.negotiate_unvalidated(conn, "Kan.Aya").await?;
assert!(conn.peer_certificate()?.is_some());
conn.write_all(text).await?;
conn.flush().await?;

View File

@ -26,7 +26,15 @@ use std::time::{Duration, Instant, SystemTime};
/// Additionally, every `Runtime` is [`Send`] and [`Sync`], though these
/// requirements may be somewhat relaxed in the future.
pub trait Runtime:
Sync + Send + Spawn + SpawnBlocking + Clone + SleepProvider + TcpProvider + TlsProvider + 'static
Sync
+ Send
+ Spawn
+ SpawnBlocking
+ Clone
+ SleepProvider
+ TcpProvider
+ TlsProvider<Self::TcpStream>
+ 'static
{
}
@ -38,7 +46,7 @@ impl<T> Runtime for T where
+ Clone
+ SleepProvider
+ TcpProvider
+ TlsProvider
+ TlsProvider<Self::TcpStream>
+ 'static
{
}
@ -161,38 +169,35 @@ pub trait CertifiedConn {
fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>>;
}
/// An object that knows how to make a TLS-over-TCP connection we
/// can use in Tor.
/// An object that knows how to wrap a TCP connection (where the type of said TCP
/// connection is `S`) with TLS.
///
/// (Note that because of Tor's peculiarities, this is not a
/// general-purpose TLS type. Unlike typical users, Tor does not want
/// its TLS library to check whether the certificates are signed
/// within the web PKI hierarchy, or what their hostnames are.
#[async_trait]
pub trait TlsConnector {
pub trait TlsConnector<S> {
/// The type of connection returned by this connector
type Conn: AsyncRead + AsyncWrite + CertifiedConn + Unpin + Send + 'static;
/// Launch a TLS-over-TCP connection to a given address.
/// Start a TLS session over the provided TCP stream `stream`.
///
/// Declare `sni_hostname` as the desired hostname, but don't
/// actually check whether the hostname in the certificate matches
/// it.
async fn connect_unvalidated(
&self,
addr: &SocketAddr,
sni_hostname: &str,
) -> IoResult<Self::Conn>;
async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn>;
}
/// Trait for a runtime that knows how to create TLS connections.
/// Trait for a runtime that knows how to create TLS connections over
/// TCP streams of type `S`.
///
/// This is separate from [`TlsConnector`] because eventually we may
/// eventually want to support multiple `TlsConnector` implementations
/// that use a single [`Runtime`].
pub trait TlsProvider {
pub trait TlsProvider<S> {
/// The Connector object that this provider can return.
type Connector: TlsConnector<Conn = Self::TlsStream> + Send + Sync + Unpin;
type Connector: TlsConnector<S, Conn = Self::TlsStream> + Send + Sync + Unpin;
/// The type of the stream returned by that connector.
type TlsStream: AsyncRead + AsyncWrite + CertifiedConn + Unpin + Send + 'static;

View File

@ -37,11 +37,13 @@ pub fn stream_pair() -> (LocalStream, LocalStream) {
w: w1,
r: r1,
pending_bytes: Vec::new(),
tls_cert: None,
};
let s2 = LocalStream {
w: w2,
r: r2,
pending_bytes: Vec::new(),
tls_cert: None,
};
(s1, s2)
}
@ -65,6 +67,15 @@ pub struct LocalStream {
r: mpsc::Receiver<IoResult<Vec<u8>>>,
/// Bytes that we have read from `r` but not yet delivered.
pending_bytes: Vec<u8>,
/// Data about the other side of this stream's fake TLS certificate, if any.
/// If this is present, I/O operations will fail with an error.
///
/// How this is intended to work: things that return `LocalStream`s that could potentially
/// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper
/// type would clear this field (after checking its contents are as expected).
///
/// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise
pub(crate) tls_cert: Option<Vec<u8>>,
}
/// Helper: pull bytes off the front of `pending_bytes` and put them
@ -85,6 +96,12 @@ impl AsyncRead for LocalStream {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
if self.tls_cert.is_some() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"attempted to treat a TLS stream as non-TLS!",
)));
}
if !self.pending_bytes.is_empty() {
return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
}
@ -107,6 +124,13 @@ impl AsyncWrite for LocalStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<IoResult<usize>> {
if self.tls_cert.is_some() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"attempted to treat a TLS stream as non-TLS!",
)));
}
match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
Ok(()) => (),
Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),

View File

@ -162,8 +162,9 @@ impl MockNetwork {
/// Tell the listener at `target_addr` (if any) about an incoming
/// connection from `source_addr` at `peer_stream`.
///
/// If this is a tls listener, only succeed when want_tls is true,
/// and return the certificate.
/// If the listener is a TLS listener, returns its certificate.
/// **Note:** Callers should check whether the presence or absence of a certificate
/// matches their expectations.
///
/// Returns an error if there isn't any such listener.
async fn send_connection(
@ -171,19 +172,12 @@ impl MockNetwork {
source_addr: SocketAddr,
target_addr: SocketAddr,
peer_stream: LocalStream,
want_tls: bool,
) -> IoResult<Option<Vec<u8>>> {
let entry = {
let listener_map = self.listening.lock().expect("Poisoned lock for listener");
listener_map.get(&target_addr).map(Clone::clone)
};
if let Some(mut entry) = entry {
if entry.tls_cert.is_some() != want_tls {
// TODO(nickm): This is not what you'd really see on a
// mismatched connection. Maybe we should change this
// to give garbage, or a warning, or something?
return Err(err(ErrorKind::ConnectionRefused));
}
if entry.send.send((peer_stream, source_addr)).await.is_ok() {
return Ok(entry.tls_cert);
}
@ -364,14 +358,16 @@ impl TcpProvider for MockNetProvider {
async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
let my_addr = self.get_origin_addr_for(addr)?;
let (mine, theirs) = stream_pair();
let (mut mine, theirs) = stream_pair();
let _no_cert = self
let cert = self
.inner
.net
.send_connection(my_addr, *addr, theirs, false)
.send_connection(my_addr, *addr, theirs)
.await?;
mine.tls_cert = cert;
Ok(mine)
}
@ -385,14 +381,12 @@ impl TcpProvider for MockNetProvider {
}
#[async_trait]
impl TlsProvider for MockNetProvider {
impl TlsProvider<LocalStream> for MockNetProvider {
type Connector = MockTlsConnector;
type TlsStream = MockTlsStream;
fn tls_connector(&self) -> MockTlsConnector {
MockTlsConnector {
provider: self.clone(),
}
MockTlsConnector {}
}
}
@ -401,10 +395,8 @@ impl TlsProvider for MockNetProvider {
/// Note that no TLS is actually performed here: connections are simply
/// told that they succeeded with a given certificate.
#[derive(Clone)]
pub struct MockTlsConnector {
/// A handle to the underlying provider.
provider: MockNetProvider,
}
#[non_exhaustive]
pub struct MockTlsConnector;
/// Mock TLS connector for use with MockNetProvider.
///
@ -422,28 +414,24 @@ pub struct MockTlsStream {
}
#[async_trait]
impl TlsConnector for MockTlsConnector {
impl TlsConnector<LocalStream> for MockTlsConnector {
type Conn = MockTlsStream;
async fn connect_unvalidated(
async fn negotiate_unvalidated(
&self,
addr: &SocketAddr,
mut stream: LocalStream,
_sni_hostname: &str,
) -> IoResult<MockTlsStream> {
let my_addr = self.provider.get_origin_addr_for(addr)?;
let (mine, theirs) = stream_pair();
let peer_cert = stream.tls_cert.take();
let peer_cert = self
.provider
.inner
.net
.send_connection(my_addr, *addr, theirs, true)
.await?;
if peer_cert.is_none() {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"attempted to wrap non-TLS stream!",
));
}
Ok(MockTlsStream {
peer_cert,
stream: mine,
})
Ok(MockTlsStream { peer_cert, stream })
}
}
@ -629,8 +617,9 @@ mod test {
let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
async {
let connector = client1.tls_connector();
let conn = client1.connect(&address).await?;
let mut conn = connector
.connect_unvalidated(&address, "zombo.example.com")
.negotiate_unvalidated(conn, "zombo.example.com")
.await?;
assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
conn.write_all(b"This is totally encrypted.").await?;
@ -638,10 +627,6 @@ mod test {
conn.read_to_end(&mut v).await?;
conn.close().await?;
assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
// Now try a non-tls connection.
let e = client1.connect(&address).await;
assert!(e.is_err());
Ok(())
},
async {

View File

@ -6,6 +6,7 @@
use crate::net::MockNetProvider;
use tor_rtcompat::{Runtime, SleepProvider, SpawnBlocking, TcpProvider, TlsProvider};
use crate::io::LocalStream;
use async_trait::async_trait;
use futures::task::{FutureObj, Spawn, SpawnError};
use futures::Future;
@ -66,9 +67,9 @@ impl<R: Runtime> TcpProvider for MockNetRuntime<R> {
}
}
impl<R: Runtime> TlsProvider for MockNetRuntime<R> {
type Connector = <MockNetProvider as TlsProvider>::Connector;
type TlsStream = <MockNetProvider as TlsProvider>::TlsStream;
impl<R: Runtime> TlsProvider<LocalStream> for MockNetRuntime<R> {
type Connector = <MockNetProvider as TlsProvider<LocalStream>>::Connector;
type TlsStream = <MockNetProvider as TlsProvider<LocalStream>>::TlsStream;
fn tls_connector(&self) -> Self::Connector {
self.net.tls_connector()
}

View File

@ -105,7 +105,7 @@ impl<R: Runtime> TcpProvider for MockSleepRuntime<R> {
}
}
impl<R: Runtime> TlsProvider for MockSleepRuntime<R> {
impl<R: Runtime> TlsProvider<R::TcpStream> for MockSleepRuntime<R> {
type Connector = R::Connector;
type TlsStream = R::TlsStream;
fn tls_connector(&self) -> Self::Connector {