From 6385f8ef63e050bf847b0a6bb03835b534667cf1 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 23 Apr 2024 18:06:36 -0700 Subject: [PATCH] feat: MSM skip doubling when window has all zeros --- Cargo.toml | 2 +- src/msm.rs | 144 +++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 102 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0810983f..440f6ba3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2curves-axiom" -version = "0.6.1" +version = "0.6.2" authors = ["Privacy Scaling Explorations team", "Taiko Labs", "Intrinsic Technologies"] license = "MIT/Apache-2.0" edition = "2021" diff --git a/src/msm.rs b/src/msm.rs index 25af9711..24fe0ca9 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -4,6 +4,7 @@ use crate::CurveAffine; use ff::Field; use ff::PrimeField; use group::Group; +use rayon::iter::IntoParallelIterator; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; @@ -287,6 +288,7 @@ impl Schedule { } pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + // Do conversion to bytes once let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { @@ -299,7 +301,34 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1; - for current_window in (0..number_of_windows).rev() { + // In each window, get the booth index of each coefficient + let mut coeffs_in_windows = Vec::with_capacity(number_of_windows); + // Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0 + let mut max_nonzero_window = None; + for current_window in 0..number_of_windows { + let coeffs_in_window: Vec = coeffs + .iter() + .map(|coeff| { + let coeff = get_booth_index(current_window, c, coeff.as_ref()); + if coeff != 0 { + max_nonzero_window = Some(current_window); + } + coeff + }) + .collect(); + coeffs_in_windows.push(coeffs_in_window); + } + // Save memory and drop coeffs as bytes since it's not needed anymore + drop(coeffs); + + if max_nonzero_window.is_none() { + return; + } + for coeffs_in_window in coeffs_in_windows + .into_iter() + .take(max_nonzero_window.unwrap() + 1) + .rev() + { for _ in 0..c { *acc = acc.double(); } @@ -337,8 +366,7 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & let mut buckets: Vec> = vec![Bucket::None; 1 << (c - 1)]; - for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_booth_index(current_window, c, coeff.as_ref()); + for (coeff, base) in coeffs_in_window.into_iter().zip(bases.iter()) { if coeff.is_positive() { buckets[coeff as usize - 1].add_assign(base); } @@ -422,52 +450,82 @@ pub fn best_multiexp_independent_points( // number of windows let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1; - // accumumator for each window - let mut acc = vec![C::Curve::identity(); number_of_windows]; - acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| { - // jacobian buckets for already scheduled points - let mut j_bucks = vec![Bucket::::None; 1 << (c - 1)]; - // schedular for affine addition - let mut sched = Schedule::new(c); - - for (base_idx, coeff) in coeffs.iter().enumerate() { - let buck_idx = get_booth_index(w, c, coeff.as_ref()); - - if buck_idx != 0 { - // parse bucket index - let sign = buck_idx.is_positive(); - let buck_idx = buck_idx.unsigned_abs() as usize - 1; + // In each window, get the booth index of each coefficient + let mut coeffs_in_windows = Vec::with_capacity(number_of_windows); + // Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0 + let mut max_nonzero_window = None; + for current_window in 0..number_of_windows { + let coeffs_in_window: Vec = coeffs + .iter() + .map(|coeff| { + let coeff = get_booth_index(current_window, c, coeff.as_ref()); + if coeff != 0 { + max_nonzero_window = Some(current_window); + } + coeff + }) + .collect(); + coeffs_in_windows.push(coeffs_in_window); + } + // Save memory and drop coeffs as bytes since it's not needed anymore + drop(coeffs); - if sched.contains(buck_idx) { - // greedy accumulation - // we use original bases here - j_bucks[buck_idx].add_assign(&bases[base_idx], sign); - } else { - // also flushes the schedule if full - sched.add(&bases_local, base_idx, buck_idx, sign); + if max_nonzero_window.is_none() { + // Everything is zero + return C::Curve::identity(); + } + let number_of_windows = max_nonzero_window.unwrap() + 1; + // accumumator for each window + let mut acc = vec![C::Curve::identity(); number_of_windows]; + coeffs_in_windows + .into_par_iter() + .take(number_of_windows) + .zip(acc.par_iter_mut()) + .enumerate() + .rev() + .for_each(|(w, (coeffs_in_window, acc))| { + // jacobian buckets for already scheduled points + let mut j_bucks = vec![Bucket::::None; 1 << (c - 1)]; + + // schedular for affine addition + let mut sched = Schedule::new(c); + + for (base_idx, buck_idx) in coeffs_in_window.into_iter().enumerate() { + if buck_idx != 0 { + // parse bucket index + let sign = buck_idx.is_positive(); + let buck_idx = buck_idx.unsigned_abs() as usize - 1; + + if sched.contains(buck_idx) { + // greedy accumulation + // we use original bases here + j_bucks[buck_idx].add_assign(&bases[base_idx], sign); + } else { + // also flushes the schedule if full + sched.add(&bases_local, base_idx, buck_idx, sign); + } } } - } - // flush the schedule - sched.execute(&bases_local); - - // summation by parts - // e.g. 3a + 2b + 1c = a + - // (a) + b + - // ((a) + b) + c - let mut running_sum = C::Curve::identity(); - for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() { - running_sum += j_buck.add(a_buck); - *acc += running_sum; - } + // flush the schedule + sched.execute(&bases_local); + + // summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() { + running_sum += j_buck.add(a_buck); + *acc += running_sum; + } - // shift accumulator to the window position - for _ in 0..c * w { - *acc = acc.double(); - } - }); + // shift accumulator to the window position + for _ in 0..c * w { + *acc = acc.double(); + } + }); acc.into_iter().sum::<_>() }