Merge branch 'sync_guard_api'

This commit is contained in:
Nick Mathewson 2021-10-10 12:15:05 -04:00
commit 99effeb532
2 changed files with 40 additions and 61 deletions

View File

@ -8,10 +8,9 @@ use crate::GuardMgrInner;
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
lock::Mutex,
stream::{self, StreamExt}, stream::{self, StreamExt},
}; };
use std::sync::Weak; use std::sync::{Mutex, Weak};
/// A message sent by to the [`report_status_events()`] task. /// A message sent by to the [`report_status_events()`] task.
#[derive(Debug)] #[derive(Debug)]
@ -48,7 +47,7 @@ pub(crate) type MsgResult = Result<Msg, futures::channel::oneshot::Canceled>;
pub(crate) async fn report_status_events( pub(crate) async fn report_status_events(
runtime: impl tor_rtcompat::SleepProvider, runtime: impl tor_rtcompat::SleepProvider,
inner: Weak<Mutex<GuardMgrInner>>, inner: Weak<Mutex<GuardMgrInner>>,
ctrl: mpsc::Receiver<MsgResult>, ctrl: mpsc::UnboundedReceiver<MsgResult>,
) { ) {
// Multiplexes a bunch of one-shot receivers to tell us about guard // Multiplexes a bunch of one-shot receivers to tell us about guard
// status outcomes. // status outcomes.
@ -72,7 +71,7 @@ pub(crate) async fn report_status_events(
Some(Ok(Msg::Status(id, status))) => { Some(Ok(Msg::Status(id, status))) => {
// We've got a report about a guard status. // We've got a report about a guard status.
if let Some(inner) = inner.upgrade() { if let Some(inner) = inner.upgrade() {
let mut inner = inner.lock().await; let mut inner = inner.lock().expect("Poisoned lock");
inner.handle_msg(id, status, &runtime); inner.handle_msg(id, status, &runtime);
} else { } else {
// The guard manager has gone away. // The guard manager has gone away.
@ -110,7 +109,7 @@ pub(crate) async fn run_periodic<R: tor_rtcompat::SleepProvider>(
) { ) {
loop { loop {
let delay = if let Some(inner) = inner.upgrade() { let delay = if let Some(inner) = inner.upgrade() {
let mut inner = inner.lock().await; let mut inner = inner.lock().expect("Poisoned lock");
let wallclock = runtime.wallclock(); let wallclock = runtime.wallclock();
let now = runtime.now(); let now = runtime.now();
inner.run_periodic_events(wallclock, now) inner.run_periodic_events(wallclock, now)

View File

@ -129,13 +129,11 @@
// filtered // filtered
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::lock::Mutex;
use futures::task::{SpawnError, SpawnExt}; use futures::task::{SpawnError, SpawnExt};
use futures::SinkExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime}; use std::time::{Duration, Instant, SystemTime};
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
@ -169,13 +167,6 @@ pub struct GuardMgr<R: Runtime> {
runtime: R, runtime: R,
/// Internal state for the guard manager. /// Internal state for the guard manager.
// TODO: I wish I could use a regular mutex rather than a
// futures::lock::Mutex, but I don't see how that's feasible. We
// need to get access to inner.ctrl and then send over it, which
// means we need an async mutex.
//
// Conceivably, I could move ctrl out to GuardMgr, and then put it
// under a sync::Mutex. Is that smart?
inner: Arc<Mutex<GuardMgrInner>>, inner: Arc<Mutex<GuardMgrInner>>,
} }
@ -213,7 +204,14 @@ struct GuardMgrInner {
/// A mpsc channel, used to tell the task running in /// A mpsc channel, used to tell the task running in
/// [`daemon::report_status_events`] about a new event to monitor. /// [`daemon::report_status_events`] about a new event to monitor.
ctrl: mpsc::Sender<daemon::MsgResult>, ///
/// This uses an `UnboundedSener` so that we don't have to await
/// while sending the message, which in turn allows the GuardMgr
/// API to be simpler. The risk, however, is that there's no
/// backpressure in the event that the task running
/// [`daemon::report_status_events`] fails to read from this
/// channel.
ctrl: mpsc::UnboundedSender<daemon::MsgResult>,
/// Information about guards that we've given out, but where we have /// Information about guards that we've given out, but where we have
/// not yet heard whether the guard was successful. /// not yet heard whether the guard was successful.
@ -249,7 +247,7 @@ impl<R: Runtime> GuardMgr<R> {
where where
S: StateMgr + Send + Sync + 'static, S: StateMgr + Send + Sync + 'static,
{ {
let (ctrl, rcv) = mpsc::channel(32); let (ctrl, rcv) = mpsc::unbounded();
let default_storage = state_mgr.create_handle("default_guards"); let default_storage = state_mgr.create_handle("default_guards");
let active_guards = default_storage.load()?.unwrap_or_else(GuardSet::new); let active_guards = default_storage.load()?.unwrap_or_else(GuardSet::new);
let inner = Arc::new(Mutex::new(GuardMgrInner { let inner = Arc::new(Mutex::new(GuardMgrInner {
@ -279,8 +277,8 @@ impl<R: Runtime> GuardMgr<R> {
/// ///
/// Return true if we were able to save, and false if we couldn't /// Return true if we were able to save, and false if we couldn't
/// get the lock. /// get the lock.
pub async fn update_persistent_state(&self) -> Result<bool, GuardMgrError> { pub fn update_persistent_state(&self) -> Result<bool, GuardMgrError> {
let inner = self.inner.lock().await; let inner = self.inner.lock().expect("Poisoned lock");
if inner.default_storage.try_lock()? { if inner.default_storage.try_lock()? {
trace!("Flushing guard state to disk."); trace!("Flushing guard state to disk.");
inner.default_storage.store(&inner.active_guards)?; inner.default_storage.store(&inner.active_guards)?;
@ -298,11 +296,11 @@ impl<R: Runtime> GuardMgr<R> {
/// potential candidate guards. /// potential candidate guards.
/// ///
/// Call this method whenever the `NetDir` changes. /// Call this method whenever the `NetDir` changes.
pub async fn update_network(&self, netdir: &NetDir) { pub fn update_network(&self, netdir: &NetDir) {
trace!("Updating guard state from network directory"); trace!("Updating guard state from network directory");
let now = self.runtime.wallclock(); let now = self.runtime.wallclock();
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().expect("Poisoned lock");
inner.update(now, Some(netdir)); inner.update(now, Some(netdir));
} }
@ -319,9 +317,9 @@ impl<R: Runtime> GuardMgr<R> {
/// We should really call this every time we read a cell, but that /// We should really call this every time we read a cell, but that
/// isn't efficient or practical. We'll probably have to refactor /// isn't efficient or practical. We'll probably have to refactor
/// things somehow. (TODO) /// things somehow. (TODO)
pub async fn note_internet_activity(&self) { pub fn note_internet_activity(&self) {
let now = self.runtime.now(); let now = self.runtime.now();
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().expect("Poisoned lock");
inner.last_time_on_internet = Some(now); inner.last_time_on_internet = Some(now);
} }
@ -329,7 +327,7 @@ impl<R: Runtime> GuardMgr<R> {
/// ///
/// (Since there is only one kind of filter right now, there's no /// (Since there is only one kind of filter right now, there's no
/// real reason to call this function, but at least it should work. /// real reason to call this function, but at least it should work.
pub async fn set_filter(&self, filter: GuardFilter, netdir: &NetDir) { pub fn set_filter(&self, filter: GuardFilter, netdir: &NetDir) {
// First we have to see how much of the possible guard space // First we have to see how much of the possible guard space
// this new filter allows. (We don't use this info yet, but we will // this new filter allows. (We don't use this info yet, but we will
// one we have nontrivial filters.) // one we have nontrivial filters.)
@ -345,7 +343,7 @@ impl<R: Runtime> GuardMgr<R> {
}; };
let now = self.runtime.wallclock(); let now = self.runtime.wallclock();
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().expect("Poisoned lock");
let restrictive_filter = frac_permitted < inner.params.filter_threshold; let restrictive_filter = frac_permitted < inner.params.filter_threshold;
@ -386,7 +384,7 @@ impl<R: Runtime> GuardMgr<R> {
/// ///
/// This function only looks at netdir when all of the known /// This function only looks at netdir when all of the known
/// guards are down; to force an update, use [`GuardMgr::update_network`]. /// guards are down; to force an update, use [`GuardMgr::update_network`].
pub async fn select_guard( pub fn select_guard(
&self, &self,
usage: GuardUsage, usage: GuardUsage,
netdir: Option<&NetDir>, netdir: Option<&NetDir>,
@ -394,7 +392,7 @@ impl<R: Runtime> GuardMgr<R> {
let now = self.runtime.now(); let now = self.runtime.now();
let wallclock = self.runtime.wallclock(); let wallclock = self.runtime.wallclock();
let mut inner = self.inner.lock().await; let mut inner = self.inner.lock().expect("Poisoned lock");
// (I am not 100% sure that we need to consider_all_retries here, but // (I am not 100% sure that we need to consider_all_retries here, but
// it should _probably_ not hurt.) // it should _probably_ not hurt.)
@ -419,12 +417,9 @@ impl<R: Runtime> GuardMgr<R> {
inner.active_guards.record_attempt(&guard_id, now); inner.active_guards.record_attempt(&guard_id, now);
// Have to do this while not holding lock, since it awaits.
// TODO: I wish this function didn't have to be async.
inner inner
.ctrl .ctrl
.send(Ok(daemon::Msg::Observe(rcv))) .unbounded_send(Ok(daemon::Msg::Observe(rcv)))
.await
.expect("Guard observer task exited prematurely"); .expect("Guard observer task exited prematurely");
Ok((guard_id, monitor, usable)) Ok((guard_id, monitor, usable))
@ -437,11 +432,10 @@ impl<R: Runtime> GuardMgr<R> {
let (snd, rcv) = futures::channel::oneshot::channel(); let (snd, rcv) = futures::channel::oneshot::channel();
let pingmsg = daemon::Msg::Ping(snd); let pingmsg = daemon::Msg::Ping(snd);
{ {
let mut inner = self.inner.lock().await; let inner = self.inner.lock().expect("Poisoned lock");
inner inner
.ctrl .ctrl
.send(Ok(pingmsg)) .unbounded_send(Ok(pingmsg))
.await
.expect("Guard observer task exited permaturely."); .expect("Guard observer task exited permaturely.");
} }
let _ = rcv.await; let _ = rcv.await;
@ -896,9 +890,9 @@ mod test {
let (guardmgr, statemgr, netdir) = init(rt.clone()); let (guardmgr, statemgr, netdir) = init(rt.clone());
let usage = GuardUsage::default(); let usage = GuardUsage::default();
guardmgr.update_network(&netdir).await; guardmgr.update_network(&netdir);
let (id, mon, usable) = guardmgr.select_guard(usage, Some(&netdir)).await.unwrap(); let (id, mon, usable) = guardmgr.select_guard(usage, Some(&netdir)).unwrap();
// Report that the circuit succeeded. // Report that the circuit succeeded.
mon.succeeded(); mon.succeeded();
@ -908,16 +902,16 @@ mod test {
// Save the state... // Save the state...
guardmgr.flush_msg_queue().await; guardmgr.flush_msg_queue().await;
guardmgr.update_persistent_state().await.unwrap(); guardmgr.update_persistent_state().unwrap();
drop(guardmgr); drop(guardmgr);
// Try reloading from the state... // Try reloading from the state...
let guardmgr2 = GuardMgr::new(rt.clone(), statemgr.clone()).unwrap(); let guardmgr2 = GuardMgr::new(rt.clone(), statemgr.clone()).unwrap();
guardmgr2.update_network(&netdir).await; guardmgr2.update_network(&netdir);
// Since the guard was confirmed, we should get the same one this time! // Since the guard was confirmed, we should get the same one this time!
let usage = GuardUsage::default(); let usage = GuardUsage::default();
let (id2, _mon, _usable) = guardmgr2.select_guard(usage, Some(&netdir)).await.unwrap(); let (id2, _mon, _usable) = guardmgr2.select_guard(usage, Some(&netdir)).unwrap();
assert_eq!(id2, id); assert_eq!(id2, id);
}) })
} }
@ -930,21 +924,15 @@ mod test {
test_with_all_runtimes!(|rt| async move { test_with_all_runtimes!(|rt| async move {
let (guardmgr, _statemgr, netdir) = init(rt); let (guardmgr, _statemgr, netdir) = init(rt);
let u = GuardUsage::default(); let u = GuardUsage::default();
guardmgr.update_network(&netdir).await; guardmgr.update_network(&netdir);
// We'll have the first two guard fail, which should make us // We'll have the first two guard fail, which should make us
// try a non-primary guard. // try a non-primary guard.
let (id1, mon, _usable) = guardmgr let (id1, mon, _usable) = guardmgr.select_guard(u.clone(), Some(&netdir)).unwrap();
.select_guard(u.clone(), Some(&netdir))
.await
.unwrap();
mon.failed(); mon.failed();
guardmgr.flush_msg_queue().await; // avoid race guardmgr.flush_msg_queue().await; // avoid race
guardmgr.flush_msg_queue().await; // avoid race guardmgr.flush_msg_queue().await; // avoid race
let (id2, mon, _usable) = guardmgr let (id2, mon, _usable) = guardmgr.select_guard(u.clone(), Some(&netdir)).unwrap();
.select_guard(u.clone(), Some(&netdir))
.await
.unwrap();
mon.failed(); mon.failed();
guardmgr.flush_msg_queue().await; // avoid race guardmgr.flush_msg_queue().await; // avoid race
guardmgr.flush_msg_queue().await; // avoid race guardmgr.flush_msg_queue().await; // avoid race
@ -952,14 +940,8 @@ mod test {
assert!(id1 != id2); assert!(id1 != id2);
// Now we should get two sampled guards. They should be different. // Now we should get two sampled guards. They should be different.
let (id3, mon3, usable3) = guardmgr let (id3, mon3, usable3) = guardmgr.select_guard(u.clone(), Some(&netdir)).unwrap();
.select_guard(u.clone(), Some(&netdir)) let (id4, mon4, usable4) = guardmgr.select_guard(u.clone(), Some(&netdir)).unwrap();
.await
.unwrap();
let (id4, mon4, usable4) = guardmgr
.select_guard(u.clone(), Some(&netdir))
.await
.unwrap();
assert!(id3 != id4); assert!(id3 != id4);
let (u3, u4) = futures::join!( let (u3, u4) = futures::join!(
@ -983,12 +965,10 @@ mod test {
test_with_all_runtimes!(|rt| async move { test_with_all_runtimes!(|rt| async move {
let (guardmgr, _statemgr, netdir) = init(rt); let (guardmgr, _statemgr, netdir) = init(rt);
let u = GuardUsage::default(); let u = GuardUsage::default();
guardmgr.update_network(&netdir).await; guardmgr.update_network(&netdir);
guardmgr guardmgr.set_filter(GuardFilter::TestingLimitKeys, &netdir);
.set_filter(GuardFilter::TestingLimitKeys, &netdir)
.await;
let (id1, _mon, _usable) = guardmgr.select_guard(u, Some(&netdir)).await.unwrap(); let (id1, _mon, _usable) = guardmgr.select_guard(u, Some(&netdir)).unwrap();
// Make sure that the filter worked. // Make sure that the filter worked.
assert_eq!(id1.rsa.as_bytes()[0] % 4, 0); assert_eq!(id1.rsa.as_bytes()[0] % 4, 0);
}) })