linkspec: Add an "all_overlapping" accessor to ByRelayIds.

Also, add a few tests for this and the other accessors.

We'll need this accessor to find whether we have any channels to
_any_ of the identities that we're trying to connect to.
This commit is contained in:
Nick Mathewson 2022-10-14 12:31:16 -04:00
parent 1a1a1af5d8
commit 0c8a5a1fa4
4 changed files with 124 additions and 1 deletions

7
Cargo.lock generated
View File

@ -582,6 +582,12 @@ version = "3.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d"
[[package]]
name = "by_address"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e245704f60eb4eb45810d65cf14eb54d2eb50a6f3715fe2d7cd01ee905c2944f"
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.12.1" version = "1.12.1"
@ -3715,6 +3721,7 @@ name = "tor-linkspec"
version = "0.5.1" version = "0.5.1"
dependencies = [ dependencies = [
"base64ct", "base64ct",
"by_address",
"cfg-if", "cfg-if",
"derive_builder_fork_arti", "derive_builder_fork_arti",
"derive_more", "derive_more",

View File

@ -18,6 +18,7 @@ pt-client = []
[dependencies] [dependencies]
base64ct = "1.5.1" base64ct = "1.5.1"
by_address = "1"
cfg-if = "1.0.0" cfg-if = "1.0.0"
derive_builder = { version = "0.11.2", package = "derive_builder_fork_arti" } derive_builder = { version = "0.11.2", package = "derive_builder_fork_arti" }
derive_more = "0.99" derive_more = "0.99"

View File

@ -52,9 +52,121 @@ impl<H: HasRelayIds> ByRelayIds<H> {
self.by_id(any_id) self.by_id(any_id)
.filter(|val| val.has_all_relay_ids_from(key)) .filter(|val| val.has_all_relay_ids_from(key))
} }
/// Return a reference to every element in this set that shares _any_ ID
/// with `key`.
///
/// No element is returned more than once.
pub fn all_overlapping<T>(&self, key: &T) -> Vec<&H>
where
T: HasRelayIds,
{
use by_address::ByAddress;
use std::collections::HashSet;
let mut items: HashSet<ByAddress<&H>> = HashSet::new();
for ident in key.identities() {
if let Some(found) = self.by_id(ident) {
items.insert(ByAddress(found));
}
}
items.into_iter().map(|by_addr| by_addr.0).collect()
}
} }
// TODO MSRV: Remove this `allow` once we no longer get a false positive // TODO MSRV: Remove this `allow` once we no longer get a false positive
// for it on our MSRV. 1.56 is affected; 1.60 is not. // for it on our MSRV. 1.56 is affected; 1.60 is not.
#[allow(unreachable_pub)] #[allow(unreachable_pub)]
pub use tor_basic_utils::n_key_set::Error as ByRelayIdsError; pub use tor_basic_utils::n_key_set::Error as ByRelayIdsError;
#[cfg(test)]
mod test {
#![allow(clippy::unwrap_used)]
use super::*;
use crate::{RelayIds, RelayIdsBuilder};
#[test]
fn lookup() {
let rsa1: RsaIdentity = (*b"12345678901234567890").into();
let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
let rsa3: RsaIdentity = (*b"abcefghijklmnopQRSTU").into();
let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();
let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();
let ed3: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyz1234567").into();
let keys1 = RelayIdsBuilder::default()
.rsa_identity(rsa1)
.ed_identity(ed1)
.build()
.unwrap();
let keys2 = RelayIdsBuilder::default()
.rsa_identity(rsa2)
.ed_identity(ed2)
.build()
.unwrap();
let mut set = ByRelayIds::new();
set.insert(keys1.clone());
set.insert(keys2.clone());
// Try by_id
assert_eq!(set.by_id(&rsa1), Some(&keys1));
assert_eq!(set.by_id(&ed1), Some(&keys1));
assert_eq!(set.by_id(&rsa2), Some(&keys2));
assert_eq!(set.by_id(&ed2), Some(&keys2));
assert_eq!(set.by_id(&rsa3), None);
assert_eq!(set.by_id(&ed3), None);
// Try exact lookup
assert_eq!(set.by_all_ids(&keys1), Some(&keys1));
assert_eq!(set.by_all_ids(&keys2), Some(&keys2));
{
let search = RelayIdsBuilder::default()
.rsa_identity(rsa1)
.build()
.unwrap();
assert_eq!(set.by_all_ids(&search), Some(&keys1));
}
{
let search = RelayIdsBuilder::default()
.rsa_identity(rsa1)
.ed_identity(ed2)
.build()
.unwrap();
assert_eq!(set.by_all_ids(&search), None);
}
// Try looking for overlap
assert_eq!(set.all_overlapping(&keys1), vec![&keys1]);
assert_eq!(set.all_overlapping(&keys2), vec![&keys2]);
{
let search = RelayIdsBuilder::default()
.rsa_identity(rsa1)
.ed_identity(ed2)
.build()
.unwrap();
let answer = set.all_overlapping(&search);
assert_eq!(answer.len(), 2);
assert!(answer.contains(&&keys1));
assert!(answer.contains(&&keys2));
}
{
let search = RelayIdsBuilder::default()
.rsa_identity(rsa2)
.build()
.unwrap();
assert_eq!(set.all_overlapping(&search), vec![&keys2]);
}
{
let search = RelayIdsBuilder::default()
.rsa_identity(rsa3)
.build()
.unwrap();
assert_eq!(set.all_overlapping(&search), Vec::<&RelayIds>::new());
}
}
}

View File

@ -48,7 +48,10 @@ pub use ids::{
RelayId, RelayIdError, RelayIdRef, RelayIdType, RelayIdTypeIter, RelayId, RelayIdError, RelayIdRef, RelayIdType, RelayIdTypeIter,
}; };
pub use ls::LinkSpec; pub use ls::LinkSpec;
pub use owned::{OwnedChanTarget, OwnedChanTargetBuilder, OwnedCircTarget, RelayIds}; pub use owned::{
OwnedChanTarget, OwnedChanTargetBuilder, OwnedCircTarget, OwnedCircTargetBuilder, RelayIds,
RelayIdsBuilder,
};
pub use traits::{ pub use traits::{
ChanTarget, CircTarget, DirectChanMethodsHelper, HasAddrs, HasChanMethod, HasRelayIds, ChanTarget, CircTarget, DirectChanMethodsHelper, HasAddrs, HasChanMethod, HasRelayIds,
HasRelayIdsLegacy, HasRelayIdsLegacy,