Skip to content

Commit

Permalink
further refactoring to internal arm types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarinn committed Aug 21, 2024
1 parent b7d2f53 commit d524791
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 67 deletions.
17 changes: 17 additions & 0 deletions curve25519-dalek/src/backend/vector/neon/edwards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
use core::convert::From;
use core::ops::{Add, Neg, Sub};

use curve25519_dalek_derive::unsafe_target_feature;
use subtle::Choice;
use subtle::ConditionallySelectable;

Expand All @@ -50,12 +51,14 @@ use super::field::{FieldElement2625x4, Lanes, Shuffle};
#[derive(Copy, Clone, Debug)]
pub struct ExtendedPoint(pub(super) FieldElement2625x4);

#[unsafe_target_feature("neon")]
impl From<edwards::EdwardsPoint> for ExtendedPoint {
fn from(P: edwards::EdwardsPoint) -> ExtendedPoint {
ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T))
}
}

#[unsafe_target_feature("neon")]
impl From<ExtendedPoint> for edwards::EdwardsPoint {
fn from(P: ExtendedPoint) -> edwards::EdwardsPoint {
let tmp = P.0.split();
Expand All @@ -68,6 +71,7 @@ impl From<ExtendedPoint> for edwards::EdwardsPoint {
}
}

#[unsafe_target_feature("neon")]
impl ConditionallySelectable for ExtendedPoint {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
Expand All @@ -78,18 +82,21 @@ impl ConditionallySelectable for ExtendedPoint {
}
}

#[unsafe_target_feature("neon")]
impl Default for ExtendedPoint {
fn default() -> ExtendedPoint {
ExtendedPoint::identity()
}
}

// The identity element is taken from the backend's constant table.
#[unsafe_target_feature("neon")]
impl Identity for ExtendedPoint {
    /// Return the precomputed `constants::EXTENDEDPOINT_IDENTITY`.
    fn identity() -> Self {
        let id: ExtendedPoint = constants::EXTENDEDPOINT_IDENTITY;
        id
    }
}

#[unsafe_target_feature("neon")]
impl ExtendedPoint {
/// Compute the double of this point.
pub fn double(&self) -> ExtendedPoint {
Expand Down Expand Up @@ -175,6 +182,7 @@ impl ExtendedPoint {
#[derive(Copy, Clone, Debug)]
pub struct CachedPoint(pub(super) FieldElement2625x4);

#[unsafe_target_feature("neon")]
impl From<ExtendedPoint> for CachedPoint {
fn from(P: ExtendedPoint) -> CachedPoint {
let mut x = P.0;
Expand All @@ -193,18 +201,21 @@ impl From<ExtendedPoint> for CachedPoint {
}
}

#[unsafe_target_feature("neon")]
impl Default for CachedPoint {
fn default() -> CachedPoint {
CachedPoint::identity()
}
}

// The identity element is taken from the backend's constant table.
#[unsafe_target_feature("neon")]
impl Identity for CachedPoint {
    /// Return the precomputed `constants::CACHEDPOINT_IDENTITY`.
    fn identity() -> Self {
        let id: CachedPoint = constants::CACHEDPOINT_IDENTITY;
        id
    }
}

#[unsafe_target_feature("neon")]
impl ConditionallySelectable for CachedPoint {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
Expand All @@ -215,6 +226,7 @@ impl ConditionallySelectable for CachedPoint {
}
}

#[unsafe_target_feature("neon")]
impl<'a> Neg for &'a CachedPoint {
type Output = CachedPoint;
/// Lazily negate the point.
Expand All @@ -229,6 +241,7 @@ impl<'a> Neg for &'a CachedPoint {
}
}

#[unsafe_target_feature("neon")]
impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint {
type Output = ExtendedPoint;

Expand Down Expand Up @@ -266,6 +279,7 @@ impl<'a, 'b> Add<&'b CachedPoint> for &'a ExtendedPoint {
}
}

#[unsafe_target_feature("neon")]
impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint {
type Output = ExtendedPoint;

Expand All @@ -279,6 +293,7 @@ impl<'a, 'b> Sub<&'b CachedPoint> for &'a ExtendedPoint {
}
}

#[unsafe_target_feature("neon")]
impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable<CachedPoint> {
fn from(point: &'a edwards::EdwardsPoint) -> Self {
let P = ExtendedPoint::from(*point);
Expand All @@ -290,6 +305,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for LookupTable<CachedPoint> {
}
}

#[unsafe_target_feature("neon")]
impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5<CachedPoint> {
fn from(point: &'a edwards::EdwardsPoint) -> Self {
let A = ExtendedPoint::from(*point);
Expand All @@ -303,6 +319,7 @@ impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable5<CachedPoint> {
}
}

#[unsafe_target_feature("neon")]
impl<'a> From<&'a edwards::EdwardsPoint> for NafLookupTable8<CachedPoint> {
fn from(point: &'a edwards::EdwardsPoint) -> Self {
let A = ExtendedPoint::from(*point);
Expand Down
106 changes: 49 additions & 57 deletions curve25519-dalek/src/backend/vector/neon/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

use core::ops::{Add, Mul, Neg};

use super::packed_simd::{i32x4, u32x2, u32x2x2, u32x4, u32x4x2, u64x2, u64x4};
use super::packed_simd::{i32x4, u32x2, u32x2x2, u32x4, u32x4x2, u64x2, u64x2x2};
use crate::backend::serial::u64::field::FieldElement51;
use crate::backend::vector::neon::constants::{
P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO,
Expand Down Expand Up @@ -135,7 +135,6 @@ fn unpack_pair(src: u32x4x2) -> (u32x2x2, u32x2x2) {
fn repack_pair(x: u32x4x2, y: u32x4x2) -> u32x4x2 {
unsafe {
use core::arch::aarch64::vcombine_u32;
use core::arch::aarch64::vget_low_u32;
use core::arch::aarch64::vtrn1_u32;

u32x4x2::new(
Expand Down Expand Up @@ -366,7 +365,6 @@ impl FieldElement2625x4 {
let rotated_carryout = |v: u32x4x2| -> u32x4x2 {
unsafe {
use core::arch::aarch64::vcombine_u32;
use core::arch::aarch64::vget_low_u32;
use core::arch::aarch64::vqshlq_u32;

let c: u32x4x2 = u32x4x2::new(
Expand All @@ -391,7 +389,6 @@ impl FieldElement2625x4 {
let combine = |v_lo: u32x4x2, v_hi: u32x4x2| -> u32x4x2 {
unsafe {
use core::arch::aarch64::vcombine_u32;
use core::arch::aarch64::vget_low_u32;
u32x4x2::new(
vcombine_u32(
vget_low_u32(v_lo.0.0),
Expand Down Expand Up @@ -432,7 +429,6 @@ impl FieldElement2625x4 {
#[rustfmt::skip] // Retain formatting of return tuple
let c9_19: u32x4x2 = unsafe {
use core::arch::aarch64::vcombine_u32;
use core::arch::aarch64::vget_low_u32;
use core::arch::aarch64::vmulq_n_u32;

let c9_19_spread: u32x4x2 = u32x4x2::new(
Expand All @@ -449,27 +445,22 @@ impl FieldElement2625x4 {
FieldElement2625x4(v)
}

// TODO: use arm types
#[inline]
#[rustfmt::skip] // Retain formatting of carry and repacking
fn reduce64(mut z: [(u64x2, u64x2); 10]) -> FieldElement2625x4 {
fn reduce64(mut z: [u64x2x2; 10]) -> FieldElement2625x4 {
#[allow(non_snake_case)]
let LOW_25_BITS: u64x2 = u64x2::splat((1 << 25) - 1);
let LOW_25_BITS: u64x2x2 = u64x2x2::splat((1 << 25) - 1);
#[allow(non_snake_case)]
let LOW_26_BITS: u64x2 = u64x2::splat((1 << 26) - 1);
let LOW_26_BITS: u64x2x2 = u64x2x2::splat((1 << 26) - 1);

let carry = |z: &mut [(u64x2, u64x2); 10], i: usize| {
let carry = |z: &mut [u64x2x2; 10], i: usize| {
debug_assert!(i < 9);
if i % 2 == 0 {
z[i + 1].0 = z[i + 1].0 + (z[i].0.shr::<26>());
z[i + 1].1 = z[i + 1].1 + (z[i].1.shr::<26>());
z[i].0 = z[i].0 & LOW_26_BITS;
z[i].1 = z[i].1 & LOW_26_BITS;
z[i + 1] = z[i + 1] + (z[i].shr::<26>());
z[i] = z[i] & LOW_26_BITS;
} else {
z[i + 1].0 = z[i + 1].0 + (z[i].0.shr::<25>());
z[i + 1].1 = z[i + 1].1 + (z[i].1.shr::<25>());
z[i].0 = z[i].0 & LOW_25_BITS;
z[i].1 = z[i].1 & LOW_25_BITS;
z[i + 1] = z[i + 1] + (z[i].shr::<25>());
z[i] = z[i] & LOW_25_BITS;
}
};

Expand All @@ -479,51 +470,53 @@ impl FieldElement2625x4 {
carry(&mut z, 3); carry(&mut z, 7);
carry(&mut z, 4); carry(&mut z, 8);

let c = (z[9].0.shr::<25>(), z[9].1.shr::<25>());
z[9] = (z[9].0 & LOW_25_BITS, z[9].1 & LOW_25_BITS);
let mut c0: (u64x2, u64x2) = (c.0 & LOW_26_BITS, c.1 & LOW_26_BITS);
let mut c1: (u64x2, u64x2) = (c.0.shr::<26>(), c.1.shr::<26>());
let c = z[9].shr::<25>();
z[9] = z[9] & LOW_25_BITS;
let mut c0: u64x2x2 = c & LOW_26_BITS;
let mut c1: u64x2x2 = c.shr::<26>();

unsafe {
use core::arch::aarch64::vmulq_n_u32;

c0 = (vmulq_n_u32(c0.0.into(), 19).into(),
vmulq_n_u32(c0.1.into(), 19).into());
c1 = (vmulq_n_u32(c1.0.into(), 19).into(),
vmulq_n_u32(c1.1.into(), 19).into());
use core::arch::aarch64::vreinterpretq_u32_u64;

c0 = u64x2x2::new(
vmulq_n_u32(vreinterpretq_u32_u64(c0.0.0), 19).into(),
vmulq_n_u32(vreinterpretq_u32_u64(c0.0.1), 19).into());
c1 = u64x2x2::new(
vmulq_n_u32(vreinterpretq_u32_u64(c1.0.0), 19).into(),
vmulq_n_u32(vreinterpretq_u32_u64(c1.0.1), 19).into());
}

z[0] = (z[0].0 + c0.0, z[0].1 + c0.1);
z[1] = (z[1].0 + c1.0, z[1].1 + c1.1);
z[0] = z[0] + c0;
z[1] = z[1] + c1;
carry(&mut z, 0);

FieldElement2625x4([
repack_pair(u32x4x2::new(z[0].0.into(), z[0].1.into()), u32x4x2::new(z[1].0.into(), z[1].1.into())),
repack_pair(u32x4x2::new(z[2].0.into(), z[2].1.into()), u32x4x2::new(z[3].0.into(), z[3].1.into())),
repack_pair(u32x4x2::new(z[4].0.into(), z[4].1.into()), u32x4x2::new(z[5].0.into(), z[5].1.into())),
repack_pair(u32x4x2::new(z[6].0.into(), z[6].1.into()), u32x4x2::new(z[7].0.into(), z[7].1.into())),
repack_pair(u32x4x2::new(z[8].0.into(), z[8].1.into()), u32x4x2::new(z[9].0.into(), z[9].1.into())),
repack_pair(u32x4x2::new(z[0].0.0.into(), z[0].0.1.into()), u32x4x2::new(z[1].0.0.into(), z[1].0.1.into())),
repack_pair(u32x4x2::new(z[2].0.0.into(), z[2].0.1.into()), u32x4x2::new(z[3].0.0.into(), z[3].0.1.into())),
repack_pair(u32x4x2::new(z[4].0.0.into(), z[4].0.1.into()), u32x4x2::new(z[5].0.0.into(), z[5].0.1.into())),
repack_pair(u32x4x2::new(z[6].0.0.into(), z[6].0.1.into()), u32x4x2::new(z[7].0.0.into(), z[7].0.1.into())),
repack_pair(u32x4x2::new(z[8].0.0.into(), z[8].0.1.into()), u32x4x2::new(z[9].0.0.into(), z[9].0.1.into())),
])
}

#[allow(non_snake_case)]
#[rustfmt::skip] // keep alignment of formulas
pub fn square_and_negate_D(&self) -> FieldElement2625x4 {
#[inline(always)]
fn m(x: u32x2x2, y: u32x2x2) -> u64x4 {
fn m(x: u32x2x2, y: u32x2x2) -> u64x2x2 {
use core::arch::aarch64::vmull_u32;
unsafe {
let z0: u64x2 = vmull_u32(x.0.0, y.0.0).into();
let z1: u64x2 = vmull_u32(x.0.1, y.0.1).into();
u64x4::new(z0, z1)
u64x2x2::new(z0, z1)
}
}

#[inline(always)]
fn m_lo(x: u32x2x2, y: u32x2x2) -> u32x2x2 {
use core::arch::aarch64::vmull_u32;
use core::arch::aarch64::vuzp1_u32;
use core::arch::aarch64::vget_low_u32;
unsafe {
let x: u32x4x2 = u32x4x2::new(
vmull_u32(x.0.0, y.0.0).into(),
Expand Down Expand Up @@ -571,17 +564,17 @@ impl FieldElement2625x4 {
let z9 = m(x0_2,x9) + m(x1_2,x8) + m(x2_2,x7) + m(x3_2,x6) + m(x4_2,x5);


let low__p37 = u64x4::splat(0x3ffffed << 37);
let even_p37 = u64x4::splat(0x3ffffff << 37);
let odd__p37 = u64x4::splat(0x1ffffff << 37);
let low__p37 = u64x2x2::splat(0x3ffffed << 37);
let even_p37 = u64x2x2::splat(0x3ffffff << 37);
let odd__p37 = u64x2x2::splat(0x1ffffff << 37);

let negate_D = |x_01: u64x4, p_01: u64x4| -> (u64x2, u64x2) {
let negate_D = |x_01: u64x2x2, p_01: u64x2x2| -> u64x2x2 {
unsafe {
use core::arch::aarch64::vcombine_u32;
use core::arch::aarch64::vreinterpretq_u32_u64;
use core::arch::aarch64::vsubq_u64;

(u64x2(x_01.0.0),
u64x2x2::new(u64x2(x_01.0.0),
vcombine_u32(
vget_low_u32(vreinterpretq_u32_u64(x_01.0.1)),
vget_high_u32(vreinterpretq_u32_u64(vsubq_u64(p_01.0.1, x_01.0.1)))).into())
Expand Down Expand Up @@ -652,16 +645,16 @@ impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 {
let (b8, b9) = unpack_pair(self.0[4]);

FieldElement2625x4::reduce64([
(vmull_u32(b0.0.0, consts.0.into()).into(), vmull_u32(b0.0.1, consts.1.into()).into()),
(vmull_u32(b1.0.0, consts.0.into()).into(), vmull_u32(b1.0.1, consts.1.into()).into()),
(vmull_u32(b2.0.0, consts.0.into()).into(), vmull_u32(b2.0.1, consts.1.into()).into()),
(vmull_u32(b3.0.0, consts.0.into()).into(), vmull_u32(b3.0.1, consts.1.into()).into()),
(vmull_u32(b4.0.0, consts.0.into()).into(), vmull_u32(b4.0.1, consts.1.into()).into()),
(vmull_u32(b5.0.0, consts.0.into()).into(), vmull_u32(b5.0.1, consts.1.into()).into()),
(vmull_u32(b6.0.0, consts.0.into()).into(), vmull_u32(b6.0.1, consts.1.into()).into()),
(vmull_u32(b7.0.0, consts.0.into()).into(), vmull_u32(b7.0.1, consts.1.into()).into()),
(vmull_u32(b8.0.0, consts.0.into()).into(), vmull_u32(b8.0.1, consts.1.into()).into()),
(vmull_u32(b9.0.0, consts.0.into()).into(), vmull_u32(b9.0.1, consts.1.into()).into())
u64x2x2::new(vmull_u32(b0.0.0, consts.0.into()).into(), vmull_u32(b0.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b1.0.0, consts.0.into()).into(), vmull_u32(b1.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b2.0.0, consts.0.into()).into(), vmull_u32(b2.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b3.0.0, consts.0.into()).into(), vmull_u32(b3.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b4.0.0, consts.0.into()).into(), vmull_u32(b4.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b5.0.0, consts.0.into()).into(), vmull_u32(b5.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b6.0.0, consts.0.into()).into(), vmull_u32(b6.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b7.0.0, consts.0.into()).into(), vmull_u32(b7.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b8.0.0, consts.0.into()).into(), vmull_u32(b8.0.1, consts.1.into()).into()),
u64x2x2::new(vmull_u32(b9.0.0, consts.0.into()).into(), vmull_u32(b9.0.1, consts.1.into()).into())
])
}
}
Expand All @@ -673,20 +666,19 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 {
#[rustfmt::skip] // Retain formatting of z_i computation
fn mul(self, rhs: &'b FieldElement2625x4) -> FieldElement2625x4 {
#[inline(always)]
fn m(x: u32x2x2, y: u32x2x2) -> u64x4 {
fn m(x: u32x2x2, y: u32x2x2) -> u64x2x2 {
use core::arch::aarch64::vmull_u32;
unsafe {
let z0: u64x2 = vmull_u32(x.0.0, y.0.0).into();
let z1: u64x2 = vmull_u32(x.0.1, y.0.1).into();
u64x4::new(z0, z1)
u64x2x2::new(z0, z1)
}
}

#[inline(always)]
fn m_lo(x: u32x2x2, y: u32x2x2) -> u32x2x2 {
use core::arch::aarch64::vmull_u32;
use core::arch::aarch64::vuzp1_u32;
use core::arch::aarch64::vget_low_u32;
unsafe {
let x: u32x4x2 = u32x4x2::new(
vmull_u32(x.0.0, y.0.0).into(),
Expand Down Expand Up @@ -741,8 +733,8 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 {
let z8 = m(x0,y8) + m(x1_2,y7) + m(x2,y6) + m(x3_2,y5) + m(x4,y4) + m(x5_2,y3) + m(x6,y2) + m(x7_2,y1) + m(x8,y0) + m(x9_2,y9_19);
let z9 = m(x0,y9) + m(x1,y8) + m(x2,y7) + m(x3,y6) + m(x4,y5) + m(x5,y4) + m(x6,y3) + m(x7,y2) + m(x8,y1) + m(x9,y0);

let f = |x: u64x4| -> (u64x2, u64x2) {
(
let f = |x: u64x2x2| -> u64x2x2 {
u64x2x2::new(
x.0.0.into(),
x.0.1.into()
)
Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/src/backend/vector/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ pub(crate) mod constants;

pub(crate) use self::edwards::{CachedPoint, ExtendedPoint};

mod packed_simd;
pub mod packed_simd;
Loading

0 comments on commit d524791

Please sign in to comment.