From dcec17ceb078d0161759081f95c27eabc43f09c5 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Tue, 23 Jul 2024 00:54:49 -0700 Subject: [PATCH] feat: skip zeroes in msm (#168) * feat: skip zeroes in msm * Update src/msm.rs Co-authored-by: David Nevado --------- Co-authored-by: David Nevado --- Cargo.toml | 12 ++++-- benches/msm.rs | 110 ++++++++++++++++++++++++++++++++++--------------- src/msm.rs | 20 ++++++++- 3 files changed, 105 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0810983f..02eebedd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,11 @@ [package] name = "halo2curves-axiom" version = "0.6.1" -authors = ["Privacy Scaling Explorations team", "Taiko Labs", "Intrinsic Technologies"] +authors = [ + "Privacy Scaling Explorations team", + "Taiko Labs", + "Intrinsic Technologies", +] license = "MIT/Apache-2.0" edition = "2021" repository = "https://github.com/axiom-crypto/halo2curves" @@ -39,7 +43,10 @@ num-traits = "0.2" paste = "1.0.11" serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } -hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } +hex = { version = "0.4", optional = true, default-features = false, features = [ + "alloc", + "serde", +] } blake2b_simd = "1" rayon = "1.8" digest = "0.10.7" @@ -87,4 +94,3 @@ harness = false [[bench]] name = "msm" harness = false -required-features = ["multicore"] diff --git a/benches/msm.rs b/benches/msm.rs index 7c38ed3a..2a98c525 100644 --- a/benches/msm.rs +++ b/benches/msm.rs @@ -13,14 +13,14 @@ extern crate criterion; use criterion::{BenchmarkId, Criterion}; -use ff::Field; +use ff::{Field, PrimeField}; use group::prime::PrimeCurveAffine; use halo2curves_axiom::bn256::{Fr as Scalar, G1Affine as Point}; use halo2curves_axiom::msm::{best_multiexp, multiexp_serial}; -use maybe_rayon::current_thread_index; -use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}; -use rand_core::SeedableRng; +use rand_core::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; +use rayon::current_thread_index; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use std::time::SystemTime; const SAMPLE_SIZE: usize = 10; @@ -30,15 +30,15 @@ const SEED: [u8; 16] = [ 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, ]; -fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { +fn generate_curvepoints(k: u8) -> Vec { let n: u64 = { assert!(k < 64); 1 << k }; - println!("\n\nGenerating 2^{k} = {n} coefficients and curve points..",); + println!("Generating 2^{k} = {n} curve points..",); let timer = SystemTime::now(); - let coeffs = (0..n) + let bases = (0..n) .into_par_iter() .map_init( || { @@ -51,10 +51,36 @@ fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { } XorShiftRng::from_seed(thread_seed) }, - |rng, _| Scalar::random(rng), + |rng, _| Point::random(rng), ) .collect(); - let bases = (0..n) + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} curve points took: {} sec.\n\n", + end.as_secs() + ); + bases +} + +fn generate_coefficients(k: u8, bits: usize) -> Vec { + let n: u64 = { + assert!(k < 64); + 1 << k + }; + let max_val: Option = match bits { + 1 => Some(1), + 8 => Some(0xff), + 16 => Some(0xffff), + 32 => Some(0xffff_ffff), + 64 => Some(0xffff_ffff_ffff_ffff), + 128 => Some(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff), + 256 => None, + _ => panic!("unexpected bit size {}", bits), + }; + + println!("Generating 2^{k} = {n} coefficients..",); + let timer = SystemTime::now(); + let coeffs = (0..n) .into_par_iter() .map_init( || { @@ -67,16 +93,25 @@ fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { } XorShiftRng::from_seed(thread_seed) }, - |rng, _| Point::random(rng), + |rng, _| { + if let Some(max_val) = max_val { + let v_lo = rng.next_u64() as u128; + let v_hi = rng.next_u64() as u128; + let mut v = v_lo + (v_hi << 64); + v &= max_val; // Mask the 128bit value to get a lower number of bits + Scalar::from_u128(v) + } else { + Scalar::random(rng) + } + }, ) .collect(); let end = timer.elapsed().unwrap(); println!( - "Generating 2^{k} = {n} coefficients and curve points took: {} sec.\n\n", + "Generating 2^{k} = {n} coefficients took: {} sec.\n\n", end.as_secs() ); - - (coeffs, bases) + coeffs } fn msm(c: &mut Criterion) { @@ -86,28 +121,37 @@ fn msm(c: &mut Criterion) { .chain(MULTICORE_RANGE.iter()) .max() .unwrap_or(&16); - let (coeffs, bases) = generate_coefficients_and_curvepoints(max_k); + let bases = generate_curvepoints(max_k); + let bits = [1, 8, 16, 32, 64, 128, 256]; + let coeffs: Vec<_> = bits + .iter() + .map(|b| generate_coefficients(max_k, *b)) + .collect(); - for k in SINGLECORE_RANGE { - group - .bench_function(BenchmarkId::new("singlecore", k), |b| { - assert!(k < 64); - let n: usize = 1 << k; - let mut acc = Point::identity().into(); - b.iter(|| multiexp_serial(&coeffs[..n], &bases[..n], &mut acc)); - }) - .sample_size(10); - } - for k in MULTICORE_RANGE { - group - .bench_function(BenchmarkId::new("multicore", k), |b| { - assert!(k < 64); - let n: usize = 1 << k; - b.iter(|| { - best_multiexp(&coeffs[..n], &bases[..n]); + for (b_index, b) in bits.iter().enumerate() { + for k in SINGLECORE_RANGE { + let id = format!("{b}b_{k}"); + group + .bench_function(BenchmarkId::new("singlecore", id), |b| { + assert!(k < 64); + let n: usize = 1 << k; + let mut acc = Point::identity().into(); + b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc)); + }) + .sample_size(10); + } + for k in MULTICORE_RANGE { + let id = format!("{b}b_{k}"); + group + .bench_function(BenchmarkId::new("multicore", id), |b| { + assert!(k < 64); + let n: usize = 1 << k; + b.iter(|| { + best_multiexp(&coeffs[b_index][..n], &bases[..n]); + }) }) - }) - .sample_size(SAMPLE_SIZE); + .sample_size(SAMPLE_SIZE); + } } group.finish(); } diff --git a/src/msm.rs b/src/msm.rs index 25af9711..1aab35ef 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -297,7 +297,25 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & (f64::from(bases.len() as u32)).ln().ceil() as usize }; - let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1; + let field_byte_size = C::Scalar::NUM_BITS.div_ceil(8u32) as usize; + // OR all coefficients in order to make a mask to figure out the maximum number of bytes used + // among all coefficients. + let mut acc_or = vec![0; field_byte_size]; + for coeff in &coeffs { + for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) { + *acc_limb = *acc_limb | *limb; + } + } + let max_byte_size = field_byte_size + - acc_or + .iter() + .rev() + .position(|v| *v != 0) + .unwrap_or(field_byte_size); + if max_byte_size == 0 { + return; + } + let number_of_windows = max_byte_size * 8 as usize / c + 1; for current_window in (0..number_of_windows).rev() { for _ in 0..c {