Merge branch 'arc_circ' into 'main'

Refactor ClientCirc APIs to use Arc<ClientCirc>.

Closes #846

See merge request tpo/core/arti!1187
This commit is contained in:
gabi-250 2023-05-17 09:47:19 +00:00
commit fb8bc19b9b
14 changed files with 83 additions and 74 deletions

View File

@ -1106,7 +1106,7 @@ impl<R: Runtime> TorClient<R> {
&self, &self,
exit_ports: &[TargetPort], exit_ports: &[TargetPort],
prefs: &StreamPrefs, prefs: &StreamPrefs,
) -> StdResult<ClientCirc, ErrorDetail> { ) -> StdResult<Arc<ClientCirc>, ErrorDetail> {
// TODO HS probably this netdir ought to be made in connect_with_prefs // TODO HS probably this netdir ought to be made in connect_with_prefs
// like for StreamInstructions::Hs. // like for StreamInstructions::Hs.
self.wait_for_bootstrap().await?; self.wait_for_bootstrap().await?;

View File

@ -1,2 +1,3 @@
ADDED: Broadened hspool to accept T:CircTarget in place of OwnedCircTarget. ADDED: Broadened hspool to accept T:CircTarget in place of OwnedCircTarget.
BREAKING: APIs now return and accept Arc<ClientCirc>

View File

@ -43,7 +43,7 @@ pub(crate) trait Buildable: Sized {
ct: &OwnedChanTarget, ct: &OwnedChanTarget,
params: &CircParameters, params: &CircParameters,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<Self>; ) -> Result<Arc<Self>>;
/// Launch a new circuit through a given relay, given a circuit target /// Launch a new circuit through a given relay, given a circuit target
/// `ct` specifying that relay. /// `ct` specifying that relay.
@ -54,7 +54,7 @@ pub(crate) trait Buildable: Sized {
ct: &OwnedCircTarget, ct: &OwnedCircTarget,
params: &CircParameters, params: &CircParameters,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<Self>; ) -> Result<Arc<Self>>;
/// Extend this circuit-like object by one hop, to the location described /// Extend this circuit-like object by one hop, to the location described
/// in `ct`. /// in `ct`.
@ -122,7 +122,7 @@ impl Buildable for ClientCirc {
ct: &OwnedChanTarget, ct: &OwnedChanTarget,
params: &CircParameters, params: &CircParameters,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<Self> { ) -> Result<Arc<Self>> {
let circ = create_common(chanmgr, rt, ct, guard_status, usage).await?; let circ = create_common(chanmgr, rt, ct, guard_status, usage).await?;
circ.create_firsthop_fast(params) circ.create_firsthop_fast(params)
.await .await
@ -139,7 +139,7 @@ impl Buildable for ClientCirc {
ct: &OwnedCircTarget, ct: &OwnedCircTarget,
params: &CircParameters, params: &CircParameters,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<Self> { ) -> Result<Arc<Self>> {
let circ = create_common(chanmgr, rt, ct, guard_status, usage).await?; let circ = create_common(chanmgr, rt, ct, guard_status, usage).await?;
circ.create_firsthop_ntor(ct, params.clone()) circ.create_firsthop_ntor(ct, params.clone())
.await .await
@ -214,7 +214,7 @@ impl<R: Runtime, C: Buildable + Sync + Send + 'static> Builder<R, C> {
n_hops_built: Arc<AtomicU32>, n_hops_built: Arc<AtomicU32>,
guard_status: Arc<GuardStatusHandle>, guard_status: Arc<GuardStatusHandle>,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<C> { ) -> Result<Arc<C>> {
match path { match path {
OwnedPath::ChannelOnly(target) => { OwnedPath::ChannelOnly(target) => {
// If we fail now, it's the guard's fault. // If we fail now, it's the guard's fault.
@ -276,7 +276,7 @@ impl<R: Runtime, C: Buildable + Sync + Send + 'static> Builder<R, C> {
params: &CircParameters, params: &CircParameters,
guard_status: Arc<GuardStatusHandle>, guard_status: Arc<GuardStatusHandle>,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<C> { ) -> Result<Arc<C>> {
let action = Action::BuildCircuit { length: path.len() }; let action = Action::BuildCircuit { length: path.len() };
let (timeout, abandon_timeout) = self.timeouts.timeouts(&action); let (timeout, abandon_timeout) = self.timeouts.timeouts(&action);
let start_time = self.runtime.now(); let start_time = self.runtime.now();
@ -420,7 +420,7 @@ impl<R: Runtime> CircuitBuilder<R> {
params: &CircParameters, params: &CircParameters,
guard_status: Arc<GuardStatusHandle>, guard_status: Arc<GuardStatusHandle>,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<ClientCirc> { ) -> Result<Arc<ClientCirc>> {
self.builder self.builder
.build_owned(path, params, guard_status, usage) .build_owned(path, params, guard_status, usage)
.await .await
@ -437,7 +437,7 @@ impl<R: Runtime> CircuitBuilder<R> {
path: &TorPath<'_>, path: &TorPath<'_>,
params: &CircParameters, params: &CircParameters,
usage: ChannelUsage, usage: ChannelUsage,
) -> Result<ClientCirc> { ) -> Result<Arc<ClientCirc>> {
let owned = path.try_into()?; let owned = path.try_into()?;
self.build_owned(owned, params, Arc::new(None.into()), usage) self.build_owned(owned, params, Arc::new(None.into()), usage)
.await .await
@ -710,7 +710,7 @@ mod test {
ct: &OwnedChanTarget, ct: &OwnedChanTarget,
_: &CircParameters, _: &CircParameters,
_usage: ChannelUsage, _usage: ChannelUsage,
) -> Result<Self> { ) -> Result<Arc<Self>> {
let (d1, d2) = timeouts_from_chantarget(ct); let (d1, d2) = timeouts_from_chantarget(ct);
rt.sleep(d1).await; rt.sleep(d1).await;
if !d2.is_zero() { if !d2.is_zero() {
@ -721,7 +721,7 @@ mod test {
hops: vec![RelayIds::from_relay_ids(ct)], hops: vec![RelayIds::from_relay_ids(ct)],
onehop: true, onehop: true,
}; };
Ok(Mutex::new(c)) Ok(Arc::new(Mutex::new(c)))
} }
async fn create<RT: Runtime>( async fn create<RT: Runtime>(
_: &ChanMgr<RT>, _: &ChanMgr<RT>,
@ -730,7 +730,7 @@ mod test {
ct: &OwnedCircTarget, ct: &OwnedCircTarget,
_: &CircParameters, _: &CircParameters,
_usage: ChannelUsage, _usage: ChannelUsage,
) -> Result<Self> { ) -> Result<Arc<Self>> {
let (d1, d2) = timeouts_from_chantarget(ct); let (d1, d2) = timeouts_from_chantarget(ct);
rt.sleep(d1).await; rt.sleep(d1).await;
if !d2.is_zero() { if !d2.is_zero() {
@ -741,7 +741,7 @@ mod test {
hops: vec![RelayIds::from_relay_ids(ct)], hops: vec![RelayIds::from_relay_ids(ct)],
onehop: false, onehop: false,
}; };
Ok(Mutex::new(c)) Ok(Arc::new(Mutex::new(c)))
} }
async fn extend<RT: Runtime>( async fn extend<RT: Runtime>(
&self, &self,

View File

@ -111,7 +111,7 @@ impl<R: Runtime> HsCircPool<R> {
pub async fn get_or_launch_client_rend<'a>( pub async fn get_or_launch_client_rend<'a>(
&self, &self,
netdir: &'a NetDir, netdir: &'a NetDir,
) -> Result<(ClientCirc, Relay<'a>)> { ) -> Result<(Arc<ClientCirc>, Relay<'a>)> {
// For rendezvous points, clients use 3-hop circuits. // For rendezvous points, clients use 3-hop circuits.
let circ = self let circ = self
.take_or_launch_stub_circuit::<OwnedCircTarget>(netdir, None) .take_or_launch_stub_circuit::<OwnedCircTarget>(netdir, None)
@ -142,7 +142,7 @@ impl<R: Runtime> HsCircPool<R> {
netdir: &NetDir, netdir: &NetDir,
kind: HsCircKind, kind: HsCircKind,
target: T, target: T,
) -> Result<ClientCirc> ) -> Result<Arc<ClientCirc>>
where where
T: CircTarget, T: CircTarget,
{ {
@ -218,7 +218,7 @@ impl<R: Runtime> HsCircPool<R> {
&self, &self,
netdir: &NetDir, netdir: &NetDir,
avoid_target: Option<&T>, avoid_target: Option<&T>,
) -> Result<ClientCirc> ) -> Result<Arc<ClientCirc>>
where where
T: CircTarget, T: CircTarget,
{ {

View File

@ -1,6 +1,6 @@
//! An internal pool object that we use to implement HsCircPool. //! An internal pool object that we use to implement HsCircPool.
use std::sync::Mutex; use std::sync::{Arc, Mutex};
use rand::{seq::IteratorRandom, Rng}; use rand::{seq::IteratorRandom, Rng};
use tor_proto::circuit::ClientCirc; use tor_proto::circuit::ClientCirc;
@ -9,7 +9,7 @@ use tor_proto::circuit::ClientCirc;
#[derive(Default)] #[derive(Default)]
pub(super) struct Pool { pub(super) struct Pool {
/// The collection of circuits themselves, in no particular order. /// The collection of circuits themselves, in no particular order.
circuits: Mutex<Vec<ClientCirc>>, circuits: Mutex<Vec<Arc<ClientCirc>>>,
} }
impl Pool { impl Pool {
@ -19,23 +19,23 @@ impl Pool {
} }
/// Add `circ` to this pool /// Add `circ` to this pool
pub(super) fn insert(&self, circ: ClientCirc) { pub(super) fn insert(&self, circ: Arc<ClientCirc>) {
self.circuits.lock().expect("lock poisoned").push(circ); self.circuits.lock().expect("lock poisoned").push(circ);
} }
/// Remove every circuit from this pool for which `f` returns false. /// Remove every circuit from this pool for which `f` returns false.
pub(super) fn retain<F>(&self, f: F) pub(super) fn retain<F>(&self, f: F)
where where
F: FnMut(&ClientCirc) -> bool, F: FnMut(&Arc<ClientCirc>) -> bool,
{ {
self.circuits.lock().expect("lock poisoned").retain(f); self.circuits.lock().expect("lock poisoned").retain(f);
} }
/// If there is any circuit in this pool for which `f` returns true, return one such circuit at random, and remove it from the pool. /// If there is any circuit in this pool for which `f` returns true, return one such circuit at random, and remove it from the pool.
pub(super) fn take_one_where<R, F>(&self, rng: &mut R, f: F) -> Option<ClientCirc> pub(super) fn take_one_where<R, F>(&self, rng: &mut R, f: F) -> Option<Arc<ClientCirc>>
where where
R: Rng, R: Rng,
F: Fn(&ClientCirc) -> bool, F: Fn(&Arc<ClientCirc>) -> bool,
{ {
let mut circuits = self.circuits.lock().expect("lock poisoned"); let mut circuits = self.circuits.lock().expect("lock poisoned");
// TODO HS: This ensures that we take a circuit at random, but at the // TODO HS: This ensures that we take a circuit at random, but at the

View File

@ -78,7 +78,7 @@ impl<R: Runtime> crate::mgr::AbstractCircBuilder for crate::build::CircuitBuilde
Ok((plan, final_spec)) Ok((plan, final_spec))
} }
async fn build_circuit(&self, plan: Plan) -> Result<(SupportedCircUsage, ClientCirc)> { async fn build_circuit(&self, plan: Plan) -> Result<(SupportedCircUsage, Arc<ClientCirc>)> {
use crate::build::GuardStatusHandle; use crate::build::GuardStatusHandle;
use tor_guardmgr::GuardStatus; use tor_guardmgr::GuardStatus;
let Plan { let Plan {

View File

@ -382,7 +382,7 @@ impl<R: Runtime> CircMgr<R> {
/// Return a circuit suitable for sending one-hop BEGINDIR streams, /// Return a circuit suitable for sending one-hop BEGINDIR streams,
/// launching it if necessary. /// launching it if necessary.
pub async fn get_or_launch_dir(&self, netdir: DirInfo<'_>) -> Result<ClientCirc> { pub async fn get_or_launch_dir(&self, netdir: DirInfo<'_>) -> Result<Arc<ClientCirc>> {
self.expire_circuits(); self.expire_circuits();
let usage = TargetCircUsage::Dir; let usage = TargetCircUsage::Dir;
self.mgr.get_or_launch(&usage, netdir).await.map(|(c, _)| c) self.mgr.get_or_launch(&usage, netdir).await.map(|(c, _)| c)
@ -398,7 +398,7 @@ impl<R: Runtime> CircMgr<R> {
netdir: DirInfo<'_>, // TODO: This has to be a NetDir. netdir: DirInfo<'_>, // TODO: This has to be a NetDir.
ports: &[TargetPort], ports: &[TargetPort],
isolation: StreamIsolation, isolation: StreamIsolation,
) -> Result<ClientCirc> { ) -> Result<Arc<ClientCirc>> {
self.expire_circuits(); self.expire_circuits();
let time = Instant::now(); let time = Instant::now();
{ {
@ -425,7 +425,7 @@ impl<R: Runtime> CircMgr<R> {
pub async fn get_or_launch_dir_specific<T: IntoOwnedChanTarget>( pub async fn get_or_launch_dir_specific<T: IntoOwnedChanTarget>(
&self, &self,
target: T, target: T,
) -> Result<ClientCirc> { ) -> Result<Arc<ClientCirc>> {
self.expire_circuits(); self.expire_circuits();
let usage = TargetCircUsage::DirSpecificTarget(target.to_owned()); let usage = TargetCircUsage::DirSpecificTarget(target.to_owned());
self.mgr self.mgr
@ -452,7 +452,7 @@ impl<R: Runtime> CircMgr<R> {
&self, &self,
planned_target: Option<T>, planned_target: Option<T>,
dir: &NetDir, dir: &NetDir,
) -> Result<ClientCirc> ) -> Result<Arc<ClientCirc>>
where where
T: IntoOwnedChanTarget, T: IntoOwnedChanTarget,
{ {

View File

@ -146,7 +146,7 @@ pub(crate) fn abstract_spec_find_supported<'a, 'b, S: AbstractSpec, C: AbstractC
/// ///
/// From this module's point of view, circuits are simply objects /// From this module's point of view, circuits are simply objects
/// with unique identities, and a possible closed-state. /// with unique identities, and a possible closed-state.
pub(crate) trait AbstractCirc: Clone + Debug { pub(crate) trait AbstractCirc: Debug {
/// Type for a unique identifier for circuits. /// Type for a unique identifier for circuits.
type Id: Clone + Debug + Hash + Eq + Send + Sync; type Id: Clone + Debug + Hash + Eq + Send + Sync;
/// Return the unique identifier for this circuit. /// Return the unique identifier for this circuit.
@ -233,7 +233,7 @@ pub(crate) trait AbstractCircBuilder: Send + Sync {
/// that was originally passed to `plan_circuit`. It _must_ also /// that was originally passed to `plan_circuit`. It _must_ also
/// contain the spec that was originally returned by /// contain the spec that was originally returned by
/// `plan_circuit`. /// `plan_circuit`.
async fn build_circuit(&self, plan: Self::Plan) -> Result<(Self::Spec, Self::Circ)>; async fn build_circuit(&self, plan: Self::Plan) -> Result<(Self::Spec, Arc<Self::Circ>)>;
/// Return a "parallelism factor" with which circuits should be /// Return a "parallelism factor" with which circuits should be
/// constructed for a given purpose. /// constructed for a given purpose.
@ -312,7 +312,7 @@ pub(crate) struct OpenEntry<S, C> {
/// Current AbstractCircSpec for this circuit's permitted usages. /// Current AbstractCircSpec for this circuit's permitted usages.
spec: S, spec: S,
/// The circuit under management. /// The circuit under management.
circ: C, circ: Arc<C>,
/// When does this circuit expire? /// When does this circuit expire?
/// ///
/// (Note that expired circuits are removed from the manager, /// (Note that expired circuits are removed from the manager,
@ -323,7 +323,7 @@ pub(crate) struct OpenEntry<S, C> {
impl<S: AbstractSpec, C: AbstractCirc> OpenEntry<S, C> { impl<S: AbstractSpec, C: AbstractCirc> OpenEntry<S, C> {
/// Make a new OpenEntry for a given circuit and spec. /// Make a new OpenEntry for a given circuit and spec.
fn new(spec: S, circ: C, expiration: ExpirationInfo) -> Self { fn new(spec: S, circ: Arc<C>, expiration: ExpirationInfo) -> Self {
OpenEntry { OpenEntry {
spec, spec,
circ, circ,
@ -744,7 +744,7 @@ pub(crate) struct AbstractCircMgr<B: AbstractCircBuilder, R: Runtime> {
/// An action to take in order to satisfy a request for a circuit. /// An action to take in order to satisfy a request for a circuit.
enum Action<B: AbstractCircBuilder> { enum Action<B: AbstractCircBuilder> {
/// We found an open circuit: return immediately. /// We found an open circuit: return immediately.
Open(B::Circ), Open(Arc<B::Circ>),
/// We found one or more pending circuits: wait until one succeeds, /// We found one or more pending circuits: wait until one succeeds,
/// or all fail. /// or all fail.
Wait(FuturesUnordered<Shared<oneshot::Receiver<PendResult<B>>>>), Wait(FuturesUnordered<Shared<oneshot::Receiver<PendResult<B>>>>),
@ -795,7 +795,7 @@ impl<B: AbstractCircBuilder + 'static, R: Runtime> AbstractCircMgr<B, R> {
self: &Arc<Self>, self: &Arc<Self>,
usage: &<B::Spec as AbstractSpec>::Usage, usage: &<B::Spec as AbstractSpec>::Usage,
dir: DirInfo<'_>, dir: DirInfo<'_>,
) -> Result<(B::Circ, CircProvenance)> { ) -> Result<(Arc<B::Circ>, CircProvenance)> {
/// Return CEIL(a/b). /// Return CEIL(a/b).
/// ///
/// Requires that a+b is less than usize::MAX. /// Requires that a+b is less than usize::MAX.
@ -1024,7 +1024,7 @@ impl<B: AbstractCircBuilder + 'static, R: Runtime> AbstractCircMgr<B, R> {
self: Arc<Self>, self: Arc<Self>,
act: Action<B>, act: Action<B>,
usage: &<B::Spec as AbstractSpec>::Usage, usage: &<B::Spec as AbstractSpec>::Usage,
) -> std::result::Result<(B::Circ, CircProvenance), RetryError<Box<Error>>> { ) -> std::result::Result<(Arc<B::Circ>, CircProvenance), RetryError<Box<Error>>> {
/// Store the error `err` into `retry_err`, as appropriate. /// Store the error `err` into `retry_err`, as appropriate.
fn record_error( fn record_error(
retry_err: &mut RetryError<Box<Error>>, retry_err: &mut RetryError<Box<Error>>,
@ -1380,7 +1380,7 @@ impl<B: AbstractCircBuilder + 'static, R: Runtime> AbstractCircMgr<B, R> {
&self, &self,
usage: &<B::Spec as AbstractSpec>::Usage, usage: &<B::Spec as AbstractSpec>::Usage,
dir: DirInfo<'_>, dir: DirInfo<'_>,
) -> Result<(<B as AbstractCircBuilder>::Spec, B::Circ)> { ) -> Result<(<B as AbstractCircBuilder>::Spec, Arc<B::Circ>)> {
let (_, plan) = self.plan_by_usage(dir, usage)?; let (_, plan) = self.plan_by_usage(dir, usage)?;
self.builder.build_circuit(plan.plan).await self.builder.build_circuit(plan.plan).await
} }
@ -1391,7 +1391,7 @@ impl<B: AbstractCircBuilder + 'static, R: Runtime> AbstractCircMgr<B, R> {
/// out to any future requests. /// out to any future requests.
/// ///
/// Return None if we have no circuit with the given ID. /// Return None if we have no circuit with the given ID.
pub(crate) fn take_circ(&self, id: &<B::Circ as AbstractCirc>::Id) -> Option<B::Circ> { pub(crate) fn take_circ(&self, id: &<B::Circ as AbstractCirc>::Id) -> Option<Arc<B::Circ>> {
let mut list = self.circs.lock().expect("poisoned lock"); let mut list = self.circs.lock().expect("poisoned lock");
list.take_open(id).map(|e| e.circ) list.take_open(id).map(|e| e.circ)
} }
@ -1703,14 +1703,14 @@ mod test {
Ok((plan, spec.clone())) Ok((plan, spec.clone()))
} }
async fn build_circuit(&self, plan: FakePlan) -> Result<(FakeSpec, FakeCirc)> { async fn build_circuit(&self, plan: FakePlan) -> Result<(FakeSpec, Arc<FakeCirc>)> {
let op = plan.op; let op = plan.op;
let sl = self.runtime.sleep(FAKE_CIRC_DELAY); let sl = self.runtime.sleep(FAKE_CIRC_DELAY);
self.runtime.allow_one_advance(FAKE_CIRC_DELAY); self.runtime.allow_one_advance(FAKE_CIRC_DELAY);
sl.await; sl.await;
match op { match op {
FakeOp::Succeed => Ok((plan.spec, FakeCirc { id: FakeId::next() })), FakeOp::Succeed => Ok((plan.spec, Arc::new(FakeCirc { id: FakeId::next() }))),
FakeOp::WrongSpec(s) => Ok((s, FakeCirc { id: FakeId::next() })), FakeOp::WrongSpec(s) => Ok((s, Arc::new(FakeCirc { id: FakeId::next() }))),
FakeOp::Fail => Err(Error::CircTimeout), FakeOp::Fail => Err(Error::CircTimeout),
FakeOp::Delay(d) => { FakeOp::Delay(d) => {
let sl = self.runtime.sleep(d); let sl = self.runtime.sleep(d);
@ -2240,7 +2240,7 @@ mod test {
#[test] #[test]
fn test_find_supported() { fn test_find_supported() {
let (ep_none, ep_web, ep_full) = get_exit_policies(); let (ep_none, ep_web, ep_full) = get_exit_policies();
let fake_circ = FakeCirc { id: FakeId::next() }; let fake_circ = Arc::new(FakeCirc { id: FakeId::next() });
let expiration = ExpirationInfo::Unused { let expiration = ExpirationInfo::Unused {
use_before: Instant::now() + Duration::from_secs(60 * 60), use_before: Instant::now() + Duration::from_secs(60 * 60),
}; };

View File

@ -66,7 +66,7 @@ pub(crate) async fn connect<R: Runtime>(
hsid: HsId, hsid: HsId,
data: &mut Data, data: &mut Data,
secret_keys: HsClientSecretKeys, secret_keys: HsClientSecretKeys,
) -> Result<ClientCirc, ConnError> { ) -> Result<Arc<ClientCirc>, ConnError> {
Context::new( Context::new(
&connector.runtime, &connector.runtime,
&*connector.circpool, &*connector.circpool,
@ -151,7 +151,7 @@ impl<'c, 'd, R: Runtime, M: MocksForConnect<R>> Context<'c, 'd, R, M> {
/// ///
/// This function handles all necessary retrying of fallible operations, /// This function handles all necessary retrying of fallible operations,
/// (and, therefore, must also limit the total work done for a particular call). /// (and, therefore, must also limit the total work done for a particular call).
async fn connect(&mut self) -> Result<ClientCirc, ConnError> { async fn connect(&mut self) -> Result<Arc<ClientCirc>, ConnError> {
// This function must do the following, retrying as appropriate. // This function must do the following, retrying as appropriate.
// - Look up the onion descriptor in the state. // - Look up the onion descriptor in the state.
// - Download the onion descriptor if one isn't there. // - Download the onion descriptor if one isn't there.
@ -376,7 +376,7 @@ trait MockableCircPool<R> {
netdir: &NetDir, netdir: &NetDir,
kind: HsCircKind, kind: HsCircKind,
target: OwnedCircTarget, target: OwnedCircTarget,
) -> tor_circmgr::Result<Self::ClientCirc>; ) -> tor_circmgr::Result<Arc<Self::ClientCirc>>;
} }
/// Mock for `ClientCirc` /// Mock for `ClientCirc`
#[async_trait] #[async_trait]
@ -402,7 +402,7 @@ impl<R: Runtime> MockableCircPool<R> for HsCircPool<R> {
netdir: &NetDir, netdir: &NetDir,
kind: HsCircKind, kind: HsCircKind,
target: OwnedCircTarget, target: OwnedCircTarget,
) -> tor_circmgr::Result<ClientCirc> { ) -> tor_circmgr::Result<Arc<ClientCirc>> {
self.get_or_launch_specific(netdir, kind, target).await self.get_or_launch_specific(netdir, kind, target).await
} }
} }
@ -426,7 +426,7 @@ impl MockableConnectorData for Data {
hsid: HsId, hsid: HsId,
data: &mut Self, data: &mut Self,
secret_keys: HsClientSecretKeys, secret_keys: HsClientSecretKeys,
) -> Result<Self::ClientCirc, ConnError> { ) -> Result<Arc<Self::ClientCirc>, ConnError> {
connect(connector, netdir, hsid, data, secret_keys).await connect(connector, netdir, hsid, data, secret_keys).await
} }
@ -502,10 +502,12 @@ mod test {
_netdir: &NetDir, _netdir: &NetDir,
kind: HsCircKind, kind: HsCircKind,
target: OwnedCircTarget, target: OwnedCircTarget,
) -> tor_circmgr::Result<Self::ClientCirc> { ) -> tor_circmgr::Result<Arc<Self::ClientCirc>> {
assert_eq!(kind, HsCircKind::ClientHsDir); assert_eq!(kind, HsCircKind::ClientHsDir);
self.mglobal.lock().unwrap().hsdirs_asked.push(target); self.mglobal.lock().unwrap().hsdirs_asked.push(target);
Ok(self.clone()) // Adding the `Arc` here is a little ugly, but that's what we get
// for using the same Mocks for everything.
Ok(Arc::new(self.clone()))
} }
} }
#[async_trait] #[async_trait]

View File

@ -118,7 +118,7 @@ impl<R: Runtime> HsClientConnector<R, connect::Data> {
hs_id: HsId, hs_id: HsId,
secret_keys: HsClientSecretKeys, secret_keys: HsClientSecretKeys,
isolation: StreamIsolation, isolation: StreamIsolation,
) -> impl Future<Output = Result<ClientCirc, ConnError>> + Send + Sync + 'r { ) -> impl Future<Output = Result<Arc<ClientCirc>, ConnError>> + Send + Sync + 'r {
// As in tor-circmgr, we take `StreamIsolation`, to ensure that callers in // As in tor-circmgr, we take `StreamIsolation`, to ensure that callers in
// arti-client pass us the final overall isolation, // arti-client pass us the final overall isolation,
// including the per-TorClient isolation. // including the per-TorClient isolation.

View File

@ -102,7 +102,7 @@ enum ServiceState<D: MockableConnectorData> {
data: D, data: D,
/// The circuit /// The circuit
#[educe(Debug(ignore))] #[educe(Debug(ignore))]
circuit: D::ClientCirc, circuit: Arc<D::ClientCirc>,
/// Last time we touched this, including reuse /// Last time we touched this, including reuse
last_used: Instant, last_used: Instant,
}, },
@ -202,7 +202,7 @@ fn obtain_circuit_or_continuation_info<D: MockableConnectorData>(
table_index: TableIndex, table_index: TableIndex,
rechecks: &mut impl Iterator, rechecks: &mut impl Iterator,
mut guard: MutexGuard<'_, Services<D>>, mut guard: MutexGuard<'_, Services<D>>,
) -> Result<Either<Continuation, D::ClientCirc>, ConnError> { ) -> Result<Either<Continuation, Arc<D::ClientCirc>>, ConnError> {
let blank_state = || ServiceState::blank(&connector.runtime); let blank_state = || ServiceState::blank(&connector.runtime);
for _recheck in rechecks { for _recheck in rechecks {
@ -383,7 +383,7 @@ impl<D: MockableConnectorData> Services<D> {
hs_id: HsId, hs_id: HsId,
isolation: Box<dyn Isolation>, isolation: Box<dyn Isolation>,
secret_keys: HsClientSecretKeys, secret_keys: HsClientSecretKeys,
) -> Result<D::ClientCirc, ConnError> { ) -> Result<Arc<D::ClientCirc>, ConnError> {
let blank_state = || ServiceState::blank(&connector.runtime); let blank_state = || ServiceState::blank(&connector.runtime);
let mut rechecks = 0..MAX_RECHECKS; let mut rechecks = 0..MAX_RECHECKS;
@ -466,7 +466,7 @@ impl<D: MockableConnectorData> Services<D> {
#[async_trait] #[async_trait]
pub trait MockableConnectorData: Default + Debug + Send + Sync + 'static { pub trait MockableConnectorData: Default + Debug + Send + Sync + 'static {
/// Client circuit /// Client circuit
type ClientCirc: Clone + Sync + Send + 'static; type ClientCirc: Sync + Send + 'static;
/// Mock state /// Mock state
type MockGlobalState: Clone + Sync + Send + 'static; type MockGlobalState: Clone + Sync + Send + 'static;
@ -478,7 +478,7 @@ pub trait MockableConnectorData: Default + Debug + Send + Sync + 'static {
hsid: HsId, hsid: HsId,
data: &mut Self, data: &mut Self,
secret_keys: HsClientSecretKeys, secret_keys: HsClientSecretKeys,
) -> Result<Self::ClientCirc, ConnError>; ) -> Result<Arc<Self::ClientCirc>, ConnError>;
/// Is circuit OK? Ie, not `.is_closing()`. /// Is circuit OK? Ie, not `.is_closing()`.
fn circuit_is_ok(circuit: &Self::ClientCirc) -> bool; fn circuit_is_ok(circuit: &Self::ClientCirc) -> bool;
@ -550,8 +550,8 @@ pub(crate) mod test {
_hsid: HsId, _hsid: HsId,
_data: &mut MockData, _data: &mut MockData,
_secret_keys: HsClientSecretKeys, _secret_keys: HsClientSecretKeys,
) -> Result<Self::ClientCirc, E> { ) -> Result<Arc<Self::ClientCirc>, E> {
let make = |()| MockCirc::new(); let make = |()| Arc::new(MockCirc::new());
let mut give = connector.mock_for_state.give.clone(); let mut give = connector.mock_for_state.give.clone();
if let Ready(ret) = &*give.borrow() { if let Ready(ret) = &*give.borrow() {
return ret.clone().map(make); return ret.clone().map(make);
@ -630,7 +630,7 @@ pub(crate) mod test {
id: u8, id: u8,
secret_keys: &HsClientSecretKeys, secret_keys: &HsClientSecretKeys,
isolation: Option<NarrowableIsolation>, isolation: Option<NarrowableIsolation>,
) -> Result<MockCirc, ConnError> { ) -> Result<Arc<MockCirc>, ConnError> {
let netdir = tor_netdir::testnet::construct_netdir() let netdir = tor_netdir::testnet::construct_netdir()
.unwrap_if_sufficient() .unwrap_if_sufficient()
.unwrap(); .unwrap();

View File

@ -0,0 +1,2 @@
BREAKING: APIs now return and accept Arc<ClientCirc>

View File

@ -91,7 +91,7 @@ pub const CIRCUIT_BUFFER_SIZE: usize = 128;
#[cfg_attr(docsrs, doc(cfg(feature = "send-control-msg")))] #[cfg_attr(docsrs, doc(cfg(feature = "send-control-msg")))]
pub use {msghandler::MsgHandler, reactor::MetaCellDisposition}; pub use {msghandler::MsgHandler, reactor::MetaCellDisposition};
#[derive(Clone, Debug)] #[derive(Debug)]
/// A circuit that we have constructed over the Tor network. /// A circuit that we have constructed over the Tor network.
/// ///
/// This struct is the interface used by the rest of the code, It is fairly /// This struct is the interface used by the rest of the code, It is fairly
@ -140,7 +140,7 @@ pub struct PendingClientCirc {
/// or a DESTROY cell. /// or a DESTROY cell.
recvcreated: oneshot::Receiver<CreateResponse>, recvcreated: oneshot::Receiver<CreateResponse>,
/// The ClientCirc object that we can expose on success. /// The ClientCirc object that we can expose on success.
circ: ClientCirc, circ: Arc<ClientCirc>,
} }
/// Description of the network's current rules for building circuits. /// Description of the network's current rules for building circuits.
@ -208,7 +208,7 @@ pub(crate) struct StreamTarget {
/// Channel to send cells down. /// Channel to send cells down.
tx: mpsc::Sender<AnyRelayMsg>, tx: mpsc::Sender<AnyRelayMsg>,
/// Reference to the circuit that this stream is on. /// Reference to the circuit that this stream is on.
circ: ClientCirc, circ: Arc<ClientCirc>,
} }
impl ClientCirc { impl ClientCirc {
@ -421,7 +421,7 @@ impl ClientCirc {
/// The caller will typically want to see the first cell in response, /// The caller will typically want to see the first cell in response,
/// to see whether it is e.g. an END or a CONNECTED. /// to see whether it is e.g. an END or a CONNECTED.
async fn begin_stream_impl( async fn begin_stream_impl(
&self, self: &Arc<ClientCirc>,
begin_msg: AnyRelayMsg, begin_msg: AnyRelayMsg,
cmd_checker: AnyCmdChecker, cmd_checker: AnyCmdChecker,
) -> Result<(StreamReader, StreamTarget)> { ) -> Result<(StreamReader, StreamTarget)> {
@ -469,7 +469,11 @@ impl ClientCirc {
/// Start a DataStream (anonymized connection) to the given /// Start a DataStream (anonymized connection) to the given
/// address and port, using a BEGIN cell. /// address and port, using a BEGIN cell.
async fn begin_data_stream(&self, msg: AnyRelayMsg, optimistic: bool) -> Result<DataStream> { async fn begin_data_stream(
self: &Arc<ClientCirc>,
msg: AnyRelayMsg,
optimistic: bool,
) -> Result<DataStream> {
let (reader, target) = self let (reader, target) = self
.begin_stream_impl(msg, DataCmdChecker::new_any()) .begin_stream_impl(msg, DataCmdChecker::new_any())
.await?; .await?;
@ -486,7 +490,7 @@ impl ClientCirc {
/// The use of a string for the address is intentional: you should let /// The use of a string for the address is intentional: you should let
/// the remote Tor relay do the hostname lookup for you. /// the remote Tor relay do the hostname lookup for you.
pub async fn begin_stream( pub async fn begin_stream(
&self, self: &Arc<ClientCirc>,
target: &str, target: &str,
port: u16, port: u16,
parameters: Option<StreamParameters>, parameters: Option<StreamParameters>,
@ -501,7 +505,7 @@ impl ClientCirc {
/// Start a new stream to the last relay in the circuit, using /// Start a new stream to the last relay in the circuit, using
/// a BEGIN_DIR cell. /// a BEGIN_DIR cell.
pub async fn begin_dir_stream(&self) -> Result<DataStream> { pub async fn begin_dir_stream(self: Arc<ClientCirc>) -> Result<DataStream> {
// Note that we always open begindir connections optimistically. // Note that we always open begindir connections optimistically.
// Since they are local to a relay that we've already authenticated // Since they are local to a relay that we've already authenticated
// with and built a circuit to, there should be no additional checks // with and built a circuit to, there should be no additional checks
@ -515,7 +519,7 @@ impl ClientCirc {
/// ///
/// Note that this function does not check for timeouts; that's /// Note that this function does not check for timeouts; that's
/// the caller's responsibility. /// the caller's responsibility.
pub async fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>> { pub async fn resolve(self: &Arc<ClientCirc>, hostname: &str) -> Result<Vec<IpAddr>> {
let resolve_msg = Resolve::new(hostname); let resolve_msg = Resolve::new(hostname);
let resolved_msg = self.try_resolve(resolve_msg).await?; let resolved_msg = self.try_resolve(resolve_msg).await?;
@ -536,7 +540,7 @@ impl ClientCirc {
/// ///
/// Note that this function does not check for timeouts; that's /// Note that this function does not check for timeouts; that's
/// the caller's responsibility. /// the caller's responsibility.
pub async fn resolve_ptr(&self, addr: IpAddr) -> Result<Vec<String>> { pub async fn resolve_ptr(self: &Arc<ClientCirc>, addr: IpAddr) -> Result<Vec<String>> {
let resolve_ptr_msg = Resolve::new_reverse(&addr); let resolve_ptr_msg = Resolve::new_reverse(&addr);
let resolved_msg = self.try_resolve(resolve_ptr_msg).await?; let resolved_msg = self.try_resolve(resolve_ptr_msg).await?;
@ -557,7 +561,7 @@ impl ClientCirc {
/// Helper: Send the resolve message, and read resolved message from /// Helper: Send the resolve message, and read resolved message from
/// resolve stream. /// resolve stream.
async fn try_resolve(&self, msg: Resolve) -> Result<Resolved> { async fn try_resolve(self: &Arc<ClientCirc>, msg: Resolve) -> Result<Resolved> {
let (reader, _) = self let (reader, _) = self
.begin_stream_impl(msg.into(), ResolveCmdChecker::new_any()) .begin_stream_impl(msg.into(), ResolveCmdChecker::new_any())
.await?; .await?;
@ -637,7 +641,7 @@ impl PendingClientCirc {
let pending = PendingClientCirc { let pending = PendingClientCirc {
recvcreated: createdreceiver, recvcreated: createdreceiver,
circ: circuit, circ: Arc::new(circuit),
}; };
(pending, reactor) (pending, reactor)
} }
@ -654,7 +658,7 @@ impl PendingClientCirc {
/// There's no authentication in CRATE_FAST, /// There's no authentication in CRATE_FAST,
/// so we don't need to know whom we're connecting to: we're just /// so we don't need to know whom we're connecting to: we're just
/// connecting to whichever relay the channel is for. /// connecting to whichever relay the channel is for.
pub async fn create_firsthop_fast(self, params: &CircParameters) -> Result<ClientCirc> { pub async fn create_firsthop_fast(self, params: &CircParameters) -> Result<Arc<ClientCirc>> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
self.circ self.circ
.control .control
@ -680,7 +684,7 @@ impl PendingClientCirc {
self, self,
target: &Tg, target: &Tg,
params: CircParameters, params: CircParameters,
) -> Result<ClientCirc> ) -> Result<Arc<ClientCirc>>
where where
Tg: tor_linkspec::CircTarget, Tg: tor_linkspec::CircTarget,
{ {
@ -799,7 +803,7 @@ impl StreamTarget {
/// Return a reference to the circuit that this `StreamTarget` is using. /// Return a reference to the circuit that this `StreamTarget` is using.
#[cfg(feature = "experimental-api")] #[cfg(feature = "experimental-api")]
pub(crate) fn circuit(&self) -> &ClientCirc { pub(crate) fn circuit(&self) -> &Arc<ClientCirc> {
&self.circ &self.circ
} }
} }
@ -1028,7 +1032,7 @@ mod test {
rt: &R, rt: &R,
chan: Channel, chan: Channel,
next_msg_from: HopNum, next_msg_from: HopNum,
) -> (ClientCirc, mpsc::Sender<ClientCircChanMsg>) { ) -> (Arc<ClientCirc>, mpsc::Sender<ClientCircChanMsg>) {
let circid = 128.into(); let circid = 128.into();
let (_created_send, created_recv) = oneshot::channel(); let (_created_send, created_recv) = oneshot::channel();
let (circmsg_send, circmsg_recv) = mpsc::channel(64); let (circmsg_send, circmsg_recv) = mpsc::channel(64);
@ -1070,7 +1074,7 @@ mod test {
async fn newcirc<R: Runtime>( async fn newcirc<R: Runtime>(
rt: &R, rt: &R,
chan: Channel, chan: Channel,
) -> (ClientCirc, mpsc::Sender<ClientCircChanMsg>) { ) -> (Arc<ClientCirc>, mpsc::Sender<ClientCircChanMsg>) {
newcirc_ext(rt, chan, 2.into()).await newcirc_ext(rt, chan, 2.into()).await
} }
@ -1372,7 +1376,7 @@ mod test {
rt: &R, rt: &R,
n_to_send: usize, n_to_send: usize,
) -> ( ) -> (
ClientCirc, Arc<ClientCirc>,
DataStream, DataStream,
mpsc::Sender<ClientCircChanMsg>, mpsc::Sender<ClientCircChanMsg>,
StreamId, StreamId,

View File

@ -109,7 +109,7 @@ pub struct DataStream {
/// DataWriterState, but for now we can't actually access that state all the time, /// DataWriterState, but for now we can't actually access that state all the time,
/// since it might be inside a boxed future. /// since it might be inside a boxed future.
#[cfg(feature = "experimental-api")] #[cfg(feature = "experimental-api")]
circuit: ClientCirc, circuit: std::sync::Arc<ClientCirc>,
} }
/// The write half of a [`DataStream`], implementing [`futures::io::AsyncWrite`]. /// The write half of a [`DataStream`], implementing [`futures::io::AsyncWrite`].