From 252fba78b792930a17f350fdd1ba33e7a9713174 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 17 Oct 2024 06:58:04 +0000 Subject: [PATCH] Add: `u8` kernels & FMA --- README.md | 19 +-- c/lib.c | 17 +++ include/simsimd/binary.h | 29 ++-- include/simsimd/dot.h | 103 ++++++++++++++- include/simsimd/fma.h | 39 ++++++ include/simsimd/simsimd.h | 114 +++++++++++++++- include/simsimd/spatial.h | 269 ++++++++++++++++++++++++++++++++++++-- scripts/bench.cxx | 32 +++-- scripts/test.py | 38 ++++-- 9 files changed, 603 insertions(+), 57 deletions(-) create mode 100644 include/simsimd/fma.h diff --git a/README.md b/README.md index 9aeba874..8d1fd0c5 100644 --- a/README.md +++ b/README.md @@ -91,15 +91,16 @@ You can learn more about the technical implementation details in the following b ## Benchmarks For reference, we use 1536-dimensional vectors, like the embeddings produced by the OpenAI Ada API. -Comparing the serial code throughput produced by GCC 12 to hand-optimized kernels in SimSIMD, we see the following single-core improvements: +Comparing the serial code throughput produced by GCC 12 to hand-optimized kernels in SimSIMD, we see the following single-core improvements for the two most common vector-vector similarity metrics - the Cosine similarity and the Euclidean distance: -| Type | Apple M2 Pro | AMD Genoa | AWS Graviton 4 | -| :----- | ---------------------------------: | ---------------------------------: | ---------------------------------: | -| `f64` | 18.5 → 28.8 GB/s
+ 56 % | 21.9 → 41.4 GB/s
+ 89 % | 20.7 → 41.3 GB/s
+ 99 % | -| `f32` | 9.2 → 29.6 GB/s
+ 221 % | 10.9 → 95.8 GB/s
+ 779 % | 4.9 → 41.9 GB/s
+ 755 % | -| `f16` | 4.6 → 14.6 GB/s
+ 217 % | 3.1 → 108.4 GB/s
+ 3,397 % | 5.4 → 39.3 GB/s
+ 627 % | -| `bf16` | 4.6 → 26.3 GB/s
+ 472 % | 0.8 → 59.5 GB/s
+7,437 % | 2.5 → 29.9 GB/s
+ 1,096 % | -| `i8` | 25.8 → 47.1 GB/s
+ 83 % | 33.1 → 65.3 GB/s
+ 97 % | 35.2 → 43.5 GB/s
+ 24 % | +| Type | Apple M2 Pro | Intel Sapphire Rapids | AWS Graviton 4 | +| :----- | ----------------------------: | -------------------------------: | ------------------------------: | +| `f64` | 18.5 → 28.8 GB/s
+ 56 % | 21.9 → 41.4 GB/s
+ 89 % | 20.7 → 41.3 GB/s
+ 99 % | +| `f32` | 9.2 → 29.6 GB/s
+ 221 % | 10.9 → 95.8 GB/s
+ 779 % | 4.9 → 41.9 GB/s
+ 755 % | +| `f16` | 4.6 → 14.6 GB/s
+ 217 % | 3.1 → 108.4 GB/s
+ 3,397 % | 5.4 → 39.3 GB/s
+ 627 % | +| `bf16` | 4.6 → 26.3 GB/s
+ 472 % | 0.8 → 59.5 GB/s
+7,437 % | 2.5 → 29.9 GB/s
+ 1,096 % | +| `i8` | 25.8 → 47.1 GB/s
+ 83 % | 33.1 → 65.3 GB/s
+ 97 % | 35.2 → 43.5 GB/s
+ 24 % | +| `u8` | | 32.5 → 66.5 GB/s
+ 105 % | | Similar speedups are often observed even when compared to BLAS and LAPACK libraries underlying most numerical computing libraries, including NumPy and SciPy in Python. Broader benchmarking results: @@ -112,7 +113,7 @@ Broader benchmarking results: The package is intended to replace the usage of `numpy.inner`, `numpy.dot`, and `scipy.spatial.distance`. Aside from drastic performance improvements, SimSIMD significantly improves accuracy in mixed precision setups. -NumPy and SciPy, processing `i8` or `f16` vectors, will use the same types for accumulators, while SimSIMD can combine `i8` enumeration, `i16` multiplication, and `i32` accumulation to avoid overflows entirely. +NumPy and SciPy, processing `i8`, `u8` or `f16` vectors, will use the same types for accumulators, while SimSIMD can combine `i8` enumeration, `i16` multiplication, and `i32` accumulation to avoid overflows entirely. The same applies to processing `f16` and `bf16` values with `f32` precision. ### Installation diff --git a/c/lib.c b/c/lib.c index 5ecb1970..81f3e8a2 100644 --- a/c/lib.c +++ b/c/lib.c @@ -108,6 +108,8 @@ extern "C" { } // Dot products +SIMSIMD_DECLARATION_DENSE(dot, i8, i8) +SIMSIMD_DECLARATION_DENSE(dot, u8, u8) SIMSIMD_DECLARATION_DENSE(dot, f16, f16) SIMSIMD_DECLARATION_DENSE(dot, bf16, bf16) SIMSIMD_DECLARATION_DENSE(dot, f32, f32) @@ -123,16 +125,19 @@ SIMSIMD_DECLARATION_DENSE(vdot, f64c, f64) // Spatial distances SIMSIMD_DECLARATION_DENSE(cos, i8, i8) +SIMSIMD_DECLARATION_DENSE(cos, u8, u8) SIMSIMD_DECLARATION_DENSE(cos, f16, f16) SIMSIMD_DECLARATION_DENSE(cos, bf16, bf16) SIMSIMD_DECLARATION_DENSE(cos, f32, f32) SIMSIMD_DECLARATION_DENSE(cos, f64, f64) SIMSIMD_DECLARATION_DENSE(l2sq, i8, i8) +SIMSIMD_DECLARATION_DENSE(l2sq, u8, u8) SIMSIMD_DECLARATION_DENSE(l2sq, f16, f16) SIMSIMD_DECLARATION_DENSE(l2sq, bf16, bf16) SIMSIMD_DECLARATION_DENSE(l2sq, f32, f32) SIMSIMD_DECLARATION_DENSE(l2sq, f64, f64) SIMSIMD_DECLARATION_DENSE(l2, i8, i8) +SIMSIMD_DECLARATION_DENSE(l2, u8, u8) SIMSIMD_DECLARATION_DENSE(l2, f16, f16) SIMSIMD_DECLARATION_DENSE(l2, bf16, bf16) SIMSIMD_DECLARATION_DENSE(l2, f32, f32) @@ -199,10 +204,13 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { void* dummy = 0; // Dense: + simsimd_dot_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); + simsimd_dot_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); simsimd_dot_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_dot_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_dot_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); simsimd_dot_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_dot_f16c((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_dot_bf16c((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_dot_f32c((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); @@ -211,23 +219,32 @@ SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void) { simsimd_vdot_bf16c((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_vdot_f32c((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); simsimd_vdot_f64c((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_cos_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); + simsimd_cos_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); simsimd_cos_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_cos_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_cos_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); simsimd_cos_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_l2sq_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); + simsimd_l2sq_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); simsimd_l2sq_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_l2sq_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_l2sq_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); simsimd_l2sq_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_l2_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); + simsimd_l2_i8((simsimd_i8_t*)dummy, (simsimd_i8_t*)dummy, 0, dummy_results); + simsimd_l2_u8((simsimd_u8_t*)dummy, (simsimd_u8_t*)dummy, 0, dummy_results); simsimd_l2_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_l2_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_l2_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); simsimd_l2_f64((simsimd_f64_t*)dummy, (simsimd_f64_t*)dummy, 0, dummy_results); + simsimd_hamming_b8((simsimd_b8_t*)dummy, (simsimd_b8_t*)dummy, 0, dummy_results); simsimd_jaccard_b8((simsimd_b8_t*)dummy, (simsimd_b8_t*)dummy, 0, dummy_results); + simsimd_kl_f16((simsimd_f16_t*)dummy, (simsimd_f16_t*)dummy, 0, dummy_results); simsimd_kl_bf16((simsimd_bf16_t*)dummy, (simsimd_bf16_t*)dummy, 0, dummy_results); simsimd_kl_f32((simsimd_f32_t*)dummy, (simsimd_f32_t*)dummy, 0, dummy_results); diff --git a/include/simsimd/binary.h b/include/simsimd/binary.h index 2f6a8059..7165e568 100644 --- a/include/simsimd/binary.h +++ b/include/simsimd/binary.h @@ -319,7 +319,20 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c simsimd_distance_t* result) { simsimd_size_t intersection = 0, union_ = 0; - // It's harder to squeeze out performance from tiny representations, so we unroll the loops for binary metrics. + //? On such vectors we can clearly see that the CPU struggles to perform this many parallel + //? population counts, because the throughput of Jaccard and Hamming in this case starts to differ. + //? One optimization, aside from Harley-Seal transforms can be using "shuffles" for nibble-popcount + //? lookups, to utilize other ports on the CPU. + //? https://github.com/ashvardanian/SimSIMD/pull/138 + // + // - `_mm512_popcnt_epi64` maps to `VPOPCNTQ (ZMM, K, ZMM)`: + // - On Ice Lake: 3 cycles latency, ports: 1*p5 + // - On Genoa: 2 cycles latency, ports: 1*FP01 + // - `_mm512_shuffle_epi8` maps to `VPSHUFB (ZMM, ZMM, ZMM)`: + // - On Ice Lake: 1 cycles latency, ports: 1*p5 + // - On Genoa: 2 cycles latency, ports: 1*FP12 + // + // It's harder to squeeze out performance from tiny representations, so we unroll the loops for binary metrics. if (n_words <= 64) { // Up to 512 bits. __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words); __m512i a_vec = _mm512_maskz_loadu_epi8(mask, a); @@ -341,20 +354,6 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c intersection = _mm512_reduce_add_epi64(_mm512_add_epi64(and2_count_vec, and1_count_vec)); union_ = _mm512_reduce_add_epi64(_mm512_add_epi64(or2_count_vec, or1_count_vec)); } else if (n_words <= 196) { // Up to 1568 bits. - // TODO: On such vectors we can clearly see that the CPU struggles to perform this many parallel - // population counts, because the throughput of Jaccard and Hamming in this case starts to differ. - // One optimization, aside from Harley-Seal transforms can be using "shuffles" for nibble-popcount - // lookups, to utilize other ports on the CPU. - // https://github.com/ashvardanian/SimSIMD/pull/138 - // - // On Ice Lake: - // - `VPOPCNTQ (ZMM, K, ZMM)` can only execute on port 5, which is a bottleneck. - // - `VPSHUFB (ZMM, ZMM, ZMM)` can only run on the same port 5 as well! - // On Zen4: - // - `VPOPCNTQ (ZMM, K, ZMM)` can run on ports: 0, 1. - // - `VPSHUFB (ZMM, ZMM, ZMM)` can run on ports: 1, 2. - // https://uops.info/table.html?search=VPOPCNTQ%20(ZMM%2C%20K%2C%20ZMM)&cb_lat=on&cb_tp=on&cb_uops=on&cb_ports=on&cb_SKX=on&cb_ICL=on&cb_TGL=on&cb_measurements=on&cb_doc=on&cb_avx512=on - // https://uops.info/table.html?search=VPSHUFB%20(ZMM%2C%20ZMM%2C%20ZMM)&cb_lat=on&cb_tp=on&cb_uops=on&cb_ports=on&cb_SKX=on&cb_ICL=on&cb_TGL=on&cb_measurements=on&cb_doc=on&cb_avx512=on __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 128); __m512i a1_vec = _mm512_loadu_epi8(a); __m512i b1_vec = _mm512_loadu_epi8(b); diff --git a/include/simsimd/dot.h b/include/simsimd/dot.h index 0b790e98..112a659f 100644 --- a/include/simsimd/dot.h +++ b/include/simsimd/dot.h @@ -13,6 +13,7 @@ * - 32-bit IEEE floating point numbers * - 16-bit IEEE floating point numbers * - 16-bit brain floating point numbers + * - 8-bit unsigned integers * - 8-bit signed integers * * For hardware architectures: @@ -54,6 +55,7 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf SIMSIMD_PUBLIC void simsimd_vdot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_u8_serial(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); /* Double-precision serial backends for all numeric types. * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. @@ -83,6 +85,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* results); @@ -120,6 +123,7 @@ SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f1 SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); /* SIMD-powered backends for various generations of AVX512 CPUs. * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop. @@ -136,6 +140,7 @@ SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32 SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); @@ -212,6 +217,7 @@ SIMSIMD_MAKE_COMPLEX_DOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_dot SIMSIMD_MAKE_COMPLEX_VDOT(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_vdot_bf16c_serial SIMSIMD_MAKE_DOT(serial, i8, i64, SIMSIMD_DEREFERENCE) // simsimd_dot_i8_serial +SIMSIMD_MAKE_DOT(serial, u8, i64, SIMSIMD_DEREFERENCE) // simsimd_dot_u8_serial SIMSIMD_MAKE_DOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32_accurate SIMSIMD_MAKE_COMPLEX_DOT(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_dot_f32c_accurate @@ -366,6 +372,27 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t cons *result = ab; } +SIMSIMD_PUBLIC void simsimd_dot_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + uint32x4_t ab_vec = vdupq_n_u32(0); + simsimd_size_t i = 0; + for (; i + 16 <= n; i += 16) { + uint8x16_t a_vec = vld1q_u8(a + i); + uint8x16_t b_vec = vld1q_u8(b + i); + ab_vec = vdotq_u32(ab_vec, a_vec, b_vec); + } + + // Take care of the tail: + uint32_t ab = vaddvq_u32(ab_vec); + for (; i < n; ++i) { + uint32_t ai = a[i], bi = b[i]; + ab += ai * bi; + } + + *result = ab; +} + #pragma clang attribute pop #pragma GCC pop_options #endif @@ -1115,12 +1142,47 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c __m256i a_i8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); __m256i b_i8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); - // Unpack int8 to int16 + // Upcast `int8` to `int16` __m256i a_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_i8_vec, 0)); __m256i a_i16_high_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_i8_vec, 1)); __m256i b_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_i8_vec, 0)); __m256i b_i16_high_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_i8_vec, 1)); + // Multiply and accumulate at int16 level, accumulate at `int32` level + ab_i32_low_vec = _mm256_add_epi32(ab_i32_low_vec, _mm256_madd_epi16(a_i16_low_vec, b_i16_low_vec)); + ab_i32_high_vec = _mm256_add_epi32(ab_i32_high_vec, _mm256_madd_epi16(a_i16_high_vec, b_i16_high_vec)); + } + + // Horizontal sum across the 256-bit register + int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); + + // Take care of the tail: + for (; i < n; ++i) + ab += (int)(a[i]) * b[i]; + *result = ab; +} + +SIMSIMD_PUBLIC void simsimd_dot_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i ab_i32_low_vec = _mm256_setzero_si256(); + __m256i ab_i32_high_vec = _mm256_setzero_si256(); + __m256i const zeros_vec = _mm256_setzero_si256(); + + // AVX2 has no instructions for unsigned 8-bit integer dot-products, + // but it has a weird instruction for mixed signed-unsigned 8-bit dot-product. + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + + // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking + // instructions instead of extracts, as they are much faster and more efficient. + __m256i a_i16_low_vec = _mm256_unpacklo_epi8(a_u8_vec, zeros_vec); + __m256i a_i16_high_vec = _mm256_unpackhi_epi8(a_u8_vec, zeros_vec); + __m256i b_i16_low_vec = _mm256_unpacklo_epi8(b_u8_vec, zeros_vec); + __m256i b_i16_high_vec = _mm256_unpackhi_epi8(b_u8_vec, zeros_vec); + // Multiply and accumulate at int16 level, accumulate at int32 level ab_i32_low_vec = _mm256_add_epi32(ab_i32_low_vec, _mm256_madd_epi16(a_i16_low_vec, b_i16_low_vec)); ab_i32_high_vec = _mm256_add_epi32(ab_i32_high_vec, _mm256_madd_epi16(a_i16_high_vec, b_i16_high_vec)); @@ -1699,6 +1761,45 @@ SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const *result = _mm512_reduce_add_epi32(ab_i32_vec); } +SIMSIMD_PUBLIC void simsimd_dot_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512i ab_i32_low_vec = _mm512_setzero_si512(); + __m512i ab_i32_high_vec = _mm512_setzero_si512(); + __m512i const zeros_vec = _mm512_setzero_si512(); + __m512i a_i16_low_vec, a_i16_high_vec, b_i16_low_vec, b_i16_high_vec; + __m512i a_u8_vec, b_u8_vec; + +simsimd_dot_u8_ice_cycle: + if (n < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } else { + a_u8_vec = _mm512_loadu_si512(a); + b_u8_vec = _mm512_loadu_si512(b); + a += 64, b += 64, n -= 64; + } + + // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking + // instructions instead of extracts, as they are much faster and more efficient. + a_i16_low_vec = _mm512_unpacklo_epi8(a_u8_vec, zeros_vec); + a_i16_high_vec = _mm512_unpackhi_epi8(a_u8_vec, zeros_vec); + b_i16_low_vec = _mm512_unpacklo_epi8(b_u8_vec, zeros_vec); + b_i16_high_vec = _mm512_unpackhi_epi8(b_u8_vec, zeros_vec); + // Unfortunately we can't use the `_mm512_dpbusd_epi32` intrinsics here either, + // as it's asymmetric with respect to the sign of the input arguments: + // Signed(ZeroExtend16(a.byte[4*j]) * SignExtend16(b.byte[4*j])) + // So we have to use the `_mm512_dpwssd_epi32` intrinsics instead, upcasting + // to 16-bit beforehand. + ab_i32_low_vec = _mm512_dpwssd_epi32(ab_i32_low_vec, a_i16_low_vec, b_i16_low_vec); + ab_i32_high_vec = _mm512_dpwssd_epi32(ab_i32_high_vec, a_i16_high_vec, b_i16_high_vec); + if (n) + goto simsimd_dot_u8_ice_cycle; + + *result = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_ICE diff --git a/include/simsimd/fma.h b/include/simsimd/fma.h new file mode 100644 index 00000000..858d0496 --- /dev/null +++ b/include/simsimd/fma.h @@ -0,0 +1,39 @@ +/** + * @file fma.h + * @brief SIMD-accelerated mixed-precision Fused-Multiply-Add operations. + * @author Ash Vardanian + * @date October 16, 2024 + * + * Contains following element-wise operations: + * - Weighted Sum: Oq[i] = Alpha * X[i] + Beta * Z[i] + * - FMA or Fused-Multiply-Add: O[i] = Alpha * X[i] * Y[i] + Beta * Z[i] + * + * For datatypes: + * - 64-bit IEEE floating point numbers + * - 32-bit IEEE floating point numbers + * - 16-bit IEEE floating point numbers + * - 16-bit brain floating point numbers + * - 8-bit unsigned integers + * - 8-bit signed integers + * + * For hardware architectures: + * - Arm: NEON, SVE + * - x86: Haswell, Ice Lake, Skylake, Genoa, Sapphire + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + */ +#ifndef SIMSIMD_FMA_H +#define SIMSIMD_FMA_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index e4013078..c260dc3c 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -546,10 +546,10 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( // case simsimd_datatype_unknown_k: break; // These data-types are not supported yet + case simsimd_datatype_i4x2_k: break; case simsimd_datatype_i16_k: break; case simsimd_datatype_i32_k: break; case simsimd_datatype_i64_k: break; - case simsimd_datatype_u8_k: break; case simsimd_datatype_u64_k: break; // Double-precision floating-point vectors @@ -819,7 +819,7 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( // break; } - // Single-byte integer vectors + // Single-byte signed integer vectors case simsimd_datatype_i8_k: { #if SIMSIMD_TARGET_NEON_I8 if (viable & simsimd_cap_neon_i8_k) @@ -863,6 +863,50 @@ SIMSIMD_INTERNAL void _simsimd_find_metric_punned_implementation( // break; } + // Single-byte unsigned integer vectors + case simsimd_datatype_u8_k: { +#if SIMSIMD_TARGET_NEON_I8 + if (viable & simsimd_cap_neon_i8_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_neon, *c = simsimd_cap_neon_i8_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_neon, *c = simsimd_cap_neon_i8_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_neon, *c = simsimd_cap_neon_i8_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_neon, *c = simsimd_cap_neon_i8_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_ICE + if (viable & simsimd_cap_ice_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_ice, *c = simsimd_cap_ice_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_u8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_u8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_u8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2_k: *m = (m_t)&simsimd_l2_u8_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + } // Binary vectors case simsimd_datatype_b8_k: { @@ -1270,6 +1314,8 @@ SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t con */ SIMSIMD_DYNAMIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_cos_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d); SIMSIMD_DYNAMIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_DYNAMIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, @@ -1280,6 +1326,8 @@ SIMSIMD_DYNAMIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const simsimd_distance_t* d); SIMSIMD_DYNAMIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_l2sq_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d); SIMSIMD_DYNAMIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_DYNAMIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, @@ -1408,6 +1456,30 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( // * @note The dot product is zero if and only if the two vectors are orthogonal. * @note Defined only for floating-point and integer data types. */ +SIMSIMD_PUBLIC void simsimd_dot_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON_F16 + simsimd_dot_i8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_dot_i8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_i8_haswell(a, b, n, d); +#else + simsimd_dot_i8_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON_F16 + simsimd_dot_u8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_dot_u8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_u8_haswell(a, b, n, d); +#else + simsimd_dot_u8_serial(a, b, n, d); +#endif +} SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_SVE_F16 @@ -1422,7 +1494,6 @@ SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* simsimd_dot_f16_serial(a, b, n, d); #endif } - SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_GENOA @@ -1435,7 +1506,6 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t con simsimd_dot_bf16_serial(a, b, n, d); #endif } - SIMSIMD_PUBLIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_SVE @@ -1582,6 +1652,18 @@ SIMSIMD_PUBLIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_cos_i8_serial(a, b, n, d); #endif } +SIMSIMD_PUBLIC void simsimd_cos_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_cos_u8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_cos_u8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_cos_u8_haswell(a, b, n, d); +#else + simsimd_cos_u8_serial(a, b, n, d); +#endif +} SIMSIMD_PUBLIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_SVE_F16 @@ -1648,6 +1730,18 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b simsimd_l2sq_i8_serial(a, b, n, d); #endif } +SIMSIMD_PUBLIC void simsimd_l2sq_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_l2sq_u8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_l2sq_u8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_l2sq_u8_haswell(a, b, n, d); +#else + simsimd_l2sq_u8_serial(a, b, n, d); +#endif +} SIMSIMD_PUBLIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_SVE_F16 @@ -1714,6 +1808,18 @@ SIMSIMD_PUBLIC void simsimd_l2_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_l2_i8_serial(a, b, n, d); #endif } +SIMSIMD_PUBLIC void simsimd_l2_u8(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_l2_u8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_l2_u8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_l2_u8_haswell(a, b, n, d); +#else + simsimd_l2_u8_serial(a, b, n, d); +#endif +} SIMSIMD_PUBLIC void simsimd_l2_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d) { #if SIMSIMD_TARGET_SVE_F16 diff --git a/include/simsimd/spatial.h b/include/simsimd/spatial.h index 69f93595..e669de4b 100644 --- a/include/simsimd/spatial.h +++ b/include/simsimd/spatial.h @@ -13,6 +13,7 @@ * - 32-bit IEEE floating point numbers * - 16-bit IEEE floating point numbers * - 16-bit brain floating point numbers + * - 8-bit unsigned integral numbers * - 8-bit signed integral numbers * - 4-bit signed integral numbers * @@ -55,6 +56,9 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_serial(simsimd_bf16_t const* a, simsimd_bf1 SIMSIMD_PUBLIC void simsimd_l2_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2_u8_serial(simsimd_u8_t const* a, simsimd_u8_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_u8_serial(simsimd_u8_t const* a, simsimd_u8_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_u8_serial(simsimd_u8_t const* a, simsimd_u8_t const*, simsimd_size_t n, simsimd_distance_t* d); /* Double-precision serial backends for all numeric types. * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. @@ -67,10 +71,6 @@ SIMSIMD_PUBLIC void simsimd_l2sq_f16_accurate(simsimd_f16_t const* a, simsimd_f1 SIMSIMD_PUBLIC void simsimd_cos_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const*, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); -SIMSIMD_PUBLIC void simsimd_cos_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); -SIMSIMD_PUBLIC void simsimd_l2_i8_accurate(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); -SIMSIMD_PUBLIC void simsimd_l2sq_i8_accurate(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); -SIMSIMD_PUBLIC void simsimd_cos_i8_accurate(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); /* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all @@ -91,6 +91,9 @@ SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_ SIMSIMD_PUBLIC void simsimd_l2_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); /* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes. * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs. @@ -117,6 +120,9 @@ SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const* a, simsimd_f64_t co SIMSIMD_PUBLIC void simsimd_l2_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d); @@ -148,6 +154,9 @@ SIMSIMD_PUBLIC void simsimd_cos_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t SIMSIMD_PUBLIC void simsimd_l2_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* d); SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* d); @@ -224,6 +233,10 @@ SIMSIMD_MAKE_COS(serial, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_cos_i8_serial SIMSIMD_MAKE_L2SQ(serial, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2sq_i8_serial SIMSIMD_MAKE_L2(serial, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2_i8_serial +SIMSIMD_MAKE_COS(serial, u8, i32, SIMSIMD_DEREFERENCE) // simsimd_cos_u8_serial +SIMSIMD_MAKE_L2SQ(serial, u8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2sq_u8_serial +SIMSIMD_MAKE_L2(serial, u8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2_u8_serial + SIMSIMD_MAKE_COS(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_cos_f32_accurate SIMSIMD_MAKE_L2SQ(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_l2sq_f32_accurate SIMSIMD_MAKE_L2(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_l2_f32_accurate @@ -236,10 +249,6 @@ SIMSIMD_MAKE_COS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_cos_bf16_ SIMSIMD_MAKE_L2SQ(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_l2sq_bf16_accurate SIMSIMD_MAKE_L2(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_l2_bf16_accurate -SIMSIMD_MAKE_COS(accurate, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_cos_i8_accurate -SIMSIMD_MAKE_L2SQ(accurate, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2sq_i8_accurate -SIMSIMD_MAKE_L2(accurate, i8, i32, SIMSIMD_DEREFERENCE) // simsimd_l2_i8_accurate - #if SIMSIMD_TARGET_ARM #if SIMSIMD_TARGET_NEON #pragma GCC push_options @@ -716,6 +725,56 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t cons *result = _simsimd_cos_normalize_f32_neon(ab, a2, b2); } +SIMSIMD_PUBLIC void simsimd_l2_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_l2sq_u8_neon(a, b, n, result); + *result = _simsimd_sqrt_f32_neon(*result); +} +SIMSIMD_PUBLIC void simsimd_l2sq_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + uint32x4_t d2_vec = vdupq_n_u32(0); + simsimd_size_t i = 0; + for (; i + 16 <= n; i += 16) { + uint8x16_t a_vec = vld1q_u8(a + i); + uint8x16_t b_vec = vld1q_u8(b + i); + uint8x16_t d_vec = vabdq_u8(a_vec, b_vec); + d2_vec = vdotq_u32(d2_vec, d_vec, d_vec); + } + uint32_t d2 = vaddvq_u32(d2_vec); + for (; i < n; ++i) { + int32_t n = (int32_t)a[i] - b[i]; + d2 += (uint32_t)(n * n); + } + *result = d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_u8_neon(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + simsimd_size_t i = 0; + uint32x4_t ab_vec = vdupq_n_u32(0); + uint32x4_t a2_vec = vdupq_n_u32(0); + uint32x4_t b2_vec = vdupq_n_u32(0); + for (; i + 16 <= n; i += 16) { + uint8x16_t a_vec = vld1q_u8(a + i); + uint8x16_t b_vec = vld1q_u8(b + i); + ab_vec = vdotq_u32(ab_vec, a_vec, b_vec); + a2_vec = vdotq_u32(a2_vec, a_vec, a_vec); + b2_vec = vdotq_u32(b2_vec, b_vec, b_vec); + } + uint32_t ab = vaddvq_u32(ab_vec); + uint32_t a2 = vaddvq_u32(a2_vec); + uint32_t b2 = vaddvq_u32(b2_vec); + + // Take care of the tail: + for (; i < n; ++i) { + uint32_t ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + + *result = _simsimd_cos_normalize_f32_neon(ab, a2, b2); +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_NEON_I8 @@ -1214,6 +1273,106 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t c __m256i b_i16_low_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_i8_vec, 0)); __m256i b_i16_high_vec = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_i8_vec, 1)); + // Multiply and accumulate as `int16`, accumulate products as `int32`: + ab_i32_low_vec = _mm256_add_epi32(ab_i32_low_vec, _mm256_madd_epi16(a_i16_low_vec, b_i16_low_vec)); + ab_i32_high_vec = _mm256_add_epi32(ab_i32_high_vec, _mm256_madd_epi16(a_i16_high_vec, b_i16_high_vec)); + a2_i32_low_vec = _mm256_add_epi32(a2_i32_low_vec, _mm256_madd_epi16(a_i16_low_vec, a_i16_low_vec)); + a2_i32_high_vec = _mm256_add_epi32(a2_i32_high_vec, _mm256_madd_epi16(a_i16_high_vec, a_i16_high_vec)); + b2_i32_low_vec = _mm256_add_epi32(b2_i32_low_vec, _mm256_madd_epi16(b_i16_low_vec, b_i16_low_vec)); + b2_i32_high_vec = _mm256_add_epi32(b2_i32_high_vec, _mm256_madd_epi16(b_i16_high_vec, b_i16_high_vec)); + } + + // Further reduce to a single sum for each vector + int ab = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); + int a2 = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(a2_i32_low_vec, a2_i32_high_vec)); + int b2 = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(b2_i32_low_vec, b2_i32_high_vec)); + + // Take care of the tail: + for (; i < n; ++i) { + int ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + + *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); +} + +SIMSIMD_PUBLIC void simsimd_l2_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_l2sq_u8_haswell(a, b, n, result); + *result = _simsimd_sqrt_f32_haswell(*result); +} +SIMSIMD_PUBLIC void simsimd_l2sq_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i d2_i32_low_vec = _mm256_setzero_si256(); + __m256i d2_i32_high_vec = _mm256_setzero_si256(); + __m256i const zeros_vec = _mm256_setzero_si256(); + + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + + // Substracting unsigned vectors in AVX2 is done by saturating subtraction: + __m256i d_u8_vec = _mm256_or_si256(_mm256_subs_epu8(a_u8_vec, b_u8_vec), _mm256_subs_epu8(b_u8_vec, a_u8_vec)); + + // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking + // instructions instead of extracts, as they are much faster and more efficient. + __m256i d_i16_low_vec = _mm256_unpacklo_epi8(d_u8_vec, zeros_vec); + __m256i d_i16_high_vec = _mm256_unpackhi_epi8(d_u8_vec, zeros_vec); + + // Multiply and accumulate at `int16` level, accumulate at `int32` level: + d2_i32_low_vec = _mm256_add_epi32(d2_i32_low_vec, _mm256_madd_epi16(d_i16_low_vec, d_i16_low_vec)); + d2_i32_high_vec = _mm256_add_epi32(d2_i32_high_vec, _mm256_madd_epi16(d_i16_high_vec, d_i16_high_vec)); + } + + // Accumulate the 32-bit integers from `d2_i32_high_vec` and `d2_i32_low_vec` + int d2 = _simsimd_reduce_i32x8_haswell(_mm256_add_epi32(d2_i32_low_vec, d2_i32_high_vec)); + + // Take care of the tail: + for (; i < n; ++i) { + int n = (int)(a[i]) - b[i]; + d2 += n * n; + } + + *result = (simsimd_f64_t)d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_u8_haswell(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i ab_i32_low_vec = _mm256_setzero_si256(); + __m256i ab_i32_high_vec = _mm256_setzero_si256(); + __m256i a2_i32_low_vec = _mm256_setzero_si256(); + __m256i a2_i32_high_vec = _mm256_setzero_si256(); + __m256i b2_i32_low_vec = _mm256_setzero_si256(); + __m256i b2_i32_high_vec = _mm256_setzero_si256(); + __m256i const zeros_vec = _mm256_setzero_si256(); + + // AVX2 has no instructions for 8-bit signed integer dot-products, + // but it has a weird instruction for mixed signed-unsigned 8-bit dot-product. + // So we need to normalize the first vector to its absolute value, + // and shift the product sign into the second vector. + // + // __m256i a_i8_abs_vec = _mm256_abs_epi8(a_i8_vec); + // __m256i b_i8_flipped_vec = _mm256_sign_epi8(b_i8_vec, a_i8_vec); + // __m256i ab_i16_vec = _mm256_maddubs_epi16(a_i8_abs_vec, b_i8_flipped_vec); + // + // The problem with this approach, however, is the `-128` value in the second vector. + // Flipping it's sign will do nothing, and the result will be incorrect. + // This can easily lead to noticeable numerical errors in the final result. + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_u8_vec = _mm256_lddqu_si256((__m256i const*)(a + i)); + __m256i b_u8_vec = _mm256_lddqu_si256((__m256i const*)(b + i)); + + // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking + // instructions instead of extracts, as they are much faster and more efficient. + __m256i a_i16_low_vec = _mm256_unpacklo_epi8(a_u8_vec, zeros_vec); + __m256i a_i16_high_vec = _mm256_unpackhi_epi8(a_u8_vec, zeros_vec); + __m256i b_i16_low_vec = _mm256_unpacklo_epi8(b_u8_vec, zeros_vec); + __m256i b_i16_high_vec = _mm256_unpackhi_epi8(b_u8_vec, zeros_vec); + // Multiply and accumulate as `int16`, accumulate products as `int32` ab_i32_low_vec = _mm256_add_epi32(ab_i32_low_vec, _mm256_madd_epi16(a_i16_low_vec, b_i16_low_vec)); ab_i32_high_vec = _mm256_add_epi32(ab_i32_high_vec, _mm256_madd_epi16(a_i16_high_vec, b_i16_high_vec)); @@ -1675,8 +1834,8 @@ SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t cons b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); n = 0; } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256(a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256(b)); + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)a)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)b)); a += 32, b += 32, n -= 32; } d_i16s_vec = _mm512_sub_epi16(a_i16_vec, b_i16_vec); @@ -1701,8 +1860,8 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const b_i16_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); n = 0; } else { - a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256(a)); - b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256(b)); + a_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)a)); + b_i16_vec = _mm512_cvtepi8_epi16(_mm256_lddqu_si256((__m256i const*)b)); a += 32, b += 32, n -= 32; } @@ -1756,6 +1915,92 @@ SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const int b2 = _mm512_reduce_add_epi32(b2_i32_vec); *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); } +SIMSIMD_PUBLIC void simsimd_l2_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_l2sq_u8_ice(a, b, n, result); + *result = _simsimd_sqrt_f32_haswell(*result); +} +SIMSIMD_PUBLIC void simsimd_l2sq_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512i d2_i32_low_vec = _mm512_setzero_si512(); + __m512i d2_i32_high_vec = _mm512_setzero_si512(); + __m512i const zeros_vec = _mm512_setzero_si512(); + __m512i d_i16_low_vec, d_i16_high_vec; + __m512i a_u8_vec, b_u8_vec, d_u8_vec; + +simsimd_l2sq_u8_ice_cycle: + if (n < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } else { + a_u8_vec = _mm512_loadu_si512(a); + b_u8_vec = _mm512_loadu_si512(b); + a += 64, b += 64, n -= 64; + } + + // Substracting unsigned vectors in AVX-512 is done by saturating subtraction: + d_u8_vec = _mm512_or_si512(_mm512_subs_epu8(a_u8_vec, b_u8_vec), _mm512_subs_epu8(b_u8_vec, a_u8_vec)); + d_i16_low_vec = _mm512_unpacklo_epi8(d_u8_vec, zeros_vec); + d_i16_high_vec = _mm512_unpackhi_epi8(d_u8_vec, zeros_vec); + + // Multiply and accumulate at `int16` level, accumulate at `int32` level: + d2_i32_low_vec = _mm512_dpwssd_epi32(d2_i32_low_vec, d_i16_low_vec, d_i16_low_vec); + d2_i32_high_vec = _mm512_dpwssd_epi32(d2_i32_high_vec, d_i16_high_vec, d_i16_high_vec); + if (n) + goto simsimd_l2sq_u8_ice_cycle; + + *result = _mm512_reduce_add_epi32(_mm512_add_epi32(d2_i32_low_vec, d2_i32_high_vec)); +} + +SIMSIMD_PUBLIC void simsimd_cos_u8_ice(simsimd_u8_t const* a, simsimd_u8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m512i ab_i32_low_vec = _mm512_setzero_si512(); + __m512i ab_i32_high_vec = _mm512_setzero_si512(); + __m512i a2_i32_low_vec = _mm512_setzero_si512(); + __m512i a2_i32_high_vec = _mm512_setzero_si512(); + __m512i b2_i32_low_vec = _mm512_setzero_si512(); + __m512i b2_i32_high_vec = _mm512_setzero_si512(); + __m512i const zeros_vec = _mm512_setzero_si512(); + __m512i a_i16_low_vec, a_i16_high_vec, b_i16_low_vec, b_i16_high_vec; + __m512i a_u8_vec, b_u8_vec; + +simsimd_cos_u8_ice_cycle: + if (n < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_u8_vec = _mm512_maskz_loadu_epi8(mask, a); + b_u8_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } else { + a_u8_vec = _mm512_loadu_si512(a); + b_u8_vec = _mm512_loadu_si512(b); + a += 64, b += 64, n -= 64; + } + + // Upcast `uint8` to `int16`. Unlike the signed version, we can use the unpacking + // instructions instead of extracts, as they are much faster and more efficient. + a_i16_low_vec = _mm512_unpacklo_epi8(a_u8_vec, zeros_vec); + a_i16_high_vec = _mm512_unpackhi_epi8(a_u8_vec, zeros_vec); + b_i16_low_vec = _mm512_unpacklo_epi8(b_u8_vec, zeros_vec); + b_i16_high_vec = _mm512_unpackhi_epi8(b_u8_vec, zeros_vec); + + // Multiply and accumulate as `int16`, accumulate products as `int32`: + ab_i32_low_vec = _mm512_dpwssds_epi32(ab_i32_low_vec, a_i16_low_vec, b_i16_low_vec); + ab_i32_high_vec = _mm512_dpwssds_epi32(ab_i32_high_vec, a_i16_high_vec, b_i16_high_vec); + a2_i32_low_vec = _mm512_dpwssds_epi32(a2_i32_low_vec, a_i16_low_vec, a_i16_low_vec); + a2_i32_high_vec = _mm512_dpwssds_epi32(a2_i32_high_vec, a_i16_high_vec, a_i16_high_vec); + b2_i32_low_vec = _mm512_dpwssds_epi32(b2_i32_low_vec, b_i16_low_vec, b_i16_low_vec); + b2_i32_high_vec = _mm512_dpwssds_epi32(b2_i32_high_vec, b_i16_high_vec, b_i16_high_vec); + if (n) + goto simsimd_cos_u8_ice_cycle; + + int ab = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_i32_low_vec, ab_i32_high_vec)); + int a2 = _mm512_reduce_add_epi32(_mm512_add_epi32(a2_i32_low_vec, a2_i32_high_vec)); + int b2 = _mm512_reduce_add_epi32(_mm512_add_epi32(b2_i32_low_vec, b2_i32_high_vec)); + *result = _simsimd_cos_normalize_f32_haswell(ab, a2, b2); +} SIMSIMD_PUBLIC void simsimd_l2_i4x2_ice(simsimd_i4x2_t const* a, simsimd_i4x2_t const* b, simsimd_size_t n_words, simsimd_distance_t* result) { diff --git a/scripts/bench.cxx b/scripts/bench.cxx index 93f7761e..b438e669 100644 --- a/scripts/bench.cxx +++ b/scripts/bench.cxx @@ -656,9 +656,13 @@ int main(int argc, char** argv) { dense_("cos_f64_neon", simsimd_cos_f64_neon, simsimd_cos_f64_serial); dense_("l2sq_f64_neon", simsimd_l2sq_f64_neon, simsimd_l2sq_f64_serial); - dense_("cos_i8_neon", simsimd_cos_i8_neon, simsimd_cos_i8_accurate); + dense_("cos_i8_neon", simsimd_cos_i8_neon, simsimd_cos_i8_serial); + dense_("l2sq_i8_neon", simsimd_l2sq_i8_neon, simsimd_l2sq_i8_serial); dense_("dot_i8_neon", simsimd_dot_i8_neon, simsimd_dot_i8_serial); - dense_("l2sq_i8_neon", simsimd_l2sq_i8_neon, simsimd_l2sq_i8_accurate); + + dense_("cos_u8_neon", simsimd_cos_u8_neon, simsimd_cos_u8_serial); + dense_("l2sq_u8_neon", simsimd_l2sq_u8_neon, simsimd_l2sq_u8_serial); + dense_("dot_u8_neon", simsimd_dot_u8_neon, simsimd_dot_u8_serial); dense_("hamming_b8_neon", simsimd_hamming_b8_neon, simsimd_hamming_b8_serial); dense_("jaccard_b8_neon", simsimd_jaccard_b8_neon, simsimd_jaccard_b8_serial); @@ -724,9 +728,13 @@ int main(int argc, char** argv) { dense_("cos_bf16_haswell", simsimd_cos_bf16_haswell, simsimd_cos_bf16_accurate); dense_("l2sq_bf16_haswell", simsimd_l2sq_bf16_haswell, simsimd_l2sq_bf16_accurate); - dense_("cos_i8_haswell", simsimd_cos_i8_haswell, simsimd_cos_i8_accurate); + dense_("cos_i8_haswell", simsimd_cos_i8_haswell, simsimd_cos_i8_serial); + dense_("l2sq_i8_haswell", simsimd_l2sq_i8_haswell, simsimd_l2sq_i8_serial); dense_("dot_i8_haswell", simsimd_dot_i8_haswell, simsimd_dot_i8_serial); - dense_("l2sq_i8_haswell", simsimd_l2sq_i8_haswell, simsimd_l2sq_i8_accurate); + + dense_("cos_u8_haswell", simsimd_cos_u8_haswell, simsimd_cos_u8_serial); + dense_("l2sq_u8_haswell", simsimd_l2sq_u8_haswell, simsimd_l2sq_u8_serial); + dense_("dot_u8_haswell", simsimd_dot_u8_haswell, simsimd_dot_u8_serial); dense_("hamming_b8_haswell", simsimd_hamming_b8_haswell, simsimd_hamming_b8_serial); dense_("jaccard_b8_haswell", simsimd_jaccard_b8_haswell, simsimd_jaccard_b8_serial); @@ -767,9 +775,13 @@ int main(int argc, char** argv) { #endif #if SIMSIMD_TARGET_ICE - dense_("cos_i8_ice", simsimd_cos_i8_ice, simsimd_cos_i8_accurate); + dense_("cos_i8_ice", simsimd_cos_i8_ice, simsimd_cos_i8_serial); + dense_("l2sq_i8_ice", simsimd_l2sq_i8_ice, simsimd_l2sq_i8_serial); dense_("dot_i8_ice", simsimd_dot_i8_ice, simsimd_dot_i8_serial); - dense_("l2sq_i8_ice", simsimd_l2sq_i8_ice, simsimd_l2sq_i8_accurate); + + dense_("cos_u8_ice", simsimd_cos_u8_ice, simsimd_cos_u8_serial); + dense_("l2sq_u8_ice", simsimd_l2sq_u8_ice, simsimd_l2sq_u8_serial); + dense_("dot_u8_ice", simsimd_dot_u8_ice, simsimd_dot_u8_serial); dense_("dot_f64_skylake", simsimd_dot_f64_skylake, simsimd_dot_f64_serial); dense_("cos_f64_skylake", simsimd_cos_f64_skylake, simsimd_cos_f64_serial); @@ -831,9 +843,13 @@ int main(int argc, char** argv) { dense_("cos_f64_serial", simsimd_cos_f64_serial, simsimd_cos_f64_serial); dense_("l2sq_f64_serial", simsimd_l2sq_f64_serial, simsimd_l2sq_f64_serial); - dense_("cos_i8_serial", simsimd_cos_i8_serial, simsimd_cos_i8_accurate); + dense_("cos_i8_serial", simsimd_cos_i8_serial, simsimd_cos_i8_serial); + dense_("l2sq_i8_serial", simsimd_l2sq_i8_serial, simsimd_l2sq_i8_serial); dense_("dot_i8_serial", simsimd_dot_i8_serial, simsimd_dot_i8_serial); - dense_("l2sq_i8_serial", simsimd_l2sq_i8_serial, simsimd_l2sq_i8_accurate); + + dense_("cos_u8_serial", simsimd_cos_u8_serial, simsimd_cos_u8_serial); + dense_("l2sq_u8_serial", simsimd_l2sq_u8_serial, simsimd_l2sq_u8_serial); + dense_("dot_u8_serial", simsimd_dot_u8_serial, simsimd_dot_u8_serial); dense_("dot_f64c_serial", simsimd_dot_f64c_serial, simsimd_dot_f64c_serial); dense_("dot_f32c_serial", simsimd_dot_f32c_serial, simsimd_dot_f32c_accurate); diff --git a/scripts/test.py b/scripts/test.py index 4076374e..2cff738c 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -51,7 +51,6 @@ import simsimd as simd - # NumPy is available on most platforms and is required for most tests. # When using PyPy on some platforms NumPy has internal issues, that will # raise a weird error, not an `ImportError`. That's why we intentionally @@ -356,10 +355,22 @@ def collect_warnings(message: str, stats: dict): # We will run all the tests many times using different instruction sets under the hood. available_capabilities: Dict[str, str] = simd.get_capabilities() possible_x86_capabilities: List[str] = ["haswell", "ice", "skylake", "sapphire", "turin", "genoa", "sierra"] -possible_arm_capabilities: List[str] = ["neon", "neon_f16", "neon_bf16", "neon_i8", "sve", "sve_f16", "sve_bf16", "sve_i8"] +possible_arm_capabilities: List[str] = [ + "neon", + "neon_f16", + "neon_bf16", + "neon_i8", + "sve", + "sve_f16", + "sve_bf16", + "sve_i8", +] possible_x86_capabilities: List[str] = [c for c in possible_x86_capabilities if available_capabilities[c]] possible_arm_capabilities: List[str] = [c for c in possible_arm_capabilities if available_capabilities[c]] -possible_capabilities: List[str] = possible_x86_capabilities if platform.machine() == "x86_64" else possible_arm_capabilities +possible_capabilities: List[str] = ( + possible_x86_capabilities if platform.machine() == "x86_64" else possible_arm_capabilities +) + def keep_one_capability(cap: str): assert cap in possible_capabilities @@ -446,6 +457,10 @@ def test_pointers_availability(): assert simd.pointer_to_cosine("i8") != 0 assert simd.pointer_to_inner("i8") != 0 + assert simd.pointer_to_sqeuclidean("u8") != 0 + assert simd.pointer_to_cosine("u8") != 0 + assert simd.pointer_to_inner("u8") != 0 + def test_capabilities_list(): """Tests the visibility of hardware capabilities.""" @@ -698,16 +713,21 @@ def test_curved_bf16(ndim, metric, capability, stats_fixture): @pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") @pytest.mark.repeat(50) @pytest.mark.parametrize("ndim", [11, 97, 1536]) +@pytest.mark.parametrize("dtype", ["int8", "uint8"]) @pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"]) @pytest.mark.parametrize("capability", possible_capabilities) -def test_dense_i8(ndim, metric, capability, stats_fixture): +def test_dense_i8(ndim, dtype, metric, capability, stats_fixture): """Compares various SIMD kernels (like Dot-products, squared Euclidean, and Cosine distances) with their NumPy or baseline counterparts, testing accuracy for small integer types, that can't be directly processed with other tools without overflowing.""" np.random.seed() - a = np.random.randint(-128, 127, size=(ndim), dtype=np.int8) - b = np.random.randint(-128, 127, size=(ndim), dtype=np.int8) + if dtype == "int8": + a = np.random.randint(-128, 127, size=(ndim), dtype=np.int8) + b = np.random.randint(-128, 127, size=(ndim), dtype=np.int8) + else: + a = np.random.randint(0, 255, size=(ndim), dtype=np.uint8) + b = np.random.randint(0, 255, size=(ndim), dtype=np.uint8) keep_one_capability(capability) baseline_kernel, simd_kernel = name_to_kernels(metric) @@ -719,8 +739,10 @@ def test_dense_i8(ndim, metric, capability, stats_fixture): if metric == "inner": assert round(float(result)) == round(float(expected)), f"Expected {expected}, but got {result}" else: - np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL), f"Expected {expected}, but got {result}" - collect_errors(metric, ndim, "int8", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) + np.testing.assert_allclose( + result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL + ), f"Expected {expected}, but got {result}" + collect_errors(metric, ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture) #! Fun fact: SciPy doesn't actually raise an `OverflowError` when overflow happens #! here, instead it raises `ValueError: math domain error` during the `sqrt` operation.