diff --git a/include/simsimd/fma.h b/include/simsimd/fma.h index ec98171..4aac3af 100644 --- a/include/simsimd/fma.h +++ b/include/simsimd/fma.h @@ -109,23 +109,27 @@ SIMSIMD_MAKE_WSUM(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) / SIMSIMD_MAKE_WSUM(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_serial SIMSIMD_MAKE_WSUM(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_serial SIMSIMD_MAKE_WSUM(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_serial -SIMSIMD_MAKE_WSUM(serial, i8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_i8_serial -SIMSIMD_MAKE_WSUM(serial, u8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_u8_serial +SIMSIMD_MAKE_WSUM(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_wsum_i8_serial +SIMSIMD_MAKE_WSUM(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_wsum_u8_serial + +SIMSIMD_MAKE_WSUM(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_accurate +SIMSIMD_MAKE_WSUM(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_accurate +SIMSIMD_MAKE_WSUM(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_accurate +SIMSIMD_MAKE_WSUM(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_wsum_i8_accurate +SIMSIMD_MAKE_WSUM(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_wsum_u8_accurate SIMSIMD_MAKE_FMA(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f64_serial SIMSIMD_MAKE_FMA(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_serial SIMSIMD_MAKE_FMA(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_serial SIMSIMD_MAKE_FMA(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_serial -SIMSIMD_MAKE_FMA(serial, i8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_i8_serial -SIMSIMD_MAKE_FMA(serial, u8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_u8_serial - -SIMSIMD_MAKE_WSUM(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_accurate -SIMSIMD_MAKE_WSUM(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_accurate -SIMSIMD_MAKE_WSUM(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_accurate +SIMSIMD_MAKE_FMA(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_fma_i8_serial +SIMSIMD_MAKE_FMA(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_fma_u8_serial SIMSIMD_MAKE_FMA(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_accurate SIMSIMD_MAKE_FMA(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_accurate SIMSIMD_MAKE_FMA(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_accurate +SIMSIMD_MAKE_FMA(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_fma_i8_accurate +SIMSIMD_MAKE_FMA(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_fma_u8_accurate SIMSIMD_PUBLIC void simsimd_wsum_f64_haswell( // simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, // @@ -513,8 +517,8 @@ SIMSIMD_PUBLIC void simsimd_fma_f32_neon( // #if SIMSIMD_TARGET_NEON_F16 #pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) +#pragma GCC target("arch=armv8.2-a+simd+fp16") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) SIMSIMD_PUBLIC void simsimd_wsum_f16_neon( // simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, // @@ -578,13 +582,14 @@ SIMSIMD_PUBLIC void simsimd_wsum_u8_neon( // float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); - uint8x8_t sum_u8_vec = vmovn_u16(vcvtq_u16_f16(sum_vec)); + uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec)); vst1_u8(result + i, sum_u8_vec); } // The tail: - for (; i < n; ++i) - result[i] = (simsimd_u8_t)(alpha_f16 * a[i] + beta_f16 * b[i]); + for (; i < n; ++i) { + SIMSIMD_F32_TO_U8(alpha_f16 * a[i] + beta_f16 * b[i], result + i); + } } SIMSIMD_PUBLIC void simsimd_fma_u8_neon( // @@ -605,13 +610,68 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_neon( // float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); - uint8x8_t sum_u8_vec = vmovn_u16(vcvtq_u16_f16(sum_vec)); + uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec)); vst1_u8(result + i, sum_u8_vec); } // The tail: - for (; i < n; ++i) - result[i] = (simsimd_u8_t)(alpha_f16 * a[i] * b[i] + beta_f16 * c[i]); + for (; i < n; ++i) { + SIMSIMD_F32_TO_U8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_wsum_i8_neon( // + simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, // + simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t* result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + int8x8_t a_i8_vec = vld1_s8(a + i); + int8x8_t b_i8_vec = vld1_s8(b + i); + float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec)); + float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec)); + float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16); + float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16); + float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec); + int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec)); + vst1_s8(result + i, sum_i8_vec); + } + + // The tail: + for (; i < n; ++i) { + SIMSIMD_F32_TO_I8(alpha_f16 * a[i] + beta_f16 * b[i], result + i); + } +} + +SIMSIMD_PUBLIC void simsimd_fma_i8_neon( // + simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_i8_t const* c, // + simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t* result) { + float16_t alpha_f16 = (float16_t)alpha; + float16_t beta_f16 = (float16_t)beta; + + // The main loop: + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + int8x8_t a_i8_vec = vld1_s8(a + i); + int8x8_t b_i8_vec = vld1_s8(b + i); + int8x8_t c_i8_vec = vld1_s8(c + i); + float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec)); + float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec)); + float16x8_t c_vec = vcvtq_f16_s16(vmovl_s8(c_i8_vec)); + float16x8_t ab_vec = vmulq_f16(a_vec, b_vec); + float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16); + float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16); + int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec)); + vst1_s8(result + i, sum_i8_vec); + } + + // The tail: + for (; i < n; ++i) { + SIMSIMD_F32_TO_I8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i); + } } #pragma clang attribute pop diff --git a/include/simsimd/types.h b/include/simsimd/types.h index 0447ccb..503f858 100644 --- a/include/simsimd/types.h +++ b/include/simsimd/types.h @@ -394,6 +394,19 @@ SIMSIMD_STATIC_ASSERT(sizeof(simsimd_bf16_t) == 2, simsimd_bf16_t_must_be_2_byte #endif #endif +#if !defined(SIMSIMD_F32_TO_I8) +#define SIMSIMD_F32_TO_I8(x, y) *(y) = (simsimd_i8_t)roundf(x) +#endif +#if !defined(SIMSIMD_F32_TO_U8) +#define SIMSIMD_F32_TO_U8(x, y) *(y) = (simsimd_u8_t)roundf(x) +#endif +#if !defined(SIMSIMD_F64_TO_I8) +#define SIMSIMD_F64_TO_I8(x, y) *(y) = (simsimd_i8_t)round(x) +#endif +#if !defined(SIMSIMD_F64_TO_U8) +#define SIMSIMD_F64_TO_U8(x, y) *(y) = (simsimd_u8_t)round(x) +#endif + /** @brief Convenience type for half-precision floating-point type conversions. */ typedef union { unsigned i; diff --git a/scripts/bench.cxx b/scripts/bench.cxx index 0d94ee2..363adf8 100644 --- a/scripts/bench.cxx +++ b/scripts/bench.cxx @@ -522,14 +522,14 @@ void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metr constexpr bool takes_three_vectors_k = std::tuple_size::arg_tuple>::value == 6; auto call_baseline = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) { if constexpr (takes_three_vectors_k) { - baseline(a.data(), b.data(), a.dimensions(), alpha, beta, d.data()); + baseline(a.data(), c.data(), a.dimensions(), alpha, beta, d.data()); } else { baseline(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data()); } }; auto call_contender = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) { if constexpr (takes_three_vectors_k) { - kernel(a.data(), b.data(), a.dimensions(), alpha, beta, d.data()); + kernel(a.data(), c.data(), a.dimensions(), alpha, beta, d.data()); } else { kernel(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data()); } @@ -545,8 +545,8 @@ void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metr auto& quad = quads[i]; quad.a = quad.b = quad.c = quad.d = vector_t(dimensions); quad.a.randomize(static_cast(i)); - quad.b.randomize(static_cast(i) + 54321u); - quad.c.randomize(static_cast(i) + 6789u); + quad.b.set(2); // Having a small constant here will help avoid overflows + quad.c.randomize(static_cast(i) + 54321u); } // Initialize the output buffers for distance calculations. @@ -563,10 +563,13 @@ void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metr l2_metric(baseline_d.data(), contender_d.data(), dimensions, &l2_metric_from_baseline[i]); l2_metric(baseline_d.data(), zeros.data(), dimensions, &l2_baseline_result_norm[i]); l2_metric(contender_d.data(), zeros.data(), dimensions, &l2_contender_result_norm[i]); - mean_delta += l2_metric_from_baseline[i]; - mean_relative_error += std::abs(l2_baseline_result_norm[i] - l2_contender_result_norm[i]) / - std::max(l2_baseline_result_norm[i], l2_contender_result_norm[i]); + + mean_delta += std::abs(l2_metric_from_baseline[i]); + mean_relative_error += + std::abs(l2_metric_from_baseline[i]) / (std::max)(l2_baseline_result_norm[i], l2_contender_result_norm[i]); } + mean_delta /= quads_count; + mean_relative_error /= quads_count; // The actual benchmarking loop. std::size_t iterations = 0; @@ -751,18 +754,6 @@ int main(int argc, char** argv) { #endif #if SIMSIMD_TARGET_NEON - dense_("dot_f16_neon", simsimd_dot_f16_neon, simsimd_dot_f16_accurate); - dense_("cos_f16_neon", simsimd_cos_f16_neon, simsimd_cos_f16_accurate); - dense_("l2sq_f16_neon", simsimd_l2sq_f16_neon, simsimd_l2sq_f16_accurate); - dense_("l2_f16_neon", simsimd_l2_f16_neon, simsimd_l2_f16_accurate); - dense_("kl_f16_neon", simsimd_kl_f16_neon, simsimd_kl_f16_accurate); - dense_("js_f16_neon", simsimd_js_f16_neon, simsimd_js_f16_accurate); - - dense_("dot_bf16_neon", simsimd_dot_bf16_neon, simsimd_dot_bf16_accurate); - dense_("cos_bf16_neon", simsimd_cos_bf16_neon, simsimd_cos_bf16_accurate); - dense_("l2sq_bf16_neon", simsimd_l2sq_bf16_neon, simsimd_l2sq_bf16_accurate); - dense_("l2_bf16_neon", simsimd_l2_bf16_neon, simsimd_l2_bf16_accurate); - dense_("dot_f32_neon", simsimd_dot_f32_neon, simsimd_dot_f32_accurate); dense_("cos_f32_neon", simsimd_cos_f32_neon, simsimd_cos_f32_accurate); dense_("l2sq_f32_neon", simsimd_l2sq_f32_neon, simsimd_l2sq_f32_accurate); @@ -821,8 +812,10 @@ int main(int argc, char** argv) { fma_("wsum_f16_neon", simsimd_wsum_f16_neon, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate); // FMA kernels for `u8` on NEON use `f16` arithmetic - fma_("fma_u8_neon", simsimd_fma_u8_neon, simsimd_fma_u8_serial, simsimd_l2_u8_serial); - fma_("wsum_u8_neon", simsimd_wsum_u8_neon, simsimd_wsum_u8_serial, simsimd_l2_u8_serial); + fma_("fma_u8_neon", simsimd_fma_u8_neon, simsimd_fma_u8_accurate, simsimd_l2_u8_serial); + fma_("wsum_u8_neon", simsimd_wsum_u8_neon, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial); + fma_("fma_i8_neon", simsimd_fma_i8_neon, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); + fma_("wsum_i8_neon", simsimd_wsum_i8_neon, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); #endif #if SIMSIMD_TARGET_NEON_BF16 @@ -1054,8 +1047,10 @@ int main(int argc, char** argv) { fma_("fma_f16_serial", simsimd_fma_f16_serial, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate); fma_("wsum_f16_serial", simsimd_wsum_f16_serial, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate); - fma_("fma_u8_serial", simsimd_fma_u8_serial, simsimd_fma_u8_serial, simsimd_l2_u8_serial); - fma_("wsum_u8_serial", simsimd_wsum_u8_serial, simsimd_wsum_u8_serial, simsimd_l2_u8_serial); + fma_("fma_u8_serial", simsimd_fma_u8_serial, simsimd_fma_u8_accurate, simsimd_l2_u8_serial); + fma_("wsum_u8_serial", simsimd_wsum_u8_serial, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial); + fma_("fma_i8_serial", simsimd_fma_i8_serial, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); + fma_("wsum_i8_serial", simsimd_wsum_i8_serial, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); bm::RunSpecifiedBenchmarks(); bm::Shutdown();