diff --git a/crates/tor-basic-utils/src/n_key_set.rs b/crates/tor-basic-utils/src/n_key_set.rs index f72cf1f9e..f50aafe89 100644 --- a/crates/tor-basic-utils/src/n_key_set.rs +++ b/crates/tor-basic-utils/src/n_key_set.rs @@ -218,6 +218,32 @@ This could be more efficient in space and time. { self.[<$key _map>].get($key).copied().map(|old_idx| self.remove_at(old_idx).expect("inconsistent state")) } + + + #[doc = concat!("Modify the element with the given value for `", stringify!($key), " by applying `func` to it.")] + /// + /// `func` is allowed to change the keys for this value. All indices + /// are updated to refer to the new keys. If the new keys conflict with + /// any previous values, those values are replaced and returned in a + /// vector. + /// + /// If `func` causes the value to have no keys at all, then the value + /// itself is also removed and returned in the result vector. + /// + /// Note that because this function needs to copy all key values and check whether + /// they have changed, it is not terribly efficient. + $vis fn [] (&mut self, $key: &BorrowAsKey_, func: F_) -> Vec<$V> + where + $KEY : std::borrow::Borrow, + BorrowAsKey_: std::hash::Hash + Eq + ?Sized, + F_: FnOnce(&mut $V) + { + if let Some(idx) = self.[<$key _map>].get($key) { + self.modify_at(*idx, func) + } else { + Vec::new() + } + } )+ /// Return an iterator over the elements in this container. @@ -324,7 +350,7 @@ This could be more efficient in space and time. $( if let Some($key) = $crate::n_key_set!( @access(removed, ($($($flag)+)?) $key : $KEY $({$($source)+})?) ) { let old_idx = self.[<$key _map>].remove($key); - debug_assert_eq!(old_idx, Some(idx)); + assert_eq!(old_idx, Some(idx)); } )* Some(removed) @@ -333,6 +359,72 @@ This could be more efficient in space and time. } } + /// Change the value at `idx` by applying `func` to it. + /// + /// `func` is allowed to change the keys for this value. All indices + /// are updated to refer to the new keys. If the new keys conflict with + /// any previous values, those values are replaced and returned in a + /// vector. + /// + /// If `func` causes the value to have no keys at all, then the value + /// itself is also removed and returned in the result vector. + /// + /// # Panics + /// + /// Panics if `idx` is not present in this set. + fn modify_at(&mut self, idx: usize, func: F_) -> Vec<$V> + where + F_: FnOnce(&mut $V) + { + let value = self.values.get_mut(idx).expect("invalid index"); + $( + let [] = $crate::n_key_set!( @access(value, ($($($flag)+)?) $key : $KEY $({$($source)+})?) ) + .map(|elt| elt.to_owned()) ; + )+ + + func(value); + + // Check whether any keys have changed, and whether there still are + // any keys. + $( + let [] = $crate::n_key_set!( @access( value, ($($($flag)+)?) $key : $KEY $({$($source)+})?) ) ; + )+ + let keys_changed = $( [].as_ref().map(std::borrow::Borrow::borrow) != [] )||+ ; + + if keys_changed { + let found_any_keys = $( [].is_some() )||+ ; + + // Remove this value from every place that it was before. + // + // We can't use remove_at, since we have changed the keys in the + // value: we have to remove them manually from each index + // instead. + $( + if let Some(orig) = [] { + let removed = self.[<$key _map>].remove(&orig); + assert_eq!(removed, Some(idx)); + } + )+ + // Remove the value from its previous place in the index. (This + // results in an extra copy when we call insert(), but if we + // didn't do it, we'd need to reimplement `insert()`.) + let removed = self.values.remove(idx); + if found_any_keys { + // This item belongs: put it back and return the vector of + // whatever was replaced.j + self.insert(removed) + } else { + // This item does not belong any longer, since all its keys + // were removed. + vec![removed] + } + } else { + // We did not change any keys, so we know we have not replaced + // any items. + vec![] + } + } + /// Re-index all the values in this map, so that the map can use a more /// compact representation. /// @@ -550,6 +642,31 @@ mod test { set.check_invariants(); } + #[test] + fn modify_value() { + let mut set: Tuple2Set = (1..=100).map(|idx| (idx, idx * idx)).collect(); + set.check_invariants(); + + let v = set.modify_by_first(&30, |elt| elt.1 = 256); + set.check_invariants(); + // one element was replaced. + assert_eq!(v.len(), 1); + assert_eq!(v[0], (16, 256)); + assert_eq!(set.by_second(&256).unwrap(), &(30, 256)); + assert_eq!(set.by_first(&30).unwrap(), &(30, 256)); + + let v = set.modify_by_first(&30, |elt| *elt = (-100, -100)); + set.check_invariants(); + // no elements were replaced. + assert_eq!(v.len(), 0); + assert_eq!(set.by_first(&30), None); + assert_eq!(set.by_second(&256), None); + assert_eq!(set.by_first(&-100).unwrap(), &(-100, -100)); + assert_eq!(set.by_second(&-100).unwrap(), &(-100, -100)); + + set.check_invariants(); + } + #[allow(dead_code)] struct Weekday { dow: u8,