Skip to content

Commit

Permalink
Implement CIOS for ARM F::mul (#134)
Browse files Browse the repository at this point in the history
* impl CIOS

* more details

* add Fast CIOS for bn256

* rolled Fast CIOS

* clean comment

* geq for last line in bigint_geq

* update comment to include WORD_SIZE

* mod in montgomery

* cargo fmt

* cargo clippy

---------

Co-authored-by: sragss <sragsdale@a16z.com>
  • Loading branch information
sragss and sragss authored Feb 9, 2024
1 parent 3c43d3c commit 9fff22c
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 103 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ serde_arrays = { version = "0.1.0", optional = true }
hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] }
blake2b_simd = "1"
rayon = "1.8"
unroll = "0.1.5"

[features]
default = ["bits"]
Expand Down
24 changes: 24 additions & 0 deletions src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ pub(crate) const fn macx(a: u64, b: u64, c: u64) -> (u64, u64) {
(res as u64, (res >> 64) as u64)
}

/// Returns `true` when `a >= b`.
///
/// Both operands are compared limb by limb starting at index 3 and moving
/// down to index 0, so index 3 is treated as the most significant limb.
/// The first pair of limbs that differs decides the result; if every limb
/// matches, the operands are equal and the answer is `true`.
#[inline(always)]
pub(crate) const fn bigint_geq(a: &[u64; 4], b: &[u64; 4]) -> bool {
    // `for` loops are not available in `const fn`, so walk the limbs with
    // an explicit counter that runs 3, 2, 1, 0.
    let mut limb = 4;
    while limb > 0 {
        limb -= 1;
        if a[limb] != b[limb] {
            return a[limb] > b[limb];
        }
    }
    // All four limbs are identical.
    true
}

/// Compute a * b, returning the result.
#[inline(always)]
pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] {
Expand Down
2 changes: 1 addition & 1 deletion src/bn256/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::bn256::assembly::field_arithmetic_asm;
#[cfg(not(feature = "asm"))]
use crate::{arithmetic::macx, field_arithmetic, field_specific};

use crate::arithmetic::{adc, mac, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/bn256/fr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use table::FR_TABLE;
#[cfg(not(feature = "bn256-table"))]
use crate::impl_from_u64;

use crate::arithmetic::{adc, mac, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
233 changes: 136 additions & 97 deletions src/derive/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,73 +63,88 @@ macro_rules! field_common {
$crate::ff_ext::jacobi::jacobi::<5>(&self.0, &$modulus.0)
}

#[cfg(feature = "asm")]
const fn montgomery_form(val: [u64; 4], r: $field) -> $field {
// Converts a 4 64-bit limb value into its congruent field representation.
// If `val` represents a 256 bit value then `r` should be R^2,
// if `val` represents the 256 MSB of a 512 bit value, then `r` should be R^3.

let (r0, carry) = mac(0, val[0], r.0[0], 0);
let (r1, carry) = mac(0, val[0], r.0[1], carry);
let (r2, carry) = mac(0, val[0], r.0[2], carry);
let (r3, r4) = mac(0, val[0], r.0[3], carry);

let (r1, carry) = mac(r1, val[1], r.0[0], 0);
let (r2, carry) = mac(r2, val[1], r.0[1], carry);
let (r3, carry) = mac(r3, val[1], r.0[2], carry);
let (r4, r5) = mac(r4, val[1], r.0[3], carry);

let (r2, carry) = mac(r2, val[2], r.0[0], 0);
let (r3, carry) = mac(r3, val[2], r.0[1], carry);
let (r4, carry) = mac(r4, val[2], r.0[2], carry);
let (r5, r6) = mac(r5, val[2], r.0[3], carry);

let (r3, carry) = mac(r3, val[3], r.0[0], 0);
let (r4, carry) = mac(r4, val[3], r.0[1], carry);
let (r5, carry) = mac(r5, val[3], r.0[2], carry);
let (r6, r7) = mac(r6, val[3], r.0[3], carry);

// Montgomery reduction
let k = r0.wrapping_mul($inv);
let (_, carry) = mac(r0, k, $modulus.0[0], 0);
let (r1, carry) = mac(r1, k, $modulus.0[1], carry);
let (r2, carry) = mac(r2, k, $modulus.0[2], carry);
let (r3, carry) = mac(r3, k, $modulus.0[3], carry);
let (r4, carry2) = adc(r4, 0, carry);

let k = r1.wrapping_mul($inv);
let (_, carry) = mac(r1, k, $modulus.0[0], 0);
let (r2, carry) = mac(r2, k, $modulus.0[1], carry);
let (r3, carry) = mac(r3, k, $modulus.0[2], carry);
let (r4, carry) = mac(r4, k, $modulus.0[3], carry);
let (r5, carry2) = adc(r5, carry2, carry);

let k = r2.wrapping_mul($inv);
let (_, carry) = mac(r2, k, $modulus.0[0], 0);
let (r3, carry) = mac(r3, k, $modulus.0[1], carry);
let (r4, carry) = mac(r4, k, $modulus.0[2], carry);
let (r5, carry) = mac(r5, k, $modulus.0[3], carry);
let (r6, carry2) = adc(r6, carry2, carry);

let k = r3.wrapping_mul($inv);
let (_, carry) = mac(r3, k, $modulus.0[0], 0);
let (r4, carry) = mac(r4, k, $modulus.0[1], carry);
let (r5, carry) = mac(r5, k, $modulus.0[2], carry);
let (r6, carry) = mac(r6, k, $modulus.0[3], carry);
let (r7, carry2) = adc(r7, carry2, carry);

// Result may be within MODULUS of the correct value
let (d0, borrow) = sbb(r4, $modulus.0[0], 0);
let (d1, borrow) = sbb(r5, $modulus.0[1], borrow);
let (d2, borrow) = sbb(r6, $modulus.0[2], borrow);
let (d3, borrow) = sbb(r7, $modulus.0[3], borrow);
let (_, borrow) = sbb(carry2, 0, borrow);
let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0);
let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry);
let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry);
let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry);
#[cfg(feature = "asm")]
{
let (r0, carry) = mac(0, val[0], r.0[0], 0);
let (r1, carry) = mac(0, val[0], r.0[1], carry);
let (r2, carry) = mac(0, val[0], r.0[2], carry);
let (r3, r4) = mac(0, val[0], r.0[3], carry);

let (r1, carry) = mac(r1, val[1], r.0[0], 0);
let (r2, carry) = mac(r2, val[1], r.0[1], carry);
let (r3, carry) = mac(r3, val[1], r.0[2], carry);
let (r4, r5) = mac(r4, val[1], r.0[3], carry);

let (r2, carry) = mac(r2, val[2], r.0[0], 0);
let (r3, carry) = mac(r3, val[2], r.0[1], carry);
let (r4, carry) = mac(r4, val[2], r.0[2], carry);
let (r5, r6) = mac(r5, val[2], r.0[3], carry);

let (r3, carry) = mac(r3, val[3], r.0[0], 0);
let (r4, carry) = mac(r4, val[3], r.0[1], carry);
let (r5, carry) = mac(r5, val[3], r.0[2], carry);
let (r6, r7) = mac(r6, val[3], r.0[3], carry);

// Montgomery reduction
let k = r0.wrapping_mul($inv);
let (_, carry) = mac(r0, k, $modulus.0[0], 0);
let (r1, carry) = mac(r1, k, $modulus.0[1], carry);
let (r2, carry) = mac(r2, k, $modulus.0[2], carry);
let (r3, carry) = mac(r3, k, $modulus.0[3], carry);
let (r4, carry2) = adc(r4, 0, carry);

let k = r1.wrapping_mul($inv);
let (_, carry) = mac(r1, k, $modulus.0[0], 0);
let (r2, carry) = mac(r2, k, $modulus.0[1], carry);
let (r3, carry) = mac(r3, k, $modulus.0[2], carry);
let (r4, carry) = mac(r4, k, $modulus.0[3], carry);
let (r5, carry2) = adc(r5, carry2, carry);

let k = r2.wrapping_mul($inv);
let (_, carry) = mac(r2, k, $modulus.0[0], 0);
let (r3, carry) = mac(r3, k, $modulus.0[1], carry);
let (r4, carry) = mac(r4, k, $modulus.0[2], carry);
let (r5, carry) = mac(r5, k, $modulus.0[3], carry);
let (r6, carry2) = adc(r6, carry2, carry);

let k = r3.wrapping_mul($inv);
let (_, carry) = mac(r3, k, $modulus.0[0], 0);
let (r4, carry) = mac(r4, k, $modulus.0[1], carry);
let (r5, carry) = mac(r5, k, $modulus.0[2], carry);
let (r6, carry) = mac(r6, k, $modulus.0[3], carry);
let (r7, carry2) = adc(r7, carry2, carry);

// Result may be within MODULUS of the correct value
let (d0, borrow) = sbb(r4, $modulus.0[0], 0);
let (d1, borrow) = sbb(r5, $modulus.0[1], borrow);
let (d2, borrow) = sbb(r6, $modulus.0[2], borrow);
let (d3, borrow) = sbb(r7, $modulus.0[3], borrow);
let (_, borrow) = sbb(carry2, 0, borrow);
let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0);
let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry);
let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry);
let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry);

$field([d0, d1, d2, d3])
}

$field([d0, d1, d2, d3])
#[cfg(not(feature = "asm"))]
{
let mut val = val;
if bigint_geq(&val, &$modulus.0) {
let mut borrow = 0;
(val[0], borrow) = sbb(val[0], $modulus.0[0], borrow);
(val[1], borrow) = sbb(val[1], $modulus.0[1], borrow);
(val[2], borrow) = sbb(val[2], $modulus.0[2], borrow);
(val[3], _) = sbb(val[3], $modulus.0[3], borrow);
}
$field::mul(&$field(val), &r)
}
}

fn from_u512(limbs: [u64; 8]) -> $field {
Expand All @@ -150,27 +165,13 @@ macro_rules! field_common {
let lower_256 = [limbs[0], limbs[1], limbs[2], limbs[3]];
let upper_256 = [limbs[4], limbs[5], limbs[6], limbs[7]];

#[cfg(feature = "asm")]
{
Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3)
}
#[cfg(not(feature = "asm"))]
{
$field(lower_256) * $r2 + $field(upper_256) * $r3
}
Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3)
}

/// Converts from an integer represented in little endian
/// into its (congruent) `$field` representation.
pub const fn from_raw(val: [u64; 4]) -> Self {
#[cfg(feature = "asm")]
{
Self::montgomery_form(val, $r2)
}
#[cfg(not(feature = "asm"))]
{
(&$field(val)).mul(&$r2)
}
Self::montgomery_form(val, $r2)
}

/// Attempts to convert a little-endian byte representation of
Expand Down Expand Up @@ -429,31 +430,69 @@ macro_rules! field_arithmetic {
}

/// Multiplies `rhs` by `self`, returning the result.
#[inline]
pub const fn mul(&self, rhs: &Self) -> $field {
// Schoolbook multiplication
#[inline(always)]
#[unroll::unroll_for_loops]
#[allow(unused_assignments)]
pub const fn mul(&self, rhs: &Self) -> Self {
// Fast Coarsely Integrated Operand Scanning (CIOS) as described
// in Algorithm 2 of EdMSM: https://eprint.iacr.org/2022/1400.pdf
//
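// The fast version (Algorithm 2) is only valid when the most significant
// limb of the modulus is at most (D - 1) / 2 - 1 = 2^63 - 2, where D = 2^64
// is the word base; equivalently, when that limb is strictly less than
// u64::MAX / 2. Otherwise fall back to schoolbook multiplication.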

if $modulus.0[3] < (u64::MAX / 2) {
const N: usize = 4;
let mut t: [u64; N] = [0u64; N];
let mut c_2: u64;
for i in 0..4 {
let mut c: u64 = 0u64;
for j in 0..4 {
(t[j], c) = mac(t[j], self.0[j], rhs.0[i], c);
}
c_2 = c;

let m = t[0].wrapping_mul(INV);
(_, c) = macx(t[0], m, $modulus.0[0]);

for j in 1..4 {
(t[j - 1], c) = mac(t[j], m, $modulus.0[j], c);
}
(t[N - 1], _) = adc(c_2, c, 0);
}

if bigint_geq(&t, &$modulus.0) {
let mut borrow = 0;
(t[0], borrow) = sbb(t[0], $modulus.0[0], borrow);
(t[1], borrow) = sbb(t[1], $modulus.0[1], borrow);
(t[2], borrow) = sbb(t[2], $modulus.0[2], borrow);
(t[3], borrow) = sbb(t[3], $modulus.0[3], borrow);
}
$field(t)
} else {
// Schoolbook multiplication

let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);

let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);

let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);

let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);

$field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7])
$field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7])
}
}

/// Subtracts `rhs` from `self`, returning the result.
Expand Down
2 changes: 1 addition & 1 deletion src/secp256k1/fp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256k1/fq.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256r1/fp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256r1/fq.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use core::fmt;
Expand Down

0 comments on commit 9fff22c

Please sign in to comment.