diff --git a/Cargo.lock b/Cargo.lock index e563ed91e..f6679f1e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3864,6 +3864,7 @@ dependencies = [ "pin-project", "rand 0.8.5", "rand_core 0.6.3", + "regex", "statrs", "subtle", "thiserror", diff --git a/crates/tor-proto/Cargo.toml b/crates/tor-proto/Cargo.toml index 388dd1520..d889b35f1 100644 --- a/crates/tor-proto/Cargo.toml +++ b/crates/tor-proto/Cargo.toml @@ -59,6 +59,7 @@ zeroize = "1" hex = "0.4" hex-literal = "0.3" itertools = "0.10.1" +regex = { version = "1", default-features = false, features = ["std"] } statrs = "0.15.0" tokio-crate = { package = "tokio", version = "1.7", features = ["full"] } tor-rtcompat = { path = "../tor-rtcompat", version = "0.5.0", features = ["tokio", "native-tls"] } diff --git a/crates/tor-proto/src/channel.rs b/crates/tor-proto/src/channel.rs index be54c24c8..214c5fa00 100644 --- a/crates/tor-proto/src/channel.rs +++ b/crates/tor-proto/src/channel.rs @@ -336,25 +336,7 @@ impl Channel { /// Return an error if this channel is somehow mismatched with the /// given target. pub fn check_match(&self, target: &T) -> Result<()> { - for desired in target.identities() { - let id_type = desired.id_type(); - match self.details.peer_id.identity(id_type) { - Some(actual) if actual == desired => {} - Some(actual) => { - return Err(Error::ChanMismatch(format!( - "Identity {} does not match target {}", - actual, desired - ))); - } - None => { - return Err(Error::ChanMismatch(format!( - "Peer does not have {} identity", - id_type - ))) - } - } - } - Ok(()) + check_id_match_helper(&self.details.peer_id, target) } /// Return true if this channel is closed and therefore unusable. @@ -469,6 +451,37 @@ impl Channel { } } +/// If there is any identity in `wanted_ident` that is not present in +/// `my_ident`, return a ChanMismatch error. +/// +/// This is a helper for [`Channel::check_match`] and +/// [`UnverifiedChannel::check_internal`]. +fn check_id_match_helper(my_ident: &T, wanted_ident: &U) -> Result<()> +where + T: HasRelayIds + ?Sized, + U: HasRelayIds + ?Sized, +{ + for desired in wanted_ident.identities() { + let id_type = desired.id_type(); + match my_ident.identity(id_type) { + Some(actual) if actual == desired => {} + Some(actual) => { + return Err(Error::ChanMismatch(format!( + "Identity {} does not match target {}", + actual, desired + ))); + } + None => { + return Err(Error::ChanMismatch(format!( + "Peer does not have {} identity", + id_type + ))) + } + } + } + Ok(()) +} + #[cfg(test)] pub(crate) mod test { // Most of this module is tested via tests that also check on the diff --git a/crates/tor-proto/src/channel/handshake.rs b/crates/tor-proto/src/channel/handshake.rs index d9e166adb..3c2f8e875 100644 --- a/crates/tor-proto/src/channel/handshake.rs +++ b/crates/tor-proto/src/channel/handshake.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use std::time::SystemTime; use tor_bytes::Reader; -use tor_linkspec::{ChanTarget, HasRelayIds, OwnedChanTarget, RelayIds}; +use tor_linkspec::{ChanTarget, OwnedChanTarget, RelayIds}; use tor_llcrypto as ll; use tor_llcrypto::pk::ed25519::Ed25519Identity; use tor_llcrypto::pk::rsa::RsaIdentity; @@ -511,24 +511,10 @@ impl Unver // We enforce that the relay proved that it has every ID that we wanted: // it may also have additional IDs that we didn't ask for. - for desired_id in peer.identities() { - let id_type = desired_id.id_type(); - match actual_identity.identity(id_type) { - Some(actual) if actual == desired_id => {} - Some(_) => { - return Err(Error::HandshakeProto(format!( - "Peer {} id not as expected", - id_type - ))) - } - None => { - return Err(Error::HandshakeProto(format!( - "Peer did not present a {} id.", - id_type - ))) - } - } - } + match super::check_id_match_helper(&actual_identity, peer) { + Err(Error::ChanMismatch(msg)) => Err(Error::HandshakeProto(msg)), + other => other, + }?; // If we reach this point, the clock skew might be may now be considered // authenticated: The certificates are what we wanted, and everything @@ -612,6 +598,7 @@ impl Verif pub(super) mod test { #![allow(clippy::unwrap_used)] use hex_literal::hex; + use regex::Regex; use std::time::{Duration, SystemTime}; use super::*; @@ -942,10 +929,10 @@ pub(super) mod test { .err() .unwrap(); - assert_eq!( - format!("{}", err), - "Handshake protocol violation: Peer Ed25519 id not as expected" - ); + let re = Regex::new( + r"Identity .* does not match target ed25519:EBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA", + ).unwrap(); + assert!(re.is_match(&format!("{}", err))); let err = certs_test( certs.clone(), @@ -958,10 +945,11 @@ pub(super) mod test { .err() .unwrap(); - assert_eq!( - format!("{}", err), - "Handshake protocol violation: Peer RSA (legacy) id not as expected" - ); + let re = Regex::new( + r"Identity .* does not match target \$9999999999999999999999999999999999999999", + ) + .unwrap(); + assert!(re.is_match(&format!("{}", err))); let err = certs_test( certs,