diff --git a/Cargo.lock b/Cargo.lock index 2c933485a..7c595bd65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2637,6 +2637,12 @@ dependencies = [ "webpki 0.21.4", ] +[[package]] +name = "rustversion" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2cc38e8fa666e2de3c4aba7edeb5ffc5246c1c2ed0e3d17e560aeeba736b23f" + [[package]] name = "ryu" version = "1.0.9" @@ -2918,6 +2924,28 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strum" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96acfc1b70604b8b2f1ffa4c57e59176c7dbb05d556c71ecd2f5498a1dee7f8" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6878079b17446e4d3eba6192bb0a2950d5b14f0ed8424b852310e5a94345d0ef" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "subtle" version = "2.4.1" @@ -3410,6 +3438,7 @@ dependencies = [ "derive_more", "futures", "once_cell", + "strum", "thiserror", ] diff --git a/crates/tor-error/Cargo.toml b/crates/tor-error/Cargo.toml index e66b6d31d..1b05b7b14 100644 --- a/crates/tor-error/Cargo.toml +++ b/crates/tor-error/Cargo.toml @@ -18,6 +18,7 @@ backtrace = { version = "0.3.39", optional = true } derive_more = "0.99" futures = "0.3" once_cell = "1" +strum = { version = "0.24", features = ["derive"] } thiserror = "1" [dev-dependencies] diff --git a/crates/tor-error/src/retriable.rs b/crates/tor-error/src/retriable.rs index 9f68ca486..568252eb8 100644 --- a/crates/tor-error/src/retriable.rs +++ b/crates/tor-error/src/retriable.rs @@ -4,6 +4,7 @@ use std::{ cmp::Ordering, time::{Duration, Instant}, }; +use strum::EnumDiscriminants; /// A description of when an operation may be retried. /// @@ -35,8 +36,12 @@ use std::{ /// traffic to port 23 (telnet), we say that building a request for such a relay /// is not retriable, even though technically such a relay might appear in the /// next consensus. -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, EnumDiscriminants)] #[non_exhaustive] +// We define a discriminant type so we can simplify loose_cmp. +#[strum_discriminants(derive(Ord, PartialOrd))] +// We don't want to expose RetryTimeDiscriminants. +#[strum_discriminants(vis())] pub enum RetryTime { /// The operation can be retried immediately, and no delay is needed. /// @@ -186,32 +191,70 @@ impl RetryTime { /// /// If you need an absolute comparison operator, convert to [`AbsRetryTime`] first. fn loose_cmp(&self, other: &Self) -> Ordering { - use Ordering::*; + use RetryTime as RT; + match (self, other) { - // Immediate precedes everything. - (RetryTime::Immediate, RetryTime::Immediate) => Equal, - (RetryTime::Immediate, _) => Less, - (_, RetryTime::Immediate) => Greater, + // When we have the same type with an internal embedded duration or time, + // we compare based on the duration or time. + (RT::After(d1), RetryTime::After(d2)) => d1.cmp(d2), + (RT::At(t1), RetryTime::At(t2)) => t1.cmp(t2), - // When we have the same type, then we can compare based on actual - // times. - (RetryTime::AfterWaiting, RetryTime::AfterWaiting) => Equal, - (RetryTime::After(d1), RetryTime::After(d2)) => d1.cmp(d2), - (RetryTime::At(t1), RetryTime::At(t2)) => t1.cmp(t2), - - // Otherwise: pretend AfterWaiting is shorter than After, is shorter - // than At. - (RetryTime::AfterWaiting, RetryTime::After(_)) => Less, - (RetryTime::AfterWaiting, RetryTime::At(_)) => Less, - (RetryTime::After(_), RetryTime::AfterWaiting) => Greater, - (RetryTime::After(_), RetryTime::At(_)) => Less, - (RetryTime::At(_), RetryTime::AfterWaiting) => Greater, - (RetryTime::At(_), RetryTime::After(_)) => Greater, - - // Everything precedes Never. - (RetryTime::Never, RetryTime::Never) => Equal, - (RetryTime::Never, _) => Greater, - (_, RetryTime::Never) => Less, + // Otherwise, we compare based on discriminant type. + // + // This can't do a perfect "apples-to-apples" comparison for + // `AfterWaiting` vs `At` vs `After`, but at least it imposes a + // total order. + (a, b) => RetryTimeDiscriminants::from(a).cmp(&RetryTimeDiscriminants::from(b)), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn comparison() { + use RetryTime as RT; + let sec = Duration::from_secs(1); + let now = Instant::now(); + + let sorted = vec![ + RT::Immediate, + RT::AfterWaiting, + RT::After(sec * 10), + RT::After(sec * 20), + RT::At(now), + RT::At(now + sec * 30), + RT::Never, + ]; + + // Verify that these objects are actually in loose-cmp sorted order. + for (i, a) in sorted.iter().enumerate() { + for (j, b) in sorted.iter().enumerate() { + assert_eq!(a.loose_cmp(b), i.cmp(&j)); + } + } + } + + #[test] + fn abs_comparison() { + use AbsRetryTime as ART; + let sec = Duration::from_secs(1); + let now = Instant::now(); + + let sorted = vec![ + ART::Immediate, + ART::At(now), + ART::At(now + sec * 30), + ART::Never, + ]; + + // Verify that these objects are actually in loose-cmp sorted order. + for (i, a) in sorted.iter().enumerate() { + for (j, b) in sorted.iter().enumerate() { + assert_eq!(a.cmp(b), i.cmp(&j)); + } } } }