Skip to content

Commit

Permalink
Implement fast sum of powers for any n
Browse files Browse the repository at this point in the history
  • Loading branch information
oleganza committed Apr 26, 2018
1 parent a473b6e commit 5a3892a
Showing 1 changed file with 90 additions and 36 deletions.
126 changes: 90 additions & 36 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ impl Poly2 {
}

/// Raises `x` to the power `n` using binary exponentiation,
/// with (1 to 2)*lg(n) scalar multiplications.
/// TODO: a consttime version of this would be awfully similar to a Montgomery ladder.
/// with `(1 to 2)*lg(n)` scalar multiplications.
/// TODO: a consttime version of this would be similar to a Montgomery ladder.
pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
let mut result = Scalar::one();
let mut aux = *x; // x, x^2, x^4, x^8, ...
Expand All @@ -95,38 +95,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
result = result * aux;
}
n = n >> 1;
aux = aux * aux; // FIXME: one unnecessary mult at the last step here!
if n > 0 {
aux = aux * aux;
}
}
result
}

/// Takes the sum of all the powers of `x`, up to `n`
/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplcations and additions.
/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
if !n.is_power_of_two() {
return sum_of_powers_slow(x, n);
}
if n == 0 || n == 1 {
return Scalar::from_u64(n as u64);
}
let mut m = n;
let mut result = Scalar::one() + x;
let mut factor = *x;
while m > 2 {
factor = factor * factor;
result = result + factor * result;
m = m / 2;
/// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\)
/// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret
/// and algorithm is fastest when \\(n\\) is the power of two.
///
/// ### Algorithm overview
///
/// First, let \\(n\\) be a power of two.
/// Then, we can divide the polynomial in two halves like so:
/// \\[
/// \begin{aligned}
/// (1+\dots+x^{n-1}) &=\\\\
/// (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) &=\\\\
/// s + x^{n/2} s.
/// \end{aligned}
/// \\]
/// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\).
/// Recursively, the total sum can be defined as:
/// \\[
/// \begin{aligned}
/// S(0) &= 0 \\\\
/// S(n) &= s_{\lg n} \\\\
/// s_0 &= 1 \\\\
/// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1}
/// \end{aligned}
/// \\]
/// This representation allows us to square \\(x\\) only \\(\lg n\\) times.
///
/// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using
/// bits \\(b_i\\) in \\(\\{0,1\\}\\):
/// \\[
/// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1}
/// \\]
/// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)),
/// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar
/// an intermediate polynomial with \\(2^i\\) terms using the above algorithm,
/// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\)
/// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\).
///
/// The full algorithm becomes:
/// \\[
/// \begin{aligned}
/// S(0) &= 0 & \\\\
/// S(1) &= 1 & \\\\
/// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\
/// &= S(i) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i
/// \end{aligned}
/// \\]
pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar {
let mut result = Scalar::zero();
let mut f = Scalar::one(); // power of x to offset interim polynomial
let mut s = Scalar::one();
let mut p = *x; // x, x^2, x^4, ..., x^{2^i}
while n > 0 {
// take a bit from n
let bit = n & 1;
n = n >> 1;

if bit == 1 {
// bits of `n` are not secret, so it's okay to be vartime because of `n` value.
result += f * s;
if n > 0 { // avoid multiplication if no bits left
f = f * p;
}
}
if n > 0 { // avoid multiplication if no bits left
s = s + p * s;
p = p * p;
}
}
result
}

// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -185,9 +232,14 @@ mod tests {
);
}

// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x)
}

#[test]
fn test_sum_of_powers() {
let x = Scalar::from_u64(10);
fn test_sum_of_powers_pow2() {
let x = Scalar::from_u64(1337133713371337);
assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0));
assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1));
assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2));
Expand All @@ -199,14 +251,16 @@ mod tests {
}

#[test]
fn test_sum_of_powers_slow() {
fn test_sum_of_powers_non_pow2() {
let x = Scalar::from_u64(10);
assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero());
assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one());
assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from_u64(11));
assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from_u64(111));
assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from_u64(1111));
assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from_u64(11111));
assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from_u64(111111));
assert_eq!(sum_of_powers(&x, 0), Scalar::zero());
assert_eq!(sum_of_powers(&x, 1), Scalar::one());
assert_eq!(sum_of_powers(&x, 2), Scalar::from_u64(11));
assert_eq!(sum_of_powers(&x, 3), Scalar::from_u64(111));
assert_eq!(sum_of_powers(&x, 4), Scalar::from_u64(1111));
assert_eq!(sum_of_powers(&x, 5), Scalar::from_u64(11111));
assert_eq!(sum_of_powers(&x, 6), Scalar::from_u64(111111));
assert_eq!(sum_of_powers(&x, 7), Scalar::from_u64(1111111));
assert_eq!(sum_of_powers(&x, 8), Scalar::from_u64(11111111));
}
}

0 comments on commit 5a3892a

Please sign in to comment.