diff --git a/Cargo.lock b/Cargo.lock index b4f3dc78a..d9702f99c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4335,7 +4335,6 @@ dependencies = [ "tempfile", "thiserror", "time", - "tokio", "tor-basic-utils", "tor-checkable", "tor-circmgr", @@ -4823,16 +4822,23 @@ dependencies = [ name = "tor-rtmock" version = "0.8.2" dependencies = [ + "amplify", "async-trait", + "educe", "futures", "futures-await-test", "humantime 2.1.0", + "itertools 0.11.0", "pin-project", "rand 0.8.5", + "slotmap", + "strum", "thiserror", "tor-basic-utils", "tor-rtcompat", "tracing", + "tracing-test", + "void", ] [[package]] diff --git a/crates/tor-dirmgr/Cargo.toml b/crates/tor-dirmgr/Cargo.toml index 8ed6b0305..82ef7f768 100644 --- a/crates/tor-dirmgr/Cargo.toml +++ b/crates/tor-dirmgr/Cargo.toml @@ -102,7 +102,6 @@ anyhow = "1.0.23" float_eq = "1.0.0" hex-literal = "0.4" tempfile = "3" -tokio = { version = "1.7", features = ["full"] } tor-linkspec = { path = "../tor-linkspec", version = "0.8.1" } tor-rtcompat = { path = "../tor-rtcompat", version = "0.9.1", features = ["tokio", "native-tls"] } tor-rtmock = { path = "../tor-rtmock", version = "0.8.1" } diff --git a/crates/tor-dirmgr/src/bridgedesc/bdtest.rs b/crates/tor-dirmgr/src/bridgedesc/bdtest.rs index 717e5e84f..8a378d8d9 100644 --- a/crates/tor-dirmgr/src/bridgedesc/bdtest.rs +++ b/crates/tor-dirmgr/src/bridgedesc/bdtest.rs @@ -27,7 +27,7 @@ use tracing_test::traced_test; use tor_linkspec::HasAddrs; use tor_rtcompat::SleepProvider; use tor_rtmock::time::MockSleepProvider; -use tor_rtmock::MockSleepRuntime; +use tor_rtmock::MockRuntime; use super::*; @@ -49,8 +49,7 @@ fn example_wallclock() -> SystemTime { example_validity().0 + Duration::from_secs(10) } -type RealRuntime = tor_rtcompat::tokio::TokioNativeTlsRuntime; -type R = MockSleepRuntime; +type R = MockRuntime; type M = Mock; type Bdm = BridgeDescMgr; type RT = RetryTime; @@ -126,11 +125,8 @@ impl Mock { } } -fn setup() -> (TempDir, Bdm, R, M, BridgeKey, rusqlite::Connection) { - 
let runtime = RealRuntime::current().unwrap(); - let runtime = MockSleepRuntime::new(runtime); +fn setup(runtime: MockRuntime) -> (TempDir, Bdm, R, M, BridgeKey, rusqlite::Connection) { let sleep = runtime.mock_sleep().clone(); - sleep.jump_to(example_wallclock()); let mut docs = HashMap::new(); @@ -228,380 +224,379 @@ fn bad_bridge(i: usize) -> BridgeKey { bad } -#[tokio::test] #[traced_test] -async fn success() -> Result<(), anyhow::Error> { - let (_db_tmp_dir, bdm, runtime, mock, bridge, ..) = setup(); +#[test] +fn success() -> Result<(), anyhow::Error> { + MockRuntime::try_test_with_various(|runtime| async { + let (_db_tmp_dir, bdm, runtime, mock, bridge, ..) = setup(runtime); - bdm.check_consistency(Some([])); + bdm.check_consistency(Some([])); - let mut events = bdm.events().fuse(); + let mut events = bdm.events().fuse(); - eprintln!("----- test downloading one descriptor -----"); + eprintln!("----- test downloading one descriptor -----"); - stream_drain_ready(&mut events).await; + stream_drain_ready(&mut events).await; - let hold = mock.mstate.lock().await; + let hold = mock.mstate.lock().await; - bdm.set_bridges(&[bridge.clone()]); - bdm.check_consistency(Some([&bridge])); + bdm.set_bridges(&[bridge.clone()]); + bdm.check_consistency(Some([&bridge])); - drop(hold); + drop(hold); - let got = stream_drain_until(3, &mut events, || async { - bdm.bridges().get(&bridge).cloned() - }) - .await; - - dbg!(runtime.wallclock(), example_validity(),); - - eprintln!("got: {:?}", got.unwrap()); - - bdm.check_consistency(Some([&bridge])); - mock.expect_download_calls(1).await; - - eprintln!("----- add a number of failing descriptors -----"); - - const NFAIL: usize = 6; - - let bad = (1..=NFAIL).map(bad_bridge).collect_vec(); - - let mut bridges = chain!(iter::once(bridge.clone()), bad.iter().cloned(),).collect_vec(); - - let hold = mock.mstate.lock().await; - - bdm.set_bridges(&bridges); - bdm.check_consistency(Some(&bridges)); - - drop(hold); - - let () = 
stream_drain_until(13, &mut events, || async { - bdm.check_consistency(Some(&bridges)); - bridges - .iter() - .all(|b| bdm.bridges().contains_key(b)) - .then_some(()) - }) - .await; - - for b in &bad { - bdm.bridges().get(b).unwrap().as_ref().unwrap_err(); - } - - bdm.check_consistency(Some(&bridges)); - mock.expect_download_calls(NFAIL).await; - - eprintln!("----- move the clock forward to do some retries ----------"); - - mock.sleep.advance(Duration::from_secs(5000)).await; - - bdm.check_consistency(Some(&bridges)); - - let () = stream_drain_until(13, &mut events, || async { - bdm.check_consistency(Some(&bridges)); - (mock.mstate.lock().await.download_calls == NFAIL).then_some(()) - }) - .await; - - stream_drain_ready(&mut events).await; - - bdm.check_consistency(Some(&bridges)); - mock.expect_download_calls(NFAIL).await; - - eprintln!("----- set the bridges to the ones we have already ----------"); - - let hold = mock.mstate.lock().await; - - bdm.set_bridges(&bridges); - bdm.check_consistency(Some(&bridges)); - - drop(hold); - - let events_counted = stream_drain_ready(&mut events).await; - assert_eq!(events_counted, 0); - bdm.check_consistency(Some(&bridges)); - mock.expect_download_calls(0).await; - - eprintln!("----- set the bridges to one fewer than we have already ----------"); - - let _ = bridges.pop().unwrap(); - - let hold = mock.mstate.lock().await; - - bdm.set_bridges(&bridges); - bdm.check_consistency(Some(&bridges)); - - drop(hold); - - let events_counted = stream_drain_ready(&mut events).await; - assert_eq!(events_counted, 1); - bdm.check_consistency(Some(&bridges)); - mock.expect_download_calls(0).await; - - eprintln!("----- remove a bridge while we have some requeued ----------"); - - let hold = mock.mstate.lock().await; - - mock.sleep.advance(Duration::from_secs(8000)).await; - bdm.check_consistency(Some(&bridges)); - - // should yield, but not produce any events yet - let count = stream_drain_ready(&mut events).await; - assert_eq!(count, 0); - 
bdm.check_consistency(Some(&bridges)); - - let removed = bridges.pop().unwrap(); - bdm.set_bridges(&bridges); - - // should produce a removed bridge event - let () = stream_drain_until(1, &mut events, || async { - bdm.check_consistency(Some(&bridges)); - (!bdm.bridges().contains_key(&removed)).then_some(()) - }) - .await; - - drop(hold); - - // should produce a removed bridge event - let () = stream_drain_until(1, &mut events, || async { - bdm.check_consistency(Some(&bridges)); - queues_are_empty(&bdm) - }) - .await; - - { - // When we cancel the download, we race with the manager. - // Maybe the download for the one we removed was started, or maybe not. - let mut mstate = mock.mstate.lock().await; - assert!( - ((NFAIL - 1)..=NFAIL).contains(&mstate.download_calls), - "{:?}", - mstate.download_calls - ); - mstate.download_calls = 0; - } - - Ok(()) -} - -#[tokio::test] -#[traced_test] -async fn cache() -> Result<(), anyhow::Error> { - let (_db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) 
= setup(); - let mut events = bdm.events().fuse(); - - let in_results = |wanted| in_results(&bdm, &bridge, wanted); - - eprintln!("----- test that a downloaded descriptor goes into the cache -----"); - - bdm.set_bridges(&[bridge.clone()]); - stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; - - mock.expect_download_calls(1).await; - - sql_conn - .query_row("SELECT * FROM BridgeDescs", [], |row| { - let get_time = |f| -> SystemTime { row.get_unwrap::<&str, OffsetDateTime>(f).into() }; - let bline: String = row.get_unwrap("bridge_line"); - let fetched: SystemTime = get_time("fetched"); - let until: SystemTime = get_time("until"); - let contents: String = row.get_unwrap("contents"); - let now = runtime.wallclock(); - assert_eq!(bline, bridge.to_string()); - assert!(fetched <= now); - assert!(now < until); - assert_eq!(contents, EXAMPLE_DESCRIPTOR); - Ok(()) + let got = stream_drain_until(3, &mut events, || async { + bdm.bridges().get(&bridge).cloned() }) - .unwrap(); + .await; - eprintln!("----- forget the descriptor and try to reload it from the cache -----"); + dbg!(runtime.wallclock(), example_validity(),); - clear_and_re_request(&bdm, &mut events, &bridge).await; - stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; + eprintln!("got: {:?}", got.unwrap()); - // Should not have been re-downloaded, since the fetch time is great. 
- mock.expect_download_calls(0).await; + bdm.check_consistency(Some([&bridge])); + mock.expect_download_calls(1).await; - eprintln!("----- corrupt the cache and check we re-download -----"); + eprintln!("----- add a number of failing descriptors -----"); - sql_conn - .execute_batch("UPDATE BridgeDescs SET contents = 'garbage'") - .unwrap(); + const NFAIL: usize = 6; - clear_and_re_request(&bdm, &mut events, &bridge).await; - stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; + let bad = (1..=NFAIL).map(bad_bridge).collect_vec(); - mock.expect_download_calls(1).await; + let mut bridges = chain!(iter::once(bridge.clone()), bad.iter().cloned(),).collect_vec(); - eprintln!("----- advance the lock and check that we do an if-modified-since -----"); + let hold = mock.mstate.lock().await; - let published = bdm - .bridges() - .get(&bridge) - .unwrap() - .as_ref() - .unwrap() - .as_ref() - .published(); + bdm.set_bridges(&bridges); + bdm.check_consistency(Some(&bridges)); - mock.mstate.lock().await.docs.insert( - EXAMPLE_PORT, - Ok(format!("{}{:?}", MOCK_NOT_MODIFIED, published)), - ); + drop(hold); - // Exceeds default max_refetch - mock.sleep.advance(Duration::from_secs(20000)).await; + let () = stream_drain_until(13, &mut events, || async { + bdm.check_consistency(Some(&bridges)); + bridges + .iter() + .all(|b| bdm.bridges().contains_key(b)) + .then_some(()) + }) + .await; - stream_drain_until(3, &mut events, || async { - (mock.mstate.lock().await.download_calls > 0).then_some(()) + for b in &bad { + bdm.bridges().get(b).unwrap().as_ref().unwrap_err(); + } + + bdm.check_consistency(Some(&bridges)); + mock.expect_download_calls(NFAIL).await; + + eprintln!("----- move the clock forward to do some retries ----------"); + + mock.sleep.advance(Duration::from_secs(5000)).await; + + bdm.check_consistency(Some(&bridges)); + + let () = stream_drain_until(13, &mut events, || async { + bdm.check_consistency(Some(&bridges)); + 
(mock.mstate.lock().await.download_calls == NFAIL).then_some(()) + }) + .await; + + stream_drain_ready(&mut events).await; + + bdm.check_consistency(Some(&bridges)); + mock.expect_download_calls(NFAIL).await; + + eprintln!("----- set the bridges to the ones we have already ----------"); + + let hold = mock.mstate.lock().await; + + bdm.set_bridges(&bridges); + bdm.check_consistency(Some(&bridges)); + + drop(hold); + + let events_counted = stream_drain_ready(&mut events).await; + assert_eq!(events_counted, 0); + bdm.check_consistency(Some(&bridges)); + mock.expect_download_calls(0).await; + + eprintln!("----- set the bridges to one fewer than we have already ----------"); + + let _ = bridges.pop().unwrap(); + + let hold = mock.mstate.lock().await; + + bdm.set_bridges(&bridges); + bdm.check_consistency(Some(&bridges)); + + drop(hold); + + let events_counted = stream_drain_ready(&mut events).await; + assert_eq!(events_counted, 1); + bdm.check_consistency(Some(&bridges)); + mock.expect_download_calls(0).await; + + eprintln!("----- remove a bridge while we have some requeued ----------"); + + let hold = mock.mstate.lock().await; + + mock.sleep.advance(Duration::from_secs(8000)).await; + bdm.check_consistency(Some(&bridges)); + + // should yield, but not produce any events yet + let count = stream_drain_ready(&mut events).await; + assert_eq!(count, 0); + bdm.check_consistency(Some(&bridges)); + + let removed = bridges.pop().unwrap(); + bdm.set_bridges(&bridges); + + // should produce a removed bridge event + let () = stream_drain_until(1, &mut events, || async { + bdm.check_consistency(Some(&bridges)); + (!bdm.bridges().contains_key(&removed)).then_some(()) + }) + .await; + + drop(hold); + + // Check that queues become empty. + // Depending on scheduling, there may be tasks still live from the work above. + // For example, one of the requeues might be still running after we did the remove. + // So we may get a number of change events. Certainly not more than 10. 
+ let () = stream_drain_until(10, &mut events, || async { + bdm.check_consistency(Some(&bridges)); + queues_are_empty(&bdm) + }) + .await; + + { + // When we cancel the download, we race with the manager. + // Maybe the download for the one we removed was started, or maybe not. + let mut mstate = mock.mstate.lock().await; + assert!( + ((NFAIL - 1)..=NFAIL).contains(&mstate.download_calls), + "{:?}", + mstate.download_calls + ); + mstate.download_calls = 0; + } + + Ok(()) }) - .await; - - mock.expect_download_calls(1).await; - - Ok(()) } -#[tokio::test] #[traced_test] -async fn dormant() -> Result<(), anyhow::Error> { - #[allow(unused_variables)] // avoids churn and makes all of these identical - let (db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) = setup(); - let mut events = bdm.events().fuse(); +#[test] +fn cache() -> Result<(), anyhow::Error> { + MockRuntime::try_test_with_various(|runtime| async { + let (_db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) = setup(runtime); + let mut events = bdm.events().fuse(); - use Dormancy::*; + let in_results = |wanted| in_results(&bdm, &bridge, wanted); - eprintln!("----- become dormant, but request a bridge -----"); - bdm.set_dormancy(Dormant); - bdm.set_bridges(&[bridge.clone()]); + eprintln!("----- test that a downloaded descriptor goes into the cache -----"); - // TODO async wait for idle: - // - // This is a bodge. What we really want to do is drive all tasks until we are idle. - // But Tokio does not provide this facility, AFAICT. I also checked smol and - // async-std, and did a moderately thorough search using lib.rs. I think the proper - // approach has to be a custom executor. (`tor_rtmock::MockSleepRuntime::wait_for` - // doesn't work because it doesn't track, and therefore doesn't progress, spawned tasks.) - // - // Instead, we do this: this is real time, not mock time. That ought to let - // everything that is going to run, do so. 
(I have verified that this test fails - // before dormancy is actually implemented.) If the 10ms we have here is too short - // (eg by random chance) then we might miss a situation where the dormancy is not - // properly effective, but we oughtn't to have a flaky test with good code, since "no - // progress was made within 10ms" is the expected behaviour. - tokio::time::sleep(Duration::from_millis(10)).await; - mock.expect_download_calls(0).await; + bdm.set_bridges(&[bridge.clone()]); + stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; - eprintln!("----- become active -----"); - bdm.set_dormancy(Active); - // This should immediately trigger the download: + mock.expect_download_calls(1).await; - stream_drain_until(3, &mut events, || async { - in_results(&bdm, &bridge, Some(Ok(()))) - }) - .await; - mock.expect_download_calls(1).await; + sql_conn + .query_row("SELECT * FROM BridgeDescs", [], |row| { + let get_time = + |f| -> SystemTime { row.get_unwrap::<&str, OffsetDateTime>(f).into() }; + let bline: String = row.get_unwrap("bridge_line"); + let fetched: SystemTime = get_time("fetched"); + let until: SystemTime = get_time("until"); + let contents: String = row.get_unwrap("contents"); + let now = runtime.wallclock(); + assert_eq!(bline, bridge.to_string()); + assert!(fetched <= now); + assert!(now < until); + assert_eq!(contents, EXAMPLE_DESCRIPTOR); + Ok(()) + }) + .unwrap(); - Ok(()) -} + eprintln!("----- forget the descriptor and try to reload it from the cache -----"); -#[tokio::test] -async fn process_doc() -> Result<(), anyhow::Error> { - #[allow(unused_variables)] // avoids churn and makes all of these identical - let (db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) 
= setup(); + clear_and_re_request(&bdm, &mut events, &bridge).await; + stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; - let text = EXAMPLE_DESCRIPTOR; - let config = BridgeDescDownloadConfig::default(); - let valid = example_validity(); + // Should not have been re-downloaded, since the fetch time is great. + mock.expect_download_calls(0).await; - let pr_t = |s: &str, t: SystemTime| { - let now = runtime.wallclock(); - eprintln!( - " {:10} {:?} {:10}", - s, - t, - t.duration_since(UNIX_EPOCH).unwrap().as_secs_f64() - - now.duration_since(UNIX_EPOCH).unwrap().as_secs_f64(), + eprintln!("----- corrupt the cache and check we re-download -----"); + + sql_conn + .execute_batch("UPDATE BridgeDescs SET contents = 'garbage'") + .unwrap(); + + clear_and_re_request(&bdm, &mut events, &bridge).await; + stream_drain_until(3, &mut events, || async { in_results(Some(Ok(()))) }).await; + + mock.expect_download_calls(1).await; + + eprintln!("----- advance the lock and check that we do an if-modified-since -----"); + + let published = bdm + .bridges() + .get(&bridge) + .unwrap() + .as_ref() + .unwrap() + .as_ref() + .published(); + + mock.mstate.lock().await.docs.insert( + EXAMPLE_PORT, + Ok(format!("{}{:?}", MOCK_NOT_MODIFIED, published)), ); - }; - let expecting_of = |text: &str, exp: Result| { - let got = process_document(&runtime, &config, text); - match exp { - Ok(exp_refetch) => { - let refetch = got.unwrap().refetch; - pr_t("refetch", refetch); - assert_eq!(refetch, exp_refetch); - } - Err(exp_msg) => { - let msg = got.as_ref().expect_err(exp_msg).to_string(); - assert!( - msg.contains(exp_msg), - "{:?} {:?} exp={:?}", - msg, - got, - exp_msg - ); - } - } - }; + // Exceeds default max_refetch + mock.sleep.advance(Duration::from_secs(20000)).await; - let expecting_at = |now: SystemTime, exp| { - mock.sleep.jump_to(now); - pr_t("now", now); - pr_t("valid.0", valid.0); - pr_t("valid.1", valid.1); - if let Ok(exp) = exp { - pr_t("expect", exp); 
- } - expecting_of(text, exp); - }; + stream_drain_until(3, &mut events, || async { + (mock.mstate.lock().await.download_calls > 0).then_some(()) + }) + .await; - let secs = Duration::from_secs; + mock.expect_download_calls(1).await; - eprintln!("----- good -----"); - expecting_of(text, Ok(runtime.wallclock() + config.max_refetch)); - - eprintln!("----- modified under signature -----"); - expecting_of( - &text.replace("\nbandwidth 10485760", "\nbandwidth 10485761"), - Err("Signature check failed"), - ); - - eprintln!("----- doc not yet valid -----"); - expecting_at( - valid.0 - secs(10), - Err("Descriptor is outside its validity time"), - ); - - eprintln!("----- need to refetch due to doc validity expiring soon -----"); - expecting_at(valid.1 - secs(5000), Ok(valid.1 - secs(1000))); - - eprintln!("----- will refetch later than usual, due to min refetch interval -----"); - { - let now = valid.1 - secs(4000); // would want to refetch at valid.1-1000 ie 30000 - expecting_at(now, Ok(now + config.min_refetch)); - } - - eprintln!("----- will refetch after doc validity ends, due to min refetch interval -----"); - { - let now = valid.1 - secs(10); - let exp = now + config.min_refetch; - assert!(exp > valid.1); - expecting_at(now, Ok(exp)); - } - - eprintln!("----- expired -----"); - expecting_at( - valid.1 + secs(10), - Err("Descriptor is outside its validity time"), - ); - - // TODO ideally we would test the `ops::Bound::Unbounded` case in process_download's - // expiry time handling, but that would require making a document with unbounded - // validity time. Even if that is possible, I don't think we have code in-tree to - // make signed test documents. - - Ok(()) + Ok(()) + }) +} + +#[traced_test] +#[test] +fn dormant() -> Result<(), anyhow::Error> { + MockRuntime::try_test_with_various(|runtime| async { + #[allow(unused_variables)] // avoids churn and makes all of these identical + let (db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) 
= setup(runtime); + let mut events = bdm.events().fuse(); + + use Dormancy::*; + + eprintln!("----- become dormant, but request a bridge -----"); + bdm.set_dormancy(Dormant); + bdm.set_bridges(&[bridge.clone()]); + + // Drive all tasks until we are idle + runtime.progress_until_stalled().await; + + eprintln!("----- become active -----"); + bdm.set_dormancy(Active); + // This should immediately trigger the download: + + stream_drain_until(3, &mut events, || async { + in_results(&bdm, &bridge, Some(Ok(()))) + }) + .await; + mock.expect_download_calls(1).await; + + Ok(()) + }) +} + +#[traced_test] +#[test] +fn process_doc() -> Result<(), anyhow::Error> { + MockRuntime::try_test_with_various(|runtime| async { + #[allow(unused_variables)] // avoids churn and makes all of these identical + let (db_tmp_path, bdm, runtime, mock, bridge, sql_conn, ..) = setup(runtime); + + let text = EXAMPLE_DESCRIPTOR; + let config = BridgeDescDownloadConfig::default(); + let valid = example_validity(); + + let pr_t = |s: &str, t: SystemTime| { + let now = runtime.wallclock(); + eprintln!( + " {:10} {:?} {:10}", + s, + t, + t.duration_since(UNIX_EPOCH).unwrap().as_secs_f64() + - now.duration_since(UNIX_EPOCH).unwrap().as_secs_f64(), + ); + }; + + let expecting_of = |text: &str, exp: Result| { + let got = process_document(&runtime, &config, text); + match exp { + Ok(exp_refetch) => { + let refetch = got.unwrap().refetch; + pr_t("refetch", refetch); + assert_eq!(refetch, exp_refetch); + } + Err(exp_msg) => { + let msg = got.as_ref().expect_err(exp_msg).to_string(); + assert!( + msg.contains(exp_msg), + "{:?} {:?} exp={:?}", + msg, + got, + exp_msg + ); + } + } + }; + + let expecting_at = |now: SystemTime, exp| { + mock.sleep.jump_to(now); + pr_t("now", now); + pr_t("valid.0", valid.0); + pr_t("valid.1", valid.1); + if let Ok(exp) = exp { + pr_t("expect", exp); + } + expecting_of(text, exp); + }; + + let secs = Duration::from_secs; + + eprintln!("----- good -----"); + expecting_of(text, 
Ok(runtime.wallclock() + config.max_refetch)); + + eprintln!("----- modified under signature -----"); + expecting_of( + &text.replace("\nbandwidth 10485760", "\nbandwidth 10485761"), + Err("Signature check failed"), + ); + + eprintln!("----- doc not yet valid -----"); + expecting_at( + valid.0 - secs(10), + Err("Descriptor is outside its validity time"), + ); + + eprintln!("----- need to refetch due to doc validity expiring soon -----"); + expecting_at(valid.1 - secs(5000), Ok(valid.1 - secs(1000))); + + eprintln!("----- will refetch later than usual, due to min refetch interval -----"); + { + let now = valid.1 - secs(4000); // would want to refetch at valid.1-1000 ie 30000 + expecting_at(now, Ok(now + config.min_refetch)); + } + + eprintln!("----- will refetch after doc validity ends, due to min refetch interval -----"); + { + let now = valid.1 - secs(10); + let exp = now + config.min_refetch; + assert!(exp > valid.1); + expecting_at(now, Ok(exp)); + } + + eprintln!("----- expired -----"); + expecting_at( + valid.1 + secs(10), + Err("Descriptor is outside its validity time"), + ); + + // TODO ideally we would test the `ops::Bound::Unbounded` case in process_download's + // expiry time handling, but that would require making a document with unbounded + // validity time. Even if that is possible, I don't think we have code in-tree to + // make signed test documents. 
+ + Ok(()) + }) } diff --git a/crates/tor-hsclient/src/state.rs b/crates/tor-hsclient/src/state.rs index 8a1564688..ea178a7b3 100644 --- a/crates/tor-hsclient/src/state.rs +++ b/crates/tor-hsclient/src/state.rs @@ -733,7 +733,7 @@ pub(crate) mod test { use tokio::pin; use tokio_crate as tokio; use tor_rtcompat::{test_with_one_runtime, SleepProvider}; - use tor_rtmock::MockSleepRuntime; + use tor_rtmock::MockRuntime; use tracing_test::traced_test; use ConnError as E; @@ -931,14 +931,7 @@ pub(crate) mod test { #[test] #[traced_test] fn expiry() { - test_with_one_runtime!(|outer_runtime| async move { - let runtime = MockSleepRuntime::new(outer_runtime.clone()); - - // We sleep this actual amount, with the real runtime, when we want to yield - // for long enough for some other task to do whatever it needs to. - // This represents an actual delay to the real test run. - const BODGE_YIELD: Duration = Duration::from_millis(125); - + MockRuntime::test_with_various(|runtime| async move { // This is the amount by which we adjust clock advances to make sure we // hit more or less than a particular value, to avoid edge cases and // cope with real time advancing too. 
@@ -950,13 +943,12 @@ pub(crate) mod test { let advance = |duration| { let hsconn = hsconn.clone(); let runtime = &runtime; - let outer_runtime = &outer_runtime; async move { // let expiry task get going and choose its expiry (wakeup) time - outer_runtime.sleep(BODGE_YIELD).await; + runtime.progress_until_stalled().await; runtime.advance(duration).await; // let expiry task run - outer_runtime.sleep(BODGE_YIELD).await; + runtime.progress_until_stalled().await; hsconn.services().unwrap().run_housekeeping(runtime.now()); } }; diff --git a/crates/tor-rtmock/Cargo.toml b/crates/tor-rtmock/Cargo.toml index 99d498f06..da6637146 100644 --- a/crates/tor-rtmock/Cargo.toml +++ b/crates/tor-rtmock/Cargo.toml @@ -12,13 +12,20 @@ categories = ["asynchronous"] repository = "https://gitlab.torproject.org/tpo/core/arti.git/" [dependencies] +amplify = { version = "4", default-features = false, features = ["derive"] } async-trait = "0.1.54" +educe = "0.4.6" futures = "0.3.14" humantime = "2" +itertools = "0.11.0" pin-project = "1" +slotmap = "1.0.6" +strum = { version = "0.25", features = ["derive"] } thiserror = "1" tor-rtcompat = { version = "0.9.1", path = "../tor-rtcompat" } tracing = "0.1.36" +tracing-test = "0.2" +void = "1" [dev-dependencies] futures-await-test = "0.3.0" diff --git a/crates/tor-rtmock/semver.md b/crates/tor-rtmock/semver.md new file mode 100644 index 000000000..e4ff48cc3 --- /dev/null +++ b/crates/tor-rtmock/semver.md @@ -0,0 +1,3 @@ +ADDED: `MockExecutor`, `MockRuntime` +ADDED: `Default` impls for many types including several `Provider`s +BREAKING: `MockNet*` use an always-failing UDP stub, not unmocked system UDP diff --git a/crates/tor-rtmock/src/lib.rs b/crates/tor-rtmock/src/lib.rs index 39345eae5..4d356422c 100644 --- a/crates/tor-rtmock/src/lib.rs +++ b/crates/tor-rtmock/src/lib.rs @@ -42,12 +42,18 @@ extern crate core; +#[macro_use] +mod util; + pub mod io; pub mod net; +pub mod task; pub mod time; mod net_runtime; +mod runtime; mod sleep_runtime; pub 
use net_runtime::MockNetRuntime; +pub use runtime::MockRuntime; pub use sleep_runtime::MockSleepRuntime; diff --git a/crates/tor-rtmock/src/net.rs b/crates/tor-rtmock/src/net.rs index 19c95d49d..96c4f6fc7 100644 --- a/crates/tor-rtmock/src/net.rs +++ b/crates/tor-rtmock/src/net.rs @@ -10,6 +10,7 @@ use super::MockNetRuntime; use core::fmt; use tor_rtcompat::tls::TlsConnector; use tor_rtcompat::{CertifiedConn, Runtime, TcpListener, TcpProvider, TlsProvider}; +use tor_rtcompat::{UdpProvider, UdpSocket}; use async_trait::async_trait; use futures::channel::mpsc; @@ -20,13 +21,14 @@ use futures::stream::{Stream, StreamExt}; use futures::FutureExt; use std::collections::HashMap; use std::fmt::Formatter; -use std::io::{Error as IoError, ErrorKind, Result as IoResult}; +use std::io::{self, Error as IoError, ErrorKind, Result as IoResult}; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use thiserror::Error; +use void::Void; /// A channel sender that we use to send incoming connections to /// listeners. @@ -40,6 +42,7 @@ type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>; /// are implemented using [`LocalStream`]. The MockNetwork object is /// shared by a large set of MockNetworkProviders, each of which has /// its own view of its address(es) on the network. +#[derive(Default)] pub struct MockNetwork { /// A map from address to the entries about listeners there. listening: Mutex>, @@ -87,6 +90,12 @@ enum AddrBehavior { /// We don't do the right thing (block) if there is a listener that /// never calls accept. /// +/// UDP is completely broken: +/// datagrams appear to be transmitted, but will never be received. 
+/// And local address assignment is not implemented +/// so [`.local_addr()`](UdpSocket::local_addr) can return `NONE` +// TODO MOCK UDP: Documentation does describe the brokennesses +/// /// We use a simple `u16` counter to decide what arbitrary port /// numbers to use: Once that counter is exhausted, we will fail with /// an assertion. We don't do anything to prevent those arbitrary @@ -149,12 +158,16 @@ pub struct ProviderBuilder { net: Arc, } +impl Default for MockNetProvider { + fn default() -> Self { + Arc::new(MockNetwork::default()).builder().provider() + } +} + impl MockNetwork { /// Make a new MockNetwork with no active listeners. pub fn new() -> Arc { - Arc::new(MockNetwork { - listening: Mutex::new(HashMap::new()), - }) + Default::default() } /// Return a [`ProviderBuilder`] for creating a [`MockNetProvider`] @@ -299,6 +312,43 @@ impl Stream for MockNetListener { } } +/// A very poor imitation of a UDP socket +#[derive(Debug)] +#[non_exhaustive] +pub struct MockUdpSocket { + /// This is uninhabited. + /// + /// To implement UDP support, implement `.bind()`, and abolish this field, + /// replacing it with the actual implementation. + void: Void, +} + +#[async_trait] +impl UdpProvider for MockNetProvider { + type UdpSocket = MockUdpSocket; + + async fn bind(&self, addr: &SocketAddr) -> IoResult { + let _ = addr; // MockNetProvider UDP is not implemented + Err(io::ErrorKind::Unsupported.into()) + } +} + +#[async_trait] +impl UdpSocket for MockUdpSocket { + async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> { + // This tuple idiom avoids unused variable warnings. + // An alternative would be to write _buf, but then when this is implemented, + // and the void::unreachable call removed, we actually *want* those warnings. 
+ void::unreachable((self.void, buf).0) + } + async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult { + void::unreachable((self.void, buf, target).0) + } + fn local_addr(&self) -> IoResult { + void::unreachable(self.void) + } +} + impl MockNetProvider { /// If we have a local addresses that is in the same family as `other`, /// return it. diff --git a/crates/tor-rtmock/src/net_runtime.rs b/crates/tor-rtmock/src/net_runtime.rs index 0f9c9cdc6..314a05c1d 100644 --- a/crates/tor-rtmock/src/net_runtime.rs +++ b/crates/tor-rtmock/src/net_runtime.rs @@ -3,16 +3,9 @@ // TODO(nickm): This is mostly copy-paste from MockSleepRuntime. If possible, // we should make it so that more code is more shared. -use crate::net::MockNetProvider; -use tor_rtcompat::{BlockOn, Runtime, SleepProvider, TcpProvider, TlsProvider, UdpProvider}; +use crate::util::impl_runtime_prelude::*; -use crate::io::LocalStream; -use async_trait::async_trait; -use futures::task::{FutureObj, Spawn, SpawnError}; -use futures::Future; -use std::io::Result as IoResult; -use std::net::SocketAddr; -use std::time::{Duration, Instant, SystemTime}; +use crate::net::MockNetProvider; /// A wrapper Runtime that overrides the SleepProvider trait for the /// underlying runtime. 
@@ -42,59 +35,10 @@ impl MockNetRuntime { } } -impl Spawn for MockNetRuntime { - fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> { - self.runtime.spawn_obj(future) - } -} - -impl BlockOn for MockNetRuntime { - fn block_on(&self, future: F) -> F::Output { - self.runtime.block_on(future) - } -} - -#[async_trait] -impl TcpProvider for MockNetRuntime { - type TcpStream = ::TcpStream; - type TcpListener = ::TcpListener; - - async fn connect(&self, addr: &SocketAddr) -> IoResult { - self.net.connect(addr).await - } - async fn listen(&self, addr: &SocketAddr) -> IoResult { - self.net.listen(addr).await - } -} - -impl TlsProvider for MockNetRuntime { - type Connector = >::Connector; - type TlsStream = >::TlsStream; - fn tls_connector(&self) -> Self::Connector { - self.net.tls_connector() - } -} - -#[async_trait] -impl UdpProvider for MockNetRuntime { - type UdpSocket = R::UdpSocket; - - #[inline] - async fn bind(&self, addr: &SocketAddr) -> IoResult { - // TODO this should probably get delegated to MockNetProvider instead - self.runtime.bind(addr).await - } -} - -impl SleepProvider for MockNetRuntime { - type SleepFuture = R::SleepFuture; - fn sleep(&self, dur: Duration) -> Self::SleepFuture { - self.runtime.sleep(dur) - } - fn now(&self) -> Instant { - self.runtime.now() - } - fn wallclock(&self) -> SystemTime { - self.runtime.wallclock() - } +impl_runtime! { + [ ] MockNetRuntime, + spawn: runtime, + block: runtime, + sleep: runtime: R, + net: net: MockNetProvider, } diff --git a/crates/tor-rtmock/src/runtime.rs b/crates/tor-rtmock/src/runtime.rs new file mode 100644 index 000000000..c8fc2a6b8 --- /dev/null +++ b/crates/tor-rtmock/src/runtime.rs @@ -0,0 +1,216 @@ +//! 
Completely mock runtime + +use amplify::Getters; +use futures::FutureExt as _; +use strum::IntoEnumIterator as _; +use void::{ResultVoidExt as _, Void}; + +use crate::util::impl_runtime_prelude::*; + +use crate::net::MockNetProvider; +use crate::task::{MockExecutor, SchedulingPolicy}; +use crate::time::MockSleepProvider; + +/// Completely mock runtime +/// +/// Suitable for test cases that wish to completely control +/// the environment experienced by the code under test. +/// +/// ### Restrictions +/// +/// The test case must advance the mock time explicitly as desired. +/// There is not currently any facility for automatically +/// making progress by advancing the mock time by the right amounts +/// for the timeouts set by the futures under test. +// ^ I think such a facility could be provided. `MockSleepProvider` would have to +// provide a method to identify the next interesting time event. +// The waitfor machinery in MockSleepProvider and MockSleepRuntime doesn't seem suitable. +/// +/// Tests that use this runtime *must not* interact with the outside world; +/// everything must go through this runtime (and its pieces). +/// +/// #### Allowed +/// +/// * Inter-future communication facilities from `futures` +/// or other runtime-agnostic crates. +/// +/// * Fast synchronous operations that will complete "immediately" or "quickly". +/// E.g.: filesystem calls. +/// +/// * `std::sync::Mutex` (assuming the use is deadlock-free in a single-threaded +/// executor, as it should be in all of Arti). +/// +/// * Slower operations that are run synchronously (without futures `await`) +/// provided their completion doesn't depend on any of the futures we're running. +/// (These kind of operations are often discouraged in async contexts, +/// because they block the async runtime or its worker threads. +/// But they are often OK in tests.) +/// +/// * All facilities provided by this `MockExecutor` and its trait impls. 
+/// +/// #### Not allowed +/// +/// * Direct access to the real-world clock (`SystemTime::now`, `Instant::now`). +/// Including `coarsetime`, which is not mocked. +/// Exception: CPU use measurements. +/// +/// * Anything that spawns threads and then communicates with those threads +/// using async Rust facilities (futures). +/// +/// * Async sockets, or async use of other kernel-based IPC or network mechanisms. +/// +/// * Anything provided by a Rust runtime/executor project (eg anything from Tokio), +/// unless it is definitively established that it's runtime-agnostic. +#[derive(Debug, Default, Clone, Getters)] +#[getter(prefix = "mock_")] +pub struct MockRuntime { + /// Tasks + task: MockExecutor, + /// Time provider + sleep: MockSleepProvider, + /// Net provider + net: MockNetProvider, +} + +/// Builder for a manually-configured `MockRuntime` +#[derive(Debug, Default, Clone)] +pub struct MockRuntimeBuilder { + /// scheduling policy + scheduling: SchedulingPolicy, + /// starting wall clock time + starting_wallclock: Option, +} + +impl_runtime! { + [ ] MockRuntime, + spawn: task, + block: task, + sleep: sleep: MockSleepProvider, + net: net: MockNetProvider, +} + +impl MockRuntime { + /// Create a new `MockRuntime` with default parameters + pub fn new() -> Self { + Self::default() + } + + /// Return a builder, for creating a `MockRuntime` with some parameters manually configured + pub fn builder() -> MockRuntimeBuilder { + Default::default() + } + + /// Run a test case with a variety of runtime parameters, to try to find bugs + /// + /// `test_case` is an async closure which receives a `MockRuntime`. + /// It will be run with a number of differently configured executors. + /// + /// ### Variations + /// + /// The only variation currently implemented is this: + /// + /// Both FIFO and LIFO scheduling policies are tested, + /// in the hope that this will help discover ordering-dependent bugs. 
+ pub fn test_with_various(mut test_case: TC) + where + TC: FnMut(MockRuntime) -> FUT, + FUT: Future, + { + Self::try_test_with_various(|runtime| test_case(runtime).map(|()| Ok::<_, Void>(()))) + .void_unwrap(); + } + + /// Run a fallible test case with a variety of runtime parameters, to try to find bugs + /// + /// `test_case` is an async closure which receives a `MockRuntime`. + /// It will be run with a number of differently configured executors. + /// + /// This function accepts a fallible closure, + /// and returns the first `Err` to the caller. + /// + /// See [`test_with_various()`](MockRuntime::test_with_various) for more details. + pub fn try_test_with_various(mut test_case: TC) -> Result<(), E> + where + TC: FnMut(MockRuntime) -> FUT, + FUT: Future>, + { + for scheduling in SchedulingPolicy::iter() { + let runtime = MockRuntime::builder().scheduling(scheduling).build(); + runtime.block_on(test_case(runtime.clone()))?; + } + Ok(()) + } + + /// Run tasks in the current executor until every task except this one is waiting + /// + /// Calls [`MockExecutor::progress_until_stalled()`]. + /// + /// # Restriction - no automatic time advance + /// + /// The mocked time will *not* be automatically advanced. + /// + /// Usually + /// (and especially if the tasks under test are waiting for timeouts or periodic events) + /// you must use + /// [`advance()`](MockRuntime::advance) + /// or + /// [`jump_to()`](MockRuntime::jump_to) + /// to ensure the simulated time progresses as required. + /// + /// # Panics + /// + /// Might malfunction or panic if more than one such call is running at once. + /// + /// (Ie, you must `.await` or drop the returned `Future` + /// before calling this method again.) + /// + /// Must be called and awaited within a future being run by `self`.
+ pub async fn progress_until_stalled(&self) { + self.task.progress_until_stalled().await; + } + + /// See [`MockSleepProvider::advance()`] + pub async fn advance(&self, dur: Duration) { + self.sleep.advance(dur).await; + } + /// See [`MockSleepProvider::jump_to()`] + pub fn jump_to(&self, new_wallclock: SystemTime) { + self.sleep.jump_to(new_wallclock); + } +} + +impl MockRuntimeBuilder { + /// Set the scheduling policy + pub fn scheduling(mut self, scheduling: SchedulingPolicy) -> Self { + self.scheduling = scheduling; + self + } + + /// Set the starting wall clock time + pub fn starting_wallclock(mut self, starting_wallclock: SystemTime) -> Self { + self.starting_wallclock = Some(starting_wallclock); + self + } + + /// Build the runtime + pub fn build(self) -> MockRuntime { + let MockRuntimeBuilder { + scheduling, + starting_wallclock, + } = self; + + let sleep = if let Some(starting_wallclock) = starting_wallclock { + MockSleepProvider::new(starting_wallclock) + } else { + MockSleepProvider::default() + }; + + let task = MockExecutor::with_scheduling(scheduling); + + MockRuntime { + sleep, + task, + ..Default::default() + } + } +} diff --git a/crates/tor-rtmock/src/sleep_runtime.rs b/crates/tor-rtmock/src/sleep_runtime.rs index bc4298f04..0e2cec38e 100644 --- a/crates/tor-rtmock/src/sleep_runtime.rs +++ b/crates/tor-rtmock/src/sleep_runtime.rs @@ -1,17 +1,12 @@ //! Declare MockSleepRuntime. -use crate::time::MockSleepProvider; -use tor_rtcompat::{BlockOn, Runtime, SleepProvider, TcpProvider, TlsProvider, UdpProvider}; - -use async_trait::async_trait; -use futures::task::{FutureObj, Spawn, SpawnError}; -use futures::Future; use pin_project::pin_project; -use std::io::Result as IoResult; -use std::net::SocketAddr; -use std::time::{Duration, Instant, SystemTime}; use tracing::trace; +use crate::time::MockSleepProvider; + +use crate::util::impl_runtime_prelude::*; + /// A wrapper Runtime that overrides the SleepProvider trait for the /// underlying runtime. 
#[derive(Clone, Debug)] @@ -51,6 +46,10 @@ impl MockSleepRuntime { /// Run a future under mock time, advancing time forward where necessary until it completes. /// Users of this function should read the whole of this documentation before using! /// + /// **NOTE** Instead of using this, consider [`MockRuntime`](crate::MockRuntime), + /// which will fully isolate the test case + /// (albeit at the cost of demanding manual management of the simulated time). + /// /// The returned future will run `fut`, expecting it to create `Sleeping` futures (as returned /// by `MockSleepProvider::sleep()` and similar functions). When all such created futures have /// been polled (indicating the future is waiting on them), time will be advanced in order that @@ -80,68 +79,12 @@ impl MockSleepRuntime { } } -impl Spawn for MockSleepRuntime { - fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> { - self.runtime.spawn_obj(future) - } -} - -impl BlockOn for MockSleepRuntime { - fn block_on(&self, future: F) -> F::Output { - self.runtime.block_on(future) - } -} - -#[async_trait] -impl TcpProvider for MockSleepRuntime { - type TcpStream = R::TcpStream; - type TcpListener = R::TcpListener; - - async fn connect(&self, addr: &SocketAddr) -> IoResult { - self.runtime.connect(addr).await - } - async fn listen(&self, addr: &SocketAddr) -> IoResult { - self.runtime.listen(addr).await - } -} - -impl TlsProvider for MockSleepRuntime { - type Connector = R::Connector; - type TlsStream = R::TlsStream; - fn tls_connector(&self) -> Self::Connector { - self.runtime.tls_connector() - } -} - -#[async_trait] -impl UdpProvider for MockSleepRuntime { - type UdpSocket = R::UdpSocket; - - async fn bind(&self, addr: &SocketAddr) -> IoResult { - self.runtime.bind(addr).await - } -} - -impl SleepProvider for MockSleepRuntime { - type SleepFuture = crate::time::Sleeping; - fn sleep(&self, dur: Duration) -> Self::SleepFuture { - self.sleep.sleep(dur) - } - fn now(&self) -> Instant { - 
self.sleep.now() - } - fn wallclock(&self) -> SystemTime { - self.sleep.wallclock() - } - fn block_advance>(&self, reason: T) { - self.sleep.block_advance(reason); - } - fn release_advance>(&self, reason: T) { - self.sleep.release_advance(reason); - } - fn allow_one_advance(&self, dur: Duration) { - self.sleep.allow_one_advance(dur); - } +impl_runtime! { + [ ] MockSleepRuntime, + spawn: runtime, + block: runtime, + sleep: sleep: MockSleepProvider, + net: runtime: R, } /// A future that advances time until another future is ready to complete. diff --git a/crates/tor-rtmock/src/task.rs b/crates/tor-rtmock/src/task.rs new file mode 100644 index 000000000..2acf4e08c --- /dev/null +++ b/crates/tor-rtmock/src/task.rs @@ -0,0 +1,747 @@ +//! Executor for running tests with mocked environment +//! +//! See [`MockExecutor`] + +use std::collections::VecDeque; +use std::fmt::{self, Debug, Display}; +use std::future::Future; +use std::iter; +use std::pin::Pin; +use std::sync::{Arc, Mutex, MutexGuard}; +use std::task::{Context, Poll, Wake, Waker}; + +use futures::pin_mut; +use futures::task::{FutureObj, Spawn, SpawnError}; +use futures::FutureExt as _; + +use educe::Educe; +use itertools::{chain, izip}; +use slotmap::DenseSlotMap; +use strum::EnumIter; +use tracing::trace; + +use tor_rtcompat::BlockOn; + +use Poll::*; +use TaskState::*; + +/// Type-erased future, one for each of our (normal) tasks +type TaskFuture = FutureObj<'static, ()>; + +/// Future for the argument to `block_on`, which is handled specially +type MainFuture<'m> = Pin<&'m mut dyn Future>; + +//---------- principal data structures ---------- + +/// Executor for running tests with mocked environment +/// +/// For test cases which don't actually wait for anything in the real world. +/// +/// This is the executor. +/// It implements [`Spawn`] and [`BlockOn`] +/// +/// It will usually be used as part of a `MockRuntime`. 
+/// +/// # Restricted environment +/// +/// Tests run with this executor must not attempt to block +/// on anything "outside": +/// every future that anything awaits must (eventually) be woken directly +/// *by some other task* in the same test case. +/// +/// (By directly we mean that the [`Waker::wake`] call is made +/// by that waking future, before that future itself awaits anything.) +/// +/// # Panics +/// +/// This executor will malfunction or panic if reentered. +#[derive(Clone, Default, Educe)] +#[educe(Debug)] +pub struct MockExecutor { + /// Mutable state + #[educe(Debug(ignore))] + data: ArcMutexData, +} + +/// Mutable state, wrapper type mostly so we can provide `.lock()` +#[derive(Clone, Default)] +struct ArcMutexData(Arc>); + +/// Task id, module to hide `Ti` alias +mod task_id { + slotmap::new_key_type! { + /// Task ID, usually called `TaskId` + /// + /// Short name in special `task_id` module so that [`Debug`] is nice + pub(super) struct Ti; + } +} +use task_id::Ti as TaskId; + +/// Executor's state +/// +/// ### Task state machine +/// +/// A task is created in `tasks`, `Awake`, so also in `awake`. +/// +/// When we poll it, we take it out of `awake` and set it to `Asleep`, +/// and then call `poll()`. +/// Any time after that, it can be made `Awake` again (and put back onto `awake`) +/// by the waker ([`ActualWaker`], wrapped in [`Waker`]). +/// +/// The task's future is of course also present here in this data structure. +/// However, during poll we must release the lock, +/// so we cannot borrow the future from `Data`. +/// Instead, we move it out. So `Task.fut` is an `Option`. +/// +/// ### "Main" task - the argument to `block_on` +/// +/// The signature of `BlockOn::block_on` accepts a non-`'static` future +/// (and a non-`Send`/`Sync` one). +/// +/// So we cannot store that future in `Data` because `Data` is `'static`. +/// Instead, this main task future is passed as an argument down the call stack. 
+/// In the data structure we simply store a placeholder, `TaskFutureInfo::Main`. +#[derive(Default)] +struct Data { + /// Tasks + /// + /// Includes tasks spawned with `spawn`, + /// and also the future passed to `block_on`. + tasks: DenseSlotMap, + + /// `awake` lists precisely: tasks that are `Awake`, plus maybe stale `TaskId`s + /// + /// Tasks are pushed onto the *back* when woken, + /// so back is the most recently woken. + awake: VecDeque, + + /// If a future from `progress_until_stalled` exists + progressing_until_stalled: Option, + + /// Scheduling policy + scheduling: SchedulingPolicy, +} + +/// How we should schedule? +#[derive(Debug, Clone, Default, EnumIter)] +#[non_exhaustive] +pub enum SchedulingPolicy { + /// Task *most* recently woken is run + /// + /// This is the default. + /// + /// It will expose starvation bugs if a task never sleeps. + /// (Which is a good thing in tests.) + #[default] + Stack, + /// Task *least* recently woken is run. + Queue, +} + +/// Record of a single task +/// +/// Tracks a spawned task, or the main task (the argument to `block_on`). +/// +/// Stored in [`Data`]`.tasks`. +struct Task { + /// For debugging output + desc: String, + /// Has this been woken via a waker? (And is it in `Data.awake`?) + state: TaskState, + /// The actual future (or a placeholder for it) + /// + /// May be `None` because we've temporarily moved it out so we can poll it + fut: Option, +} + +/// A future as stored in our record of a [`Task`] +enum TaskFutureInfo { + /// The [`Future`]. All is normal. + Normal(TaskFuture), + /// The future isn't here because this task is the main future for `block_on` + Main, +} + +/// State of a task - do we think it needs to be polled? +/// +/// Stored in [`Task`]`.state`. 
+#[derive(Debug)] +enum TaskState { + /// Awake - needs to be polled + /// + /// Established by [`waker.wake()`](Waker::wake) + Awake, + /// Asleep - does *not* need to be polled + /// + /// Established each time just before we call the future's [`poll`](Future::poll) + Asleep, +} + +/// Actual implementor of `Wake` for use in a `Waker` +/// +/// Futures (eg, channels from [`futures`]) will use this to wake a task +/// when it should be polled. +struct ActualWaker { + /// Executor state + data: ArcMutexData, + + /// Which task this is + id: TaskId, +} + +/// State used for an in-progress call to +/// [`progress_until_stalled`][`MockExecutor::progress_until_stalled`] +/// +/// If present in [`Data`], an (async) call to `progress_until_stalled` +/// is in progress. +/// +/// The future from `progress_until_stalled`, [`ProgressUntilStalledFuture`] +/// is a normal-ish future. +/// It can be polled in the normal way. +/// When it is polled, it looks here, in `finished`, to see if it's `Ready`. +/// +/// The future is made ready, and woken (via `waker`), +/// by bespoke code in the task executor loop. +/// +/// When `ProgressUntilStalledFuture` (maybe completes and) is dropped, +/// its `Drop` impl is used to remove this from `Data.progressing_until_stalled`. +#[derive(Debug)] +struct ProgressingUntilStalled { + /// Have we, in fact, stalled? + /// + /// Made `Ready` by special code in the executor loop + finished: Poll<()>, + + /// Waker + /// + /// Signalled by special code in the executor loop + waker: Option, +} + +/// Future from +/// [`progress_until_stalled`][`MockExecutor::progress_until_stalled`] +/// +/// See [`ProgressingUntilStalled`] for an overview of this aspect of the contraption. +/// +/// Existence of this struct implies `Data.progressing_until_stalled` is `Some`. +/// There can only be one at a time. 
+#[derive(Educe)] +#[educe(Debug)] +struct ProgressUntilStalledFuture { + /// Executor's state; this future's state is in `.progressing_until_stalled` + #[educe(Debug(ignore))] + data: ArcMutexData, +} + +//---------- creation ---------- + +impl MockExecutor { + /// Make a `MockExecutor` with default parameters + pub fn new() -> Self { + Self::default() + } + + /// Make a `MockExecutor` with a specific `SchedulingPolicy` + pub fn with_scheduling(scheduling: SchedulingPolicy) -> Self { + Data { + scheduling, + ..Default::default() + } + .into() + } +} + +impl From for MockExecutor { + fn from(data: Data) -> MockExecutor { + MockExecutor { + data: ArcMutexData(Arc::new(Mutex::new(data))), + } + } +} + +//---------- spawning ---------- + +impl MockExecutor { + /// Spawn a task and return something to identify it + /// + /// `desc` should `Display` as some kind of short string (ideally without spaces) + /// and will be used in the `Debug` impl and trace log messages from `MockExecutor`. + /// + /// The returned value is an opaque task identifier which is very cheap to clone + /// and which can be used by the caller in debug logging, + /// if it's desired to correlate with the debug output from `MockExecutor`. + /// Most callers will want to ignore it. + /// + /// This method is infallible. (The `MockExecutor` cannot be shut down.) + pub fn spawn_identified( + &self, + desc: impl Display, + fut: impl Future + Send + Sync + 'static, + ) -> impl Debug + Clone + Send + Sync + 'static { + self.spawn_internal(desc.to_string(), FutureObj::from(Box::new(fut))) + } + + /// Spawn a task and return its `TaskId` + /// + /// Convenience method for use by `spawn_identified` and `spawn_obj`. + /// The future passed to `block_on` is not handled here.
+ fn spawn_internal(&self, desc: String, fut: TaskFuture) -> TaskId { + let mut data = self.data.lock(); + data.insert_task(desc, TaskFutureInfo::Normal(fut)) + } +} + +impl Data { + /// Insert a task given its `TaskFutureInfo` and return its `TaskId`. + fn insert_task(&mut self, desc: String, fut: TaskFutureInfo) -> TaskId { + let state = Awake; + let id = self.tasks.insert(Task { + state, + desc, + fut: Some(fut), + }); + self.awake.push_back(id); + trace!("MockExecutor spawned {:?}={:?}", id, self.tasks[id]); + id + } +} + +impl Spawn for MockExecutor { + fn spawn_obj(&self, future: TaskFuture) -> Result<(), SpawnError> { + self.spawn_internal("".into(), future); + Ok(()) + } +} + +//---------- block_on ---------- + +impl BlockOn for MockExecutor { + /// Run `fut` to completion, synchronously + /// + /// # Panics + /// + /// Might malfunction or panic if: + /// + /// * The provided future doesn't complete (without externally blocking), + /// but instead waits for something. + /// + /// * The `MockExecutor` is reentered. (Eg, `block_on` is reentered.) + fn block_on(&self, fut: F) -> F::Output + where + F: Future, + { + let mut value: Option = None; + let fut = { + let value = &mut value; + async move { + trace!("MockExecutor block_on future..."); + let t = fut.await; + trace!("MockExecutor block_on future returned..."); + *value = Some(t); + trace!("MockExecutor block_on future exiting."); + } + }; + + { + pin_mut!(fut); + self.data + .lock() + .insert_task("main".into(), TaskFutureInfo::Main); + self.execute_to_completion(fut); + } + + #[allow(clippy::let_and_return)] // clarity + let value = value.take().unwrap_or_else(|| { + let data = self.data.lock(); + panic!( + r" +all futures blocked. waiting for the real world? or deadlocked (waiting for each other) ? 
+ +{data:#?} +" + ); + }); + + value + } +} + +//---------- execution - core implementation ---------- + +impl MockExecutor { + /// Keep polling tasks until nothing more can be done + /// + /// Ie, stop when `awake` is empty and `progressing_until_stalled` is `None`. + fn execute_to_completion(&self, mut main_fut: MainFuture) { + trace!("MockExecutor execute_to_completion..."); + loop { + self.execute_until_first_stall(main_fut.as_mut()); + + // Handle `progressing_until_stalled` + let pus_waker = { + let mut data = self.data.lock(); + let pus = &mut data.progressing_until_stalled; + trace!("MockExecutor execute_to_completion PUS={:?}", &pus); + let Some(pus) = pus else { + // No progressing_until_stalled, we're actually done. + break; + }; + assert_eq!( + pus.finished, Pending, + "ProgressingUntilStalled finished twice?!" + ); + pus.finished = Ready(()); + pus.waker + .clone() + .expect("ProgressUntilStalledFuture not ever polled!") + }; + pus_waker.wake(); + } + trace!("MockExecutor execute_to_completion done"); + } + + /// Keep polling tasks until `awake` is empty + /// + /// (Ignores `progressing_until_stalled` - so if one is active, + /// will return when all other tasks have blocked.) 
+ /// + /// # Panics + /// + /// Might malfunction or panic if called reentrantly + fn execute_until_first_stall(&self, mut main_fut: MainFuture) { + trace!("MockExecutor execute_until_first_stall ..."); + 'outer: loop { + // Take an `Awake` task off `awake`, mark it `Asleep`, and take its future out for polling + let (id, mut fut) = 'inner: loop { + let mut data = self.data.lock(); + let Some(id) = data.schedule() else { break 'outer }; + let Some(task) = data.tasks.get_mut(id) else { + trace!("MockExecutor {id:?} vanished"); + continue; + }; + task.state = Asleep; + let fut = task.fut.take().expect("future missing from task!"); + break 'inner (id, fut); + }; + + // Poll the selected task + let waker = Waker::from(Arc::new(ActualWaker { + data: self.data.clone(), + id, + })); + trace!("MockExecutor {id:?} polling..."); + let mut cx = Context::from_waker(&waker); + let r = match &mut fut { + TaskFutureInfo::Normal(fut) => fut.poll_unpin(&mut cx), + TaskFutureInfo::Main => main_fut.as_mut().poll(&mut cx), + }; + + // Deal with the returned `Poll` + { + let mut data = self.data.lock(); + let task = data + .tasks + .get_mut(id) + .expect("task vanished while we were polling it"); + + match r { + Pending => { + trace!("MockExecutor {id:?} -> Pending"); + if task.fut.is_some() { + panic!("task reinserted while we polled it?!"); + } + // The task might have been woken *by its own poll method*. + // That's why we set it to `Asleep` *earlier* rather than here. + // All we need to do is put the future back. + task.fut = Some(fut); + } + Ready(()) => { + trace!("MockExecutor {id:?} -> Ready"); + // Oh, it finished! + // It might be in `awake`, but that's allowed to contain stale tasks, + // so we *don't* need to scan that list and remove it. + data.tasks.remove(id); + } + } + } + } + trace!("MockExecutor execute_until_first_stall done."); + } +} + +impl Data { + /// Return the next task to run + /// + /// The task is removed from `awake`, but **`state` is not set to `Asleep`**.
+ /// The caller must restore the invariant! + fn schedule(&mut self) -> Option { + use SchedulingPolicy as SP; + match self.scheduling { + SP::Stack => self.awake.pop_back(), + SP::Queue => self.awake.pop_front(), + } + } +} + +impl Wake for ActualWaker { + fn wake(self: Arc) { + let mut data = self.data.lock(); + trace!("MockExecutor {:?} wake", &self.id); + let Some(task) = data.tasks.get_mut(self.id) else { return }; + match task.state { + Awake => {} + Asleep => { + task.state = Awake; + data.awake.push_back(self.id); + } + } + } +} + +//---------- "progress until stalled" functionality ---------- + +impl MockExecutor { + /// Run tasks in the current executor until every other task is waiting + /// + /// # Panics + /// + /// Might malfunction or panic if more than one such call is running at once. + /// + /// (Ie, you must `.await` or drop the returned `Future` + /// before calling this method again.) + /// + /// Must be called and awaited within a future being run by `self`. + pub fn progress_until_stalled(&self) -> impl Future { + let mut data = self.data.lock(); + assert!( + data.progressing_until_stalled.is_none(), + "progress_until_stalled called more than once" + ); + trace!("MockExecutor progress_until_stalled..."); + data.progressing_until_stalled = Some(ProgressingUntilStalled { + finished: Pending, + waker: None, + }); + ProgressUntilStalledFuture { + data: self.data.clone(), + } + } +} + +impl Future for ProgressUntilStalledFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let mut data = self.data.lock(); + let pus = data.progressing_until_stalled.as_mut(); + trace!("MockExecutor progress_until_stalled polling... 
{:?}", &pus); + let pus = pus.expect("ProgressingUntilStalled missing"); + pus.waker = Some(cx.waker().clone()); + pus.finished + } +} + +impl Drop for ProgressUntilStalledFuture { + fn drop(&mut self) { + self.data.lock().progressing_until_stalled = None; + } +} + +//---------- ancillary and convenience functions ---------- + +/// Trait to let us assert at compile time that something is nicely `Sync` etc. +trait EnsureSyncSend: Sync + Send + 'static {} +impl EnsureSyncSend for ActualWaker {} +impl EnsureSyncSend for MockExecutor {} + +impl ArcMutexData { + /// Lock and obtain the guard + /// + /// Convenience method which panics on poison + fn lock(&self) -> MutexGuard { + self.0.lock().expect("data lock poisoned") + } +} + +//---------- bespoke Debug impls ---------- + +// See `impl Debug for Data` for notes on the output +impl Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Task { desc, state, fut } = self; + write!(f, "{:?}", desc)?; + write!(f, "=")?; + match fut { + None => write!(f, "P")?, + Some(TaskFutureInfo::Normal(_)) => write!(f, "f")?, + Some(TaskFutureInfo::Main) => write!(f, "m")?, + } + match state { + Awake => write!(f, "W")?, + Asleep => write!(f, "s")?, + }; + Ok(()) + } +} + +/// Helper: `Debug`s as a list of tasks, given the `Data` for lookups and a list of the ids +struct DebugTasks<'d, F>(&'d Data, F); + +// See `impl Debug for Data` for notes on the output +impl Debug for DebugTasks<'_, F> +where + F: Fn() -> I, + I: Iterator, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let DebugTasks(data, ids) = self; + for (id, delim) in izip!(ids(), chain!(iter::once(""), iter::repeat(" ")),) { + write!(f, "{delim}{id:?}")?; + match data.tasks.get(id) { + None => write!(f, "-")?, + Some(task) => write!(f, "={task:?}")?, + } + } + Ok(()) + } +} + +/// `Task`s in `Data` are printed as `Ti(ID)"SPEC"=FLAGS"`. 
+/// +/// `FLAGS` are: +/// +/// * `P`: this task is being polled (its `TaskFutureInfo` is absent) +/// * `f`: this is a normal task with a future and its future is present in `Data` +/// * `m`: this is the main task from `block_on` +/// +/// * `W`: the task is awake +/// * `s`: the task is asleep +// +// We do it this way because the naive dump from derive is very expansive +// and makes it impossible to see the wood for the trees. +// This very compact representation makes it easier to find a task of interest in the output. +// +// This is implemented in `impl Debug for Task`. +impl Debug for Data { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Data { + tasks, + awake, + progressing_until_stalled: pus, + scheduling, + } = self; + let mut s = f.debug_struct("Data"); + s.field("tasks", &DebugTasks(self, || tasks.keys())); + s.field("awake", &DebugTasks(self, || awake.iter().cloned())); + s.field("p.u.s", pus); + s.field("scheduling", scheduling); + s.finish() + } +} + +#[cfg(test)] +mod test { + // @@ begin test lint list maintained by maint/add_warning @@ + #![allow(clippy::bool_assert_comparison)] + #![allow(clippy::clone_on_copy)] + #![allow(clippy::dbg_macro)] + #![allow(clippy::print_stderr)] + #![allow(clippy::print_stdout)] + #![allow(clippy::single_char_pattern)] + #![allow(clippy::unwrap_used)] + #![allow(clippy::unchecked_duration_subtraction)] + //!
+ use super::*; + use futures::channel::mpsc; + use futures::{SinkExt as _, StreamExt as _}; + use tracing_test::traced_test; + + #[traced_test] + #[test] + fn simple() { + let runtime = MockExecutor::default(); + let val = runtime.block_on(async { 42 }); + assert_eq!(val, 42); + } + + #[traced_test] + #[test] + fn stall() { + let runtime = MockExecutor::default(); + + runtime.block_on({ + let runtime = runtime.clone(); + async move { + const N: usize = 3; + let (mut txs, mut rxs): (Vec<_>, Vec<_>) = + (0..N).map(|_| mpsc::channel::(5)).unzip(); + + let mut rx_n = rxs.pop().unwrap(); + + for (i, mut rx) in rxs.into_iter().enumerate() { + runtime.spawn_identified(i, { + let mut txs = txs.clone(); + async move { + loop { + eprintln!("task {i} rx..."); + let v = rx.next().await.unwrap(); + let nv = v + 1; + eprintln!("task {i} rx {v}, tx {nv}"); + let v = nv; + txs[v].send(v).await.unwrap(); + } + } + }); + } + + dbg!(); + let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err(); + + dbg!(); + runtime.progress_until_stalled().await; + + dbg!(); + let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err(); + + dbg!(); + txs[0].send(0).await.unwrap(); + + dbg!(); + runtime.progress_until_stalled().await; + + dbg!(); + let r = rx_n.next().await; + assert_eq!(r, Some(N - 1)); + + dbg!(); + let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err(); + + runtime.spawn_identified("tx", { + let txs = txs.clone(); + async { + eprintln!("sending task..."); + for (i, mut tx) in txs.into_iter().enumerate() { + eprintln!("sending 0 to {i}..."); + tx.send(0).await.unwrap(); + } + eprintln!("sending task done"); + } + }); + + for i in 0..txs.len() { + eprintln!("main {i} wait stall..."); + runtime.progress_until_stalled().await; + eprintln!("main {i} rx wait..."); + let r = rx_n.next().await; + eprintln!("main {i} rx = {r:?}"); + assert!(r == Some(0) || r == Some(N - 1)); + } + + eprintln!("finishing..."); + runtime.progress_until_stalled().await; + eprintln!("finished."); + } + }); + } 
+}
diff --git a/crates/tor-rtmock/src/time.rs b/crates/tor-rtmock/src/time.rs
index 9e053cb2f..7ca114f4c 100644
--- a/crates/tor-rtmock/src/time.rs
+++ b/crates/tor-rtmock/src/time.rs
@@ -86,6 +86,13 @@ pub struct Sleeping {
     provider: Weak>,
 }
 
+impl Default for MockSleepProvider {
+    fn default() -> Self { // starts at a fixed, hard-coded wall-clock instant
+        let wallclock = humantime::parse_rfc3339("2023-07-05T11:25:56Z").expect("parse");
+        MockSleepProvider::new(wallclock)
+    }
+}
+
 impl MockSleepProvider {
     /// Create a new MockSleepProvider, starting at a given wall-clock time.
     pub fn new(wallclock: SystemTime) -> Self {
diff --git a/crates/tor-rtmock/src/util.rs b/crates/tor-rtmock/src/util.rs
new file mode 100644
index 000000000..51d4a916c
--- /dev/null
+++ b/crates/tor-rtmock/src/util.rs
@@ -0,0 +1,119 @@
+//! Internal utilities for `tor_rtmock`
+
+/// Implements `Runtime` for a struct made of multiple sub-providers
+///
+/// The `$SomeMockRuntime` type must be a struct containing
+/// field(s) which implement `SleepProvider`, `NetProvider`, etc.
+///
+/// `$gens` are the generics, written as (for example) `[ <R: Runtime> ]`.
+///
+/// The remaining arguments are the fields.
+/// For each field there's:
+///  - the short name of what is being provided (a fixed identifier)
+///  - the field name in `$SomeMockRuntime`
+///  - for some cases, the type of that field
+///
+/// The fields must be specified in the expected order (`spawn`, `block`, `sleep`, `net`)!
+//
+// This could be further reduced with more macrology:
+// ambassador might be able to remove most of the body (although does it do async well?)
+// derive-adhoc would allow a more natural input syntax and avoid restating field types
+macro_rules! impl_runtime { {
+    [ $($gens:tt)* ] $SomeMockRuntime:ty,
+    spawn: $spawn:ident,
+    block: $block:ident,
+    sleep: $sleep:ident: $SleepProvider:ty,
+    net: $net:ident: $NetProvider:ty,
+} => {
+    impl $($gens)* Spawn for $SomeMockRuntime {
+        fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
+            self.$spawn.spawn_obj(future)
+        }
+    }
+
+    impl $($gens)* BlockOn for $SomeMockRuntime {
+        fn block_on<F: Future>(&self, future: F) -> F::Output {
+            self.$block.block_on(future)
+        }
+    }
+
+    #[async_trait]
+    impl $($gens)* TcpProvider for $SomeMockRuntime {
+        type TcpStream = <$NetProvider as TcpProvider>::TcpStream;
+        type TcpListener = <$NetProvider as TcpProvider>::TcpListener;
+
+        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::TcpStream> {
+            self.$net.connect(addr).await
+        }
+        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::TcpListener> {
+            self.$net.listen(addr).await
+        }
+    }
+
+    impl $($gens)* TlsProvider<<$NetProvider as TcpProvider>::TcpStream> for $SomeMockRuntime {
+        type Connector = <$NetProvider as TlsProvider<
+            <$NetProvider as TcpProvider>::TcpStream
+        >>::Connector;
+        type TlsStream = <$NetProvider as TlsProvider<
+            <$NetProvider as TcpProvider>::TcpStream
+        >>::TlsStream;
+        fn tls_connector(&self) -> Self::Connector {
+            self.$net.tls_connector()
+        }
+    }
+
+    #[async_trait]
+    impl $($gens)* UdpProvider for $SomeMockRuntime {
+        type UdpSocket = <$NetProvider as UdpProvider>::UdpSocket;
+
+        #[inline]
+        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
+            self.$net.bind(addr).await
+        }
+    }
+
+    impl $($gens)* SleepProvider for $SomeMockRuntime {
+        type SleepFuture = <$SleepProvider as SleepProvider>::SleepFuture;
+
+        fn sleep(&self, dur: Duration) -> Self::SleepFuture {
+            self.$sleep.sleep(dur)
+        }
+        fn now(&self) -> Instant {
+            self.$sleep.now()
+        }
+        fn wallclock(&self) -> SystemTime {
+            self.$sleep.wallclock()
+        }
+        fn block_advance<T: Into<String>>(&self, reason: T) {
+            self.$sleep.block_advance(reason);
+        }
+        fn release_advance<T: Into<String>>(&self, reason: T) {
+            self.$sleep.release_advance(reason);
+        }
+        fn allow_one_advance(&self, dur: Duration) {
+            self.$sleep.allow_one_advance(dur);
+        }
+    }
+} }
+
+/// Prelude that must be imported to use [`impl_runtime!`](impl_runtime)
+//
+// This could have been part of the expansion of `impl_runtime!`,
+// but it seems rather too exciting for a macro to import things as a side gig.
+//
+// Arguably this ought to be an internal crate::prelude instead.
+// But crate-internal preludes are controversial within the Arti team.  -Diziet
+//
+// For macro visibility reasons, this must come *lexically after* the macro,
+// to allow it to refer to the macro in the doc comment.
+pub(crate) mod impl_runtime_prelude {
+    pub(crate) use async_trait::async_trait;
+    pub(crate) use futures::task::{FutureObj, Spawn, SpawnError};
+    pub(crate) use futures::Future;
+    pub(crate) use std::io::Result as IoResult;
+    pub(crate) use std::net::SocketAddr;
+    pub(crate) use std::time::{Duration, Instant, SystemTime};
+    pub(crate) use tor_rtcompat::{
+        BlockOn, Runtime, SleepProvider, TcpProvider, TlsProvider, UdpProvider,
+    };
+}