random_idx_where: Ensure uniform distribution of choice

Previously, this was more likely to select elements that occurred after
other elements that didn't satisfy the predicate.
This commit is contained in:
Jim Newsome 2023-06-22 20:07:59 -05:00
parent 35e6cc285b
commit 7a8bade262
1 changed file with 32 additions and 22 deletions

View File

@@ -81,7 +81,7 @@ impl Pool {
F: Fn(&Arc<ClientCirc>) -> bool, F: Fn(&Arc<ClientCirc>) -> bool,
{ {
// Select a circuit satisfying `f` at random. // Select a circuit satisfying `f` at random.
let rv = match random_idx_where(rng, &self.circuits[..], f) { let rv = match random_idx_where(rng, &mut self.circuits[..], f) {
Some(idx) => Some(self.circuits.swap_remove(idx)), Some(idx) => Some(self.circuits.swap_remove(idx)),
None => None, None => None,
}; };
@@ -134,23 +134,26 @@ impl Pool {
/// Helper: find a random item `elt` in `slice` such that `predicate(elt)` is /// Helper: find a random item `elt` in `slice` such that `predicate(elt)` is
/// true. Return the index of that item. /// true. Return the index of that item.
/// ///
/// /// Can arbitrarily reorder `slice`. This allows us to visit the indices in uniform-at-random
/// We optimize for the assumption that most elements of `slice` will satisfy /// order, without having to do any O(N) operations or allocations.
/// the predicate, and so we won't usually need to look over the whole slice as fn random_idx_where<R, T, P>(rng: &mut R, mut slice: &mut [T], predicate: P) -> Option<usize>
/// we would do if we had used [`IteratorRandom`](rand::seq::IteratorRandom).
fn random_idx_where<R, T, P>(rng: &mut R, slice: &[T], predicate: P) -> Option<usize>
where where
R: Rng, R: Rng,
P: Fn(&T) -> bool, P: Fn(&T) -> bool,
{ {
let n_circuits = slice.len(); while !slice.is_empty() {
if n_circuits == 0 { let idx = rng.gen_range(0..slice.len());
return None; if predicate(&slice[idx]) {
return Some(idx);
}
let last_idx = slice.len() - 1;
// Move the one we just tried to the end,
// and eliminate it from consideration.
slice.swap(idx, last_idx);
slice = &mut slice[..last_idx];
} }
let shift = rng.gen_range(0..n_circuits); // We didn't find any.
(shift..n_circuits) None
.chain(0..shift)
.find(|idx| predicate(&slice[*idx]))
} }
#[cfg(test)] #[cfg(test)]
@@ -172,34 +175,41 @@ mod test {
#[test] #[test]
fn random_idx() { fn random_idx() {
let mut rng = testing_rng(); let mut rng = testing_rng();
let numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27]; let mut orig_numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
let mut numbers = orig_numbers.clone();
let mut found = vec![false; numbers.len()]; let mut found: std::collections::HashMap<i32, bool> =
numbers.iter().map(|n| (*n, false)).collect();
for _ in 0..1000 { for _ in 0..1000 {
let idx = random_idx_where(&mut rng, &numbers[..], |n| n & 1 == 1).unwrap(); let idx = random_idx_where(&mut rng, &mut numbers[..], |n| n & 1 == 1).unwrap();
assert!(numbers[idx] & 1 == 1); assert!(numbers[idx] & 1 == 1);
found[idx] = true; found.insert(numbers[idx], true);
} }
for (idx, num) in numbers.iter().enumerate() { for num in numbers.iter() {
assert!(found[idx] == (num & 1 == 1)); assert!(found[num] == (num & 1 == 1));
} }
// Numbers may be reordered, but should still have the same elements.
numbers.sort();
orig_numbers.sort();
assert_eq!(numbers, orig_numbers);
} }
#[test] #[test]
fn random_idx_empty() { fn random_idx_empty() {
let mut rng = testing_rng(); let mut rng = testing_rng();
let idx = random_idx_where(&mut rng, &[], |_: &i32| panic!()); let idx = random_idx_where(&mut rng, &mut [], |_: &i32| panic!());
assert_eq!(idx, None); assert_eq!(idx, None);
} }
#[test] #[test]
fn random_idx_none() { fn random_idx_none() {
let mut rng = testing_rng(); let mut rng = testing_rng();
let numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27]; let mut numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
assert_eq!( assert_eq!(
random_idx_where(&mut rng, &numbers[..], |_: &i32| false), random_idx_where(&mut rng, &mut numbers[..], |_: &i32| false),
None None
); );
} }