Merge branch 'connect-hsdesc-bounds' into 'main'

hsclient: Build cached descriptor TimerangeBounds from descriptor lifetime.

See merge request tpo/core/arti!1154
This commit is contained in:
gabi-250 2023-05-13 12:14:10 +00:00
commit 3490ece8cf
6 changed files with 398 additions and 47 deletions

View File

@ -44,6 +44,7 @@ use std::mem;
pub mod iter;
pub mod n_key_set;
pub mod rangebounds;
pub mod retry;
pub mod test_rng;

View File

@ -0,0 +1,302 @@
//! This module exposes helpers for working with types that implement
//! [`RangeBounds`](std::ops::RangeBounds).
use std::cmp::{self, Ord};
use std::ops::{Bound, RangeBounds};
/// An extension trait for [`RangeBounds`](std::ops::RangeBounds).
pub trait RangeBoundsExt<T>: RangeBounds<T> {
/// Compute the intersection of two `RangeBound`s.
///
/// In essence, this computes the intersection of the intervals described by bounds of the
/// two objects.
///
/// Returns `None` if the intersection of the two ranges is the empty set.
fn intersect<'a, U: RangeBounds<T>>(
&'a self,
other: &'a U,
) -> Option<(Bound<&'a T>, Bound<&'a T>)>;
}
impl<T, R> RangeBoundsExt<T> for R
where
R: RangeBounds<T>,
T: Ord,
{
fn intersect<'a, U: RangeBounds<T>>(
&'a self,
other: &'a U,
) -> Option<(Bound<&'a T>, Bound<&'a T>)> {
use Bound::*;
let this_start = self.start_bound();
let other_start = other.start_bound();
let this_end = self.end_bound();
let other_end = other.end_bound();
let start = bounds_max(this_start, other_start);
let end = bounds_min(this_end, other_end);
match (start, end) {
(Excluded(start), Excluded(end)) | (Included(start), Excluded(end)) if start == end => {
// The interval (n, n) = [n, n) = {} (empty set).
None
}
(Included(start), Included(end))
| (Included(start), Excluded(end))
| (Excluded(start), Included(end))
| (Excluded(start), Excluded(end))
if start > end =>
{
// For any a > b, the intervals [a, b], [a, b), (a, b], (a, b) are empty.
None
}
_ => Some((start, end)),
}
}
}
/// Return the largest of `b1` and `b2`.
///
/// If one of the bounds is [Unbounded](Bound::Unbounded), the other will be returned.
fn bounds_max<'a, T: Ord>(b1: Bound<&'a T>, b2: Bound<&'a T>) -> Bound<&'a T> {
use Bound::*;
match (b1, b2) {
(Included(b1), Included(b2)) => Included(cmp::max(b1, b2)),
(Excluded(b1), Excluded(b2)) => Excluded(cmp::max(b1, b2)),
(Excluded(b1), Included(b2)) if b1 >= b2 => Excluded(b1),
(Excluded(_), Included(b2)) => Included(b2),
(Included(b1), Excluded(b2)) if b2 >= b1 => Excluded(b2),
(Included(b1), Excluded(_)) => Included(b1),
(b, Unbounded) | (Unbounded, b) => b,
}
}
/// Return the smallest of `b1` and `b2`.
///
/// If one of the bounds is [Unbounded](Bound::Unbounded), the other will be returned.
fn bounds_min<'a, T: Ord>(b1: Bound<&'a T>, b2: Bound<&'a T>) -> Bound<&'a T> {
use Bound::*;
match (b1, b2) {
(Included(b1), Included(b2)) => Included(cmp::min(b1, b2)),
(Excluded(b1), Excluded(b2)) => Excluded(cmp::min(b1, b2)),
(Excluded(b1), Included(b2)) if b1 <= b2 => Excluded(b1),
(Excluded(_), Included(b2)) => Included(b2),
(Included(b1), Excluded(b2)) if b2 <= b1 => Excluded(b2),
(Included(b1), Excluded(_)) => Included(b1),
(b, Unbounded) | (Unbounded, b) => b,
}
}
#[cfg(test)]
mod test {
// @@ begin test lint list maintained by maint/add_warning @@
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_duration_subtraction)]
//! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
use super::*;
use std::fmt::Debug;
use std::time::{Duration, SystemTime};
use Bound::{Excluded as Excl, Included as Incl, Unbounded};
/// A helper that computes the intersection of `range1` and `range2`.
///
/// This function also asserts that the intersection operation is commutative.
fn intersect<'a, T, R: RangeBounds<T>>(
range1: &'a R,
range2: &'a R,
) -> Option<(Bound<&'a T>, Bound<&'a T>)>
where
T: PartialEq + Ord + Debug,
{
let intersection1 = range1.intersect(range2);
let intersection2 = range2.intersect(range1);
assert_eq!(intersection1, intersection2);
intersection1
}
/// A helper for randomly generating either an inclusive or an exclusive bound with a
/// particular value.
fn random_bound<T>(value: T) -> Bound<T> {
if rand::random() {
Bound::Included(value)
} else {
Bound::Excluded(value)
}
}
#[test]
fn no_overlap() {
#[allow(clippy::type_complexity)]
const NON_OVERLAPPING_RANGES: &[(
(Bound<usize>, Bound<usize>),
(Bound<usize>, Bound<usize>),
)] = &[
// (1, 2) and (3, 4)
((Excl(1), Excl(2)), (Excl(3), Excl(4))),
// (1, 2) and (2, 3)
((Excl(1), Excl(2)), (Excl(2), Excl(3))),
// (1, 2) and [2, 3)
((Excl(1), Excl(2)), (Incl(2), Excl(3))),
// (1, 2) and [2, 3]
((Excl(1), Excl(2)), (Incl(3), Incl(4))),
// (-inf, 2) and [2, 3]
((Unbounded, Excl(2)), (Incl(2), Incl(3))),
// (-inf, 2) and (2, inf)
((Unbounded, Excl(2)), (Excl(2), Unbounded)),
// (-inf, 2) and [2, inf)
((Unbounded, Excl(2)), (Incl(2), Unbounded)),
];
for (range1, range2) in NON_OVERLAPPING_RANGES {
let intersection = intersect(range1, range2);
assert!(
intersection.is_none(),
"{:?} and {:?} => {:?}",
range1,
range2,
intersection
);
}
}
#[test]
fn intersect_unbounded_start() {
// (-inf, 3)
let range1 = (Unbounded, Excl(3));
// [2, 5]
let range2 = (Incl(2), Incl(5));
let intersection = intersect(&range1, &range2).unwrap();
// intersection = [2 3]
assert_eq!(intersection.start_bound(), Bound::Included(&2));
assert_eq!(intersection.end_bound(), Bound::Excluded(&3));
}
#[test]
fn intersect_unbounded_end() {
// (8, inf)
let range1 = (Excl(8), Unbounded);
// [8, 20]
let range2 = (Incl(8), Incl(20));
let intersection = intersect(&range1, &range2).unwrap();
// intersection = (8, 20]
assert_eq!(intersection.start_bound(), Bound::Excluded(&8));
assert_eq!(intersection.end_bound(), Bound::Included(&20));
}
#[test]
fn intersect_unbounded_range() {
#[allow(clippy::type_complexity)]
const RANGES: &[(Bound<usize>, Bound<usize>)] = &[
// (1, 2)
(Excl(1), Excl(2)),
// (1, 2]
(Excl(1), Incl(2)),
// [1, 2]
(Incl(1), Incl(2)),
// [1, 2)
(Incl(1), Excl(2)),
// (1, inf)
(Excl(1), Unbounded),
// [1, inf)
(Incl(1), Unbounded),
// (-inf, 2)
(Unbounded, Excl(2)),
// (-inf, 2]
(Unbounded, Incl(2)),
];
// The intersection of any interval I with (Unbounded, Unbounded) will be I.
let range1 = (Unbounded, Unbounded);
for range2 in RANGES {
let range2 = (range2.0.as_ref(), range2.1.as_ref());
assert_eq!(intersect(&range1, &range2).unwrap(), range2);
}
}
#[test]
fn intersect_time_bounds() {
const MIN: Duration = Duration::from_secs(60);
// time (relative to now): 0 1 2 3
// | | | |
// [t1, t2]: [.......]
// [t3, t4]: [.......]
// intersection: [...]
let now = SystemTime::now();
let t1 = now;
let t2 = now + 2 * MIN;
let t3 = now + 1 * MIN;
let t4 = now + 3 * MIN;
let b1 = (Bound::Included(t1), Bound::Included(t2));
let b2 = (Bound::Included(t3), Bound::Included(t4));
let expected = (Bound::Included(&t3), Bound::Included(&t2));
assert_eq!(intersect(&b1, &b2).unwrap(), expected);
// t1 - - t2 - -
// t3 - - t4
//
// time (relative to now): 0 1 2 3 4 5 6 7
// | | | | | | | |
// [t1, t2]: [.......]
// [t3, t4]: [............]
let t3 = now + 4 * MIN;
let t4 = now + 7 * MIN;
let b2 = (Bound::Included(t3), Bound::Included(t4));
assert!(intersect(&b1, &b2).is_none());
}
#[test]
fn combinatorial() {
for i in 0..10 {
for j in 0..10 {
for k in 0..10 {
for l in 0..10 {
let range1 = (random_bound(i), random_bound(j));
let range2 = (random_bound(k), random_bound(l));
let intersection = intersect(&range1, &range2);
for witness in 0..10 {
let c1 = range1.contains(&witness);
let c2 = range2.contains(&witness);
let both_contain_witness = c1 && c2;
if both_contain_witness {
// If both ranges contain `witness` they definitely intersect.
assert!(intersection.unwrap().contains(&witness));
} else if let Some(intersection) = intersection {
// If one of them doesn't contain `witness`, `witness` is
// definitely not part of the intersection.
assert!(!intersection.contains(&witness));
}
}
}
}
}
}
}
}

View File

@ -105,13 +105,9 @@ impl<T> TimerangeBound<T> {
/// The caller takes responsibility for making sure that the bounds are
/// actually checked.
pub fn dangerously_into_parts(self) -> (T, (Bound<time::SystemTime>, Bound<time::SystemTime>)) {
(
self.obj,
(
self.start.map(Bound::Included).unwrap_or(Bound::Unbounded),
self.end.map(Bound::Included).unwrap_or(Bound::Unbounded),
),
)
let bounds = self.bounds();
(self.obj, bounds)
}
/// Return a reference to the inner object of this TimeRangeBound, without
@ -143,6 +139,27 @@ impl<T> TimerangeBound<T> {
{
self.as_ref().dangerously_map(|t| &**t)
}
/// Return the underlying time bounds of this object.
pub fn bounds(&self) -> (Bound<time::SystemTime>, Bound<time::SystemTime>) {
(self.start_bound().cloned(), self.end_bound().cloned())
}
}
impl<T> RangeBounds<time::SystemTime> for TimerangeBound<T> {
fn start_bound(&self) -> Bound<&time::SystemTime> {
self.start
.as_ref()
.map(Bound::Included)
.unwrap_or(Bound::Unbounded)
}
fn end_bound(&self) -> Bound<&time::SystemTime> {
self.end
.as_ref()
.map(Bound::Included)
.unwrap_or(Bound::Unbounded)
}
}
impl<T> crate::Timebound<T> for TimerangeBound<T> {

View File

@ -3,9 +3,7 @@
use std::time::Duration;
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use std::time::SystemTime;
use async_trait::async_trait;
use educe::Educe;
@ -227,7 +225,7 @@ impl<'c, 'd, R: Runtime, M: MocksForConnect<R>> Context<'c, 'd, R, M> {
// https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/1118#note_2894463
let mut attempts = hs_dirs.iter().cycle().take(MAX_TOTAL_ATTEMPTS);
let mut errors = RetryError::in_attempt_to("retrieve hidden service descriptor");
let (desc, bounds) = loop {
let desc = loop {
let relay = match attempts.next() {
Some(relay) => relay,
None => {
@ -264,10 +262,13 @@ impl<'c, 'd, R: Runtime, M: MocksForConnect<R>> Context<'c, 'd, R, M> {
// Store the bounded value in the cache for reuse,
// but return a reference to the unwrapped `HsDesc`.
//
// Because the `HsDesc` must be owned by `data.desc`,
// we must first wrap it in the TimerangeBound,
// The `HsDesc` must be owned by `data.desc`,
// so first add it to `data.desc`,
// and then dangerously_assume_timely to get a reference out again.
let ret = self.data.desc.insert(TimerangeBound::new(desc, bounds));
//
// It is safe to dangerously_assume_timely,
// as descriptor_fetch_attempt has already checked the timeliness of the descriptor.
let ret = self.data.desc.insert(desc);
Ok(ret.as_ref().dangerously_assume_timely())
}
@ -277,12 +278,12 @@ impl<'c, 'd, R: Runtime, M: MocksForConnect<R>> Context<'c, 'd, R, M> {
///
/// On success, returns the descriptor.
///
/// Also returns a `RangeBounds<SystemTime>` which represents the descriptor's validity.
/// (This is separate, because the descriptor's validity at the current time *has* been checked,)
/// While the returned descriptor is `TimerangeBound`, its validity at the current time *has*
/// been checked.
async fn descriptor_fetch_attempt(
&self,
hsdir: &Relay<'_>,
) -> Result<(HsDesc, impl RangeBounds<SystemTime>), DescriptorErrorDetail> {
) -> Result<TimerangeBound<HsDesc>, DescriptorErrorDetail> {
let request = tor_dirclient::request::HsDescDownloadRequest::new(self.hs_blind_id);
trace!(
"hsdir for {}, trying {}/{}, request {:?} (http request {:?}",
@ -330,19 +331,14 @@ impl<'c, 'd, R: Runtime, M: MocksForConnect<R>> Context<'c, 'd, R, M> {
let now = self.runtime.wallclock();
let hsdesc = HsDesc::parse_decrypt_validate(
HsDesc::parse_decrypt_validate(
&desc_text,
&self.hs_blind_id,
now,
&self.subcredential,
hsc_desc_enc.as_ref().map(|(kp, ks)| (kp, *ks)),
)
.map_err(DescriptorErrorDetail::from)?;
let unbounded_todo = Bound::Unbounded::<SystemTime>; // TODO HS remove
let bound = (unbounded_todo, unbounded_todo);
Ok((hsdesc, bound))
.map_err(DescriptorErrorDetail::from)
}
}
@ -454,6 +450,7 @@ mod test {
use super::*;
use crate::*;
use futures::FutureExt as _;
use std::ops::{Bound, RangeBounds};
use std::{iter, panic::AssertUnwindSafe};
use tokio_crate as tokio;
use tor_async_utils::JoinReadWrite;
@ -571,21 +568,20 @@ mod test {
secret_keys_builder.ks_hsc_desc_enc(sk);
let secret_keys = secret_keys_builder.build().unwrap();
let _got = AssertUnwindSafe(
Context::new(
&runtime,
&mocks,
netdir,
hsid,
&mut data,
secret_keys,
mocks.clone(),
)
.unwrap()
.connect(),
let mut ctx = Context::new(
&runtime,
&mocks,
netdir,
hsid,
&mut data,
secret_keys,
mocks.clone(),
)
.catch_unwind() // TODO HS remove this and the AssertUnwindSafe
.await;
.unwrap();
let _got = AssertUnwindSafe(ctx.connect())
.catch_unwind() // TODO HS remove this and the AssertUnwindSafe
.await;
let (hs_blind_id_key, subcredential) = HsIdKey::try_from(hsid)
.unwrap()
@ -602,7 +598,8 @@ mod test {
&subcredential,
Some((&pk, &sk)),
)
.unwrap();
.unwrap()
.dangerously_assume_timely();
let mglobal = mocks.mglobal.lock().unwrap();
assert_eq!(mglobal.hsdirs_asked.len(), 1);
@ -613,6 +610,16 @@ mod test {
format!("{:?}", Some(hsdesc))
);
// Check how long the descriptor is valid for
let bounds = ctx.data.desc.as_ref().unwrap().bounds();
assert_eq!(bounds.start_bound(), Bound::Unbounded);
let desc_valid_until = humantime::parse_rfc3339("2023-02-11T20:00:00Z").unwrap();
assert_eq!(
bounds.end_bound(),
Bound::Included(desc_valid_until).as_ref()
);
// TODO hs check the circuit in got is the one we gave out
}

View File

@ -85,6 +85,7 @@ smallvec = "1.10"
thiserror = "1"
time = { version = "0.3", features = ["std", "parsing", "macros"] }
tinystr = "0.7.0"
tor-basic-utils = { path = "../tor-basic-utils", version = "0.7.0" }
tor-bytes = { path = "../tor-bytes", version = "0.7.0" }
tor-cert = { path = "../tor-cert", version = "0.7.0" }
tor-checkable = { path = "../tor-checkable", version = "0.5.0" }

View File

@ -18,6 +18,8 @@ mod middle;
mod outer;
pub use desc_enc::DecryptionError;
use tor_basic_utils::rangebounds::RangeBoundsExt;
use tor_error::internal;
use crate::{NetdocErrorKind as EK, Result};
@ -262,14 +264,35 @@ impl HsDesc {
valid_at: SystemTime,
subcredential: &Subcredential,
hsc_desc_enc: Option<(&HsClientDescEncKey, &HsClientDescEncSecretKey)>,
) -> Result<Self> {
Self::parse(input, blinded_onion_id)?
.check_signature()?
.check_valid_at(&valid_at)?
.decrypt(subcredential, hsc_desc_enc)?
.check_valid_at(&valid_at)?
.check_signature()
.map_err(|e| e.into())
) -> Result<TimerangeBound<Self>> {
let unchecked_desc = Self::parse(input, blinded_onion_id)?.check_signature()?;
let (inner_desc, new_bounds) = {
// We use is_valid_at and dangerously_into_parts instead of check_valid_at because we
// need the time bounds of the outer layer (for computing the intersection with the
// time bounds of the inner layer).
unchecked_desc.is_valid_at(&valid_at)?;
// It's safe to use dangerously_into_parts() as we've just checked if unchecked_desc is
// valid at the current time
let (unchecked_desc, bounds) = unchecked_desc.dangerously_into_parts();
let inner_timerangebound = unchecked_desc.decrypt(subcredential, hsc_desc_enc)?;
let new_bounds = bounds
.intersect(&inner_timerangebound)
.map(|(b1, b2)| (b1.cloned(), b2.cloned()));
(inner_timerangebound, new_bounds)
};
let hsdesc = inner_desc.check_valid_at(&valid_at)?.check_signature()?;
// If we've reached this point, it means the descriptor is valid at specified time. This
// means the time bounds of the two layers definitely intersect, so new_bounds **must** be
// Some. It is a bug if new_bounds is None.
let new_bounds = new_bounds
.ok_or_else(|| internal!("failed to compute TimerangeBounds for a valid descriptor"))?;
Ok(TimerangeBound::new(hsdesc, new_bounds))
}
}
@ -288,7 +311,7 @@ impl EncryptedHsDesc {
// TODO hs: I'm not sure that taking `hsc_desc_enc` as an argument is correct. Instead, maybe
// we should take a set of keys?
pub fn decrypt(
self,
&self,
subcredential: &Subcredential,
hsc_desc_enc: Option<(&HsClientDescEncKey, &HsClientDescEncSecretKey)>,
) -> Result<TimerangeBound<SignatureGated<HsDesc>>> {