From 723c9762fbb09bf3e82ca9fcb4ed0adc12d11b5a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 11:14:00 -0600 Subject: [PATCH 1/7] Add field conversion to/from `[u64;4]` (#80) * feat: add field conversion to/from `[u64;4]` * Added conversion tests * Added `montgomery_reduce_short` for no-asm * For bn256, uses assembly conversion when asm feature is on * fix: remove conflict for asm * chore: bump rust-toolchain to 1.67.0 --- rust-toolchain | 2 +- src/bn256/assembly.rs | 8 +++++++ src/bn256/fq.rs | 25 ++++++++++------------ src/bn256/fr.rs | 25 ++++++++++------------ src/derive/field.rs | 49 +++++++++++++++++++++++++++++++++++++++++++ src/secp256k1/fp.rs | 20 ++++++++++-------- src/secp256k1/fq.rs | 20 ++++++++++-------- src/secp256r1/fp.rs | 28 +++++++++++++------------ src/secp256r1/fq.rs | 20 ++++++++++-------- src/tests/field.rs | 16 ++++++++++++++ 10 files changed, 144 insertions(+), 69 deletions(-) diff --git a/rust-toolchain b/rust-toolchain index 7cc6ef41..77c582d8 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.63.0 \ No newline at end of file +1.67.0 \ No newline at end of file diff --git a/src/bn256/assembly.rs b/src/bn256/assembly.rs index a466005a..a8c1984d 100644 --- a/src/bn256/assembly.rs +++ b/src/bn256/assembly.rs @@ -608,6 +608,14 @@ macro_rules! 
field_arithmetic_asm { $field([r0, r1, r2, r3]) } } + + impl From<$field> for [u64; 4] { + fn from(elt: $field) -> [u64; 4] { + // Turn into canonical form by computing + // (a.R) / R = a + elt.montgomery_reduce_256().0 + } + } }; } diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index b8e1383a..0e2fc10f 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -3,7 +3,7 @@ use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] use crate::{field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::bn256::LegendreSymbol; use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ @@ -271,20 +271,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - - #[cfg(not(feature = "asm"))] - let tmp = - Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - #[cfg(feature = "asm")] - let tmp = self.montgomery_reduce_256(); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -384,6 +376,11 @@ mod test { crate::tests::field::random_field_tests::("fq".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("fq".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 890c12e8..cd422d4b 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature 
= "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, @@ -300,20 +300,12 @@ impl ff::PrimeField for Fr { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - - #[cfg(not(feature = "asm"))] - let tmp = - Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - #[cfg(feature = "asm")] - let tmp = self.montgomery_reduce_256(); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -406,6 +398,11 @@ mod test { ); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("fr".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/derive/field.rs b/src/derive/field.rs index 0a88556a..945ee981 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -267,6 +267,12 @@ macro_rules! field_common { } } + impl From<[u64; 4]> for $field { + fn from(digits: [u64; 4]) -> Self { + Self::from_raw(digits) + } + } + impl From<$field> for [u8; 32] { fn from(value: $field) -> [u8; 32] { value.to_repr() @@ -442,6 +448,49 @@ macro_rules! 
field_arithmetic { $field([d0 & mask, d1 & mask, d2 & mask, d3 & mask]) } + + /// Montgomery reduce where last 4 registers are 0 + #[inline(always)] + pub(crate) const fn montgomery_reduce_short(r: &[u64; 4]) -> $field { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . + + let k = r[0].wrapping_mul($inv); + let (_, r0) = macx(r[0], k, $modulus.0[0]); + let (r1, r0) = mac(r[1], k, $modulus.0[1], r0); + let (r2, r0) = mac(r[2], k, $modulus.0[2], r0); + let (r3, r0) = mac(r[3], k, $modulus.0[3], r0); + + let k = r1.wrapping_mul($inv); + let (_, r1) = macx(r1, k, $modulus.0[0]); + let (r2, r1) = mac(r2, k, $modulus.0[1], r1); + let (r3, r1) = mac(r3, k, $modulus.0[2], r1); + let (r0, r1) = mac(r0, k, $modulus.0[3], r1); + + let k = r2.wrapping_mul($inv); + let (_, r2) = macx(r2, k, $modulus.0[0]); + let (r3, r2) = mac(r3, k, $modulus.0[1], r2); + let (r0, r2) = mac(r0, k, $modulus.0[2], r2); + let (r1, r2) = mac(r1, k, $modulus.0[3], r2); + + let k = r3.wrapping_mul($inv); + let (_, r3) = macx(r3, k, $modulus.0[0]); + let (r0, r3) = mac(r0, k, $modulus.0[1], r3); + let (r1, r3) = mac(r1, k, $modulus.0[2], r3); + let (r2, r3) = mac(r2, k, $modulus.0[3], r3); + + // Result may be within MODULUS of the correct value + (&$field([r0, r1, r2, r3])).sub(&$modulus) + } + } + + impl From<$field> for [u64; 4] { + fn from(elt: $field) -> [u64; 4] { + // Turn into canonical form by computing + // (a.R) / R = a + $field::montgomery_reduce_short(&elt.0).0 + } } }; } diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index f332f0b6..01fecf84 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -255,15 +255,12 @@ impl ff::PrimeField for 
Fp { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fp::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -353,6 +350,11 @@ mod test { crate::tests::field::random_field_tests::("secp256k1 base".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256k1 base".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 9c5e7665..d38dc517 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -266,15 +266,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fq::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + 
res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -360,6 +357,11 @@ mod test { crate::tests::field::random_field_tests::("secp256k1 scalar".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256k1 scalar".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index d351b64d..e26f19fc 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -273,15 +273,12 @@ impl ff::PrimeField for Fp { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fp::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -368,19 +365,24 @@ mod test { #[test] fn test_field() { - crate::tests::field::random_field_tests::("secp256k1 base".to_string()); + crate::tests::field::random_field_tests::("secp256r1 base".to_string()); + } + + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256r1 base".to_string()); } #[test] #[cfg(feature = "bits")] fn test_bits() { - 
crate::tests::field::random_bits_tests::("secp256k1 base".to_string()); + crate::tests::field::random_bits_tests::("secp256r1 base".to_string()); } #[test] fn test_serialization() { - crate::tests::field::random_serialization_test::("secp256k1 base".to_string()); + crate::tests::field::random_serialization_test::("secp256r1 base".to_string()); #[cfg(feature = "derive_serde")] - crate::tests::field::random_serde_test::("secp256k1 base".to_string()); + crate::tests::field::random_serde_test::("secp256r1 base".to_string()); } } diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index e28c3fe6..05fcf1fa 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use core::convert::TryInto; use core::fmt; @@ -262,15 +262,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fq::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -362,6 +359,11 @@ mod test { crate::tests::field::random_field_tests::("secp256r1 scalar".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256r1 scalar".to_string()); + } + #[test] fn test_serialization() { crate::tests::field::random_serialization_test::("secp256r1 scalar".to_string()); diff --git 
a/src/tests/field.rs b/src/tests/field.rs index a85b3f0e..a064441e 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -212,6 +212,22 @@ fn random_expansion_tests(mut rng: R, type_name: String) { end_timer!(start); } +pub fn random_conversion_tests>(type_name: String) { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + let _message = format!("conversion {type_name}"); + let start = start_timer!(|| _message); + for _ in 0..1000000 { + let a = F::random(&mut rng); + let bytes = a.to_repr(); + let b = F::from_repr(bytes).unwrap(); + assert_eq!(a, b); + } + end_timer!(start); +} + #[cfg(feature = "bits")] pub fn random_bits_tests(type_name: String) { let mut rng = XorShiftRng::from_seed([ From 1d71d343b3764f6a819513a0cd9a855cd3f6b698 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Tue, 22 Aug 2023 16:18:30 +0200 Subject: [PATCH 2/7] Compute Legendre symbol for `hash_to_curve` (#77) * Add `Legendre` trait and macro - Add Legendre macro with norm and legendre symbol computation - Add macro for automatic implementation in prime fields * Add legendre macro call for prime fields * Remove unused imports * Remove leftover * Add `is_quadratic_non_residue` for hash_to_curve * Add `legendre` function * Compute modulus separately * Substitute division for shift * Update modulus computation * Add quadratic residue check func * Add quadratic residue tests * Add hash_to_curve bench * Implement Legendre trait for all curves * Move misplaced comment * Add all curves to hash bench * fix: add suggestion for legendre_exp * fix: imports after rebase --- Cargo.toml | 4 +++ benches/hash_to_curve.rs | 59 ++++++++++++++++++++++++++++++++++++++++ src/bn256/fq.rs | 33 ++++++---------------- src/bn256/fq2.rs | 56 +++++++++++++++++++------------------- src/bn256/fr.rs | 10 +++++-- src/bn256/mod.rs | 7 ----- src/hash_to_curve.rs | 9 ++++-- src/legendre.rs | 50 
++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/pasta/mod.rs | 9 ++++++ src/secp256k1/fp.rs | 7 +++++ src/secp256k1/fq.rs | 6 ++++ src/secp256r1/fp.rs | 7 +++++ src/secp256r1/fq.rs | 8 +++++- src/tests/curve.rs | 5 +++- src/tests/field.rs | 16 ++++++++++- 16 files changed, 220 insertions(+), 68 deletions(-) create mode 100644 benches/hash_to_curve.rs create mode 100644 src/legendre.rs diff --git a/Cargo.toml b/Cargo.toml index e0b9d8ba..f29c917e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,3 +63,7 @@ required-features = ["reexport"] [[bench]] name = "group" harness = false + +[[bench]] +name = "hash_to_curve" +harness = false diff --git a/benches/hash_to_curve.rs b/benches/hash_to_curve.rs new file mode 100644 index 00000000..76f6733a --- /dev/null +++ b/benches/hash_to_curve.rs @@ -0,0 +1,59 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use pasta_curves::arithmetic::CurveExt; +use rand_core::{OsRng, RngCore}; +use std::iter; + +fn hash_to_secp256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256k1"); +} + +fn hash_to_secq256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secq256k1"); +} + +fn hash_to_secp256r1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256r1"); +} + +fn hash_to_pallas(c: &mut Criterion) { + hash_to_curve::(c, "Pallas"); +} + +fn hash_to_vesta(c: &mut Criterion) { + hash_to_curve::(c, "Vesta"); +} + +fn hash_to_bn256(c: &mut Criterion) { + hash_to_curve::(c, "Bn256"); +} + +fn hash_to_grumpkin(c: &mut Criterion) { + hash_to_curve::(c, "Grumpkin"); +} + +fn hash_to_curve(c: &mut Criterion, name: &'static str) { + { + let hasher = G::hash_to_curve("test"); + let mut rng = OsRng; + let message = iter::repeat_with(|| rng.next_u32().to_be_bytes()) + .take(1024) + .flatten() + .collect::>(); + + c.bench_function(&format!("Hash to {}", name), move |b| { + b.iter(|| hasher(black_box(&message))) + }); + } +} + +criterion_group!( + benches, + hash_to_secp256k1, + hash_to_secq256k1, + hash_to_secp256r1, + 
hash_to_pallas, + hash_to_vesta, + hash_to_bn256, + hash_to_grumpkin, +); +criterion_main!(benches); diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index 0e2fc10f..0024723a 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -1,11 +1,10 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, macx, sbb}; -use crate::bn256::LegendreSymbol; -use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::arithmetic::{adc, mac, sbb}; +use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, impl_binops_multiplicative, @@ -160,27 +159,10 @@ impl Fq { pub const fn size() -> usize { 32 } - - pub fn legendre(&self) -> LegendreSymbol { - // s = self^((modulus - 1) // 2) - // 0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3 - let s = &[ - 0x9e10460b6c3e7ea3u64, - 0xcbc0b548b438e546u64, - 0xdc2822db40c0ac2eu64, - 0x183227397098d014u64, - ]; - let s = self.pow(s); - if s == Self::zero() { - LegendreSymbol::Zero - } else if s == Self::one() { - LegendreSymbol::QuadraticResidue - } else { - LegendreSymbol::QuadraticNonResidue - } - } } +prime_field_legendre!(Fq); + impl ff::Field for Fq { const ZERO: Self = Self::zero(); const ONE: Self = Self::one(); @@ -310,6 +292,7 @@ impl WithSmallOrderMulGroup<3> for Fq { #[cfg(test)] mod test { use super::*; + use crate::legendre::Legendre; use ff::Field; use rand_core::OsRng; @@ -322,7 +305,7 @@ mod test { let a = Fq::random(OsRng); let mut b = a; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -335,7 +318,7 @@ mod test { for _ in 0..10000 { let 
mut b = c; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index e5a249ee..468b018c 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -1,6 +1,6 @@ use super::fq::{Fq, NEGATIVE_ONE}; -use super::LegendreSymbol; use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::legendre::Legendre; use core::convert::TryInto; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; @@ -125,6 +125,30 @@ impl_binops_additive!(Fq2, Fq2); impl_binops_multiplicative!(Fq2, Fq2); impl_sum_prod!(Fq2); +impl Legendre for Fq2 { + type BasePrimeField = Fq; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! { + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &*LEGENDRE_EXP + } + + /// Norm of Fq2 as extension field in i over Fq + #[inline] + fn norm(&self) -> Self::BasePrimeField { + let mut t0 = self.c0; + let mut t1 = self.c1; + t0 = t0.square(); + t1 = t1.square(); + t1 + t0 + } +} + impl Fq2 { #[inline] pub const fn zero() -> Fq2 { @@ -174,10 +198,6 @@ impl Fq2 { res } - pub fn legendre(&self) -> LegendreSymbol { - self.norm().legendre() - } - pub fn mul_assign(&mut self, other: &Self) { let mut t1 = self.c0 * other.c0; let mut t0 = self.c0 + self.c1; @@ -298,15 +318,6 @@ impl Fq2 { self.c1 += &t0; } - /// Norm of Fq2 as extension field in i over Fq - pub fn norm(&self) -> Fq { - let mut t0 = self.c0; - let mut t1 = self.c1; - t0 = t0.square(); - t1 = t1.square(); - t1 + t0 - } - pub fn invert(&self) -> CtOption { let mut t1 = self.c1; t1 = t1.square(); @@ -696,17 +707,6 @@ fn test_fq2_mul_nonresidue() { } } -#[test] -fn test_fq2_legendre() { - assert_eq!(LegendreSymbol::Zero, Fq2::ZERO.legendre()); - // i^2 = -1 - let mut m1 = Fq2::ONE; - m1 = m1.neg(); - 
assert_eq!(LegendreSymbol::QuadraticResidue, m1.legendre()); - m1.mul_by_nonresidue(); - assert_eq!(LegendreSymbol::QuadraticNonResidue, m1.legendre()); -} - #[test] pub fn test_sqrt() { let mut rng = XorShiftRng::from_seed([ @@ -716,7 +716,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let a = Fq2::random(&mut rng); - if a.legendre() == LegendreSymbol::QuadraticNonResidue { + if a.legendre() == -Fq::ONE { assert!(bool::from(a.sqrt().is_none())); } } @@ -725,7 +725,7 @@ pub fn test_sqrt() { let a = Fq2::random(&mut rng); let mut b = a; b.square_assign(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -738,7 +738,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let mut b = c; b.square_assign(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index cd422d4b..8a57ff9f 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -1,7 +1,7 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; #[cfg(feature = "bn256-table")] #[rustfmt::skip] @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature = "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, mac, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, @@ -166,6 +166,7 @@ field_common!( R3 ); impl_sum_prod!(Fr); +prime_field_legendre!(Fr); #[cfg(not(feature = "bn256-table"))] impl_from_u64!(Fr, R2); @@ -470,4 +471,9 @@ mod test { end_timer!(timer); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff 
--git a/src/bn256/mod.rs b/src/bn256/mod.rs index 9cd08946..3530b765 100644 --- a/src/bn256/mod.rs +++ b/src/bn256/mod.rs @@ -16,10 +16,3 @@ pub use fq12::*; pub use fq2::*; pub use fq6::*; pub use fr::*; - -#[derive(Debug, PartialEq, Eq)] -pub enum LegendreSymbol { - Zero = 0, - QuadraticResidue = 1, - QuadraticNonResidue = -1, -} diff --git a/src/hash_to_curve.rs b/src/hash_to_curve.rs index 4cef7095..5d51ec75 100644 --- a/src/hash_to_curve.rs +++ b/src/hash_to_curve.rs @@ -5,6 +5,8 @@ use pasta_curves::arithmetic::CurveExt; use static_assertions::const_assert; use subtle::{ConditionallySelectable, ConstantTimeEq}; +use crate::legendre::Legendre; + /// Hashes over a message and writes the output to all of `buf`. /// Modified from https://github.com/zcash/pasta_curves/blob/7e3fc6a4919f6462a32b79dd226cb2587b7961eb/src/hashtocurve.rs#L11. fn hash_to_field>( @@ -94,6 +96,7 @@ pub(crate) fn svdw_map_to_curve( ) -> C where C: CurveExt, + C::Base: Legendre, { let one = C::Base::ONE; let a = C::a(); @@ -128,7 +131,7 @@ where // 14. gx1 = gx1 + B let gx1 = gx1 + b; // 15. e1 = is_square(gx1) - let e1 = gx1.sqrt().is_some(); + let e1 = !gx1.ct_quadratic_non_residue(); // 16. x2 = c2 + tv4 let x2 = c2 + tv4; // 17. gx2 = x2^2 @@ -140,7 +143,7 @@ where // 20. gx2 = gx2 + B let gx2 = gx2 + b; // 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops - let e2 = gx2.sqrt().is_some() & (!e1); + let e2 = !gx2.ct_quadratic_non_residue() & (!e1); // 22. x3 = tv2^2 let x3 = tv2.square(); // 23. 
x3 = x3 * tv3 @@ -182,7 +185,7 @@ pub(crate) fn svdw_hash_to_curve<'a, C>( ) -> Box C + 'a> where C: CurveExt, - C::Base: FromUniformBytes<64>, + C::Base: FromUniformBytes<64> + Legendre, { let [c1, c2, c3, c4] = svdw_precomputed_constants::(z); diff --git a/src/legendre.rs b/src/legendre.rs new file mode 100644 index 00000000..6f6fda17 --- /dev/null +++ b/src/legendre.rs @@ -0,0 +1,50 @@ +use ff::{Field, PrimeField}; +use subtle::{Choice, ConstantTimeEq}; + +pub trait Legendre: Field { + type BasePrimeField: PrimeField; + + // This is (p-1)/2 where p is the modulus of the base prime field + fn legendre_exp() -> &'static [u64]; + + fn norm(&self) -> Self::BasePrimeField; + + #[inline] + fn legendre(&self) -> Self::BasePrimeField { + self.norm().pow(Self::legendre_exp()) + } + + #[inline] + fn ct_quadratic_residue(&self) -> Choice { + self.legendre().ct_eq(&Self::BasePrimeField::ONE) + } + + #[inline] + fn ct_quadratic_non_residue(&self) -> Choice { + self.legendre().ct_eq(&-Self::BasePrimeField::ONE) + } +} + +#[macro_export] +macro_rules! prime_field_legendre { + ($field:ident ) => { + impl crate::legendre::Legendre for $field { + type BasePrimeField = Self; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! 
{ + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-<$field as ff::Field>::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &*LEGENDRE_EXP + } + + #[inline] + fn norm(&self) -> Self::BasePrimeField { + self.clone() + } + } + }; +} diff --git a/src/lib.rs b/src/lib.rs index b75d7143..3fa8e98f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ mod arithmetic; pub mod hash_to_curve; +#[macro_use] +pub mod legendre; pub mod serde; pub mod bn256; diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index 164697b5..078b663e 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -38,6 +38,9 @@ const ENDO_PARAMS_EP: EndoParameters = EndoParameters { endo!(Eq, Fp, ENDO_PARAMS_EQ); endo!(Ep, Fq, ENDO_PARAMS_EP); +prime_field_legendre!(Fp); +prime_field_legendre!(Fq); + #[test] fn test_endo() { use ff::Field; @@ -71,3 +74,9 @@ fn test_endo() { } } } + +#[test] +fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + crate::tests::field::random_quadratic_residue_test::(); +} diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index 01fecf84..f6a2a54b 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -295,6 +295,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -367,4 +369,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256k1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index d38dc517..304f5f10 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -302,6 +302,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -374,4 +376,8 @@ mod test { #[cfg(feature = "derive_serde")] 
crate::tests::field::random_serde_test::("secp256k1 scalar".to_string()); } + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index e26f19fc..228e4a67 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -313,6 +313,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -385,4 +387,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index 05fcf1fa..1b98761c 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,6 +1,5 @@ use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; -use core::convert::TryInto; use core::fmt; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; @@ -298,6 +297,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -370,4 +371,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 scalar".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 9bb0fe0b..54d23791 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -2,6 +2,7 @@ use crate::ff::Field; use crate::group::prime::PrimeCurveAffine; +use crate::legendre::Legendre; use crate::tests::fe_from_str; use crate::{group::GroupEncoding, serde::SerdeObject}; use crate::{hash_to_curve, CurveAffine, CurveExt}; @@ -343,7 +344,9 @@ pub fn svdw_map_to_curve_test( z: G::Base, precomputed_constants: [&'static str; 4], test_vector: 
impl IntoIterator, -) { +) where + ::Base: Legendre, +{ let [c1, c2, c3, c4] = hash_to_curve::svdw_precomputed_constants::(z); assert_eq!([c1, c2, c3, c4], precomputed_constants.map(fe_from_str)); for (u, (x, y)) in test_vector.into_iter() { diff --git a/src/tests/field.rs b/src/tests/field.rs index a064441e..b04f801e 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -1,6 +1,7 @@ -use crate::ff::Field; use crate::serde::SerdeObject; +use crate::{ff::Field, legendre::Legendre}; use ark_std::{end_timer, start_timer}; +use ff::PrimeField; use rand::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; @@ -287,3 +288,16 @@ where } end_timer!(start); } + +pub fn random_quadratic_residue_test() { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + for _ in 0..100000 { + let elem = F::random(&mut rng); + let is_quad_res_or_zero: bool = elem.sqrt().is_some().into(); + let is_quad_non_res: bool = elem.ct_quadratic_non_residue().into(); + assert_eq!(!is_quad_non_res, is_quad_res_or_zero) + } +} From 2bb4633111e77b7a962d2bbbf2de0c1641f16880 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 24 Aug 2023 15:36:42 +0200 Subject: [PATCH 3/7] Add simplified SWU method (#81) * Fix broken link * Add simple SWU algorithm * Add simplified SWU hash_to_curve for secp256r1 * add: sswu z reference * update MAP_ID identifier Co-authored-by: Han --------- Co-authored-by: Han --- src/hash_to_curve.rs | 129 ++++++++++++++++++++++++++++++++++++++++- src/secp256r1/curve.rs | 124 +++++++++++++++++++++++---------------- 2 files changed, 200 insertions(+), 53 deletions(-) diff --git a/src/hash_to_curve.rs b/src/hash_to_curve.rs index 5d51ec75..22251102 100644 --- a/src/hash_to_curve.rs +++ b/src/hash_to_curve.rs @@ -3,7 +3,7 @@ use ff::{Field, FromUniformBytes, PrimeField}; use pasta_curves::arithmetic::CurveExt; use static_assertions::const_assert; -use 
subtle::{ConditionallySelectable, ConstantTimeEq}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::legendre::Legendre; @@ -85,6 +85,94 @@ fn hash_to_field>( } } +// Implementation of +#[allow(clippy::too_many_arguments)] +pub(crate) fn simple_svdw_map_to_curve(u: C::Base, z: C::Base) -> C +where + C: CurveExt, +{ + let zero = C::Base::ZERO; + let one = C::Base::ONE; + let a = C::a(); + let b = C::b(); + + //1. tv1 = u^2 + let tv1 = u.square(); + //2. tv1 = Z * tv1 + let tv1 = z * tv1; + //3. tv2 = tv1^2 + let tv2 = tv1.square(); + //4. tv2 = tv2 + tv1 + let tv2 = tv2 + tv1; + //5. tv3 = tv2 + 1 + let tv3 = tv2 + one; + //6. tv3 = B * tv3 + let tv3 = b * tv3; + //7. tv4 = CMOV(Z, -tv2, tv2 != 0) # tv4 = z if tv2 is 0 else tv4 = -tv2 + let tv2_is_not_zero = !tv2.ct_eq(&zero); + let tv4 = C::Base::conditional_select(&z, &-tv2, tv2_is_not_zero); + //8. tv4 = A * tv4 + let tv4 = a * tv4; + //9. tv2 = tv3^2 + let tv2 = tv3.square(); + //10. tv6 = tv4^2 + let tv6 = tv4.square(); + //11. tv5 = A * tv6 + let tv5 = a * tv6; + //12. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //13. tv2 = tv2 * tv3 + let tv2 = tv2 * tv3; + //14. tv6 = tv6 * tv4 + let tv6 = tv6 * tv4; + //15. tv5 = B * tv6 + let tv5 = b * tv6; + //16. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //17. x = tv1 * tv3 + let x = tv1 * tv3; + //18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6) + let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z); + //19. y = tv1 * u + let y = tv1 * u; + //20. y = y * y1 + let y = y * y1; + //21. x = CMOV(x, tv3, is_gx1_square) + let x = C::Base::conditional_select(&x, &tv3, is_gx1_square); + //22. y = CMOV(y, y1, is_gx1_square) + let y = C::Base::conditional_select(&y, &y1, is_gx1_square); + //23. e1 = sgn0(u) == sgn0(y) + let e1 = u.is_odd().ct_eq(&y.is_odd()); + //24. y = CMOV(-y, y, e1) # Select correct sign of y + let y = C::Base::conditional_select(&-y, &y, e1); + //25. x = x / tv4 + let x = x * tv4.invert().unwrap(); + //26. 
return (x, y) + C::new_jacobian(x, y, one).unwrap() +} + +#[allow(clippy::type_complexity)] +pub(crate) fn simple_svdw_hash_to_curve<'a, C>( + curve_id: &'static str, + domain_prefix: &'a str, + z: C::Base, +) -> Box C + 'a> +where + C: CurveExt, + C::Base: FromUniformBytes<64>, +{ + Box::new(move |message| { + let mut us = [C::Base::ZERO; 2]; + hash_to_field("SSWU", curve_id, domain_prefix, message, &mut us); + + let [q0, q1]: [C; 2] = us.map(|u| simple_svdw_map_to_curve(u, z)); + + let r = q0 + &q1; + debug_assert!(bool::from(r.is_on_curve())); + r + }) +} + #[allow(clippy::too_many_arguments)] pub(crate) fn svdw_map_to_curve( u: C::Base, @@ -176,7 +264,44 @@ where C::new_jacobian(x, y, one).unwrap() } -/// Implementation of https://www.ietf.org/id/draft-irtf-cfrg-hash-to-curve-16.html#name-shallue-van-de-woestijne-met +// Implement https://datatracker.ietf.org/doc/html/rfc9380#name-sqrt_ratio-for-any-field +// Copied from ff sqrt_ratio_generic substituting F::ROOT_OF_UNITY for input Z +fn sqrt_ratio(num: &F, div: &F, z: &F) -> (Choice, F) { + // General implementation: + // + // a = num * inv0(div) + // = { 0 if div is zero + // { num/div otherwise + // + // b = z * a + // = { 0 if div is zero + // { z*num/div otherwise + + // Since z is non-square, a and b are either both zero (and both square), or + // only one of them is square. We can therefore choose the square root to return + // based on whether a is square, but for the boolean output we need to handle the + // num != 0 && div == 0 case specifically.
+ + let a = div.invert().unwrap_or(F::ZERO) * num; + let b = a * z; + let sqrt_a = a.sqrt(); + let sqrt_b = b.sqrt(); + + let num_is_zero = num.is_zero(); + let div_is_zero = div.is_zero(); + let is_square = sqrt_a.is_some(); + let is_nonsquare = sqrt_b.is_some(); + assert!(bool::from( + num_is_zero | div_is_zero | (is_square ^ is_nonsquare) + )); + + ( + is_square & (num_is_zero | !div_is_zero), + CtOption::conditional_select(&sqrt_b, &sqrt_a, is_square).unwrap(), + ) +} + +/// Implementation of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#section-6.6.1 #[allow(clippy::type_complexity)] pub(crate) fn svdw_hash_to_curve<'a, C>( curve_id: &'static str, diff --git a/src/secp256r1/curve.rs b/src/secp256r1/curve.rs index 5ce4522c..7e6e24ef 100644 --- a/src/secp256r1/curve.rs +++ b/src/secp256r1/curve.rs @@ -1,6 +1,7 @@ use crate::ff::WithSmallOrderMulGroup; use crate::ff::{Field, PrimeField}; use crate::group::{prime::PrimeCurveAffine, Curve, Group as _, GroupEncoding}; +use crate::hash_to_curve::simple_svdw_hash_to_curve; use crate::secp256r1::Fp; use crate::secp256r1::Fq; use crate::{Coordinates, CurveAffine, CurveExt}; @@ -75,77 +76,98 @@ new_curve_impl!( SECP_A, SECP_B, "secp256r1", - |_, _| unimplemented!(), + |curve_id, domain_prefix| simple_svdw_hash_to_curve(curve_id, domain_prefix, Secp256r1::SSVDW_Z), ); -#[test] -fn test_curve() { - crate::tests::curve::curve_tests::(); +impl Secp256r1 { + // Optimal Z with: + // 0xffffffff00000001000000000000000000000000fffffffffffffffffffffff5 + // Z = -10 (reference: ) + const SSVDW_Z: Fp = Fp::from_raw([ + 0xfffffffffffffff5, + 0x00000000ffffffff, + 0x0000000000000000, + 0xffffffff00000001, + ]); } -#[test] -fn test_serialization() { - crate::tests::curve::random_serialization_test::(); - #[cfg(feature = "derive_serde")] - crate::tests::curve::random_serde_test::(); -} - -#[test] -fn ecdsa_example() { +#[cfg(test)] +mod tests { + use super::*; use crate::group::Curve; - use crate::CurveAffine; + 
use crate::secp256r1::{Fp, Fq, Secp256r1}; use ff::FromUniformBytes; use rand_core::OsRng; - fn mod_n(x: Fp) -> Fq { - let mut x_repr = [0u8; 32]; - x_repr.copy_from_slice(x.to_repr().as_ref()); - let mut x_bytes = [0u8; 64]; - x_bytes[..32].copy_from_slice(&x_repr[..]); - Fq::from_uniform_bytes(&x_bytes) + #[test] + fn test_hash_to_curve() { + crate::tests::curve::hash_to_curve_test::(); + } + + #[test] + fn test_curve() { + crate::tests::curve::curve_tests::(); } - let g = Secp256r1::generator(); + #[test] + fn test_serialization() { + crate::tests::curve::random_serialization_test::(); + #[cfg(feature = "derive_serde")] + crate::tests::curve::random_serde_test::(); + } + + #[test] + fn ecdsa_example() { + fn mod_n(x: Fp) -> Fq { + let mut x_repr = [0u8; 32]; + x_repr.copy_from_slice(x.to_repr().as_ref()); + let mut x_bytes = [0u8; 64]; + x_bytes[..32].copy_from_slice(&x_repr[..]); + Fq::from_uniform_bytes(&x_bytes) + } + + let g = Secp256r1::generator(); - for _ in 0..1000 { - // Generate a key pair - let sk = Fq::random(OsRng); - let pk = (g * sk).to_affine(); + for _ in 0..1000 { + // Generate a key pair + let sk = Fq::random(OsRng); + let pk = (g * sk).to_affine(); - // Generate a valid signature - // Suppose `m_hash` is the message hash - let msg_hash = Fq::random(OsRng); + // Generate a valid signature + // Suppose `m_hash` is the message hash + let msg_hash = Fq::random(OsRng); - let (r, s) = { - // Draw arandomness - let k = Fq::random(OsRng); - let k_inv = k.invert().unwrap(); + let (r, s) = { + // Draw arandomness + let k = Fq::random(OsRng); + let k_inv = k.invert().unwrap(); - // Calculate `r` - let r_point = (g * k).to_affine().coordinates().unwrap(); - let x = r_point.x(); - let r = mod_n(*x); + // Calculate `r` + let r_point = (g * k).to_affine().coordinates().unwrap(); + let x = r_point.x(); + let r = mod_n(*x); - // Calculate `s` - let s = k_inv * (msg_hash + (r * sk)); + // Calculate `s` + let s = k_inv * (msg_hash + (r * sk)); - (r, s) - }; + 
(r, s) + }; - { - // Verify - let s_inv = s.invert().unwrap(); - let u_1 = msg_hash * s_inv; - let u_2 = r * s_inv; + { + // Verify + let s_inv = s.invert().unwrap(); + let u_1 = msg_hash * s_inv; + let u_2 = r * s_inv; - let v_1 = g * u_1; - let v_2 = pk * u_2; + let v_1 = g * u_1; + let v_2 = pk * u_2; - let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); - let x_candidate = r_point.x(); - let r_candidate = mod_n(*x_candidate); + let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); + let x_candidate = r_point.x(); + let r_candidate = mod_n(*x_candidate); - assert_eq!(r, r_candidate); + assert_eq!(r, r_candidate); + } } } } From 6e2ff3853c8fe91300650a733100640dacf313e6 Mon Sep 17 00:00:00 2001 From: Han Date: Mon, 4 Sep 2023 12:24:29 +0800 Subject: [PATCH 4/7] Bring back curve algorithms for `a = 0` (#82) * refactor: bring back curve algorithms for `a = 0` * fix: clippy warning --- benches/group.rs | 16 +- benches/hash_to_curve.rs | 2 +- src/bn256/fq2.rs | 2 +- src/derive/curve.rs | 441 +++++++++++++++++++++++++-------------- src/legendre.rs | 2 +- 5 files changed, 291 insertions(+), 172 deletions(-) diff --git a/benches/group.rs b/benches/group.rs index 68cfee53..b1936e68 100644 --- a/benches/group.rs +++ b/benches/group.rs @@ -18,28 +18,28 @@ fn criterion_benchmark(c: &mut Criterion) { let v = vec![G::generator(); N]; let mut q = vec![G::AffineExt::identity(); N]; - c.bench_function(&format!("{} check on curve", name), move |b| { + c.bench_function(&format!("{name} check on curve"), move |b| { b.iter(|| black_box(p1).is_on_curve()) }); - c.bench_function(&format!("{} check equality", name), move |b| { + c.bench_function(&format!("{name} check equality"), move |b| { b.iter(|| black_box(p1) == black_box(p1)) }); - c.bench_function(&format!("{} to affine", name), move |b| { + c.bench_function(&format!("{name} to affine"), move |b| { b.iter(|| G::AffineExt::from(black_box(p1))) }); - c.bench_function(&format!("{} doubling", name), move |b| { + 
c.bench_function(&format!("{name} doubling"), move |b| { b.iter(|| black_box(p1).double()) }); - c.bench_function(&format!("{} addition", name), move |b| { + c.bench_function(&format!("{name} addition"), move |b| { b.iter(|| black_box(p1).add(&p2)) }); - c.bench_function(&format!("{} mixed addition", name), move |b| { + c.bench_function(&format!("{name} mixed addition"), move |b| { b.iter(|| black_box(p2).add(&p1_affine)) }); - c.bench_function(&format!("{} scalar multiplication", name), move |b| { + c.bench_function(&format!("{name} scalar multiplication"), move |b| { b.iter(|| black_box(p1) * black_box(s)) }); - c.bench_function(&format!("{} batch to affine n={}", name, N), move |b| { + c.bench_function(&format!("{name} batch to affine n={N}"), move |b| { b.iter(|| { G::batch_normalize(black_box(&v), black_box(&mut q)); black_box(&q)[0] diff --git a/benches/hash_to_curve.rs b/benches/hash_to_curve.rs index 76f6733a..bda1c1d3 100644 --- a/benches/hash_to_curve.rs +++ b/benches/hash_to_curve.rs @@ -40,7 +40,7 @@ fn hash_to_curve(c: &mut Criterion, name: &'static str) { .flatten() .collect::>(); - c.bench_function(&format!("Hash to {}", name), move |b| { + c.bench_function(&format!("Hash to {name}"), move |b| { b.iter(|| hasher(black_box(&message))) }); } diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index 468b018c..66d2c6a7 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -135,7 +135,7 @@ impl Legendre for Fq2 { static ref LEGENDRE_EXP: Vec = (num_bigint::BigUint::from_bytes_le((-::ONE).to_repr().as_ref())/2usize).to_u64_digits(); } - &*LEGENDRE_EXP + &LEGENDRE_EXP } /// Norm of Fq2 as extension field in i over Fq diff --git a/src/derive/curve.rs b/src/derive/curve.rs index 1eeef572..098d1a2f 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -114,8 +114,7 @@ macro_rules! 
new_curve_impl { $base::from_bytes(&xbytes).and_then(|x| { CtOption::new(Self::identity(), x.is_zero() & (is_inf)).or_else(|| { - let x3 = x.square() * x; - (x3 + $name::curve_constant_a() * x + $name::curve_constant_b()).sqrt().and_then(|y| { + $name_affine::y2(x).sqrt().and_then(|y| { let sign = Choice::from(y.to_bytes()[0] & 1); let y = $base::conditional_select(&y, &-y, ysign ^ sign); @@ -321,18 +320,10 @@ macro_rules! new_curve_impl { } } - const fn curve_constant_a() -> $base { - $name_affine::curve_constant_a() - } - - const fn curve_constant_b() -> $base { - $name_affine::curve_constant_b() - } - #[inline] fn curve_constant_3b() -> $base { lazy_static::lazy_static! { - static ref CONST_3B: $base = $constant_b + $constant_b + $constant_b; + static ref CONST_3B: $base = $constant_b + $constant_b + $constant_b; } *CONST_3B } @@ -354,23 +345,24 @@ macro_rules! new_curve_impl { } } - const fn curve_constant_a() -> $base { - $constant_a - } - - const fn curve_constant_b() -> $base { - $constant_b + #[inline(always)] + fn y2(x: $base) -> $base { + if $constant_a == $base::ZERO { + let x3 = x.square() * x; + (x3 + $constant_b) + } else { + let x2 = x.square(); + ((x2 + $constant_a) * x + $constant_b) + } } - pub fn random(mut rng: impl RngCore) -> Self { loop { let x = $base::random(&mut rng); let ysign = (rng.next_u32() % 2) as u8; - let x3 = x.square() * x; - let y = (x3 + $name::curve_constant_a() * x + $name::curve_constant_b()).sqrt(); - if let Some(y) = Option::<$base>::from(y) { + let y2 = $name_affine::y2(x); + if let Some(y) = Option::<$base>::from(y2.sqrt()) { let sign = y.to_bytes()[0] & 1; let y = if ysign ^ sign == 0 { y } else { -y }; @@ -479,20 +471,30 @@ macro_rules! 
new_curve_impl { } fn is_on_curve(&self) -> Choice { - // Check (Y/Z)^2 = (X/Z)^3 + a(X/Z) + b - // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b + if $constant_a == $base::ZERO { + // Check (Y/Z)^2 = (X/Z)^3 + b + // <=> Z Y^2 - X^3 = Z^3 b + + (self.z * self.y.square() - self.x.square() * self.x) + .ct_eq(&(self.z.square() * self.z * $constant_b)) + | self.z.is_zero() + } else { + // Check (Y/Z)^2 = (X/Z)^3 + a(X/Z) + b + // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b - (self.z * self.y.square() - self.x.square() * self.x - $name::curve_constant_a() * self.x * self.z.square()) - .ct_eq(&(self.z.square() * self.z * $name::curve_constant_b())) - | self.z.is_zero() + let z2 = self.z.square(); + (self.z * self.y.square() - (self.x.square() + $constant_a * z2) * self.x) + .ct_eq(&(z2 * self.z * $constant_b)) + | self.z.is_zero() + } } fn b() -> Self::Base { - $name::curve_constant_b() + $constant_b } fn a() -> Self::Base { - $name::curve_constant_a() + $constant_a } fn new_jacobian(x: Self::Base, y: Self::Base, z: Self::Base) -> CtOption { @@ -566,46 +568,76 @@ macro_rules! 
new_curve_impl { } fn double(&self) -> Self { - // Algorithm 3, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x.square(); - let t1 = self.y.square(); - let t2 = self.z.square(); - let t3 = self.x * self.y; - let t3 = t3 + t3; - let z3 = self.x * self.z; - let z3 = z3 + z3; - let x3 = $name::curve_constant_a() * z3; - let y3 = $name::mul_by_3b(&t2); - let y3 = x3 + y3; - let x3 = t1 - y3; - let y3 = t1 + y3; - let y3 = x3 * y3; - let x3 = t3 * x3; - let z3 = $name::mul_by_3b(&z3); - let t2 = $name::curve_constant_a() * t2; - let t3 = t0 - t2; - let t3 = $name::curve_constant_a() * t3; - let t3 = t3 + z3; - let z3 = t0 + t0; - let t0 = z3 + t0; - let t0 = t0 + t2; - let t0 = t0 * t3; - let y3 = y3 + t0; - let t2 = self.y * self.z; - let t2 = t2 + t2; - let t0 = t2 * t3; - let x3 = x3 - t0; - let z3 = t2 * t1; - let z3 = z3 + z3; - let z3 = z3 + z3; - - let tmp = $name { - x: x3, - y: y3, - z: z3, - }; - - $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + if $constant_a == $base::ZERO { + // Algorithm 9, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.y.square(); + let z3 = t0 + t0; + let z3 = z3 + z3; + let z3 = z3 + z3; + let t1 = self.y * self.z; + let t2 = self.z.square(); + let t2 = $name::mul_by_3b(&t2); + let x3 = t2 * z3; + let y3 = t0 + t2; + let z3 = t1 * z3; + let t1 = t2 + t2; + let t2 = t1 + t2; + let t0 = t0 - t2; + let y3 = t0 * y3; + let y3 = x3 + y3; + let t1 = self.x * self.y; + let x3 = t0 * t1; + let x3 = x3 + x3; + + let tmp = $name { + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + } else { + // Algorithm 3, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x.square(); + let t1 = self.y.square(); + let t2 = self.z.square(); + let t3 = self.x * self.y; + let t3 = t3 + t3; + let z3 = self.x * self.z; + let z3 = z3 + z3; + let x3 = $constant_a * z3; + let y3 = $name::mul_by_3b(&t2); + let y3 = x3 + y3; + let x3 = t1 - y3; + let y3 = t1 + y3; + let 
y3 = x3 * y3; + let x3 = t3 * x3; + let z3 = $name::mul_by_3b(&z3); + let t2 = $constant_a * t2; + let t3 = t0 - t2; + let t3 = $constant_a * t3; + let t3 = t3 + z3; + let z3 = t0 + t0; + let t0 = z3 + t0; + let t0 = t0 + t2; + let t0 = t0 * t3; + let y3 = y3 + t0; + let t2 = self.y * self.z; + let t2 = t2 + t2; + let t0 = t2 * t3; + let x3 = x3 - t0; + let z3 = t2 * t1; + let z3 = z3 + z3; + let z3 = z3 + z3; + + let tmp = $name { + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + } } fn generator() -> Self { @@ -823,9 +855,15 @@ macro_rules! new_curve_impl { type CurveExt = $name; fn is_on_curve(&self) -> Choice { - // y^2 - x^3 - ax ?= b - (self.y.square() - self.x.square() * self.x - $name::curve_constant_a() * self.x).ct_eq(&$name::curve_constant_b()) - | self.is_identity() + if $constant_a == $base::ZERO { + // y^2 - x^3 ?= b + (self.y.square() - self.x.square() * self.x).ct_eq(&$constant_b) + | self.is_identity() + } else { + // y^2 - x^3 - ax ?= b + (self.y.square() - (self.x.square() + $constant_a) * self.x).ct_eq(&$constant_b) + | self.is_identity() + } } fn coordinates(&self) -> CtOption> { @@ -840,11 +878,11 @@ macro_rules! new_curve_impl { } fn a() -> Self::Base { - $name::curve_constant_a() + $constant_a } fn b() -> Self::Base { - $name::curve_constant_b() + $constant_b } } @@ -892,52 +930,95 @@ macro_rules! 
new_curve_impl { type Output = $name; fn add(self, rhs: &'a $name) -> $name { - // Algorithm 1, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t2 = self.z * rhs.z; - let t3 = self.x + self.y; - let t4 = rhs.x + rhs.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = self.x + self.z; - let t5 = rhs.x + rhs.z; - let t4 = t4 * t5; - let t5 = t0 + t2; - let t4 = t4 - t5; - let t5 = self.y + self.z; - let x3 = rhs.y + rhs.z; - let t5 = t5 * x3; - let x3 = t1 + t2; - let t5 = t5 - x3; - let z3 = $name::curve_constant_a() * t4; - let x3 = $name::mul_by_3b(&t2); - let z3 = x3 + z3; - let x3 = t1 - z3; - let z3 = t1 + z3; - let y3 = x3 * z3; - let t1 = t0 + t0; - let t1 = t1 + t0; - let t2 = $name::curve_constant_a() * t2; - let t4 = $name::mul_by_3b(&t4); - let t1 = t1 + t2; - let t2 = t0 - t2; - let t2 = $name::curve_constant_a() * t2; - let t4 = t4 + t2; - let t0 = t1 * t4; - let y3 = y3 + t0; - let t0 = t5 * t4; - let x3 = t3 * x3; - let x3 = x3 - t0; - let t0 = t3 * t1; - let z3 = t5 * z3; - let z3 = z3 + t0; - - $name { - x: x3, - y: y3, - z: z3, + if $constant_a == $base::ZERO { + // Algorithm 7, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t2 = self.z * rhs.z; + let t3 = self.x + self.y; + let t4 = rhs.x + rhs.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = self.y + self.z; + let x3 = rhs.y + rhs.z; + let t4 = t4 * x3; + let x3 = t1 + t2; + let t4 = t4 - x3; + let x3 = self.x + self.z; + let y3 = rhs.x + rhs.z; + let x3 = x3 * y3; + let y3 = t0 + t2; + let y3 = x3 - y3; + let x3 = t0 + t0; + let t0 = x3 + t0; + let t2 = $name::mul_by_3b(&t2); + let z3 = t1 + t2; + let t1 = t1 - t2; + let y3 = $name::mul_by_3b(&y3); + let x3 = t4 * y3; + let t2 = t3 * t1; + let x3 = t2 - x3; + let y3 = y3 * t0; + let t1 = t1 * z3; + let y3 = t1 + y3; + let t0 = t0 * t3; + let z3 = z3 * t4; + let z3 = z3 + t0; + + $name { + x: x3, + y: 
y3, + z: z3, + } + } else { + // Algorithm 1, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t2 = self.z * rhs.z; + let t3 = self.x + self.y; + let t4 = rhs.x + rhs.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = self.x + self.z; + let t5 = rhs.x + rhs.z; + let t4 = t4 * t5; + let t5 = t0 + t2; + let t4 = t4 - t5; + let t5 = self.y + self.z; + let x3 = rhs.y + rhs.z; + let t5 = t5 * x3; + let x3 = t1 + t2; + let t5 = t5 - x3; + let z3 = $constant_a * t4; + let x3 = $name::mul_by_3b(&t2); + let z3 = x3 + z3; + let x3 = t1 - z3; + let z3 = t1 + z3; + let y3 = x3 * z3; + let t1 = t0 + t0; + let t1 = t1 + t0; + let t2 = $constant_a * t2; + let t4 = $name::mul_by_3b(&t4); + let t1 = t1 + t2; + let t2 = t0 - t2; + let t2 = $constant_a * t2; + let t4 = t4 + t2; + let t0 = t1 * t4; + let y3 = y3 + t0; + let t0 = t5 * t4; + let x3 = t3 * x3; + let x3 = x3 - t0; + let t0 = t3 * t1; + let z3 = t5 * z3; + let z3 = z3 + t0; + + $name { + x: x3, + y: y3, + z: z3, + } } } } @@ -947,48 +1028,86 @@ macro_rules! 
new_curve_impl { // Mixed addition fn add(self, rhs: &'a $name_affine) -> $name { - // Algorithm 2, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t3 = rhs.x + rhs.y; - let t4 = self.x + self.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = rhs.x * self.z; - let t4 = t4 + self.x; - let t5 = rhs.y * self.z; - let t5 = t5 + self.y; - let z3 = $name::curve_constant_a() * t4; - let x3 = $name::mul_by_3b(&self.z); - let z3 = x3 + z3; - let x3 = t1 - z3; - let z3 = t1 + z3; - let y3 = x3 * z3; - let t1 = t0 + t0; - let t1 = t1 + t0; - let t2 = $name::curve_constant_a() * self.z; - let t4 = $name::mul_by_3b(&t4); - let t1 = t1 + t2; - let t2 = t0 - t2; - let t2 = $name::curve_constant_a() * t2; - let t4 = t4 + t2; - let t0 = t1 * t4; - let y3 = y3 + t0; - let t0 = t5 * t4; - let x3 = t3 * x3; - let x3 = x3 - t0; - let t0 = t3 * t1; - let z3 = t5 * z3; - let z3 = z3 + t0; - - let tmp = $name{ - x: x3, - y: y3, - z: z3, - }; - - $name::conditional_select(&tmp, self, rhs.is_identity()) + if $constant_a == $base::ZERO { + // Algorithm 8, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t3 = rhs.x + rhs.y; + let t4 = self.x + self.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = rhs.y * self.z; + let t4 = t4 + self.y; + let y3 = rhs.x * self.z; + let y3 = y3 + self.x; + let x3 = t0 + t0; + let t0 = x3 + t0; + let t2 = $name::mul_by_3b(&self.z); + let z3 = t1 + t2; + let t1 = t1 - t2; + let y3 = $name::mul_by_3b(&y3); + let x3 = t4 * y3; + let t2 = t3 * t1; + let x3 = t2 - x3; + let y3 = y3 * t0; + let t1 = t1 * z3; + let y3 = t1 + y3; + let t0 = t0 * t3; + let z3 = z3 * t4; + let z3 = z3 + t0; + + let tmp = $name{ + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, self, rhs.is_identity()) + } else { + // Algorithm 2, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let 
t3 = rhs.x + rhs.y; + let t4 = self.x + self.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = rhs.x * self.z; + let t4 = t4 + self.x; + let t5 = rhs.y * self.z; + let t5 = t5 + self.y; + let z3 = $constant_a * t4; + let x3 = $name::mul_by_3b(&self.z); + let z3 = x3 + z3; + let x3 = t1 - z3; + let z3 = t1 + z3; + let y3 = x3 * z3; + let t1 = t0 + t0; + let t1 = t1 + t0; + let t2 = $constant_a * self.z; + let t4 = $name::mul_by_3b(&t4); + let t1 = t1 + t2; + let t2 = t0 - t2; + let t2 = $constant_a * t2; + let t4 = t4 + t2; + let t0 = t1 * t4; + let y3 = y3 + t0; + let t0 = t5 * t4; + let x3 = t3 * x3; + let x3 = x3 - t0; + let t0 = t3 * t1; + let z3 = t5 * z3; + let z3 = z3 + t0; + + let tmp = $name{ + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, self, rhs.is_identity()) + } } } diff --git a/src/legendre.rs b/src/legendre.rs index 6f6fda17..7e4b9971 100644 --- a/src/legendre.rs +++ b/src/legendre.rs @@ -28,7 +28,7 @@ pub trait Legendre: Field { #[macro_export] macro_rules! prime_field_legendre { ($field:ident ) => { - impl crate::legendre::Legendre for $field { + impl $crate::legendre::Legendre for $field { type BasePrimeField = Self; #[inline] From 8e3a33af78c941bb87ab8a5e81dc4cb3d09c0d69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:43:26 +0200 Subject: [PATCH 5/7] fix: Improve serialization for prime fields (#85) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: Improve serialization for prime fields Summary: 256-bit field serialization is currently 4x u64, ie. the native format. This implements the standard of byte-serialization (corresponding to the PrimeField::{to,from}_repr), and an hex-encoded variant of that for (de)serializers that are human-readable (concretely, json). 
- Added a new macro `serialize_deserialize_32_byte_primefield!` for custom serialization and deserialization of 32-byte prime field in different struct (Fq, Fp, Fr) across the secp256r, bn256, and derive libraries. - Implemented the new macro for serialization and deserialization in various structs, replacing the previous `serde::{Deserialize, Serialize}` direct use. - Enhanced error checking in the custom serialization methods to ensure valid field elements. - Updated the test function in the tests/field.rs file to include JSON serialization and deserialization tests for object integrity checking. * fixup! fix: Improve serialization for prime fields --------- Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> --- Cargo.toml | 4 +++- src/bn256/fq.rs | 7 +++---- src/bn256/fr.rs | 7 +++---- src/derive/field.rs | 35 +++++++++++++++++++++++++++++++++++ src/secp256k1/fp.rs | 7 +++---- src/secp256k1/fq.rs | 7 +++---- src/secp256r1/fp.rs | 7 +++---- src/secp256r1/fq.rs | 7 +++---- src/tests/field.rs | 7 +++++++ 9 files changed, 63 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f29c917e..a843a97c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ criterion = { version = "0.3", features = ["html_reports"] } rand_xorshift = "0.3" ark-std = { version = "0.3" } bincode = "1.3.3" +serde_json = "1.0.105" [dependencies] subtle = "2.4" @@ -30,6 +31,7 @@ num-traits = "0.2" paste = "1.0.11" serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } +hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" [features] @@ -37,7 +39,7 @@ default = ["reexport", "bits"] asm = [] bits = ["ff/bits"] bn256-table = [] -derive_serde = ["serde/derive", "serde_arrays"] +derive_serde = ["serde/derive", "serde_arrays", "hex"] prefetch = [] print-trace = ["ark-std/print-trace"] reexport = [] diff --git
a/src/bn256/fq.rs b/src/bn256/fq.rs index 0024723a..fec8d863 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -16,9 +16,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `p = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47` @@ -28,9 +25,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 const MODULUS: Fq = Fq([ diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 8a57ff9f..7e3b5ae8 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -31,9 +31,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_r$ where /// /// `r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001` @@ -43,9 +40,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fr` values are always in // Montgomery form; i.e., Fr(a) = aR mod r, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fr(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fr); + /// Constant representing the modulus /// r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 const MODULUS: Fr = Fr([ diff --git a/src/derive/field.rs b/src/derive/field.rs index 945ee981..da962457 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -686,3 +686,38 @@ macro_rules! field_bits { } }; } + +/// A macro to help define serialization and deserialization for prime field implementations +/// that use 32-byte representations. This assumes the concerned type implements PrimeField +/// (for from_repr, to_repr). +#[macro_export] +macro_rules! serialize_deserialize_32_byte_primefield { + ($type:ty) => { + impl ::serde::Serialize for $type { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_repr(); + if serializer.is_human_readable() { + hex::serde::serialize(bytes, serializer) + } else { + bytes.serialize(serializer) + } + } + } + + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $type { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + <[u8; 32]>::deserialize(deserializer)? 
+ }; + Option::from(Self::from_repr(bytes)).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + }; +} diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index f6a2a54b..c346dc6c 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f const MODULUS: Fp = Fp([ diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 304f5f10..189daaba 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 const MODULUS: Fq = Fq([ diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index 228e4a67..bf86e157 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff const MODULUS: Fp = Fp([ diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index 1b98761c..d1a7b809 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -5,9 +5,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551` @@ -17,9 +14,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 const MODULUS: Fq = Fq([ diff --git a/src/tests/field.rs b/src/tests/field.rs index b04f801e..02f5509f 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -280,11 +280,18 @@ where let _message = format!("serialization with serde {type_name}"); let start = start_timer!(|| _message); for _ in 0..1000000 { + // byte serialization let a = F::random(&mut rng); let bytes = bincode::serialize(&a).unwrap(); let reader = std::io::Cursor::new(bytes); let b: F = bincode::deserialize_from(reader).unwrap(); assert_eq!(a, b); + + // json serialization + let json = serde_json::to_string(&a).unwrap(); + let reader = std::io::Cursor::new(json); + let b: F = serde_json::from_reader(reader).unwrap(); + assert_eq!(a, b); } end_timer!(start); } From 2f3e388eef9b788adf126bb4a8abb10877a0a04d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:02:14 +0200 Subject: [PATCH 6/7] refactor: (De)Serialization of points using `GroupEncoding` (#88) * refactor: implement (De)Serialization of points using the `GroupEncoding` trait - Updated curve point (de)serialization logic from the internal representation to the representation offered by the implementation of the `GroupEncoding` trait. * fix: add explicit json serde tests --- src/derive/curve.rs | 70 +++++++++++++++++++++++++++++++++++++++++++-- src/tests/curve.rs | 12 ++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/derive/curve.rs b/src/derive/curve.rs index 098d1a2f..660e85d8 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -288,8 +288,72 @@ macro_rules! 
new_curve_impl { } + /// A macro to help define point serialization using the [`group::GroupEncoding`] trait + /// This assumes both point types ($name, $nameaffine) implement [`group::GroupEncoding`]. + #[cfg(feature = "derive_serde")] + macro_rules! serialize_deserialize_to_from_bytes { + () => { + impl ::serde::Serialize for $name { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! { + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? + }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + + impl ::serde::Serialize for $name_affine { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! { + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name_affine { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? 
+ }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + }; + } + #[derive(Copy, Clone, Debug)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name { pub x: $base, pub y: $base, @@ -297,13 +361,13 @@ macro_rules! new_curve_impl { } #[derive(Copy, Clone, PartialEq)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name_affine { pub x: $base, pub y: $base, } - + #[cfg(feature = "derive_serde")] + serialize_deserialize_to_from_bytes!(); impl_compressed!(); impl_uncompressed!(); diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 54d23791..2f93bbb4 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -74,12 +74,24 @@ where assert_eq!(projective_point.to_affine(), affine_point_rec); assert_eq!(affine_point, affine_point_rec); } + { + let affine_json = serde_json::to_string(&affine_point).unwrap(); + let reader = std::io::Cursor::new(affine_json); + let affine_point_rec: G::AffineExt = serde_json::from_reader(reader).unwrap(); + assert_eq!(affine_point, affine_point_rec); + } { let projective_bytes = bincode::serialize(&projective_point).unwrap(); let reader = std::io::Cursor::new(projective_bytes); let projective_point_rec: G = bincode::deserialize_from(reader).unwrap(); assert_eq!(projective_point, projective_point_rec); } + { + let projective_json = serde_json::to_string(&projective_point).unwrap(); + let reader = std::io::Cursor::new(projective_json); + let projective_point_rec: G = serde_json::from_reader(reader).unwrap(); + assert_eq!(projective_point, projective_point_rec); + } } } From ee7cb86ce7d733586e7ac48e4dc25930d7851d85 Mon Sep 17 00:00:00 2001 From: einar-taiko <126954546+einar-taiko@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:09:47 +0800 Subject: [PATCH 7/7] Insert MSM and FFT code and their benchmarks. 
(#86) * Insert MSM and FFT code and their benchmarks. Resolves taikoxyz/zkevm-circuits#150. * feedback * Add instructions * feeback * Implement feedback: Actually supply the correct arguments to `best_multiexp`. Split into `singlecore` and `multicore` benchmarks so Criterion's result caching and comparison over multiple runs makes sense. Rewrite point and scalar generation. * Use slicing and parallelism to to decrease running time. Laptop measurements: k=22: 109 sec k=16: 1 sec * Refactor msm * Refactor fft * Update module comments * Fix formatting * Implement suggestion for fixing CI --- Cargo.toml | 13 +++- benches/fft.rs | 57 ++++++++++++++++++ benches/msm.rs | 116 +++++++++++++++++++++++++++++++++++ src/fft.rs | 134 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 + src/msm.rs | 153 +++++++++++++++++++++++++++++++++++++++++++++++ src/multicore.rs | 16 +++++ 7 files changed, 491 insertions(+), 1 deletion(-) create mode 100644 benches/fft.rs create mode 100644 benches/msm.rs create mode 100644 src/fft.rs create mode 100644 src/msm.rs create mode 100644 src/multicore.rs diff --git a/Cargo.toml b/Cargo.toml index a843a97c..43fa7d03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,9 +33,11 @@ serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" +maybe-rayon = { version = "0.1.0", default-features = false } [features] -default = ["reexport", "bits"] +default = ["reexport", "bits", "multicore"] +multicore = ["maybe-rayon/threads"] asm = [] bits = ["ff/bits"] bn256-table = [] @@ -69,3 +71,12 @@ harness = false [[bench]] name = "hash_to_curve" harness = false + +[[bench]] +name = "fft" +harness = false + +[[bench]] +name = "msm" +harness = false +required-features = ["multicore"] diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 00000000..a250308d --- 
/dev/null +++ b/benches/fft.rs @@ -0,0 +1,57 @@ +//! This benchmarks Fast-Fourier Transform (FFT). +//! Since it is over a finite field, it is actually the Number Theoretic +//! Transform (NTT). It uses the `Fr` scalar field from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- fft +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. + +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use group::ff::Field; +use halo2curves::bn256::Fr as Scalar; +use halo2curves::fft::best_fft; +use rand_core::OsRng; +use std::ops::Range; +use std::time::SystemTime; + +const RANGE: Range = 3..19; + +fn generate_data(k: u32) -> Vec { + let n = 1 << k; + let timer = SystemTime::now(); + println!("\n\nGenerating 2^{k} = {n} values..",); + let data: Vec = (0..n).map(|_| Scalar::random(OsRng)).collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} values took: {} sec.\n\n", + end.as_secs() + ); + data +} + +fn fft(c: &mut Criterion) { + let max_k = RANGE.max().unwrap_or(16); + let mut data = generate_data(max_k); + let omega = Scalar::random(OsRng); + let mut group = c.benchmark_group("fft"); + for k in RANGE { + group.bench_function(BenchmarkId::new("k", k), |b| { + let n = 1 << k; + assert!(n <= data.len()); + b.iter(|| { + best_fft(&mut data[..n], omega, k); + }); + }); + } + group.finish(); +} + +criterion_group!(benches, fft); +criterion_main!(benches); diff --git a/benches/msm.rs b/benches/msm.rs new file mode 100644 index 00000000..c78952b7 --- /dev/null +++ b/benches/msm.rs @@ -0,0 +1,116 @@ +//! This benchmarks Multi Scalar Multiplication (MSM). +//! It measures `G1` from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- msm +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. 
+ +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use ff::Field; +use group::prime::PrimeCurveAffine; +use halo2curves::bn256::{Fr as Scalar, G1Affine as Point}; +use halo2curves::msm::{best_multiexp, multiexp_serial}; +use maybe_rayon::current_thread_index; +use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rand_core::SeedableRng; +use rand_xorshift::XorShiftRng; +use std::time::SystemTime; + +const SAMPLE_SIZE: usize = 10; +const SINGLECORE_RANGE: [u8; 6] = [3, 8, 10, 12, 14, 16]; +const MULTICORE_RANGE: [u8; 9] = [3, 8, 10, 12, 14, 16, 18, 20, 22]; +const SEED: [u8; 16] = [ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, +]; + +fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { + let n: u64 = { + assert!(k < 64); + 1 << k + }; + + println!("\n\nGenerating 2^{k} = {n} coefficients and curve points..",); + let timer = SystemTime::now(); + let coeffs = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, + |rng, _| Scalar::random(rng), + ) + .collect(); + let bases = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, + |rng, _| Point::random(rng), + ) + .collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} coefficients and curve points took: {} sec.\n\n", + end.as_secs() + ); + + (coeffs, bases) +} + +fn msm(c: &mut Criterion) { + let mut group = c.benchmark_group("msm"); + let max_k = *SINGLECORE_RANGE 
+ .iter() + .chain(MULTICORE_RANGE.iter()) + .max() + .unwrap_or(&16); + let (coeffs, bases) = generate_coefficients_and_curvepoints(max_k); + + for k in SINGLECORE_RANGE { + group + .bench_function(BenchmarkId::new("singlecore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + let mut acc = Point::identity().into(); + b.iter(|| multiexp_serial(&coeffs[..n], &bases[..n], &mut acc)); + }) + .sample_size(10); + } + for k in MULTICORE_RANGE { + group + .bench_function(BenchmarkId::new("multicore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + b.iter(|| { + best_multiexp(&coeffs[..n], &bases[..n]); + }) + }) + .sample_size(SAMPLE_SIZE); + } + group.finish(); +} + +criterion_group!(benches, msm); +criterion_main!(benches); diff --git a/src/fft.rs b/src/fft.rs new file mode 100644 index 00000000..6eb3487e --- /dev/null +++ b/src/fft.rs @@ -0,0 +1,134 @@ +use crate::multicore; +pub use crate::{CurveAffine, CurveExt}; +use ff::Field; +use group::{GroupOpsOwned, ScalarMulOwned}; + +/// This represents an element of a group with basic operations that can be +/// performed. This allows an FFT implementation (for example) to operate +/// generically over either a field or elliptic curve group. +pub trait FftGroup: + Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned +{ +} + +impl FftGroup for T +where + Scalar: Field, + T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned, +{ +} + +/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size +/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative +/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when +/// interpreted as the coefficients of a polynomial of degree $n - 1$, is +/// transformed into the evaluations of this polynomial at each of the $n$ +/// distinct powers of $\omega$. This transformation is invertible by providing +/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element +/// by $n$. 
+/// +/// This will use multithreading if beneficial. +pub fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + fn bitreverse(mut n: usize, l: usize) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r + } + + let threads = multicore::current_num_threads(); + let log_threads = threads.ilog2(); + let n = a.len(); + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2)) + .scan(Scalar::ONE, |w, _| { + let tw = *w; + *w *= &omega; + Some(tw) + }) + .collect(); + + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = n / 2; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + }); + chunk *= 2; + twiddle_chunk /= 2; + } + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) + } +} + +/// This performs recursive butterfly arithmetic +pub fn recursive_butterfly_arithmetic>( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0] += &t; + a[1] -= &t; + } else { + let (left, right) = a.split_at_mut(n / 2); + multicore::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= 
&t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3fa8e98f..670a6448 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ mod arithmetic; +pub mod fft; pub mod hash_to_curve; +pub mod msm; +pub mod multicore; #[macro_use] pub mod legendre; pub mod serde; diff --git a/src/msm.rs b/src/msm.rs new file mode 100644 index 00000000..de30be55 --- /dev/null +++ b/src/msm.rs @@ -0,0 +1,153 @@ +use ff::PrimeField; +use group::Group; +use pasta_curves::arithmetic::CurveAffine; + +use crate::multicore; + +pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + 
other += a; + other + } + Bucket::Projective(a) => other + a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + +/// Performs a small multi-exponentiation operation. +/// Uses the double-and-add algorithm with doublings shared across points. +pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + let mut acc = C::Curve::identity(); + + // for byte idx + for byte_idx in (0..32).rev() { + // for bit idx + for bit_idx in (0..8).rev() { + acc = acc.double(); + // for each coeff + for coeff_idx in 0..coeffs.len() { + let byte = coeffs[coeff_idx].as_ref()[byte_idx]; + if ((byte >> bit_idx) & 1) != 0 { + acc += bases[coeff_idx]; + } + } + } + } + + acc +} + +/// Performs a multi-exponentiation operation. +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. 
+pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = multicore::current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + multicore::scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + multiexp_serial(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + multiexp_serial(coeffs, bases, &mut acc); + acc + } +} diff --git a/src/multicore.rs b/src/multicore.rs new file mode 100644 index 00000000..d8323553 --- /dev/null +++ b/src/multicore.rs @@ -0,0 +1,16 @@ +pub use maybe_rayon::{ + iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, + join, scope, Scope, +}; + +#[cfg(feature = "multicore")] +pub use maybe_rayon::{ + current_num_threads, + iter::{IndexedParallelIterator, IntoParallelRefIterator}, + slice::ParallelSliceMut, +}; + +#[cfg(not(feature = "multicore"))] +pub fn current_num_threads() -> usize { + 1 +}