diff --git a/Cargo.lock b/Cargo.lock index a4c4bdb71..9f3457862 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3300,6 +3300,7 @@ dependencies = [ "arrayref", "bytes", "digest 0.10.3", + "educe", "generic-array", "getrandom 0.2.6", "hex-literal", diff --git a/crates/tor-bytes/Cargo.toml b/crates/tor-bytes/Cargo.toml index 2a65985d9..96672e4de 100644 --- a/crates/tor-bytes/Cargo.toml +++ b/crates/tor-bytes/Cargo.toml @@ -15,6 +15,7 @@ repository = "https://gitlab.torproject.org/tpo/core/arti.git/" arrayref = "0.3" bytes = "1" digest = { version = "0.10.0", features = ["subtle", "mac"] } +educe = "0.4.6" generic-array = "0.14.3" signature = "1" thiserror = "1" diff --git a/crates/tor-bytes/src/writer.rs b/crates/tor-bytes/src/writer.rs index 7a6aa928a..b7985888d 100644 --- a/crates/tor-bytes/src/writer.rs +++ b/crates/tor-bytes/src/writer.rs @@ -1,8 +1,12 @@ //! Internal: Declare the Writer type for tor-bytes -use crate::Result; +use std::marker::PhantomData; + +use educe::Educe; + use crate::Writeable; use crate::WriteableOnce; +use crate::{Error, Result}; /// A byte-oriented trait for writing to small arrays. /// @@ -81,6 +85,94 @@ pub trait Writer { fn write_and_consume(&mut self, e: E) { e.write_into(self); } + + /// Arranges to write a u8 length, and some data whose encoding is that length + /// + /// Prefer to use this function, rather than manual length calculations + /// and ad-hoc `write_u8`. + /// Using this facility eliminates the need to separately keep track of the lengths. + /// + /// The returned `NestedWriter` should be used to write the contents, + /// inside the byte-counted section. + /// + /// Then you **must** call `finish` to finalise the buffer. 
+ fn write_nested_u8len(&mut self) -> NestedWriter<'_, Self, u8> { + write_nested_generic(self) + } + /// Arranges to write a u16 length, and some data whose encoding is that length + fn write_nested_u16len(&mut self) -> NestedWriter<'_, Self, u16> { + write_nested_generic(self) + } + /// Arranges to write a u32 length, and some data whose encoding is that length + fn write_nested_u32len(&mut self) -> NestedWriter<'_, Self, u32> { + write_nested_generic(self) + } +} + +/// Work-in-progress state for writing a nested (length-counted) item +/// +/// You must call `finish`! +#[derive(Educe)] +#[educe(Deref, DerefMut)] +pub struct NestedWriter<'w, W, L> +where + W: ?Sized, +{ + /// Variance doesn't matter since this is local to the module, but for form's sake: + /// Be invariant in `L`, as maximally conservative. + length_type: PhantomData<*mut L>, + + /// The outer writer + outer: &'w mut W, + + /// Our inner buffer + /// + /// Caller can use us as `Writer` via `DerefMut` + /// + /// (An alternative would be to `impl Writer` but that involves recapitulating + /// the impl for `Vec` and we do not have the `ambassador` crate to help us. + /// Exposing this inner `Vec` is harmless.) + /// + /// We must allocate here because some `Writer`s are streaming + #[educe(Deref, DerefMut)] + inner: Vec, +} + +/// Implementation of `write_nested_*` - generic over the length type +fn write_nested_generic(w: &mut W) -> NestedWriter +where + W: Writer + ?Sized, + L: Default + Copy + Sized + Writeable + TryFrom, +{ + NestedWriter { + length_type: PhantomData, + outer: w, + inner: vec![], + } +} + +impl<'w, W, L> NestedWriter<'w, W, L> +where + W: Writer + ?Sized, + L: Default + Copy + Sized + Writeable + TryFrom + std::ops::Not, +{ + /// Ends writing the nested data: writes the length, then the data, to the outer writer + /// + /// You must check the return value. + /// It will only be `Err` if the amount you wrote doesn't fit into the length field. 
+ /// + /// Sadly, you may well be implementing a `Writeable`, in which case you + /// will have nothing good to do with the error, and must panic. + /// In these cases you should have ensured, somehow, that overflow cannot happen. + /// Ideally, by making your `Writeable` type incapable of holding values + /// whose encoded length doesn't fit in the length field. + pub fn finish(self) -> Result<()> { + let length = self.inner.len(); + let length: L = length.try_into().map_err(|_| Error::BadLengthValue)?; + self.outer.write(&length); + self.outer.write(&self.inner); + Ok(()) + } } #[cfg(test)] @@ -132,4 +224,27 @@ mod tests { v.write_and_consume(Sequence(3)); assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5, 0, 1, 2]); } + + #[test] + fn nested() { + let mut v: Vec = b"abc".to_vec(); + + let mut w = v.write_nested_u8len(); + w.write_u8(b'x'); + w.finish().unwrap(); + + let mut w = v.write_nested_u16len(); + w.write_u8(b'y'); + w.finish().unwrap(); + + let mut w = v.write_nested_u32len(); + w.write_u8(b'z'); + w.finish().unwrap(); + + assert_eq!(&v, b"abc\x01x\0\x01y\0\0\0\x01z"); + + let mut w = v.write_nested_u8len(); + w.write_zeros(256); + assert_eq!(w.finish().err().unwrap(), Error::BadLengthValue); + } }