Merge branch 'ticket_525_part2' into 'main'

Apply restricted_msg to ChanMsg parts of tor-proto

See merge request tpo/core/arti!1013
This commit is contained in:
Nick Mathewson 2023-02-09 18:06:39 +00:00
commit 696b9bd2d5
30 changed files with 376 additions and 343 deletions

1
Cargo.lock generated
View File

@ -3604,6 +3604,7 @@ dependencies = [
"bitflags",
"bytes",
"caret",
"derive_more",
"educe",
"hex",
"hex-literal",

View File

@ -0,0 +1 @@
MODIFIED: Error::BadMessage is deprecated, Error::InvalidMessage is new.

View File

@ -1,5 +1,7 @@
//! Internal: Declare an Error type for tor-bytes
use std::borrow::Cow;
use thiserror::Error;
use tor_error::{into_internal, Bug};
@ -26,8 +28,13 @@ pub enum Error {
BadLengthValue,
/// An attempt to parse an object failed for some reason related to its
/// contents.
#[deprecated(since = "0.6.2", note = "Use InvalidMessage instead.")]
#[error("Bad object: {0}")]
BadMessage(&'static str),
/// An attempt to parse an object failed for some reason related to its
/// contents.
#[error("Bad object: {0}")]
InvalidMessage(Cow<'static, str>),
/// A parsing error that should never happen.
///
/// We use this one in lieu of calling assert() and expect() and
@ -42,7 +49,9 @@ impl PartialEq for Error {
match (self, other) {
(Truncated, Truncated) => true,
(ExtraneousBytes, ExtraneousBytes) => true,
#[allow(deprecated)]
(BadMessage(a), BadMessage(b)) => a == b,
(InvalidMessage(a), InvalidMessage(b)) => a == b,
(BadLengthValue, BadLengthValue) => true,
// notably, this means that an internal error is equal to nothing, not even itself.
(_, _) => false,

View File

@ -182,7 +182,7 @@ mod ed25519_impls {
fn take_from(b: &mut Reader<'_>) -> Result<Self> {
let bytes = b.take(32)?;
Self::from_bytes(array_ref![bytes, 0, 32])
.map_err(|_| Error::BadMessage("Couldn't decode Ed25519 public key"))
.map_err(|_| Error::InvalidMessage("Couldn't decode Ed25519 public key".into()))
}
}
@ -208,7 +208,7 @@ mod ed25519_impls {
fn take_from(b: &mut Reader<'_>) -> Result<Self> {
let bytes = b.take(64)?;
Self::from_bytes(array_ref![bytes, 0, 64])
.map_err(|_| Error::BadMessage("Couldn't decode Ed25519 signature."))
.map_err(|_| Error::InvalidMessage("Couldn't decode Ed25519 signature.".into()))
}
}
}

View File

@ -27,6 +27,7 @@ arrayref = "0.3"
bitflags = "1"
bytes = "1"
caret = { path = "../caret", version = "0.3.0" }
derive_more = "0.99.3"
educe = "0.4.6"
paste = "1"
rand = "0.8"

View File

@ -6,3 +6,5 @@ BREAKING: Moved ChanMsg methods into a trait.
BREAKING: Moved RelayMsg methods into a trait.
BREAKING: Renamed ChanCell->AnyChanCell, ChanMsg->AnyChanMsg.
BREAKING: Renamed RelayCell->AnyRelayCell, RelayMsg->AnyRelayMsg.
BREAKING: Make ChannelCodec::decode() parameterized.
BREAKING: RelayEarly is now a real type.

View File

@ -1,7 +1,7 @@
//! Implementation for encoding and decoding of ChanCells.
use super::CELL_DATA_LEN;
use crate::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
use super::{ChanCell, CELL_DATA_LEN};
use crate::chancell::{ChanCmd, ChanMsg, CircId};
use crate::Error;
use arrayref::{array_mut_ref, array_ref};
use tor_bytes::{self, Reader, Writer};
@ -49,8 +49,12 @@ impl ChannelCodec {
}
/// Write the given cell into the provided BytesMut object.
pub fn write_cell(&mut self, item: AnyChanCell, dst: &mut BytesMut) -> crate::Result<()> {
let AnyChanCell { circid, msg } = item;
pub fn write_cell<M: ChanMsg>(
&mut self,
item: ChanCell<M>,
dst: &mut BytesMut,
) -> crate::Result<()> {
let ChanCell { circid, msg } = item;
let cmd = msg.cmd();
dst.write_u32(circid.into());
dst.write_u8(cmd.into());
@ -83,7 +87,10 @@ impl ChannelCodec {
///
/// On a definite decoding error, return Err(_). On a cell that might
/// just be truncated, return Ok(None).
pub fn decode_cell(&mut self, src: &mut BytesMut) -> crate::Result<Option<AnyChanCell>> {
pub fn decode_cell<M: ChanMsg>(
&mut self,
src: &mut BytesMut,
) -> crate::Result<Option<ChanCell<M>>> {
/// Wrap `be` as an appropriate type.
fn wrap_err(be: tor_bytes::Error) -> crate::Error {
crate::Error::BytesErr {
@ -113,7 +120,7 @@ impl ChannelCodec {
let mut r = Reader::from_bytes(&cell);
let circid: CircId = r.take_u32().map_err(wrap_err)?.into();
r.advance(if varcell { 3 } else { 1 }).map_err(wrap_err)?;
let msg = msg::AnyChanMsg::decode_from_reader(cmd, &mut r).map_err(wrap_err)?;
let msg = M::decode_from_reader(cmd, &mut r).map_err(wrap_err)?;
if !cmd.accepts_circid_val(circid) {
return Err(Error::ChanProto(format!(
@ -121,6 +128,6 @@ impl ChannelCodec {
circid, cmd
)));
}
Ok(Some(AnyChanCell { circid, msg }))
Ok(Some(ChanCell { circid, msg }))
}
}

View File

@ -11,8 +11,6 @@ use educe::Educe;
/// Trait for the 'bodies' of channel messages.
pub trait Body: Readable {
/// Convert this type into a ChanMsg, wrapped as appropriate.
fn into_message(self) -> AnyChanMsg;
/// Decode a channel cell body from a provided reader.
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
r.extract()
@ -32,6 +30,7 @@ crate::restrict::restricted_msg! {
/// a TLS connection.
#[derive(Clone, Debug)]
#[non_exhaustive]
@omit_from "avoid_conflict_with_a_blanket_implementation"
pub enum AnyChanMsg : ChanMsg {
/// A Padding message
Padding,
@ -96,9 +95,6 @@ impl Padding {
}
}
impl Body for Padding {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Padding(self)
}
fn encode_onto<W: Writer + ?Sized>(self, _w: &mut W) -> EncodeResult<()> {
Ok(())
}
@ -124,9 +120,6 @@ impl Vpadding {
}
}
impl Body for Vpadding {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Vpadding(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_zeros(self.len as usize);
Ok(())
@ -135,7 +128,9 @@ impl Body for Vpadding {
impl Readable for Vpadding {
fn take_from(r: &mut Reader<'_>) -> Result<Self> {
if r.remaining() > std::u16::MAX as usize {
return Err(Error::BadMessage("Too many bytes in VPADDING cell"));
return Err(Error::InvalidMessage(
"Too many bytes in VPADDING cell".into(),
));
}
Ok(Vpadding {
len: r.remaining() as u16,
@ -165,9 +160,6 @@ macro_rules! fixed_len_handshake {
}
}
impl Body for $name {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::$name(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_all(&self.handshake[..]);
Ok(())
@ -261,9 +253,6 @@ pub struct Create2 {
handshake: Vec<u8>,
}
impl Body for Create2 {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Create2(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_u16(self.handshake_type);
let handshake_len = self
@ -335,9 +324,6 @@ impl Created2 {
}
}
impl Body for Created2 {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Created2(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
let handshake_len = self
.handshake
@ -407,13 +393,10 @@ impl Relay {
}
/// Wrap this Relay message into a RelayMsg as a RELAY_EARLY cell.
pub fn into_early(self) -> AnyChanMsg {
AnyChanMsg::RelayEarly(self)
AnyChanMsg::RelayEarly(RelayEarly(self))
}
}
impl Body for Relay {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Relay(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_all(&self.body[..]);
Ok(())
@ -427,8 +410,31 @@ impl Readable for Relay {
}
}
/// Alias for Relay: these two cell types have the same body.
pub type RelayEarly = Relay;
/// A Relay cell that is allowed to contain a CREATE message.
///
/// Only a limited number of these may be sent on each circuit.
#[derive(Clone, Debug, derive_more::Deref, derive_more::From, derive_more::Into)]
pub struct RelayEarly(Relay);
impl Readable for RelayEarly {
fn take_from(r: &mut Reader<'_>) -> Result<Self> {
Ok(RelayEarly(Relay::take_from(r)?))
}
}
impl Body for RelayEarly {
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
self.0.encode_onto(w)
}
}
impl RelayEarly {
/// Consume this RelayEarly message and return a RelayCellBody for
/// encryption/decryption.
//
// (Since this method takes `self` by value, we can't take advantage of
// Deref.)
pub fn into_relay_body(self) -> RawCellBody {
*self.0.body
}
}
/// The Destroy message tears down a circuit.
///
@ -451,9 +457,6 @@ impl Destroy {
}
}
impl Body for Destroy {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Destroy(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_u8(self.reason.into());
Ok(())
@ -608,9 +611,6 @@ impl Netinfo {
}
}
impl Body for Netinfo {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Netinfo(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_u32(self.timestamp);
let their_addr = self
@ -708,9 +708,6 @@ impl Versions {
}
}
impl Body for Versions {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Versions(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
for v in &self.versions {
w.write_u16(*v);
@ -814,9 +811,6 @@ impl Default for PaddingNegotiate {
}
impl Body for PaddingNegotiate {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::PaddingNegotiate(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_u8(0); // version
w.write_u8(self.command.get());
@ -829,8 +823,8 @@ impl Readable for PaddingNegotiate {
fn take_from(r: &mut Reader<'_>) -> Result<Self> {
let v = r.take_u8()?;
if v != 0 {
return Err(Error::BadMessage(
"Unrecognized padding negotiation version",
return Err(Error::InvalidMessage(
"Unrecognized padding negotiation version".into(),
));
}
let command = r.take_u8()?.into();
@ -940,9 +934,6 @@ impl Certs {
}
impl Body for Certs {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Certs(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
let n_certs: u8 = self
.certs
@ -1001,9 +992,6 @@ impl AuthChallenge {
}
impl Body for AuthChallenge {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::AuthChallenge(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_all(&self.challenge[..]);
let n_methods = self
@ -1057,9 +1045,6 @@ impl Authenticate {
}
}
impl Body for Authenticate {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Authenticate(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_u16(self.authtype);
let authlen = self
@ -1098,9 +1083,6 @@ impl Authorize {
}
}
impl Body for Authorize {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Authorize(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_all(&self.content[..]);
Ok(())
@ -1150,9 +1132,6 @@ impl Unrecognized {
}
}
impl Body for Unrecognized {
fn into_message(self) -> AnyChanMsg {
AnyChanMsg::Unrecognized(self)
}
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()> {
w.write_all(&self.content[..]);
Ok(())
@ -1167,12 +1146,6 @@ impl Readable for Unrecognized {
}
}
impl<B: Body> From<B> for AnyChanMsg {
fn from(body: B) -> Self {
body.into_message()
}
}
/// Helper: declare a From<> implementation from message types for
/// cells that don't take a circid.
macro_rules! msg_into_cell {
@ -1181,7 +1154,7 @@ macro_rules! msg_into_cell {
fn from(body: $body) -> super::AnyChanCell {
super::AnyChanCell {
circid: 0.into(),
msg: body.into_message(),
msg: body.into(),
}
}
}

View File

@ -293,7 +293,9 @@ impl<M: RelayMsg> RelayCell<M> {
r.advance(4)?; // digest
let len = r.take_u16()? as usize;
if r.remaining() < len {
return Err(Error::BadMessage("Insufficient data in relay cell"));
return Err(Error::InvalidMessage(
"Insufficient data in relay cell".into(),
));
}
r.truncate(len);
let msg = M::decode_from_reader(cmd, r)?;

View File

@ -91,16 +91,16 @@ impl Readable for NtorV3Extension {
Ok(match tag {
NtorV3ExtensionType::CC_REQUEST => {
if len != 0 {
return Err(tor_bytes::Error::BadMessage(
"invalid length for RequestCongestionControl",
return Err(tor_bytes::Error::InvalidMessage(
"invalid length for RequestCongestionControl".into(),
));
}
NtorV3Extension::RequestCongestionControl
}
NtorV3ExtensionType::CC_RESPONSE => {
if len != 1 {
return Err(tor_bytes::Error::BadMessage(
"invalid length for AckCongestionControl",
return Err(tor_bytes::Error::InvalidMessage(
"invalid length for AckCongestionControl".into(),
));
}
let sendme_inc = reader.take_u8()?;

View File

@ -31,6 +31,7 @@ crate::restrict::restricted_msg! {
/// A single parsed relay message, sent or received along a circuit
#[derive(Debug, Clone)]
#[non_exhaustive]
@omit_from "avoid_conflict_with_a_blanket_implementation"
pub enum AnyRelayMsg : RelayMsg {
/// Create a stream
Begin,
@ -107,20 +108,12 @@ pub enum AnyRelayMsg : RelayMsg {
/// Internal: traits in common different cell bodies.
pub trait Body: Sized {
/// Convert this type into a RelayMsg, wrapped appropriate.
fn into_message(self) -> AnyRelayMsg;
/// Decode a relay cell body from a provided reader.
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self>;
/// Encode the body of this cell into the end of a writer.
fn encode_onto<W: Writer + ?Sized>(self, w: &mut W) -> EncodeResult<()>;
}
impl<B: Body> From<B> for AnyRelayMsg {
fn from(b: B) -> AnyRelayMsg {
b.into_message()
}
}
bitflags! {
/// A set of recognized flags that can be attached to a begin cell.
///
@ -209,9 +202,6 @@ impl Begin {
}
impl Body for Begin {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Begin(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let addr = {
if r.peek(1)? == b"[" {
@ -220,7 +210,7 @@ impl Body for Begin {
let a = r.take_until(b']')?;
let colon = r.take_u8()?;
if colon != b':' {
return Err(Error::BadMessage("missing port in begin cell"));
return Err(Error::InvalidMessage("missing port in begin cell".into()));
}
a
} else {
@ -232,15 +222,17 @@ impl Body for Begin {
let flags = if r.remaining() >= 4 { r.take_u32()? } else { 0 };
if !addr.is_ascii() {
return Err(Error::BadMessage("target address in begin cell not ascii"));
return Err(Error::InvalidMessage(
"target address in begin cell not ascii".into(),
));
}
let port = std::str::from_utf8(port)
.map_err(|_| Error::BadMessage("port in begin cell not utf8"))?;
.map_err(|_| Error::InvalidMessage("port in begin cell not utf8".into()))?;
let port = port
.parse()
.map_err(|_| Error::BadMessage("port in begin cell not a valid port"))?;
.map_err(|_| Error::InvalidMessage("port in begin cell not a valid port".into()))?;
Ok(Begin {
addr: addr.into(),
@ -324,9 +316,6 @@ impl AsRef<[u8]> for Data {
}
impl Body for Data {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Data(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Data {
body: r.take(r.remaining())?.into(),
@ -437,9 +426,6 @@ impl End {
}
}
impl Body for End {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::End(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
if r.remaining() == 0 {
return Ok(End {
@ -528,9 +514,6 @@ impl Connected {
}
}
impl Body for Connected {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Connected(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
if r.remaining() == 0 {
return Ok(Connected { addr: None });
@ -538,7 +521,9 @@ impl Body for Connected {
let ipv4 = r.take_u32()?;
let addr = if ipv4 == 0 {
if r.take_u8()? != 6 {
return Err(Error::BadMessage("Invalid address type in CONNECTED cell"));
return Err(Error::InvalidMessage(
"Invalid address type in CONNECTED cell".into(),
));
}
IpAddr::V6(r.extract()?)
} else {
@ -607,9 +592,6 @@ impl Sendme {
}
}
impl Body for Sendme {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Sendme(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let digest = if r.remaining() == 0 {
None
@ -622,7 +604,7 @@ impl Body for Sendme {
Some(r.take(dlen as usize)?.into())
}
_ => {
return Err(Error::BadMessage("Unrecognized SENDME version."));
return Err(Error::InvalidMessage("Unrecognized SENDME version.".into()));
}
}
};
@ -672,9 +654,6 @@ impl Extend {
}
}
impl Body for Extend {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Extend(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let addr = r.extract()?;
let port = r.take_u16()?;
@ -713,9 +692,6 @@ impl Extended {
}
}
impl Body for Extended {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Extended(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let handshake = r.take(TAP_S_HANDSHAKE_LEN)?.into();
Ok(Extended { handshake })
@ -777,9 +753,6 @@ impl Extend2 {
}
impl Body for Extend2 {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Extend2(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let n = r.take_u8()?;
let linkspec = r.extract_n(n as usize)?;
@ -836,9 +809,6 @@ impl Extended2 {
}
}
impl Body for Extended2 {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Extended2(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let hlen = r.take_u16()?;
let handshake = r.take(hlen as usize)?;
@ -881,9 +851,6 @@ impl Truncated {
}
}
impl Body for Truncated {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Truncated(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Truncated {
reason: r.take_u8()?.into(),
@ -937,9 +904,6 @@ impl Resolve {
}
}
impl Body for Resolve {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Resolve(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let query = r.take_until(0)?;
Ok(Resolve {
@ -995,7 +959,9 @@ impl Readable for ResolvedVal {
let len = r.take_u8()? as usize;
if let Some(expected_len) = res_len(tp) {
if len != expected_len {
return Err(Error::BadMessage("Wrong length for RESOLVED answer"));
return Err(Error::InvalidMessage(
"Wrong length for RESOLVED answer".into(),
));
}
}
Ok(match tp {
@ -1106,9 +1072,6 @@ impl Resolved {
}
}
impl Body for Resolved {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Resolved(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let mut answers = Vec::new();
while r.remaining() > 0 {
@ -1161,9 +1124,6 @@ impl Unrecognized {
}
impl Body for Unrecognized {
fn into_message(self) -> AnyRelayMsg {
AnyRelayMsg::Unrecognized(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Unrecognized {
cmd: 0.into(),
@ -1187,9 +1147,6 @@ macro_rules! empty_body {
#[non_exhaustive]
pub struct $name {}
impl $crate::relaycell::msg::Body for $name {
fn into_message(self) -> $crate::relaycell::msg::AnyRelayMsg {
$crate::relaycell::msg::AnyRelayMsg::$name(self)
}
fn decode_from_reader(_r: &mut Reader<'_>) -> Result<Self> {
Ok(Self::default())
}

View File

@ -157,9 +157,6 @@ pub struct EstablishIntro {
}
impl msg::Body for EstablishIntro {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::EstablishIntro(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let auth_key_type = r.take_u8()?.into();
let auth_key_len = r.take_u16()?;
@ -259,9 +256,6 @@ impl EstablishRendezvous {
}
}
impl msg::Body for EstablishRendezvous {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::EstablishRendezvous(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let cookie = r.extract()?;
r.take_rest();
@ -277,9 +271,6 @@ impl msg::Body for EstablishRendezvous {
pub struct Introduce1(Introduce);
impl msg::Body for Introduce1 {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::Introduce1(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Self(Introduce::decode_from_reader(r)?))
}
@ -300,9 +291,6 @@ impl Introduce1 {
pub struct Introduce2(Introduce);
impl msg::Body for Introduce2 {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::Introduce2(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Self(Introduce::decode_from_reader(r)?))
}
@ -343,7 +331,9 @@ impl Introduce {
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let legacy_key_id: RsaIdentity = r.extract()?;
if !legacy_key_id.is_zero() {
return Err(BytesError::BadMessage("legacy key id in Introduce1."));
return Err(BytesError::InvalidMessage(
"legacy key id in Introduce1.".into(),
));
}
let auth_key_type = r.take_u8()?.into();
let auth_key_len = r.take_u16()?;
@ -387,10 +377,6 @@ pub struct Rendezvous1 {
}
impl Body for Rendezvous1 {
fn into_message(self) -> msg::AnyRelayMsg {
todo!()
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
todo!()
}
@ -409,10 +395,6 @@ pub struct Rendezvous2 {
}
impl Body for Rendezvous2 {
fn into_message(self) -> msg::AnyRelayMsg {
todo!()
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
todo!()
}
@ -434,10 +416,6 @@ pub struct IntroEstablished {
}
impl Body for IntroEstablished {
fn into_message(self) -> msg::AnyRelayMsg {
todo!()
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
todo!()
}
@ -466,10 +444,6 @@ pub struct IntroduceAck {
}
impl Body for IntroduceAck {
fn into_message(self) -> msg::AnyRelayMsg {
todo!()
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
todo!()
}

View File

@ -98,7 +98,7 @@ impl Readable for Address {
}
T_IPV4 => Self::Ipv4(r.extract()?),
T_IPV6 => Self::Ipv6(r.extract()?),
_ => return Err(Error::BadMessage("Invalid address type")),
_ => return Err(Error::InvalidMessage("Invalid address type".into())),
})
})
}
@ -133,10 +133,10 @@ impl FromStr for Address {
Ok(Self::Ipv6(ipv6))
} else {
if s.len() > MAX_HOSTNAME_LEN {
return Err(Error::BadMessage("Hostname too long"));
return Err(Error::InvalidMessage("Hostname too long".into()));
}
if s.contains('\0') {
return Err(Error::BadMessage("Nul byte not permitted"));
return Err(Error::InvalidMessage("Nul byte not permitted".into()));
}
let mut addr = s.to_string();
@ -188,10 +188,6 @@ impl ConnectUdp {
}
impl msg::Body for ConnectUdp {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::ConnectUdp(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let flags = r.take_u32()?;
let addr = r.extract()?;
@ -230,18 +226,14 @@ impl ConnectedUdp {
}
impl msg::Body for ConnectedUdp {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::ConnectedUdp(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
let our_address: AddressPort = r.extract()?;
if our_address.addr.is_hostname() {
return Err(Error::BadMessage("Our address is a Hostname"));
return Err(Error::InvalidMessage("Our address is a Hostname".into()));
}
let their_address: AddressPort = r.extract()?;
if their_address.addr.is_hostname() {
return Err(Error::BadMessage("Their address is a Hostname"));
return Err(Error::InvalidMessage("Their address is a Hostname".into()));
}
Ok(Self {
@ -305,10 +297,6 @@ impl AsRef<[u8]> for Datagram {
}
impl msg::Body for Datagram {
fn into_message(self) -> msg::AnyRelayMsg {
msg::AnyRelayMsg::Datagram(self)
}
fn decode_from_reader(r: &mut Reader<'_>) -> Result<Self> {
Ok(Datagram {
body: r.take(r.remaining())?.into(),

View File

@ -1,7 +1,7 @@
//! Declare a restricted variant of our message types.
/// Re-export tor_bytes here, so that the macro can use it.
pub use tor_bytes;
/// Re-export tor_bytes and paste here, so that the macro can use it.
pub use {paste, tor_bytes};
/// Declare a restricted version of
/// [`AnyRelayMsg`](crate::relaycell::msg::AnyRelayMsg) or
@ -40,18 +40,20 @@ pub use tor_bytes;
macro_rules! restricted_msg {
{
$(#[$meta:meta])*
$(@omit_from $omit_from:literal)?
$v:vis enum $name:ident : RelayMsg {
$($tt:tt)*
}
} => {
$crate::restrict::restricted_msg!{
[
base_type: $crate::relaycell::msg::AnyRelayMsg,
any_type: $crate::relaycell::msg::AnyRelayMsg,
msg_mod: $crate::relaycell::msg,
cmd_type: $crate::relaycell::RelayCmd,
unrecognized: $crate::relaycell::msg::Unrecognized,
body_trait: $crate::relaycell::msg::Body,
msg_trait: $crate::relaycell::RelayMsg
msg_trait: $crate::relaycell::RelayMsg,
omit_from: $($omit_from)?
]
$(#[$meta])*
$v enum $name { $($tt)*}
@ -59,18 +61,20 @@ macro_rules! restricted_msg {
};
{
$(#[$meta:meta])*
$(@omit_from $omit_from:literal)?
$v:vis enum $name:ident : ChanMsg {
$($tt:tt)*
}
} => {
$crate::restrict::restricted_msg!{
[
base_type: $crate::chancell::msg::AnyChanMsg,
any_type: $crate::chancell::msg::AnyChanMsg,
msg_mod: $crate::chancell::msg,
cmd_type: $crate::chancell::ChanCmd,
unrecognized: $crate::chancell::msg::Unrecognized,
body_trait: $crate::chancell::msg::Body,
msg_trait: $crate::chancell::ChanMsg
msg_trait: $crate::chancell::ChanMsg,
omit_from: $($omit_from)?
]
$(#[$meta])*
$v enum $name { $($tt)*}
@ -78,12 +82,13 @@ macro_rules! restricted_msg {
};
{
[
base_type: $base:ty,
any_type: $any_msg:ty,
msg_mod: $msg_mod:path,
cmd_type: $cmd_type:ty,
unrecognized: $unrec_type:ty,
body_trait: $body_type:ty,
msg_trait: $msg_trait:ty
msg_trait: $msg_trait:ty,
omit_from: $($omit_from:literal)?
]
$(#[$meta:meta])*
$v:vis enum $name:ident {
@ -98,7 +103,7 @@ macro_rules! restricted_msg {
$(,)?
}
} => {
paste::paste!{
$crate::restrict::paste::paste!{
$(#[$meta])*
$v enum $name {
$(
@ -151,12 +156,65 @@ macro_rules! restricted_msg {
$(
_ => Self::$unrecognized($unrec_type::decode_with_cmd(cmd, r)?),
)?
// TODO: This message is too terse! This message type should maybe take a Cow?
#[allow(unreachable_patterns)] // This is unreachable if we had an Unrecognized variant above.
_ => return Err($crate::restrict::tor_bytes::Error::BadMessage("Unexpected command")),
_ => return Err($crate::restrict::tor_bytes::Error::InvalidMessage(
format!("Unexpected command {} in {}", cmd, stringify!($name)).into()
)),
})
}
}
$(
#[cfg(feature = $omit_from)]
)?
impl From<$name> for $any_msg {
fn from(msg: $name) -> $any_msg {
match msg {
$(
$( #[cfg(feature=$feat)] )?
$name::$case(b) => Self::$case(b),
)*
$(
$name::$unrecognized(u) => $any_msg::Unrecognized(u),
)?
}
}
}
$(
#[cfg(feature = $omit_from)]
)?
impl TryFrom<$any_msg> for $name {
type Error = $any_msg;
fn try_from(msg: $any_msg) -> std::result::Result<$name, $any_msg> {
Ok(match msg {
$(
$( #[cfg(feature=$feat)] )?
$any_msg::$case(b) => $name::$case(b),
)*
$(
$any_msg::Unrecognized(u) => Self::$unrecognized(u),
)?
#[allow(unreachable_patterns)]
other => return Err(other),
})
}
}
$(
$( #[cfg(feature=$feat)] )?
impl From<$msg_mod :: $case> for $name {
fn from(m: $msg_mod::$case) -> $name {
$name :: $case(m)
}
}
)*
$(
impl From<$unrec_type> for $name {
fn from (u: $unrec_type) -> $name {
$name::$unrecognized(u)
}
}
)?
}
}
}

View File

@ -3,6 +3,7 @@
// Reminder: you can think of a cell as an message plus a circuitid.
#![allow(clippy::uninlined_format_args)]
use tor_cell::chancell::msg::AnyChanMsg;
use tor_cell::chancell::{codec, msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
use tor_cell::Error;
@ -31,7 +32,7 @@ fn cell(body: &str, msg: msg::AnyChanMsg, id: CircId, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
bm.extend_from_slice(&b"next thing"[..]);
let decoded = codec.decode_cell(&mut bm).unwrap();
let decoded = codec.decode_cell::<AnyChanMsg>(&mut bm).unwrap();
assert_eq!(bm.len(), 10);
decoded.unwrap()
};
@ -40,7 +41,7 @@ fn cell(body: &str, msg: msg::AnyChanMsg, id: CircId, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
// no extra bytes this time.
let decoded = codec.decode_cell(&mut bm).unwrap();
let decoded = codec.decode_cell::<AnyChanMsg>(&mut bm).unwrap();
assert_eq!(bm.len(), 0);
decoded.unwrap()
};
@ -90,7 +91,7 @@ fn test_simple_cells() {
let mut bm = BytesMut::new();
bm.extend_from_slice(&m);
codec::ChannelCodec::new(4)
.decode_cell(&mut bm)
.decode_cell::<AnyChanMsg>(&mut bm)
.unwrap()
.unwrap()
};
@ -109,7 +110,7 @@ fn short_cell(body: &str) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
let len_orig = bm.len();
let d = codec.decode_cell(&mut bm);
let d = codec.decode_cell::<AnyChanMsg>(&mut bm);
assert!(d.unwrap().is_none()); // "Ok(None)" means truncated.
assert_eq!(bm.len(), len_orig);
}
@ -141,7 +142,7 @@ fn bad_cell(body: &str, err: Error, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
bm.extend_from_slice(&b"next thing"[..]);
codec.decode_cell(&mut bm).err().unwrap()
codec.decode_cell::<AnyChanMsg>(&mut bm).err().unwrap()
};
assert_eq!(format!("{:?}", decoded), format!("{:?}", err));

View File

@ -91,7 +91,9 @@ fn test_cells() {
let m = decode("02 0000 9999 12345678 01f3 6e6565642d746f2d6b6e6f77 00000000");
assert_eq!(
AnyRelayCell::decode(m).err(),
Some(Error::BadMessage("Insufficient data in relay cell"))
Some(Error::InvalidMessage(
"Insufficient data in relay cell".into()
))
);
// check accessors.
@ -173,7 +175,10 @@ fn test_address() {
let hostname = "a".repeat(256);
let addr = Address::from_str(hostname.as_str());
assert!(addr.is_err());
assert_eq!(addr.err(), Some(Error::BadMessage("Hostname too long")));
assert_eq!(
addr.err(),
Some(Error::InvalidMessage("Hostname too long".into()))
);
// Some Unicode emojis (go Gen-Z!).
let hostname = "👍️👍️👍️";
@ -187,6 +192,6 @@ fn test_address() {
assert!(addr.is_err());
assert_eq!(
addr.err(),
Some(Error::BadMessage("Nul byte not permitted"))
Some(Error::InvalidMessage("Nul byte not permitted".into()))
);
}

View File

@ -406,6 +406,6 @@ fn test_padding_negotiate() {
assert_eq!(
decode_err(cmd, "90 0303", true),
BytesError::BadMessage("Unrecognized padding negotiation version")
BytesError::InvalidMessage("Unrecognized padding negotiation version".into())
);
}

View File

@ -101,14 +101,14 @@ fn test_begin() {
msg_error(
cmd,
"5b3a3a5d21", // [::]!
BytesError::BadMessage("missing port in begin cell"),
BytesError::InvalidMessage("missing port in begin cell".into()),
);
// hand-generated failure case: not ascii.
msg_error(
cmd,
"746f7270726f6a656374e284a22e6f72673a34343300", // torproject™.org:443
BytesError::BadMessage("target address in begin cell not ascii"),
BytesError::InvalidMessage("target address in begin cell not ascii".into()),
);
// failure on construction: bad address.
@ -151,7 +151,7 @@ fn test_connected() {
msg_error(
cmd,
"00000000 07 20010db8 00000000 00000000 00001122 00000E10",
BytesError::BadMessage("Invalid address type in CONNECTED cell"),
BytesError::InvalidMessage("Invalid address type in CONNECTED cell".into()),
);
}
@ -418,7 +418,7 @@ fn test_resolved() {
msg_error(
cmd,
"04 03 010203 00000001",
BytesError::BadMessage("Wrong length for RESOLVED answer"),
BytesError::InvalidMessage("Wrong length for RESOLVED answer".into()),
);
}
@ -564,7 +564,7 @@ fn test_connect_udp() {
msg_error(
cmd,
"00000000 07 04 01020304",
BytesError::BadMessage("Invalid address type"),
BytesError::InvalidMessage("Invalid address type".into()),
);
// A zero length address with and without hostname payload.
@ -608,14 +608,14 @@ fn test_connected_udp() {
cmd,
"01 04 01020304 0050
04 04 05060708 0050",
BytesError::BadMessage("Our address is a Hostname"),
BytesError::InvalidMessage("Our address is a Hostname".into()),
);
// Invalid their_address
msg_error(
cmd,
"04 04 01020304 0050
01 04 05060708 0050",
BytesError::BadMessage("Their address is a Hostname"),
BytesError::InvalidMessage("Their address is a Hostname".into()),
);
}
@ -773,7 +773,7 @@ fn test_introduce() {
02 0004 00010203
00
01090804",
BytesError::BadMessage("legacy key id in Introduce1."),
BytesError::InvalidMessage("legacy key id in Introduce1.".into()),
);
}
// TODO: need to add tests for:

View File

@ -301,13 +301,15 @@ impl Readable for CertExt {
Ok(match ext_type {
ExtType::SIGNED_WITH_ED25519_KEY => CertExt::SignedWithEd25519(SignedWithEd25519Ext {
pk: ed25519::Ed25519Identity::from_bytes(body)
.ok_or(BytesError::BadMessage("wrong length on Ed25519 key"))?,
pk: ed25519::Ed25519Identity::from_bytes(body).ok_or(
BytesError::InvalidMessage("wrong length on Ed25519 key".into()),
)?,
}),
_ => {
if (flags & 1) != 0 {
return Err(BytesError::BadMessage(
"unrecognized certificate extension, with 'affects_validation' flag set.",
return Err(BytesError::InvalidMessage(
"unrecognized certificate extension, with 'affects_validation' flag set."
.into(),
));
}
CertExt::Unrecognized(UnrecognizedExt {
@ -335,7 +337,9 @@ impl Ed25519Cert {
if v != 1 {
// This would be something other than a "v1" certificate. We don't
// understand those.
return Err(BytesError::BadMessage("Unrecognized certificate version"));
return Err(BytesError::InvalidMessage(
"Unrecognized certificate version".into(),
));
}
let cert_type = r.take_u8()?.into();
let exp_hours = r.take_u32()?;
@ -591,8 +595,8 @@ mod test {
assert!(e.is_err());
assert_eq!(
e.err().unwrap(),
BytesError::BadMessage(
"unrecognized certificate extension, with 'affects_validation' flag set."
BytesError::InvalidMessage(
"unrecognized certificate extension, with 'affects_validation' flag set.".into()
)
);

View File

@ -82,7 +82,9 @@ impl ExternallySigned<TimerangeBound<RsaCrosscert>> for UncheckedRsaCrosscert {
fn is_well_signed(&self, k: &Self::Key) -> Result<(), Self::Error> {
k.verify(&self.0.digest[..], &self.0.signature[..])
.map_err(|_| {
tor_bytes::Error::BadMessage("Invalid signature on RSA->Ed identity crosscert")
tor_bytes::Error::InvalidMessage(
"Invalid signature on RSA->Ed identity crosscert".into(),
)
})?;
Ok(())
}

View File

@ -17,7 +17,7 @@ fn cant_parse() {
assert_eq!(
decode_err(&hex!("03")),
Error::BadMessage("Unrecognized certificate version")
Error::InvalidMessage("Unrecognized certificate version".into())
);
assert_eq!(
@ -30,7 +30,7 @@ fn cant_parse() {
FF1A5203FA27F86EF7528D89A0845D2520166E340754FFEA2AAE0F612B7CE5DA
094A0236CDAC45034B0B6842C18E7F6B51B93A3CF7E60663B8AD061C30A62602"
)),
Error::BadMessage("wrong length on Ed25519 key")
Error::InvalidMessage("wrong length on Ed25519 key".into())
);
assert_eq!(
@ -43,8 +43,8 @@ fn cant_parse() {
FF1A5203FA27F86EF7528D89A0845D2520166E340754FFEA2AAE0F612B7CE5DA
094A0236CDAC45034B0B6842C18E7F6B51B93A3CF7E60663B8AD061C30A62602"
)),
Error::BadMessage(
"unrecognized certificate extension, with 'affects_validation' flag set."
Error::InvalidMessage(
"unrecognized certificate extension, with 'affects_validation' flag set.".into()
)
);
}

View File

@ -75,8 +75,10 @@ use safelog::sensitive as sv;
use std::pin::Pin;
use std::sync::{Mutex, MutexGuard};
use std::time::Duration;
use tor_cell::chancell::ChanMsg;
use tor_cell::chancell::msg::AnyChanMsg;
use tor_cell::chancell::{msg, msg::PaddingNegotiate, AnyChanCell, CircId};
use tor_cell::chancell::{ChanCell, ChanMsg};
use tor_cell::restricted_msg;
use tor_error::internal;
use tor_linkspec::{HasRelayIds, OwnedChanTarget};
use tor_rtcompat::SleepProvider;
@ -114,9 +116,39 @@ use crate::channel::unique_id::CircUniqIdContext;
pub(crate) use codec::CodecError;
pub use handshake::{OutboundClientHandshake, UnverifiedChannel, VerifiedChannel};
restricted_msg! {
/// A channel message that we allow to be sent from a server to a client on
/// an open channel.
///
/// (An Open channel here is one on which we have received a NETINFO cell.)
///
/// Note that an unexpected message type will _not_ be ignored: instead, it
/// will cause the channel to shut down.
#[derive(Clone, Debug)]
pub(crate) enum OpenChanMsgS2C : ChanMsg {
Padding,
Vpadding,
// Not Create*, since we are not a relay.
// Not Created, since we never send CREATE.
CreatedFast,
Created2,
Relay,
// Not RelayEarly, since we are a client.
Destroy,
// Not PaddingNegotiate, since we are not a relay.
// Not Versions, Certs, AuthChallenge, Authenticate: they are for handshakes.
// Not Authorize: it is reserved, but unused.
}
}
/// A channel cell that we allot to be sent on an open channel from
/// a server to a client.
pub(crate) type OpenChanCellS2C = ChanCell<OpenChanMsgS2C>;
/// Type alias: A Sink and Stream that transforms a TLS connection into
/// a cell-based communication mechanism.
type CellFrame<T> = futures_codec::Framed<T, crate::channel::codec::ChannelCodec>;
type CellFrame<T> =
futures_codec::Framed<T, crate::channel::codec::ChannelCodec<OpenChanMsgS2C, AnyChanMsg>>;
/// An open client channel, ready to send and receive Tor cells.
///

View File

@ -1,8 +1,9 @@
//! Wrap tor_cell::...:::ChannelCodec for use with the futures_codec
//! crate.
use std::io::Error as IoError;
use std::{io::Error as IoError, marker::PhantomData};
use tor_cell::chancell::{codec, AnyChanCell};
use futures::{AsyncRead, AsyncWrite};
use tor_cell::chancell::{codec, ChanCell, ChanMsg};
use asynchronous_codec as futures_codec;
use bytes::BytesMut;
@ -31,36 +32,89 @@ pub(crate) enum CodecError {
/// for use with futures_codec.
///
/// This type lets us wrap a TLS channel (or some other secure
/// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we
/// can forget about byte-oriented communication.
pub(crate) struct ChannelCodec(codec::ChannelCodec);
/// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we can
/// forget about byte-oriented communication.
///
/// It's parameterized on two message types: one that we're allowed to receive
/// (`IN`), and one that we're allowed to send (`OUT`).
pub(crate) struct ChannelCodec<IN, OUT> {
/// The cell codec that we'll use to encode and decode our cells.
inner: codec::ChannelCodec,
/// Tells the compiler that we're using IN, and we might
/// consume values of type IN.
_phantom_in: PhantomData<fn(IN)>,
/// Tells the compiler that we're using OUT, and we might
/// produce values of type OUT.
_phantom_out: PhantomData<fn() -> OUT>,
}
impl ChannelCodec {
impl<IN, OUT> ChannelCodec<IN, OUT> {
/// Create a new ChannelCodec with a given link protocol.
pub(crate) fn new(link_proto: u16) -> Self {
ChannelCodec(codec::ChannelCodec::new(link_proto))
ChannelCodec {
inner: codec::ChannelCodec::new(link_proto),
_phantom_in: PhantomData,
_phantom_out: PhantomData,
}
}
/// Consume this codec, and return a new one that sends and receives
/// different message types.
pub(crate) fn change_message_types<IN2, OUT2>(self) -> ChannelCodec<IN2, OUT2> {
ChannelCodec {
inner: self.inner,
_phantom_in: PhantomData,
_phantom_out: PhantomData,
}
}
}
impl futures_codec::Encoder for ChannelCodec {
type Item = AnyChanCell;
impl<IN, OUT> futures_codec::Encoder for ChannelCodec<IN, OUT>
where
OUT: ChanMsg,
{
type Item = ChanCell<OUT>;
type Error = CodecError;
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
self.0.write_cell(item, dst).map_err(CodecError::EncCell)?;
self.inner
.write_cell(item, dst)
.map_err(CodecError::EncCell)?;
Ok(())
}
}
impl futures_codec::Decoder for ChannelCodec {
type Item = AnyChanCell;
impl<IN, OUT> futures_codec::Decoder for ChannelCodec<IN, OUT>
where
IN: ChanMsg,
{
type Item = ChanCell<IN>;
type Error = CodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
self.0.decode_cell(src).map_err(CodecError::DecCell)
self.inner.decode_cell(src).map_err(CodecError::DecCell)
}
}
/// Consume a [`Framed`](futures_codec::Framed) codec user, and produce one that
/// sends and receives different message types.
pub(crate) fn change_message_types<T, IN, OUT, IN2, OUT2>(
framed: futures_codec::Framed<T, ChannelCodec<IN, OUT>>,
) -> futures_codec::Framed<T, ChannelCodec<IN2, OUT2>>
where
T: AsyncRead + AsyncWrite,
IN: ChanMsg,
OUT: ChanMsg,
IN2: ChanMsg,
OUT2: ChanMsg,
{
futures_codec::Framed::from_parts(
framed
.into_parts()
.map_codec(ChannelCodec::change_message_types),
)
}
#[cfg(test)]
pub(crate) mod test {
#![allow(clippy::unwrap_used)]
@ -70,6 +124,7 @@ pub(crate) mod test {
use futures::task::{Context, Poll};
use hex_literal::hex;
use std::pin::Pin;
use tor_cell::chancell::msg::AnyChanMsg;
use super::{futures_codec, ChannelCodec};
use tor_cell::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
@ -128,7 +183,9 @@ pub(crate) mod test {
}
}
fn frame_buf(mbuf: MsgBuf) -> futures_codec::Framed<MsgBuf, ChannelCodec> {
fn frame_buf(
mbuf: MsgBuf,
) -> futures_codec::Framed<MsgBuf, ChannelCodec<AnyChanMsg, AnyChanMsg>> {
futures_codec::Framed::new(mbuf, ChannelCodec::new(4))
}

View File

@ -5,9 +5,10 @@ use asynchronous_codec as futures_codec;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use tor_cell::restricted_msg;
use tor_error::internal;
use crate::channel::codec::{ChannelCodec, CodecError};
use crate::channel::codec::{self, ChannelCodec, CodecError};
use crate::channel::UniqId;
use crate::util::skew::ClockSkew;
use crate::{Error, Result};
@ -106,6 +107,26 @@ pub struct VerifiedChannel<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S
clock_skew: ClockSkew,
}
restricted_msg! {
/// A restricted subset of ChanMsg that can arrive during a handshake.
///
/// (These are messages that come after the VERSIONS cell, up to and
/// including the NETINFO.)
///
/// Note that unrecognized message types (ones not yet implemented in Arti)
/// cause an error, rather than getting ignored. That's intentional: if we
/// start to allow them in the future, we should negotiate a new Channel
/// protocol for the VERSIONS cell.
#[derive(Clone,Debug)]
enum HandshakeMsg : ChanMsg {
Padding,
Vpadding,
AuthChallenge,
Certs,
Netinfo
}
}
/// Convert a CodecError to an Error, under the context that it occurs while
/// doing a channel handshake.
fn codec_err_to_handshake(err: CodecError) -> Error {
@ -210,7 +231,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider>
// Now we can switch to using a "Framed". We can ignore the
// AsyncRead/AsyncWrite aspects of the tls, and just treat it
// as a stream and a sink for cells.
let codec = ChannelCodec::new(link_protocol);
let codec = ChannelCodec::<HandshakeMsg, HandshakeMsg>::new(link_protocol);
let mut tls = futures_codec::Framed::new(self.tls, codec);
// Read until we have the netinfo cells.
@ -221,14 +242,12 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider>
// Loop: reject duplicate and unexpected cells
trace!("{}: waiting for rest of handshake.", self.unique_id);
while let Some(m) = tls.next().await {
use msg::AnyChanMsg::*;
use HandshakeMsg::*;
let (_, m) = m.map_err(codec_err_to_handshake)?.into_circid_and_msg();
trace!("{}: received a {} cell.", self.unique_id, m.cmd());
match m {
// Are these technically allowed?
Padding(_) | Vpadding(_) => (),
// Unrecognized cells get ignored.
Unrecognized(_) => (),
// Clients don't care about AuthChallenge
AuthChallenge(_) => {
if seen_authchallenge {
@ -253,13 +272,6 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider>
netinfo = Some((n, coarsetime::Instant::now()));
break;
}
// No other cell types are allowed.
m => {
return Err(Error::HandshakeProto(format!(
"Unexpected cell type {}",
m.cmd()
)))
}
}
}
@ -285,7 +297,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static, S: SleepProvider>
};
Ok(UnverifiedChannel {
link_protocol,
tls,
tls: codec::change_message_types(tls),
certs_cell,
netinfo_cell,
clock_skew,
@ -806,7 +818,7 @@ pub(super) mod test {
assert!(matches!(err, Error::HandshakeProto(_)));
assert_eq!(
format!("{}", err),
"Handshake protocol violation: Unexpected cell type CREATE"
"Handshake protocol violation: Invalid cell on handshake: Error while parsing channel cell"
);
});
}

View File

@ -7,6 +7,8 @@
//! or in the error handling behavior.
use super::circmap::{CircEnt, CircMap};
use super::OpenChanCellS2C;
use crate::channel::OpenChanMsgS2C;
use crate::circuit::halfcirc::HalfCirc;
use crate::util::err::{ChannelClosed, ReactorError};
use crate::{Error, Result};
@ -35,8 +37,9 @@ use crate::circuit::celltypes::{ClientCircChanMsg, CreateResponse};
use tracing::{debug, trace};
/// A boxed trait object that can provide `ChanCell`s.
pub(super) type BoxedChannelStream =
Box<dyn Stream<Item = std::result::Result<AnyChanCell, CodecError>> + Send + Unpin + 'static>;
pub(super) type BoxedChannelStream = Box<
dyn Stream<Item = std::result::Result<OpenChanCellS2C, CodecError>> + Send + Unpin + 'static,
>;
/// A boxed trait object that can sink `ChanCell`s.
pub(super) type BoxedChannelSink =
Box<dyn Sink<AnyChanCell, Error = CodecError> + Send + Unpin + 'static>;
@ -303,9 +306,9 @@ impl<S: SleepProvider> Reactor<S> {
/// Helper: process a cell on a channel. Most cell types get ignored
/// or rejected; a few get delivered to circuits.
async fn handle_cell(&mut self, cell: AnyChanCell) -> Result<()> {
async fn handle_cell(&mut self, cell: OpenChanCellS2C) -> Result<()> {
let (circid, msg) = cell.into_circid_and_msg();
use AnyChanMsg::*;
use OpenChanMsgS2C::*;
match msg {
Relay(_) | Padding(_) | Vpadding(_) => {} // too frequent to log.
@ -313,41 +316,15 @@ impl<S: SleepProvider> Reactor<S> {
}
match msg {
// These aren't allowed on clients.
Create(_) | CreateFast(_) | Create2(_) | RelayEarly(_) | PaddingNegotiate(_) => Err(
Error::ChanProto(format!("{} cell on client channel", msg.cmd())),
),
// In theory this is allowed in clients, but we should never get
// one, since we don't use TAP.
Created(_) => Err(Error::ChanProto(format!(
"{} cell received, but we never send CREATEs",
msg.cmd()
))),
// These aren't allowed after handshaking is done.
Versions(_) | Certs(_) | Authorize(_) | Authenticate(_) | AuthChallenge(_)
| Netinfo(_) => Err(Error::ChanProto(format!(
"{} cell after handshake is done",
msg.cmd()
))),
// These are allowed, and need to be handled.
Relay(_) => self.deliver_relay(circid, msg).await,
Relay(_) => self.deliver_relay(circid, msg.into()).await,
Destroy(_) => self.deliver_destroy(circid, msg).await,
Destroy(_) => self.deliver_destroy(circid, msg.into()).await,
CreatedFast(_) | Created2(_) => self.deliver_created(circid, msg).await,
CreatedFast(_) | Created2(_) => self.deliver_created(circid, msg.into()).await,
// These are always ignored.
Padding(_) | Vpadding(_) => Ok(()),
// Unrecognized cell types should be safe to allow _on channels_,
// since they can't propagate.
Unrecognized(_) => Ok(()),
// tor_cells knows about this type, but we don't.
_ => Ok(()),
}
}
@ -476,7 +453,7 @@ pub(crate) mod test {
use tor_linkspec::OwnedChanTarget;
use tor_rtcompat::Runtime;
type CodecResult = std::result::Result<AnyChanCell, CodecError>;
type CodecResult = std::result::Result<OpenChanCellS2C, CodecError>;
pub(crate) fn new_reactor<R: Runtime>(
runtime: R,
@ -617,7 +594,7 @@ pub(crate) mod test {
rtc.sleep(Duration::from_millis(100)).await;
trace!("sending createdfast");
// We'll get a bad handshake result from this createdfast cell.
let created_cell = AnyChanCell::new(id, msg::CreatedFast::new(*b"x").into());
let created_cell = OpenChanCellS2C::new(id, msg::CreatedFast::new(*b"x").into());
input.send(Ok(created_cell)).await.unwrap();
reactor.run_once().await.unwrap();
};
@ -646,26 +623,13 @@ pub(crate) mod test {
use tor_cell::chancell::msg;
let (_chan, mut reactor, _output, mut input) = new_reactor(rt);
// We shouldn't get create cells, ever.
let create_cell = msg::Create2::new(4, *b"hihi").into();
input
.send(Ok(AnyChanCell::new(9.into(), create_cell)))
.await
.unwrap();
// shouldn't get created2 cells for nonexistent circuits
let created2_cell = msg::Created2::new(*b"hihi").into();
input
.send(Ok(AnyChanCell::new(7.into(), created2_cell)))
.send(Ok(OpenChanCellS2C::new(7.into(), created2_cell)))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
assert_eq!(
format!("{}", e),
"Channel protocol violation: CREATE2 cell on client channel"
);
let e = reactor.run_once().await.unwrap_err().unwrap_err();
assert_eq!(
format!("{}", e),
@ -675,7 +639,7 @@ pub(crate) mod test {
// Can't get a relay cell on a circuit we've never heard of.
let relay_cell = msg::Relay::new(b"abc").into();
input
.send(Ok(AnyChanCell::new(4.into(), relay_cell)))
.send(Ok(OpenChanCellS2C::new(4.into(), relay_cell)))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
@ -684,29 +648,9 @@ pub(crate) mod test {
"Channel protocol violation: Relay cell on nonexistent circuit"
);
// Can't get handshaking cells while channel is open.
let versions_cell = msg::Versions::new([3]).unwrap().into();
input
.send(Ok(AnyChanCell::new(0.into(), versions_cell)))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
assert_eq!(
format!("{}", e),
"Channel protocol violation: VERSIONS cell after handshake is done"
);
// We don't accept CREATED.
let created_cell = msg::Created::new(&b"xyzzy"[..]).into();
input
.send(Ok(AnyChanCell::new(25.into(), created_cell)))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
assert_eq!(
format!("{}", e),
"Channel protocol violation: CREATED cell received, but we never send CREATEs"
);
// There used to be tests here for other types, but now that we only
// accept OpenClientChanCell, we know that the codec can't even try
// to give us e.g. VERSIONS or CREATE.
});
}
@ -737,9 +681,9 @@ pub(crate) mod test {
// If a relay cell is sent on an open channel, the correct circuit
// should get it.
let relaycell: AnyChanMsg = msg::Relay::new(b"do you suppose").into();
let relaycell: OpenChanMsgS2C = msg::Relay::new(b"do you suppose").into();
input
.send(Ok(AnyChanCell::new(13.into(), relaycell.clone())))
.send(Ok(OpenChanCellS2C::new(13.into(), relaycell.clone())))
.await
.unwrap();
reactor.run_once().await.unwrap();
@ -748,7 +692,7 @@ pub(crate) mod test {
// If a relay cell is sent on an opening channel, that's an error.
input
.send(Ok(AnyChanCell::new(7.into(), relaycell.clone())))
.send(Ok(OpenChanCellS2C::new(7.into(), relaycell.clone())))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
@ -759,7 +703,7 @@ pub(crate) mod test {
// If a relay cell is sent on a non-existent channel, that's an error.
input
.send(Ok(AnyChanCell::new(101.into(), relaycell.clone())))
.send(Ok(OpenChanCellS2C::new(101.into(), relaycell.clone())))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
@ -774,7 +718,7 @@ pub(crate) mod test {
// We can do this 25 more times according to our setup:
for _ in 0..25 {
input
.send(Ok(AnyChanCell::new(23.into(), relaycell.clone())))
.send(Ok(OpenChanCellS2C::new(23.into(), relaycell.clone())))
.await
.unwrap();
reactor.run_once().await.unwrap(); // should be fine.
@ -782,7 +726,7 @@ pub(crate) mod test {
// This one will fail.
input
.send(Ok(AnyChanCell::new(23.into(), relaycell.clone())))
.send(Ok(OpenChanCellS2C::new(23.into(), relaycell.clone())))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();
@ -819,9 +763,9 @@ pub(crate) mod test {
};
// Destroying an opening circuit is fine.
let destroycell: AnyChanMsg = msg::Destroy::new(0.into()).into();
let destroycell: OpenChanMsgS2C = msg::Destroy::new(0.into()).into();
input
.send(Ok(AnyChanCell::new(7.into(), destroycell.clone())))
.send(Ok(OpenChanCellS2C::new(7.into(), destroycell.clone())))
.await
.unwrap();
reactor.run_once().await.unwrap();
@ -830,7 +774,7 @@ pub(crate) mod test {
// Destroying an open circuit is fine.
input
.send(Ok(AnyChanCell::new(13.into(), destroycell.clone())))
.send(Ok(OpenChanCellS2C::new(13.into(), destroycell.clone())))
.await
.unwrap();
reactor.run_once().await.unwrap();
@ -839,14 +783,14 @@ pub(crate) mod test {
// Destroying a DestroySent circuit is fine.
input
.send(Ok(AnyChanCell::new(23.into(), destroycell.clone())))
.send(Ok(OpenChanCellS2C::new(23.into(), destroycell.clone())))
.await
.unwrap();
reactor.run_once().await.unwrap();
// Destroying a nonexistent circuit is an error.
input
.send(Ok(AnyChanCell::new(101.into(), destroycell.clone())))
.send(Ok(OpenChanCellS2C::new(101.into(), destroycell.clone())))
.await
.unwrap();
let e = reactor.run_once().await.unwrap_err().unwrap_err();

View File

@ -803,6 +803,7 @@ mod test {
//! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
use super::*;
use crate::channel::OpenChanCellS2C;
use crate::channel::{test::new_reactor, CodecError};
use crate::crypto::cell::RelayCellBody;
use chanmsg::{AnyChanMsg, Created2, CreatedFast};
@ -860,7 +861,7 @@ mod test {
) -> (
Channel,
Receiver<AnyChanCell>,
Sender<std::result::Result<AnyChanCell, CodecError>>,
Sender<std::result::Result<OpenChanCellS2C, CodecError>>,
) {
let (channel, chan_reactor, rx, tx) = new_reactor(rt.clone());
rt.spawn(async {
@ -1351,7 +1352,7 @@ mod test {
StreamId,
usize,
Receiver<AnyChanCell>,
Sender<std::result::Result<AnyChanCell, CodecError>>,
Sender<std::result::Result<OpenChanCellS2C, CodecError>>,
) {
let (chan, mut rx, sink2) = working_fake_channel(rt);
let (circ, mut sink) = newcirc(rt, chan).await;

View File

@ -126,9 +126,9 @@ mod test {
bad(msg::CreatedFast::new(&b"guaranteed in this world"[..]).into());
bad(msg::Created2::new(&b"and the next"[..]).into());
good(msg::Relay::new(&b"guaranteed guaranteed"[..]).into());
bad(msg::AnyChanMsg::RelayEarly(msg::Relay::new(
&b"for the world and its mother"[..],
)));
bad(msg::AnyChanMsg::RelayEarly(
msg::Relay::new(&b"for the world and its mother"[..]).into(),
));
bad(msg::Versions::new([1, 2, 3]).unwrap().into());
}
}

View File

@ -324,7 +324,7 @@ where
let mut rng = rand::thread_rng();
let unique_id = reactor.unique_id;
use tor_cell::relaycell::msg::{Body, Extend2};
use tor_cell::relaycell::msg::Extend2;
// Perform the first part of the cryptographic handshake
let (state, msg) = H::client1(&mut rng, key)?;
@ -339,7 +339,7 @@ where
);
let extend_msg = Extend2::new(linkspecs, handshake_id, msg);
let cell = AnyRelayCell::new(0.into(), extend_msg.into_message());
let cell = AnyRelayCell::new(0.into(), extend_msg.into());
// Send the message to the last hop...
reactor.send_relay_cell(
@ -1007,7 +1007,7 @@ impl Reactor {
// the whole circuit (e.g. by returning an error).
let msg = chancell::msg::Relay::from_raw(body.into());
let msg = if early {
AnyChanMsg::RelayEarly(msg)
AnyChanMsg::RelayEarly(msg.into())
} else {
AnyChanMsg::Relay(msg)
};

View File

@ -41,18 +41,20 @@ impl Readable for SocksAddr {
let hlen = r.take_u8()?;
let hostname = r.take(hlen as usize)?;
let hostname = std::str::from_utf8(hostname)
.map_err(|_| BytesError::BadMessage("bad utf8 on hostname"))?
.map_err(|_| BytesError::InvalidMessage("bad utf8 on hostname".into()))?
.to_string();
let hostname = hostname
.try_into()
.map_err(|_| BytesError::BadMessage("hostname too long"))?;
.map_err(|_| BytesError::InvalidMessage("hostname too long".into()))?;
Ok(SocksAddr::Hostname(hostname))
}
4 => {
let ip6: std::net::Ipv6Addr = r.extract()?;
Ok(SocksAddr::Ip(ip6.into()))
}
_ => Err(BytesError::BadMessage("unrecognized address type.")),
_ => Err(BytesError::InvalidMessage(
"unrecognized address type.".into(),
)),
}
}
}

View File

@ -124,7 +124,7 @@ impl SocksProxyHandshake {
.to_string();
let hostname = hostname
.try_into()
.map_err(|_| BytesError::BadMessage("hostname too long"))?;
.map_err(|_| BytesError::InvalidMessage("hostname too long".into()))?;
SocksAddr::Hostname(hostname)
} else {
let ip4: std::net::Ipv4Addr = ip.into();