cell: Parameterize ChannelCodec::decode and encode.

This change lets us use ChannelCodec to encode and decode any
restricted channel message type we want.  (Later on, we'll turn the
related Codec class in tor-proto into a more type-restricted version
of this.)
This commit is contained in:
Nick Mathewson 2023-02-07 16:37:16 -05:00
parent 9046ef90d0
commit b6f6fa4d4b
3 changed files with 21 additions and 12 deletions

View File

@ -6,3 +6,4 @@ BREAKING: Moved ChanMsg methods into a trait.
BREAKING: Moved RelayMsg methods into a trait.
BREAKING: Renamed ChanCell->AnyChanCell, ChanMsg->AnyChanMsg.
BREAKING: Renamed RelayCell->AnyRelayCell, RelayMsg->AnyRelayMsg.
BREAKING: Make ChannelCodec::decode() parameterized.

View File

@ -1,7 +1,7 @@
//! Implementation for encoding and decoding of ChanCells.
use super::CELL_DATA_LEN;
use crate::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
use super::{ChanCell, CELL_DATA_LEN};
use crate::chancell::{ChanCmd, ChanMsg, CircId};
use crate::Error;
use arrayref::{array_mut_ref, array_ref};
use tor_bytes::{self, Reader, Writer};
@ -49,8 +49,12 @@ impl ChannelCodec {
}
/// Write the given cell into the provided BytesMut object.
pub fn write_cell(&mut self, item: AnyChanCell, dst: &mut BytesMut) -> crate::Result<()> {
let AnyChanCell { circid, msg } = item;
pub fn write_cell<M: ChanMsg>(
&mut self,
item: ChanCell<M>,
dst: &mut BytesMut,
) -> crate::Result<()> {
let ChanCell { circid, msg } = item;
let cmd = msg.cmd();
dst.write_u32(circid.into());
dst.write_u8(cmd.into());
@ -83,7 +87,10 @@ impl ChannelCodec {
///
/// On a definite decoding error, return Err(_). On a cell that might
/// just be truncated, return Ok(None).
pub fn decode_cell(&mut self, src: &mut BytesMut) -> crate::Result<Option<AnyChanCell>> {
pub fn decode_cell<M: ChanMsg>(
&mut self,
src: &mut BytesMut,
) -> crate::Result<Option<ChanCell<M>>> {
/// Wrap `be` as an appropriate type.
fn wrap_err(be: tor_bytes::Error) -> crate::Error {
crate::Error::BytesErr {
@ -113,7 +120,7 @@ impl ChannelCodec {
let mut r = Reader::from_bytes(&cell);
let circid: CircId = r.take_u32().map_err(wrap_err)?.into();
r.advance(if varcell { 3 } else { 1 }).map_err(wrap_err)?;
let msg = msg::AnyChanMsg::decode_from_reader(cmd, &mut r).map_err(wrap_err)?;
let msg = M::decode_from_reader(cmd, &mut r).map_err(wrap_err)?;
if !cmd.accepts_circid_val(circid) {
return Err(Error::ChanProto(format!(
@ -121,6 +128,6 @@ impl ChannelCodec {
circid, cmd
)));
}
Ok(Some(AnyChanCell { circid, msg }))
Ok(Some(ChanCell { circid, msg }))
}
}

View File

@ -3,6 +3,7 @@
// Reminder: you can think of a cell as an message plus a circuitid.
#![allow(clippy::uninlined_format_args)]
use tor_cell::chancell::msg::AnyChanMsg;
use tor_cell::chancell::{codec, msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
use tor_cell::Error;
@ -31,7 +32,7 @@ fn cell(body: &str, msg: msg::AnyChanMsg, id: CircId, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
bm.extend_from_slice(&b"next thing"[..]);
let decoded = codec.decode_cell(&mut bm).unwrap();
let decoded = codec.decode_cell::<AnyChanMsg>(&mut bm).unwrap();
assert_eq!(bm.len(), 10);
decoded.unwrap()
};
@ -40,7 +41,7 @@ fn cell(body: &str, msg: msg::AnyChanMsg, id: CircId, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
// no extra bytes this time.
let decoded = codec.decode_cell(&mut bm).unwrap();
let decoded = codec.decode_cell::<AnyChanMsg>(&mut bm).unwrap();
assert_eq!(bm.len(), 0);
decoded.unwrap()
};
@ -90,7 +91,7 @@ fn test_simple_cells() {
let mut bm = BytesMut::new();
bm.extend_from_slice(&m);
codec::ChannelCodec::new(4)
.decode_cell(&mut bm)
.decode_cell::<AnyChanMsg>(&mut bm)
.unwrap()
.unwrap()
};
@ -109,7 +110,7 @@ fn short_cell(body: &str) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
let len_orig = bm.len();
let d = codec.decode_cell(&mut bm);
let d = codec.decode_cell::<AnyChanMsg>(&mut bm);
assert!(d.unwrap().is_none()); // "Ok(None)" means truncated.
assert_eq!(bm.len(), len_orig);
}
@ -141,7 +142,7 @@ fn bad_cell(body: &str, err: Error, pad_body: bool) {
let mut bm = BytesMut::new();
bm.extend_from_slice(&body[..]);
bm.extend_from_slice(&b"next thing"[..]);
codec.decode_cell(&mut bm).err().unwrap()
codec.decode_cell::<AnyChanMsg>(&mut bm).err().unwrap()
};
assert_eq!(format!("{:?}", decoded), format!("{:?}", err));