Fix: Float rounding in FMA
ashvardanian committed Oct 19, 2024
1 parent d666f55 commit ea24713
Showing 3 changed files with 107 additions and 39 deletions.
92 changes: 76 additions & 16 deletions include/simsimd/fma.h
@@ -109,23 +109,27 @@ SIMSIMD_MAKE_WSUM(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) /
SIMSIMD_MAKE_WSUM(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_serial
SIMSIMD_MAKE_WSUM(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_serial
SIMSIMD_MAKE_WSUM(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_serial
SIMSIMD_MAKE_WSUM(serial, i8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_i8_serial
SIMSIMD_MAKE_WSUM(serial, u8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_u8_serial
SIMSIMD_MAKE_WSUM(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_wsum_i8_serial
SIMSIMD_MAKE_WSUM(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_wsum_u8_serial

SIMSIMD_MAKE_WSUM(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_accurate
SIMSIMD_MAKE_WSUM(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_accurate
SIMSIMD_MAKE_WSUM(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_accurate
SIMSIMD_MAKE_WSUM(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_wsum_i8_accurate
SIMSIMD_MAKE_WSUM(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_wsum_u8_accurate

SIMSIMD_MAKE_FMA(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f64_serial
SIMSIMD_MAKE_FMA(serial, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_serial
SIMSIMD_MAKE_FMA(serial, f16, f32, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_serial
SIMSIMD_MAKE_FMA(serial, bf16, f32, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_serial
SIMSIMD_MAKE_FMA(serial, i8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_i8_serial
SIMSIMD_MAKE_FMA(serial, u8, i64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_u8_serial

SIMSIMD_MAKE_WSUM(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_wsum_f32_accurate
SIMSIMD_MAKE_WSUM(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_wsum_f16_accurate
SIMSIMD_MAKE_WSUM(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_wsum_bf16_accurate
SIMSIMD_MAKE_FMA(serial, i8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_I8) // simsimd_fma_i8_serial
SIMSIMD_MAKE_FMA(serial, u8, f32, SIMSIMD_DEREFERENCE, SIMSIMD_F32_TO_U8) // simsimd_fma_u8_serial

SIMSIMD_MAKE_FMA(accurate, f32, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT) // simsimd_fma_f32_accurate
SIMSIMD_MAKE_FMA(accurate, f16, f64, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16) // simsimd_fma_f16_accurate
SIMSIMD_MAKE_FMA(accurate, bf16, f64, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16) // simsimd_fma_bf16_accurate
SIMSIMD_MAKE_FMA(accurate, i8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_I8) // simsimd_fma_i8_accurate
SIMSIMD_MAKE_FMA(accurate, u8, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F64_TO_U8) // simsimd_fma_u8_accurate
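For orientation, these macro instantiations generate element-wise scaling kernels; judging by the serial tail loops further down in this diff, each computes roughly the following. A minimal scalar sketch with illustrative names, not the actual macro expansion:

#include <stddef.h>

/* Weighted sum: result[i] = alpha * a[i] + beta * b[i] */
static void wsum_f32_sketch(float const *a, float const *b, size_t n,
                            double alpha, double beta, float *result) {
    for (size_t i = 0; i != n; ++i)
        result[i] = (float)(alpha * a[i] + beta * b[i]);
}

/* Fused multiply-add: result[i] = alpha * a[i] * b[i] + beta * c[i] */
static void fma_f32_sketch(float const *a, float const *b, float const *c, size_t n,
                           double alpha, double beta, float *result) {
    for (size_t i = 0; i != n; ++i)
        result[i] = (float)(alpha * a[i] * b[i] + beta * c[i]);
}

The i8 and u8 variants additionally round the f32 (serial) or f64 (accurate) intermediate back to the 8-bit type via the SIMSIMD_F32_TO_* and SIMSIMD_F64_TO_* macros added in types.h below.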

SIMSIMD_PUBLIC void simsimd_wsum_f64_haswell( //
simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, //
@@ -513,8 +517,8 @@ SIMSIMD_PUBLIC void simsimd_fma_f32_neon( //

#if SIMSIMD_TARGET_NEON_F16
#pragma GCC push_options
#pragma GCC target("arch=armv8.2-a+simd")
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function)
#pragma GCC target("arch=armv8.2-a+simd+fp16")
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)

SIMSIMD_PUBLIC void simsimd_wsum_f16_neon( //
simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, //
@@ -578,13 +582,14 @@ SIMSIMD_PUBLIC void simsimd_wsum_u8_neon( //
float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16);
float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16);
float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec);
uint8x8_t sum_u8_vec = vmovn_u16(vcvtq_u16_f16(sum_vec));
uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec));
vst1_u8(result + i, sum_u8_vec);
}

// The tail:
for (; i < n; ++i)
result[i] = (simsimd_u8_t)(alpha_f16 * a[i] + beta_f16 * b[i]);
for (; i < n; ++i) {
SIMSIMD_F32_TO_U8(alpha_f16 * a[i] + beta_f16 * b[i], result + i);
}
}

SIMSIMD_PUBLIC void simsimd_fma_u8_neon( //
@@ -605,13 +610,68 @@ SIMSIMD_PUBLIC void simsimd_fma_u8_neon( //
float16x8_t ab_vec = vmulq_f16(a_vec, b_vec);
float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16);
float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16);
uint8x8_t sum_u8_vec = vmovn_u16(vcvtq_u16_f16(sum_vec));
uint8x8_t sum_u8_vec = vmovn_u16(vcvtaq_u16_f16(sum_vec));
vst1_u8(result + i, sum_u8_vec);
}

// The tail:
for (; i < n; ++i)
result[i] = (simsimd_u8_t)(alpha_f16 * a[i] * b[i] + beta_f16 * c[i]);
for (; i < n; ++i) {
SIMSIMD_F32_TO_U8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i);
}
}

SIMSIMD_PUBLIC void simsimd_wsum_i8_neon( //
simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, //
simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t* result) {
float16_t alpha_f16 = (float16_t)alpha;
float16_t beta_f16 = (float16_t)beta;

// The main loop:
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
int8x8_t a_i8_vec = vld1_s8(a + i);
int8x8_t b_i8_vec = vld1_s8(b + i);
float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec));
float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec));
float16x8_t a_scaled_vec = vmulq_n_f16(a_vec, alpha_f16);
float16x8_t b_scaled_vec = vmulq_n_f16(b_vec, beta_f16);
float16x8_t sum_vec = vaddq_f16(a_scaled_vec, b_scaled_vec);
int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec));
vst1_s8(result + i, sum_i8_vec);
}

// The tail:
for (; i < n; ++i) {
SIMSIMD_F32_TO_I8(alpha_f16 * a[i] + beta_f16 * b[i], result + i);
}
}

SIMSIMD_PUBLIC void simsimd_fma_i8_neon( //
simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_i8_t const* c, //
simsimd_size_t n, simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t* result) {
float16_t alpha_f16 = (float16_t)alpha;
float16_t beta_f16 = (float16_t)beta;

// The main loop:
simsimd_size_t i = 0;
for (; i + 8 <= n; i += 8) {
int8x8_t a_i8_vec = vld1_s8(a + i);
int8x8_t b_i8_vec = vld1_s8(b + i);
int8x8_t c_i8_vec = vld1_s8(c + i);
float16x8_t a_vec = vcvtq_f16_s16(vmovl_s8(a_i8_vec));
float16x8_t b_vec = vcvtq_f16_s16(vmovl_s8(b_i8_vec));
float16x8_t c_vec = vcvtq_f16_s16(vmovl_s8(c_i8_vec));
float16x8_t ab_vec = vmulq_f16(a_vec, b_vec);
float16x8_t ab_scaled_vec = vmulq_n_f16(ab_vec, alpha_f16);
float16x8_t sum_vec = vfmaq_n_f16(ab_scaled_vec, c_vec, beta_f16);
int8x8_t sum_i8_vec = vmovn_s16(vcvtaq_s16_f16(sum_vec));
vst1_s8(result + i, sum_i8_vec);
}

// The tail:
for (; i < n; ++i) {
SIMSIMD_F32_TO_I8(alpha_f16 * a[i] * b[i] + beta_f16 * c[i], result + i);
}
}

#pragma clang attribute pop
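Note on the intrinsic change above: vcvtq_u16_f16 and vcvtq_s16_f16 convert by rounding toward zero, while the vcvtaq_* variants round to the nearest integer with ties away from zero, matching the roundf-based scalar tails. A plain-C sketch of the difference, illustrative only and not part of this commit:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    float blended = 41.9f;                      /* e.g. alpha * a[i] + beta * b[i] */
    uint8_t truncated = (uint8_t)blended;       /* old conversion: 41 (toward zero) */
    uint8_t rounded = (uint8_t)roundf(blended); /* new conversion: 42 (to nearest) */
    printf("%d vs %d\n", truncated, rounded);
    return 0;
}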
13 changes: 13 additions & 0 deletions include/simsimd/types.h
@@ -394,6 +394,19 @@ SIMSIMD_STATIC_ASSERT(sizeof(simsimd_bf16_t) == 2, simsimd_bf16_t_must_be_2_byte
#endif
#endif

#if !defined(SIMSIMD_F32_TO_I8)
#define SIMSIMD_F32_TO_I8(x, y) *(y) = (simsimd_i8_t)roundf(x)
#endif
#if !defined(SIMSIMD_F32_TO_U8)
#define SIMSIMD_F32_TO_U8(x, y) *(y) = (simsimd_u8_t)roundf(x)
#endif
#if !defined(SIMSIMD_F64_TO_I8)
#define SIMSIMD_F64_TO_I8(x, y) *(y) = (simsimd_i8_t)round(x)
#endif
#if !defined(SIMSIMD_F64_TO_U8)
#define SIMSIMD_F64_TO_U8(x, y) *(y) = (simsimd_u8_t)round(x)
#endif

/** @brief Convenience type for half-precision floating-point type conversions. */
typedef union {
unsigned i;
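A minimal usage sketch of the new conversion macros (they round to nearest but do not clamp, so callers are expected to keep values within the target range, which is also why the benchmark below seeds one operand with a small constant); the typedef here is a stand-in for the library's own:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

typedef int8_t simsimd_i8_t; /* stand-in for the library typedef */
#define SIMSIMD_F32_TO_I8(x, y) *(y) = (simsimd_i8_t)roundf(x)

int main(void) {
    simsimd_i8_t out;
    SIMSIMD_F32_TO_I8(-2.6f, &out); /* rounds to -3; a plain cast would truncate to -2 */
    printf("%d\n", (int)out);
    return 0;
}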
41 changes: 18 additions & 23 deletions scripts/bench.cxx
@@ -522,14 +522,14 @@ void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metr
constexpr bool takes_three_vectors_k = std::tuple_size<typename function_traits<kernel_at>::arg_tuple>::value == 6;
auto call_baseline = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) {
if constexpr (takes_three_vectors_k) {
baseline(a.data(), b.data(), a.dimensions(), alpha, beta, d.data());
baseline(a.data(), c.data(), a.dimensions(), alpha, beta, d.data());
} else {
baseline(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data());
}
};
auto call_contender = [&](vector_t const& a, vector_t const& b, vector_t const& c, vector_t& d) {
if constexpr (takes_three_vectors_k) {
kernel(a.data(), b.data(), a.dimensions(), alpha, beta, d.data());
kernel(a.data(), c.data(), a.dimensions(), alpha, beta, d.data());
} else {
kernel(a.data(), b.data(), c.data(), a.dimensions(), alpha, beta, d.data());
}
@@ -545,8 +545,8 @@ void measure_fma(bm::State& state, kernel_at kernel, kernel_at baseline, l2_metr
auto& quad = quads[i];
quad.a = quad.b = quad.c = quad.d = vector_t(dimensions);
quad.a.randomize(static_cast<std::uint32_t>(i));
quad.b.randomize(static_cast<std::uint32_t>(i) + 54321u);
quad.c.randomize(static_cast<std::uint32_t>(i) + 6789u);
quad.b.set(2); // Having a small constant here will help avoid overflows
quad.c.randomize(static_cast<std::uint32_t>(i) + 54321u);
}

// Initialize the output buffers for distance calculations.
@@ -563,10 +563,13 @@
l2_metric(baseline_d.data(), contender_d.data(), dimensions, &l2_metric_from_baseline[i]);
l2_metric(baseline_d.data(), zeros.data(), dimensions, &l2_baseline_result_norm[i]);
l2_metric(contender_d.data(), zeros.data(), dimensions, &l2_contender_result_norm[i]);
mean_delta += l2_metric_from_baseline[i];
mean_relative_error += std::abs(l2_baseline_result_norm[i] - l2_contender_result_norm[i]) /
std::max(l2_baseline_result_norm[i], l2_contender_result_norm[i]);

mean_delta += std::abs(l2_metric_from_baseline[i]);
mean_relative_error +=
std::abs(l2_metric_from_baseline[i]) / (std::max)(l2_baseline_result_norm[i], l2_contender_result_norm[i]);
}
mean_delta /= quads_count;
mean_relative_error /= quads_count;
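In effect, the accuracy statistics reported by this benchmark now follow (plain notation, not from the source):

    mean_delta          = (1 / quads_count) * sum_i l2(baseline_i, contender_i)
    mean_relative_error = (1 / quads_count) * sum_i l2(baseline_i, contender_i) / max(||baseline_i||, ||contender_i||)

That is, both accumulators are normalized by the number of quads, and the relative error is now measured against the distance between the two result vectors rather than the difference of their norms.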

// The actual benchmarking loop.
std::size_t iterations = 0;
@@ -751,18 +754,6 @@ int main(int argc, char** argv) {
#endif

#if SIMSIMD_TARGET_NEON
dense_<f16_k>("dot_f16_neon", simsimd_dot_f16_neon, simsimd_dot_f16_accurate);
dense_<f16_k>("cos_f16_neon", simsimd_cos_f16_neon, simsimd_cos_f16_accurate);
dense_<f16_k>("l2sq_f16_neon", simsimd_l2sq_f16_neon, simsimd_l2sq_f16_accurate);
dense_<f16_k>("l2_f16_neon", simsimd_l2_f16_neon, simsimd_l2_f16_accurate);
dense_<f16_k>("kl_f16_neon", simsimd_kl_f16_neon, simsimd_kl_f16_accurate);
dense_<f16_k>("js_f16_neon", simsimd_js_f16_neon, simsimd_js_f16_accurate);

dense_<bf16_k>("dot_bf16_neon", simsimd_dot_bf16_neon, simsimd_dot_bf16_accurate);
dense_<bf16_k>("cos_bf16_neon", simsimd_cos_bf16_neon, simsimd_cos_bf16_accurate);
dense_<bf16_k>("l2sq_bf16_neon", simsimd_l2sq_bf16_neon, simsimd_l2sq_bf16_accurate);
dense_<bf16_k>("l2_bf16_neon", simsimd_l2_bf16_neon, simsimd_l2_bf16_accurate);

dense_<f32_k>("dot_f32_neon", simsimd_dot_f32_neon, simsimd_dot_f32_accurate);
dense_<f32_k>("cos_f32_neon", simsimd_cos_f32_neon, simsimd_cos_f32_accurate);
dense_<f32_k>("l2sq_f32_neon", simsimd_l2sq_f32_neon, simsimd_l2sq_f32_accurate);
@@ -821,8 +812,10 @@ int main(int argc, char** argv) {
fma_<f16_k>("wsum_f16_neon", simsimd_wsum_f16_neon, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate);

// FMA kernels for `u8` on NEON use `f16` arithmetic
fma_<u8_k>("fma_u8_neon", simsimd_fma_u8_neon, simsimd_fma_u8_serial, simsimd_l2_u8_serial);
fma_<u8_k>("wsum_u8_neon", simsimd_wsum_u8_neon, simsimd_wsum_u8_serial, simsimd_l2_u8_serial);
fma_<u8_k>("fma_u8_neon", simsimd_fma_u8_neon, simsimd_fma_u8_accurate, simsimd_l2_u8_serial);
fma_<u8_k>("wsum_u8_neon", simsimd_wsum_u8_neon, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial);
fma_<i8_k>("fma_i8_neon", simsimd_fma_i8_neon, simsimd_fma_i8_accurate, simsimd_l2_i8_serial);
fma_<i8_k>("wsum_i8_neon", simsimd_wsum_i8_neon, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial);
#endif

#if SIMSIMD_TARGET_NEON_BF16
@@ -1054,8 +1047,10 @@ int main(int argc, char** argv) {

fma_<f16_k>("fma_f16_serial", simsimd_fma_f16_serial, simsimd_fma_f16_accurate, simsimd_l2_f16_accurate);
fma_<f16_k>("wsum_f16_serial", simsimd_wsum_f16_serial, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate);
fma_<u8_k>("fma_u8_serial", simsimd_fma_u8_serial, simsimd_fma_u8_serial, simsimd_l2_u8_serial);
fma_<u8_k>("wsum_u8_serial", simsimd_wsum_u8_serial, simsimd_wsum_u8_serial, simsimd_l2_u8_serial);
fma_<u8_k>("fma_u8_serial", simsimd_fma_u8_serial, simsimd_fma_u8_accurate, simsimd_l2_u8_serial);
fma_<u8_k>("wsum_u8_serial", simsimd_wsum_u8_serial, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial);
fma_<i8_k>("fma_i8_serial", simsimd_fma_i8_serial, simsimd_fma_i8_accurate, simsimd_l2_i8_serial);
fma_<i8_k>("wsum_i8_serial", simsimd_wsum_i8_serial, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial);

bm::RunSpecifiedBenchmarks();
bm::Shutdown();
