diff --git a/crates/arti-rpcserver/src/objmap.rs b/crates/arti-rpcserver/src/objmap.rs index 1fbaef780..4be9109dc 100644 --- a/crates/arti-rpcserver/src/objmap.rs +++ b/crates/arti-rpcserver/src/objmap.rs @@ -93,30 +93,34 @@ mod fake_generational_arena { /// A mechanism to look up RPC `Objects` by their `ObjectId`. #[derive(Default)] pub(crate) struct ObjMap { - /// Generationally indexed arena of object references. + /// Generationally indexed arena of strong object references. + strong_arena: Arena>, + /// Generationally indexed arena of weak object references. /// /// Invariants: - /// * No object has more than one weak reference in this arena. - /// * Every weak `entry` in this arena at position `idx` has a corresponding + /// * No object has more than one reference in this arena. + /// * Every `entry` in this arena at position `idx` has a corresponding /// entry in `reverse_map` entry such that /// `reverse_map[entry.tagged_addr()] == idx`. - arena: Arena, - /// Backwards reference to look up arena references by the underlying object identity. + weak_arena: Arena, + /// Backwards reference to look up weak arena references by the underlying + /// object identity. /// /// Invariants: - /// * For every weak `(addr,idx)` entry in this map, there is a corresponding - /// ArenaEntry in `arena` such that `arena[idx].tagged_addr() == addr` - reverse_map: HashMap, + /// * For every weak `(addr,idx)` entry in this map, there is a + /// corresponding ArenaEntry in `arena` such that + /// `arena[idx].tagged_addr() == addr` + reverse_map: HashMap, /// Testing only: How many times have we tidied this map? #[cfg(test)] n_tidies: usize, } -/// A single entry to an Object stored in the generational arena. +/// A single entry to a weak Object stored in the generational arena. /// -struct ArenaEntry { +struct WeakArenaEntry { /// The actual Arc or Weak reference for the object that we're storing here. - obj: ObjRef, + obj: Weak, /// /// This contains a strong or weak reference, along with the object's true TypeId. /// See the [`TaggedAddr`] for more info on @@ -124,35 +128,6 @@ struct ArenaEntry { id: any::TypeId, } -/// Strong or weak reference to an Object. -enum ObjRef { - /// A strong reference - Strong(Arc), - /// A weak reference - Weak(Weak), -} - -impl ObjRef { - /// Try to return a strong reference to this object, upgrading a weak - /// reference if needed. - /// - /// A `None` return indicates a dangling weak reference. - fn strong(&self) -> Option> { - match self { - ObjRef::Strong(s) => Some(s.clone()), - ObjRef::Weak(w) => Weak::upgrade(w), - } - } - - /// Return the [`RawAddr`] associated with this object. - fn raw_addr(&self) -> RawAddr { - match self { - ObjRef::Strong(s) => raw_addr_of(s), - ObjRef::Weak(w) => raw_addr_of_weak(w), - } - } -} - /// The raw address of an object held in an Arc or Weak. /// /// This will be the same for every clone of an Arc, and the same for every Weak @@ -202,7 +177,12 @@ struct TaggedAddr { /// A generational index for [`ObjMap`]. #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub(crate) struct GenIdx(generational_arena::Index); +pub(crate) enum GenIdx { + /// An index into the arena of weak references. + Weak(generational_arena::Index), + /// An index into the arena of strong references + Strong(generational_arena::Index), +} /// Return the [`RawAddr`] of an arbitrary `Arc`. fn raw_addr_of(arc: &Arc) -> RawAddr { @@ -216,59 +196,37 @@ fn raw_addr_of_weak(arc: &Weak) -> RawAddr { RawAddr(Weak::as_ptr(arc) as *const () as usize) } -impl ArenaEntry { - /// Create a new `ArenaEntry` for a strong reference. - fn new_strong(object: Arc) -> Self { - let id = (*object).type_id(); - Self { - obj: ObjRef::Strong(object), - id, - } - } - - /// Create a new `ArenaEntry` for a weak reference. - fn new_weak(object: &Arc) -> Self { +impl WeakArenaEntry { + /// Create a new `WeakArenaEntry` for a weak reference. + fn new(object: &Arc) -> Self { let id = (**object).type_id(); Self { - obj: ObjRef::Weak(Arc::downgrade(object)), + obj: Arc::downgrade(object), id, } } /// Return true if this `ArenaEntry` is really present. /// - /// Note that this function can produce false positives (if the entry is Weak - /// and its last strong reference is dropped in another thread), but it can + /// Note that this function can produce false positives (if the entry's + /// last strong reference is dropped in another thread), but it can /// never produce false negatives. fn is_present(&self) -> bool { - match &self.obj { - ObjRef::Strong(_) => true, - ObjRef::Weak(w) => { - // This is safe from false negatives because: if we can ever - // observe strong_count == 0, then there is no way for anybody - // else to "resurrect" the object. - w.strong_count() > 0 - } - } + // This is safe from false negatives because: if we can ever + // observe strong_count == 0, then there is no way for anybody + // else to "resurrect" the object. + self.obj.strong_count() > 0 } /// Return a strong reference to the object in this entry, if possible. fn strong(&self) -> Option> { - match &self.obj { - ObjRef::Strong(s) => Some(Arc::clone(s)), - ObjRef::Weak(w) => Weak::upgrade(w), - } - } - - /// Return true if this is a weak reference. - fn is_weak(&self) -> bool { - matches!(&self.obj, ObjRef::Weak(_)) + Weak::upgrade(&self.obj) } /// Return the [`TaggedAddr`] that can be used to identify this entry's object. fn tagged_addr(&self) -> TaggedAddr { TaggedAddr { - addr: self.obj.raw_addr(), + addr: raw_addr_of_weak(&self.obj), type_id: self.id, } } @@ -308,10 +266,14 @@ impl GenIdx { use base64ct::Encoding; use rand::Rng; use tor_bytes::Writer; - let (a, b) = self.0.into_raw_parts(); - let x = rng.gen::(); + let (weak_bit, idx) = match self { + GenIdx::Weak(idx) => (1, idx), + GenIdx::Strong(idx) => (0, idx), + }; + let (a, b) = idx.into_raw_parts(); + let x = rng.gen::() << 1; let mut bytes = Vec::new(); - bytes.write_u64(x); + bytes.write_u64(x | weak_bit); bytes.write_u64((a as u64).wrapping_add(x)); bytes.write_u64(b.wrapping_sub(x)); rpc::ObjectId::from(base64ct::Base64UrlUnpadded::encode_string(&bytes[..])) @@ -330,6 +292,8 @@ impl GenIdx { .map_err(|_| rpc::LookupError::NoObject(id.clone())) }; let x = get_u64()?; + let is_weak = (x & 1) == 1; + let x = x & !1; let a = get_u64()?; let b = get_u64()?; r.should_be_exhausted() @@ -338,7 +302,12 @@ impl GenIdx { let a = a.wrapping_sub(x) as usize; let b = b.wrapping_add(x); - Ok(GenIdx(generational_arena::Index::from_raw_parts(a, b))) + let idx = generational_arena::Index::from_raw_parts(a, b); + if is_weak { + Ok(GenIdx::Weak(idx)) + } else { + Ok(GenIdx::Strong(idx)) + } } } @@ -348,7 +317,7 @@ impl ObjMap { Self::default() } - /// Reclaim unused space in this map. + /// Reclaim unused space in this map's weak arena. /// /// This runs in `O(n)` time. fn tidy(&mut self) { @@ -356,26 +325,26 @@ impl ObjMap { { self.n_tidies += 1; } - self.arena.retain(|index, entry| { + self.weak_arena.retain(|index, entry| { let present = entry.is_present(); if !present { // For everything we are removing from the `arena`, we must also // remove it from `reverse_map`. let ptr = entry.tagged_addr(); let found = self.reverse_map.remove(&ptr); - debug_assert_eq!(found, Some(GenIdx(index))); + debug_assert_eq!(found, Some(index)); } present }); } - /// If needed, clean this arena and resize it. + /// If needed, clean the weak arena and resize it. /// /// (We call this whenever we're about to add an entry. This ensures that /// our insertion operations run in `O(1)` time.) fn adjust_size(&mut self) { // If we're about to fill the arena... - if self.arena.len() >= self.arena.capacity() { + if self.weak_arena.len() >= self.weak_arena.capacity() { // ... we delete any dead `Weak` entries. self.tidy(); // Then, if the arena is still above half-full, we double the @@ -385,17 +354,15 @@ impl ObjMap { // entries, or else we might re-run tidy() too soon. But we don't // want to grow the arena if tidy() removed _most_ entries, or some // normal usage patterns will lead to unbounded growth.) - if self.arena.len() > self.arena.capacity() / 2 { - self.arena.reserve(self.arena.capacity()); + if self.weak_arena.len() > self.weak_arena.capacity() / 2 { + self.weak_arena.reserve(self.weak_arena.capacity()); } } } /// Unconditionally insert a strong entry for `value` in self, and return its index. pub(crate) fn insert_strong(&mut self, value: Arc) -> GenIdx { - self.adjust_size(); - - GenIdx(self.arena.insert(ArenaEntry::new_strong(value))) + GenIdx::Strong(self.strong_arena.insert(value)) } /// Ensure that there is a weak entry for `value` in self, and return an @@ -406,31 +373,38 @@ impl ObjMap { let ptr = TaggedAddr::for_object(&value); if let Some(idx) = self.reverse_map.get(&ptr) { #[cfg(debug_assertions)] - match self.arena.get(idx.0) { + match self.weak_arena.get(*idx) { Some(entry) => debug_assert!(entry.tagged_addr() == ptr), None => panic!("Found a dangling reference"), } - return *idx; + return GenIdx::Weak(*idx); } self.adjust_size(); - - let idx = GenIdx(self.arena.insert(ArenaEntry::new_weak(&value))); + let idx = self.weak_arena.insert(WeakArenaEntry::new(&value)); self.reverse_map.insert(ptr, idx); - idx + GenIdx::Weak(idx) } /// Return the entry from this ObjMap for `idx`. pub(crate) fn lookup(&self, idx: GenIdx) -> Option> { - self.arena.get(idx.0).and_then(ArenaEntry::strong) + match idx { + GenIdx::Weak(idx) => self.weak_arena.get(idx).and_then(WeakArenaEntry::strong), + GenIdx::Strong(idx) => self.strong_arena.get(idx).map(Arc::clone), + } } /// Remove the entry at `idx`, if any. pub(crate) fn remove(&mut self, idx: GenIdx) { - if let Some(entry) = self.arena.remove(idx.0) { - if entry.is_weak() { - let old_idx = self.reverse_map.remove(&entry.tagged_addr()); - debug_assert_eq!(old_idx, Some(idx)); + match idx { + GenIdx::Weak(idx) => { + if let Some(entry) = self.weak_arena.remove(idx) { + let old_idx = self.reverse_map.remove(&entry.tagged_addr()); + debug_assert_eq!(old_idx, Some(idx)); + } + } + GenIdx::Strong(idx) => { + self.strong_arena.remove(idx); } } } @@ -438,19 +412,16 @@ impl ObjMap { /// Testing only: Assert that every invariant for this structure is met. #[cfg(test)] fn assert_okay(&self) { - for (index, entry) in self.arena.iter() { - if !entry.is_weak() { - continue; - }; + for (index, entry) in self.weak_arena.iter() { let ptr = entry.tagged_addr(); - assert_eq!(self.reverse_map.get(&ptr), Some(&GenIdx(index))); + assert_eq!(self.reverse_map.get(&ptr), Some(&index)); assert_eq!(ptr, entry.tagged_addr()); } for (ptr, idx) in self.reverse_map.iter() { let entry = self - .arena - .get(idx.0) + .weak_arena + .get(*idx) .expect("Dangling pointer in reverse map"); assert_eq!(&entry.tagged_addr(), ptr); @@ -708,6 +679,7 @@ mod test { #[test] fn tidy() { let mut map = ObjMap::new(); + let mut keep_these = vec![]; let mut s = vec![]; let mut w = vec![]; for _ in 0..100 { @@ -717,7 +689,9 @@ mod test { w.push(map.insert_weak(o.clone())); t.push(o); } - s.push(map.insert_strong(Arc::new(ExampleObject("cafe".into())))); + let obj = Arc::new(ExampleObject("cafe".into())); + keep_these.push(obj.clone()); + s.push(map.insert_weak(obj)); drop(t); map.assert_okay(); } @@ -727,11 +701,11 @@ mod test { assert!(w.iter().all(|id| map.lookup(*id).is_none())); assert!(s.iter().all(|id| map.lookup(*id).is_some())); - assert_ne!(dbg!(map.arena.len()), 1100); + assert_ne!(map.weak_arena.len() + map.strong_arena.len(), 1100); map.assert_okay(); map.tidy(); map.assert_okay(); - assert_eq!(map.arena.len(), 100); + assert_eq!(map.weak_arena.len() + map.strong_arena.len(), 100); // This number is a bit arbitrary. assert!(dbg!(map.n_tidies) < 30); @@ -746,16 +720,21 @@ mod test { let mut map = ObjMap::new(); map.insert_strong(obj); map.insert_strong(wrap); - assert_eq!(map.arena.len(), 2); + assert_eq!(map.strong_arena.len(), 2); } #[test] fn objid_encoding() { use rand::Rng; fn test_roundtrip(a: usize, b: u64, rng: &mut tor_basic_utils::test_rng::TestingRng) { - let idx = GenIdx(generational_arena::Index::from_raw_parts(a, b)); - let s1 = dbg!(idx.encode_with_rng(rng)); - let s2 = dbg!(idx.encode_with_rng(rng)); + let idx = generational_arena::Index::from_raw_parts(a, b); + let idx = if rng.gen_bool(0.5) { + GenIdx::Strong(idx) + } else { + GenIdx::Weak(idx) + }; + let s1 = idx.encode_with_rng(rng); + let s2 = idx.encode_with_rng(rng); assert_ne!(s1, s2); assert_eq!(idx, GenIdx::try_decode(&s1).unwrap()); assert_eq!(idx, GenIdx::try_decode(&s2).unwrap());