From 5a3892a64ab4618bec1653654263378d2d4b39bb Mon Sep 17 00:00:00 2001 From: Oleg Andreev Date: Thu, 26 Apr 2018 11:15:39 -0700 Subject: [PATCH] Implement fast sum of powers for any n --- src/util.rs | 126 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 36 deletions(-) diff --git a/src/util.rs b/src/util.rs index 007b725f..f877759f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -84,8 +84,8 @@ impl Poly2 { } /// Raises `x` to the power `n` using binary exponentiation, -/// with (1 to 2)*lg(n) scalar multiplications. -/// TODO: a consttime version of this would be awfully similar to a Montgomery ladder. +/// with `(1 to 2)*lg(n)` scalar multiplications. +/// TODO: a consttime version of this would be similar to a Montgomery ladder. pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { let mut result = Scalar::one(); let mut aux = *x; // x, x^2, x^4, x^8, ... @@ -95,38 +95,85 @@ pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar { result = result * aux; } n = n >> 1; - aux = aux * aux; // FIXME: one unnecessary mult at the last step here! + if n > 0 { + aux = aux * aux; + } } result } -/// Takes the sum of all the powers of `x`, up to `n` -/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplcations and additions. -/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions. -/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2. -pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar { - if !n.is_power_of_two() { - return sum_of_powers_slow(x, n); - } - if n == 0 || n == 1 { - return Scalar::from_u64(n as u64); - } - let mut m = n; - let mut result = Scalar::one() + x; - let mut factor = *x; - while m > 2 { - factor = factor * factor; - result = result + factor * result; - m = m / 2; +/// Computes the sum of all the powers of \\(x\\) \\(S(n) = (x^0 + \dots + x^{n-1})\\) +/// using \\(O(\lg n)\\) multiplications and additions. Length \\(n\\) is not considered secret +/// and algorithm is fastest when \\(n\\) is the power of two. +/// +/// ### Algorithm overview +/// +/// First, let \\(n\\) be a power of two. +/// Then, we can divide the polynomial in two halves like so: +/// \\[ +/// \begin{aligned} +/// (1+\dots+x^{n-1}) &=\\\\ +/// (1+\dots+x^{n/2-1}) + x^{n/2} (1+\dots+x^{n/2-1}) &=\\\\ +/// s + x^{n/2} s. +/// \end{aligned} +/// \\] +/// We can divide each \\(s\\) in half until we arrive to a degree-1 polynomial \\((1+x\cdot 1)\\). +/// Recursively, the total sum can be defined as: +/// \\[ +/// \begin{aligned} +/// S(0) &= 0 \\\\ +/// S(n) &= s_{\lg n} \\\\ +/// s_0 &= 1 \\\\ +/// s_i &= s_{i-1} + x^{2^{i-1}} s_{i-1} +/// \end{aligned} +/// \\] +/// This representation allows us to square \\(x\\) only \\(\lg n\\) times. +/// +/// Lets apply this to \\(n\\) which is not a power of two (\\(2^{k-1} < n < 2^k\\)) which can be represented in binary using +/// bits \\(b_i\\) in \\(\\{0,1\\}\\): +/// \\[ +/// n = b_0 2^0 + \dots + b_{k-1} 2^{k-1} +/// \\] +/// If we scan the bits of \\(n\\) from low to high (\\(i \in [0,k)\\)), +/// we can conditionally (if \\(b_i = 1\\)) add to a resulting scalar +/// an intermediate polynomial with \\(2^i\\) terms using the above algorithm, +/// provided we offset the polynomial by \\(x^{n_i}\\), the next power of \\(x\\) +/// for the existing sum, where \\(n_i = \sum_{j=0}^{i-1} b_j 2^j\\). +/// +/// The full algorithm becomes: +/// \\[ +/// \begin{aligned} +/// S(0) &= 0 & \\\\ +/// S(1) &= 1 & \\\\ +/// S(i) &= S(i-1) + x^{n_i} s_i b_i\\\\ +/// &= S(i) + x^{n_{i-1}} (x^{2^{i-1}})^{b_{i-1}} s_i b_i +/// \end{aligned} +/// \\] +pub fn sum_of_powers(x: &Scalar, mut n: usize) -> Scalar { + let mut result = Scalar::zero(); + let mut f = Scalar::one(); // power of x to offset interim polynomial + let mut s = Scalar::one(); + let mut p = *x; // x, x^2, x^4, ..., x^{2^i} + while n > 0 { + // take a bit from n + let bit = n & 1; + n = n >> 1; + + if bit == 1 { + // bits of `n` are not secret, so it's okay to be vartime because of `n` value. + result += f * s; + if n > 0 { // avoid multiplication if no bits left + f = f * p; + } + } + if n > 0 { // avoid multiplication if no bits left + s = s + p * s; + p = p * p; + } } result } -// takes the sum of all of the powers of x, up to n -fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { - exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x) -} - #[cfg(test)] mod tests { use super::*; @@ -185,9 +232,14 @@ mod tests { ); } + // takes the sum of all of the powers of x, up to n + fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { + exp_iter(*x).take(n).fold(Scalar::zero(), |acc, x| acc + x) + } + #[test] - fn test_sum_of_powers() { - let x = Scalar::from_u64(10); + fn test_sum_of_powers_pow2() { + let x = Scalar::from_u64(1337133713371337); assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0)); assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1)); assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2)); @@ -199,14 +251,16 @@ mod tests { } #[test] - fn test_sum_of_powers_slow() { + fn test_sum_of_powers_non_pow2() { let x = Scalar::from_u64(10); - assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero()); - assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one()); - assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from_u64(11)); - assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from_u64(111)); - assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from_u64(1111)); - assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from_u64(11111)); - assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from_u64(111111)); + assert_eq!(sum_of_powers(&x, 0), Scalar::zero()); + assert_eq!(sum_of_powers(&x, 1), Scalar::one()); + assert_eq!(sum_of_powers(&x, 2), Scalar::from_u64(11)); + assert_eq!(sum_of_powers(&x, 3), Scalar::from_u64(111)); + assert_eq!(sum_of_powers(&x, 4), Scalar::from_u64(1111)); + assert_eq!(sum_of_powers(&x, 5), Scalar::from_u64(11111)); + assert_eq!(sum_of_powers(&x, 6), Scalar::from_u64(111111)); + assert_eq!(sum_of_powers(&x, 7), Scalar::from_u64(1111111)); + assert_eq!(sum_of_powers(&x, 8), Scalar::from_u64(11111111)); } }