random_idx_where: Ensure uniform distribution of choice

Previously, the choice was biased: the function was more likely to select
elements that occurred after other elements that didn't satisfy the predicate.
This commit is contained in:
Jim Newsome 2023-06-22 20:07:59 -05:00
parent 35e6cc285b
commit 7a8bade262
1 changed file with 32 additions and 22 deletions

View File

@@ -81,7 +81,7 @@ impl Pool {
F: Fn(&Arc<ClientCirc>) -> bool,
{
// Select a circuit satisfying `f` at random.
let rv = match random_idx_where(rng, &self.circuits[..], f) {
let rv = match random_idx_where(rng, &mut self.circuits[..], f) {
Some(idx) => Some(self.circuits.swap_remove(idx)),
None => None,
};
@@ -134,23 +134,26 @@ impl Pool {
/// Helper: find a random item `elt` in `slice` such that `predicate(elt)` is
/// true. Return the index of that item, or `None` if `slice` is empty or no
/// item satisfies `predicate`.
///
///
/// We optimize for the assumption that most elements of `slice` will satisfy
/// the predicate, and so we won't usually need to look over the whole slice as
/// we would do if we had used [`IteratorRandom`](rand::seq::IteratorRandom).
fn random_idx_where<R, T, P>(rng: &mut R, slice: &[T], predicate: P) -> Option<usize>
/// Can arbitrarily reorder `slice`. This allows us to visit the indices in uniform-at-random
/// order, without having to do any O(N) operations or allocations.
///
/// The returned index refers to `slice` as it is ordered when this function
/// returns, i.e. after any reordering has taken place.
fn random_idx_where<R, T, P>(rng: &mut R, mut slice: &mut [T], predicate: P) -> Option<usize>
where
R: Rng,
P: Fn(&T) -> bool,
{
let n_circuits = slice.len();
if n_circuits == 0 {
return None;
// Each pass draws one index uniformly from the elements that have not yet
// been rejected; `slice` shrinks by one for every rejected element.
while !slice.is_empty() {
let idx = rng.gen_range(0..slice.len());
if predicate(&slice[idx]) {
// `slice` is only ever shortened from the end, so `idx` is also a
// valid index into the slice the caller passed in.
return Some(idx);
}
let last_idx = slice.len() - 1;
// Move the one we just tried to the end,
// and eliminate it from consideration.
slice.swap(idx, last_idx);
slice = &mut slice[..last_idx];
}
// Start at a random offset, scan forward with wrap-around so that every
// index is visited once, and take the first index whose element satisfies
// the predicate.
let shift = rng.gen_range(0..n_circuits);
(shift..n_circuits)
.chain(0..shift)
.find(|idx| predicate(&slice[*idx]))
// We didn't find any.
None
}
#[cfg(test)]
@@ -172,34 +175,41 @@ mod test {
// Checks that repeated selection with an "is odd" predicate only ever returns
// odd elements, and that over 1000 draws every odd element is returned at
// least once while no even element ever is.
#[test]
fn random_idx() {
let mut rng = testing_rng();
let numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
let mut orig_numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
let mut numbers = orig_numbers.clone();
let mut found = vec![false; numbers.len()];
// Track hits by value rather than by position.
let mut found: std::collections::HashMap<i32, bool> =
numbers.iter().map(|n| (*n, false)).collect();
for _ in 0..1000 {
let idx = random_idx_where(&mut rng, &numbers[..], |n| n & 1 == 1).unwrap();
let idx = random_idx_where(&mut rng, &mut numbers[..], |n| n & 1 == 1).unwrap();
// The selected element must satisfy the predicate (be odd).
assert!(numbers[idx] & 1 == 1);
found[idx] = true;
found.insert(numbers[idx], true);
}
// An element must have been selected at least once if and only if it is odd.
for (idx, num) in numbers.iter().enumerate() {
assert!(found[idx] == (num & 1 == 1));
for num in numbers.iter() {
assert!(found[num] == (num & 1 == 1));
}
// Numbers may be reordered, but should still have the same elements.
numbers.sort();
orig_numbers.sort();
assert_eq!(numbers, orig_numbers);
}
// An empty slice must yield `None`. The predicate panics if it is ever called,
// so this also checks that it is never invoked for an empty slice.
#[test]
fn random_idx_empty() {
let mut rng = testing_rng();
let idx = random_idx_where(&mut rng, &[], |_: &i32| panic!());
let idx = random_idx_where(&mut rng, &mut [], |_: &i32| panic!());
assert_eq!(idx, None);
}
// With a predicate that is always false, no element can be selected, so the
// result must be `None` even though the slice is non-empty.
#[test]
fn random_idx_none() {
let mut rng = testing_rng();
let numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
let mut numbers: Vec<i32> = vec![1, 3, 4, 8, 11, 19, 12, 6, 27];
assert_eq!(
random_idx_where(&mut rng, &numbers[..], |_: &i32| false),
random_idx_where(&mut rng, &mut numbers[..], |_: &i32| false),
None
);
}