Skip to content

Commit

Permalink
feat: MSM skip doubling when window has all zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanpwang committed Apr 24, 2024
1 parent 7613f82 commit 6385f8e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "halo2curves-axiom"
version = "0.6.1"
version = "0.6.2"
authors = ["Privacy Scaling Explorations team", "Taiko Labs", "Intrinsic Technologies"]
license = "MIT/Apache-2.0"
edition = "2021"
Expand Down
144 changes: 101 additions & 43 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::CurveAffine;
use ff::Field;
use ff::PrimeField;
use group::Group;
use rayon::iter::IntoParallelIterator;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
Expand Down Expand Up @@ -287,6 +288,7 @@ impl<C: CurveAffine> Schedule<C> {
}

pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
// Do conversion to bytes once
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
Expand All @@ -299,7 +301,34 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &

let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;

for current_window in (0..number_of_windows).rev() {
// In each window, get the booth index of each coefficient
let mut coeffs_in_windows = Vec::with_capacity(number_of_windows);
// Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_window = None;
for current_window in 0..number_of_windows {
let coeffs_in_window: Vec<i32> = coeffs
.iter()
.map(|coeff| {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
if coeff != 0 {
max_nonzero_window = Some(current_window);
}
coeff
})
.collect();
coeffs_in_windows.push(coeffs_in_window);
}
// Save memory and drop coeffs as bytes since it's not needed anymore
drop(coeffs);

if max_nonzero_window.is_none() {
return;
}
for coeffs_in_window in coeffs_in_windows
.into_iter()
.take(max_nonzero_window.unwrap() + 1)
.rev()
{
for _ in 0..c {
*acc = acc.double();
}
Expand Down Expand Up @@ -337,8 +366,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
for (coeff, base) in coeffs_in_window.into_iter().zip(bases.iter()) {
if coeff.is_positive() {
buckets[coeff as usize - 1].add_assign(base);
}
Expand Down Expand Up @@ -422,52 +450,82 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(

// number of windows
let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;
// accumumator for each window
let mut acc = vec![C::Curve::identity(); number_of_windows];
acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| {
// jacobian buckets for already scheduled points
let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];

// schedular for affine addition
let mut sched = Schedule::new(c);

for (base_idx, coeff) in coeffs.iter().enumerate() {
let buck_idx = get_booth_index(w, c, coeff.as_ref());

if buck_idx != 0 {
// parse bucket index
let sign = buck_idx.is_positive();
let buck_idx = buck_idx.unsigned_abs() as usize - 1;
// In each window, get the booth index of each coefficient
let mut coeffs_in_windows = Vec::with_capacity(number_of_windows);
// Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_window = None;
for current_window in 0..number_of_windows {
let coeffs_in_window: Vec<i32> = coeffs
.iter()
.map(|coeff| {
let coeff = get_booth_index(current_window, c, coeff.as_ref());
if coeff != 0 {
max_nonzero_window = Some(current_window);
}
coeff
})
.collect();
coeffs_in_windows.push(coeffs_in_window);
}
// Save memory and drop coeffs as bytes since it's not needed anymore
drop(coeffs);

if sched.contains(buck_idx) {
// greedy accumulation
// we use original bases here
j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
} else {
// also flushes the schedule if full
sched.add(&bases_local, base_idx, buck_idx, sign);
if max_nonzero_window.is_none() {
// Everything is zero
return C::Curve::identity();
}
let number_of_windows = max_nonzero_window.unwrap() + 1;
// accumumator for each window
let mut acc = vec![C::Curve::identity(); number_of_windows];
coeffs_in_windows
.into_par_iter()
.take(number_of_windows)
.zip(acc.par_iter_mut())
.enumerate()
.rev()
.for_each(|(w, (coeffs_in_window, acc))| {
// jacobian buckets for already scheduled points
let mut j_bucks = vec![Bucket::<C>::None; 1 << (c - 1)];

// schedular for affine addition
let mut sched = Schedule::new(c);

for (base_idx, buck_idx) in coeffs_in_window.into_iter().enumerate() {
if buck_idx != 0 {
// parse bucket index
let sign = buck_idx.is_positive();
let buck_idx = buck_idx.unsigned_abs() as usize - 1;

if sched.contains(buck_idx) {
// greedy accumulation
// we use original bases here
j_bucks[buck_idx].add_assign(&bases[base_idx], sign);
} else {
// also flushes the schedule if full
sched.add(&bases_local, base_idx, buck_idx, sign);
}
}
}
}

// flush the schedule
sched.execute(&bases_local);

// summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
running_sum += j_buck.add(a_buck);
*acc += running_sum;
}
// flush the schedule
sched.execute(&bases_local);

// summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() {
running_sum += j_buck.add(a_buck);
*acc += running_sum;
}

// shift accumulator to the window position
for _ in 0..c * w {
*acc = acc.double();
}
});
// shift accumulator to the window position
for _ in 0..c * w {
*acc = acc.double();
}
});
acc.into_iter().sum::<_>()
}

Expand Down

0 comments on commit 6385f8e

Please sign in to comment.