Skip to content

Commit

Permalink
Merge pull request #211 from ashvardanian/main-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian authored Oct 16, 2024
2 parents abe25d5 + 5b8790a commit 617de0b
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions include/simsimd/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,22 +572,27 @@ SIMSIMD_PUBLIC void simsimd_l2_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const
}
SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
simsimd_distance_t* result) {
int32x4_t d2_vec = vdupq_n_s32(0);

// The naive approach is to upcast 8-bit signed integers into 16-bit signed integers
// for subtraction, then multiply within 16-bit integers and accumulate the results
// into 32-bit integers. This approach is slow on modern Arm CPUs. On Graviton 4,
// that approach results in 17 GB/s of throughput, compared to 39 GB/s for `i8`
// dot-products.
//
// Luckily we can use the `vabdq_s8` which technically returns `i8` values, but it's a
// matter of reinterpret-casting! That approach boosts us to 33 GB/s of throughput.
uint32x4_t d2_vec = vdupq_n_u32(0);
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
int8x8_t a_vec = vld1_s8(a + i);
int8x8_t b_vec = vld1_s8(b + i);
int16x8_t a_vec16 = vmovl_s8(a_vec);
int16x8_t b_vec16 = vmovl_s8(b_vec);
int16x8_t d_vec = vsubq_s16(a_vec16, b_vec16);
int32x4_t d_low = vmull_s16(vget_low_s16(d_vec), vget_low_s16(d_vec));
int32x4_t d_high = vmull_s16(vget_high_s16(d_vec), vget_high_s16(d_vec));
d2_vec = vaddq_s32(d2_vec, vaddq_s32(d_low, d_high));
for (; i + 16 <= n; i += 16) {
int8x16_t a_vec = vld1q_s8(a + i);
int8x16_t b_vec = vld1q_s8(b + i);
uint8x16_t d_vec = vreinterpretq_u8_s8(vabdq_s8(a_vec, b_vec));
d2_vec = vdotq_u32(d2_vec, d_vec, d_vec);
}
int32_t d2 = vaddvq_s32(d2_vec);
uint32_t d2 = vaddvq_u32(d2_vec);
for (; i < n; ++i) {
int32_t n = (int32_t)a[i] - b[i];
d2 += n * n;
d2 += (uint32_t)(n * n);
}
*result = d2;
}
Expand Down

0 comments on commit 617de0b

Please sign in to comment.