diff --git a/Cargo.lock b/Cargo.lock index 8aed041db..a1a25d2f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -582,6 +582,12 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +[[package]] +name = "by_address" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e245704f60eb4eb45810d65cf14eb54d2eb50a6f3715fe2d7cd01ee905c2944f" + [[package]] name = "bytemuck" version = "1.12.1" @@ -3715,6 +3721,7 @@ name = "tor-linkspec" version = "0.5.1" dependencies = [ "base64ct", + "by_address", "cfg-if", "derive_builder_fork_arti", "derive_more", diff --git a/crates/tor-linkspec/Cargo.toml b/crates/tor-linkspec/Cargo.toml index 39828bcd8..fc2653bd5 100644 --- a/crates/tor-linkspec/Cargo.toml +++ b/crates/tor-linkspec/Cargo.toml @@ -18,6 +18,7 @@ pt-client = [] [dependencies] base64ct = "1.5.1" +by_address = "1" cfg-if = "1.0.0" derive_builder = { version = "0.11.2", package = "derive_builder_fork_arti" } derive_more = "0.99" diff --git a/crates/tor-linkspec/src/ids/by_id.rs b/crates/tor-linkspec/src/ids/by_id.rs index 899e86397..be95a70c3 100644 --- a/crates/tor-linkspec/src/ids/by_id.rs +++ b/crates/tor-linkspec/src/ids/by_id.rs @@ -52,9 +52,121 @@ impl ByRelayIds { self.by_id(any_id) .filter(|val| val.has_all_relay_ids_from(key)) } + + /// Return a reference to every element in this set that shares _any_ ID + /// with `key`. + /// + /// No element is returned more than once. + pub fn all_overlapping(&self, key: &T) -> Vec<&H> + where + T: HasRelayIds, + { + use by_address::ByAddress; + use std::collections::HashSet; + + let mut items: HashSet> = HashSet::new(); + + for ident in key.identities() { + if let Some(found) = self.by_id(ident) { + items.insert(ByAddress(found)); + } + } + + items.into_iter().map(|by_addr| by_addr.0).collect() + } } // TODO MSRV: Remove this `allow` once we no longer get a false positive // for it on our MSRV. 1.56 is affected; 1.60 is not. #[allow(unreachable_pub)] pub use tor_basic_utils::n_key_set::Error as ByRelayIdsError; + +#[cfg(test)] +mod test { + #![allow(clippy::unwrap_used)] + + use super::*; + use crate::{RelayIds, RelayIdsBuilder}; + + #[test] + fn lookup() { + let rsa1: RsaIdentity = (*b"12345678901234567890").into(); + let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into(); + let rsa3: RsaIdentity = (*b"abcefghijklmnopQRSTU").into(); + let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into(); + let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into(); + let ed3: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyz1234567").into(); + + let keys1 = RelayIdsBuilder::default() + .rsa_identity(rsa1) + .ed_identity(ed1) + .build() + .unwrap(); + + let keys2 = RelayIdsBuilder::default() + .rsa_identity(rsa2) + .ed_identity(ed2) + .build() + .unwrap(); + + let mut set = ByRelayIds::new(); + set.insert(keys1.clone()); + set.insert(keys2.clone()); + + // Try by_id + assert_eq!(set.by_id(&rsa1), Some(&keys1)); + assert_eq!(set.by_id(&ed1), Some(&keys1)); + assert_eq!(set.by_id(&rsa2), Some(&keys2)); + assert_eq!(set.by_id(&ed2), Some(&keys2)); + assert_eq!(set.by_id(&rsa3), None); + assert_eq!(set.by_id(&ed3), None); + + // Try exact lookup + assert_eq!(set.by_all_ids(&keys1), Some(&keys1)); + assert_eq!(set.by_all_ids(&keys2), Some(&keys2)); + { + let search = RelayIdsBuilder::default() + .rsa_identity(rsa1) + .build() + .unwrap(); + assert_eq!(set.by_all_ids(&search), Some(&keys1)); + } + { + let search = RelayIdsBuilder::default() + .rsa_identity(rsa1) + .ed_identity(ed2) + .build() + .unwrap(); + assert_eq!(set.by_all_ids(&search), None); + } + + // Try looking for overlap + assert_eq!(set.all_overlapping(&keys1), vec![&keys1]); + assert_eq!(set.all_overlapping(&keys2), vec![&keys2]); + { + let search = RelayIdsBuilder::default() + .rsa_identity(rsa1) + .ed_identity(ed2) + .build() + .unwrap(); + let answer = set.all_overlapping(&search); + assert_eq!(answer.len(), 2); + assert!(answer.contains(&&keys1)); + assert!(answer.contains(&&keys2)); + } + { + let search = RelayIdsBuilder::default() + .rsa_identity(rsa2) + .build() + .unwrap(); + assert_eq!(set.all_overlapping(&search), vec![&keys2]); + } + { + let search = RelayIdsBuilder::default() + .rsa_identity(rsa3) + .build() + .unwrap(); + assert_eq!(set.all_overlapping(&search), Vec::<&RelayIds>::new()); + } + } +} diff --git a/crates/tor-linkspec/src/lib.rs b/crates/tor-linkspec/src/lib.rs index 1c121733a..8ae929523 100644 --- a/crates/tor-linkspec/src/lib.rs +++ b/crates/tor-linkspec/src/lib.rs @@ -48,7 +48,10 @@ pub use ids::{ RelayId, RelayIdError, RelayIdRef, RelayIdType, RelayIdTypeIter, }; pub use ls::LinkSpec; -pub use owned::{OwnedChanTarget, OwnedChanTargetBuilder, OwnedCircTarget, RelayIds}; +pub use owned::{ + OwnedChanTarget, OwnedChanTargetBuilder, OwnedCircTarget, OwnedCircTargetBuilder, RelayIds, + RelayIdsBuilder, +}; pub use traits::{ ChanTarget, CircTarget, DirectChanMethodsHelper, HasAddrs, HasChanMethod, HasRelayIds, HasRelayIdsLegacy,