Skip to content

Commit

Permalink
Performance improvements for shuffle and partial_shuffle (#1272)
Browse files Browse the repository at this point in the history
* Made shuffle and partial_shuffle faster
* Use criterion benchmarks for shuffle
* Added a note about RNG word size
* Tidied comments
* Added a debug_assert
* Added a comment re possible further optimization
* Added and updated copyright notices
* Revert cfg mistake
* Reverted change to mod.rs
* Removed ChaCha20 benches from shuffle
* moved debug_assert out of a const fn
  • Loading branch information
wainwrightmark authored Jan 8, 2023
1 parent 1e96eb4 commit 4bde8a0
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 19 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,9 @@ criterion = { version = "0.4" }
[[bench]]
name = "seq_choose"
path = "benches/seq_choose.rs"
harness = false

[[bench]]
name = "shuffle"
path = "benches/shuffle.rs"
harness = false
2 changes: 1 addition & 1 deletion benches/seq_choose.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2022 Developers of the Rand project.
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand Down
50 changes: 50 additions & 0 deletions benches/shuffle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use rand::SeedableRng;

criterion_group!(
name = benches;
config = Criterion::default();
targets = bench
);
criterion_main!(benches);

pub fn bench(c: &mut Criterion) {
bench_rng::<rand_chacha::ChaCha12Rng>(c, "ChaCha12");
bench_rng::<rand_pcg::Pcg32>(c, "Pcg32");
bench_rng::<rand_pcg::Pcg64>(c, "Pcg64");
}

fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) {
for length in [1, 2, 3, 10, 100, 1000, 10000].map(|x| black_box(x)) {
c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| {
let mut rng = Rng::seed_from_u64(123);
let mut vec: Vec<usize> = (0..length).collect();
b.iter(|| {
vec.shuffle(&mut rng);
vec[0]
})
});

if length >= 10 {
c.bench_function(
format!("partial_shuffle_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
let mut vec: Vec<usize> = (0..length).collect();
b.iter(|| {
vec.partial_shuffle(&mut rng, length / 2);
vec[0]
})
},
);
}
}
}
8 changes: 8 additions & 0 deletions src/seq/coin_flipper.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::RngCore;

pub(crate) struct CoinFlipper<R: RngCore> {
Expand Down
108 changes: 108 additions & 0 deletions src/seq/increasing_uniform.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::{Rng, RngCore};

/// Similar to a Uniform distribution,
/// but after returning a number in the range [0,n], n is increased by 1.
pub(crate) struct IncreasingUniform<R: RngCore> {
pub rng: R,
n: u32,
// Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) )
chunk: u32,
chunk_remaining: u8,
}

impl<R: RngCore> IncreasingUniform<R> {
/// Create a dice roller.
/// The next item returned will be a random number in the range [0,n]
pub fn new(rng: R, n: u32) -> Self {
// If n = 0, the first number returned will always be 0
// so we don't need to generate a random number
let chunk_remaining = if n == 0 { 1 } else { 0 };
Self {
rng,
n,
chunk: 0,
chunk_remaining,
}
}

/// Returns a number in [0,n] and increments n by 1.
/// Generates new random bits as needed
/// Panics if `n >= u32::MAX`
#[inline]
pub fn next_index(&mut self) -> usize {
let next_n = self.n + 1;

// There's room for further optimisation here:
// gen_range uses rejection sampling (or other method; see #1196) to avoid bias.
// When the initial sample is biased for range 0..bound
// it may still be viable to use for a smaller bound
// (especially if small biases are considered acceptable).

let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| {
// If the chunk is empty, generate a new chunk
let (bound, remaining) = calculate_bound_u32(next_n);
// bound = (n + 1) * (n + 2) *..* (n + remaining)
self.chunk = self.rng.gen_range(0..bound);
// Chunk is a random number in
// [0, (n + 1) * (n + 2) *..* (n + remaining) )

remaining - 1
});

let result = if next_chunk_remaining == 0 {
// `chunk` is a random number in the range [0..n+1)
// Because `chunk_remaining` is about to be set to zero
// we do not need to clear the chunk here
self.chunk as usize
} else {
// `chunk` is a random number in a range that is a multiple of n+1
// so r will be a random number in [0..n+1)
let r = self.chunk % next_n;
self.chunk /= next_n;
r as usize
};

self.chunk_remaining = next_chunk_remaining;
self.n = next_n;
result
}
}

#[inline]
/// Calculates `bound`, `count` such that bound (m)*(m+1)*..*(m + remaining - 1)
fn calculate_bound_u32(m: u32) -> (u32, u8) {
debug_assert!(m > 0);
#[inline]
const fn inner(m: u32) -> (u32, u8) {
let mut product = m;
let mut current = m + 1;

loop {
if let Some(p) = u32::checked_mul(product, current) {
product = p;
current += 1;
} else {
// Count has a maximum value of 13 for when min is 1 or 2
let count = (current - m) as u8;
return (product, count);
}
}
}

const RESULT2: (u32, u8) = inner(2);
if m == 2 {
// Making this value a constant instead of recalculating it
// gives a significant (~50%) performance boost for small shuffles
return RESULT2;
}

inner(m)
}
51 changes: 33 additions & 18 deletions src/seq/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand Down Expand Up @@ -29,6 +29,8 @@ mod coin_flipper;
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod index;

mod increasing_uniform;

#[cfg(feature = "alloc")]
use core::ops::Index;

Expand All @@ -42,6 +44,7 @@ use crate::distributions::WeightedError;
use crate::Rng;

use self::coin_flipper::CoinFlipper;
use self::increasing_uniform::IncreasingUniform;

/// Extension trait on slices, providing random mutation and sampling methods.
///
Expand Down Expand Up @@ -620,10 +623,11 @@ impl<T> SliceRandom for [T] {
where
R: Rng + ?Sized,
{
for i in (1..self.len()).rev() {
// invariant: elements with index > i have been locked in place.
self.swap(i, gen_index(rng, i + 1));
if self.len() <= 1 {
// There is no need to shuffle an empty or single element slice
return;
}
self.partial_shuffle(rng, self.len());
}

fn partial_shuffle<R>(
Expand All @@ -632,19 +636,30 @@ impl<T> SliceRandom for [T] {
where
R: Rng + ?Sized,
{
// This applies Durstenfeld's algorithm for the
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
// for an unbiased permutation, but exits early after choosing `amount`
// elements.

let len = self.len();
let end = if amount >= len { 0 } else { len - amount };
let m = self.len().saturating_sub(amount);

for i in (end..len).rev() {
// invariant: elements with index > i have been locked in place.
self.swap(i, gen_index(rng, i + 1));
// The algorithm below is based on Durstenfeld's algorithm for the
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
// for an unbiased permutation.
// It ensures that the last `amount` elements of the slice
// are randomly selected from the whole slice.

//`IncreasingUniform::next_index()` is faster than `gen_index`
//but only works for 32 bit integers
//So we must use the slow method if the slice is longer than that.
if self.len() < (u32::MAX as usize) {
let mut chooser = IncreasingUniform::new(rng, m as u32);
for i in m..self.len() {
let index = chooser.next_index();
self.swap(i, index);
}
} else {
for i in m..self.len() {
let index = gen_index(rng, i + 1);
self.swap(i, index);
}
}
let r = self.split_at_mut(end);
let r = self.split_at_mut(m);
(r.1, r.0)
}
}
Expand Down Expand Up @@ -765,11 +780,11 @@ mod test {

let mut r = crate::test::rng(414);
nums.shuffle(&mut r);
assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let res = nums.partial_shuffle(&mut r, 6);
assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
}

#[derive(Clone)]
Expand Down

0 comments on commit 4bde8a0

Please sign in to comment.