New u8 Kernels #213

Merged 4 commits on Oct 18, 2024
@ashvardanian commented on Oct 17, 2024

Old Kernels, New Types: u8

On Intel Sapphire Rapids, the l2sq throughput for u8 vectors grows from 21 GB/s in serial code to 66 GB/s with the AVX2 kernel targeting Haswell, and to 94 GB/s with the AVX-512 kernel targeting Ice Lake and newer CPUs.

CPU Caches:
  L1 Data 48 KiB (x8)
  L1 Instruction 32 KiB (x8)
  L2 Unified 2048 KiB (x8)
  L3 Unified 61440 KiB (x1)
Load Average: 0.86, 0.63, 0.40
-----------------------------------------------------------------------------------------------------------
Benchmark                                                 Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
cos_i8_haswell<1536d>/min_time:10.000/threads:1        94.6 ns         94.6 ns    147959034 abs_delta=24.2675n bytes=32.4634G/s pairs=10.5675M/s relative_error=24.1004n
l2sq_i8_haswell<1536d>/min_time:10.000/threads:1       75.7 ns         75.7 ns    184969645 abs_delta=0 bytes=40.5813G/s pairs=13.2101M/s relative_error=0
dot_i8_haswell<1536d>/min_time:10.000/threads:1        75.1 ns         75.1 ns    186534921 abs_delta=0 bytes=40.9319G/s pairs=13.3242M/s relative_error=0
cos_u8_haswell<1536d>/min_time:10.000/threads:1        76.7 ns         76.7 ns    182722149 abs_delta=50.0079n bytes=40.0566G/s pairs=13.0392M/s relative_error=198.972n
l2sq_u8_haswell<1536d>/min_time:10.000/threads:1       46.6 ns         46.6 ns    296839472 abs_delta=0 bytes=65.9322G/s pairs=21.4623M/s relative_error=0
dot_u8_haswell<1536d>/min_time:10.000/threads:1        40.3 ns         40.3 ns    347943232 abs_delta=0 bytes=76.1757G/s pairs=24.7968M/s relative_error=0
cos_i8_ice<1536d>/min_time:10.000/threads:1            62.5 ns         62.5 ns    226246267 abs_delta=24.2675n bytes=49.1862G/s pairs=16.0111M/s relative_error=24.1004n
l2sq_i8_ice<1536d>/min_time:10.000/threads:1           49.7 ns         49.7 ns    278397268 abs_delta=0 bytes=61.7525G/s pairs=20.1017M/s relative_error=0
dot_i8_ice<1536d>/min_time:10.000/threads:1            48.7 ns         48.7 ns    287594728 abs_delta=0 bytes=63.1145G/s pairs=20.5451M/s relative_error=0
cos_u8_ice<1536d>/min_time:10.000/threads:1            46.2 ns         46.2 ns    302861979 abs_delta=50.0079n bytes=66.4599G/s pairs=21.6341M/s relative_error=198.972n
l2sq_u8_ice<1536d>/min_time:10.000/threads:1           32.8 ns         32.8 ns    426392200 abs_delta=0 bytes=93.5376G/s pairs=30.4484M/s relative_error=0
dot_u8_ice<1536d>/min_time:10.000/threads:1            31.8 ns         31.8 ns    439705238 abs_delta=0 bytes=96.7051G/s pairs=31.4795M/s relative_error=0
cos_i8_serial<1536d>/min_time:10.000/threads:1         94.4 ns         94.4 ns    148349619 abs_delta=0 bytes=32.5526G/s pairs=10.5965M/s relative_error=0
l2sq_i8_serial<1536d>/min_time:10.000/threads:1         149 ns          149 ns     93679282 abs_delta=0 bytes=20.5551G/s pairs=6.69113M/s relative_error=0
dot_i8_serial<1536d>/min_time:10.000/threads:1          298 ns          298 ns     46918869 abs_delta=0 bytes=10.2953G/s pairs=3.35132M/s relative_error=0
cos_u8_serial<1536d>/min_time:10.000/threads:1         94.4 ns         94.4 ns    148342871 abs_delta=0 bytes=32.5511G/s pairs=10.5961M/s relative_error=0
l2sq_u8_serial<1536d>/min_time:10.000/threads:1         149 ns          149 ns     93674000 abs_delta=0 bytes=20.5559G/s pairs=6.69137M/s relative_error=0
dot_u8_serial<1536d>/min_time:10.000/threads:1          298 ns          298 ns     46917154 abs_delta=0 bytes=10.2953G/s pairs=3.35132M/s relative_error=0

On Apple M2 Pro:

CPU Caches:
  L1 Data 64 KiB
  L1 Instruction 128 KiB
  L2 Unified 4096 KiB (x12)
Load Average: 3.49, 4.60, 4.59
----------------------------------------------------------------------------------------------------------
Benchmark                                                Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------------------------
cos_i8_neon<1536d>/min_time:10.000/threads:1          66.3 ns         65.9 ns    212749629 abs_delta=23.4158n bytes=46.596G/s pairs=15.168M/s relative_error=23.2195n
l2sq_i8_neon<1536d>/min_time:10.000/threads:1         58.3 ns         57.6 ns    245204458 abs_delta=0 bytes=53.3318G/s pairs=17.3606M/s relative_error=0
dot_i8_neon<1536d>/min_time:10.000/threads:1          57.5 ns         57.5 ns    249456897 abs_delta=0 bytes=53.4667G/s pairs=17.4045M/s relative_error=0
cos_u8_neon<1536d>/min_time:10.000/threads:1          65.4 ns         65.3 ns    215399031 abs_delta=44.134n bytes=47.0148G/s pairs=15.3043M/s relative_error=175.85n
l2sq_u8_neon<1536d>/min_time:10.000/threads:1         57.2 ns         57.1 ns    245530038 abs_delta=0 bytes=53.7771G/s pairs=17.5056M/s relative_error=0
dot_u8_neon<1536d>/min_time:10.000/threads:1          56.5 ns         56.3 ns    249268353 abs_delta=0 bytes=54.527G/s pairs=17.7497M/s relative_error=0
cos_i8_serial<1536d>/min_time:10.000/threads:1         172 ns          172 ns     81341484 abs_delta=0 bytes=17.869G/s pairs=5.81672M/s relative_error=0
l2sq_i8_serial<1536d>/min_time:10.000/threads:1        160 ns          160 ns     87444582 abs_delta=0 bytes=19.2064G/s pairs=6.25208M/s relative_error=0
dot_i8_serial<1536d>/min_time:10.000/threads:1         484 ns          480 ns     29330428 abs_delta=0 bytes=6.39618G/s pairs=2.08209M/s relative_error=0
cos_u8_serial<1536d>/min_time:10.000/threads:1         174 ns          174 ns     80410276 abs_delta=0 bytes=17.7017G/s pairs=5.76227M/s relative_error=0
l2sq_u8_serial<1536d>/min_time:10.000/threads:1        163 ns          163 ns     85610679 abs_delta=0 bytes=18.8776G/s pairs=6.14506M/s relative_error=0
dot_u8_serial<1536d>/min_time:10.000/threads:1         483 ns          481 ns     29273460 abs_delta=0 bytes=6.39193G/s pairs=2.08071M/s relative_error=0
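
As a sanity check on the tables above, the bytes and pairs counters of the 1536-dimensional 8-bit kernels can be reproduced from the reported time per call. The sketch below assumes that each call reads two vectors of 1536 one-byte elements; that is an assumption about how the benchmark counts bytes, not something stated in this PR, and the helper names are hypothetical.

```python
# Hypothetical helpers for cross-checking the benchmark counters above.
# Assumption: one call processes a pair of vectors, 1 byte per element.


def pairs_per_second(time_ns: float) -> float:
    """Vector pairs processed per second, given the time of one call in ns."""
    return 1e9 / time_ns


def bytes_per_second(time_ns: float, dimensions: int, bytes_per_element: int = 1) -> float:
    """Input bytes consumed per second, counting both vectors of the pair."""
    bytes_per_pair = 2 * dimensions * bytes_per_element
    return bytes_per_pair * pairs_per_second(time_ns)


if __name__ == "__main__":
    # Times for l2sq_u8 on Intel Sapphire Rapids, copied from the table above.
    for name, time_ns in [("serial", 149.0), ("haswell", 46.6), ("ice", 32.8)]:
        gbps = bytes_per_second(time_ns, dimensions=1536) / 1e9
        print(f"l2sq_u8_{name}: {gbps:.1f} GB/s")
```

Because the table rounds the time column to three significant digits, the recomputed values land within about 1% of the reported counters rather than matching them exactly.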

L2 vs L2sq

How fast can we compute the Euclidean distance in $\mathbb{R}^3$?

$$d(\mathbf{a}, \mathbf{b}) = \sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2 + (a_3 - b_3)^2}$$

Is it much slower than computing the squared Euclidean distance? Yes: it can easily be 30% slower.

$$d^2(\mathbf{a}, \mathbf{b}) = (a_1 - b_1)^2 + (a_2 - b_2)^2 + (a_3 - b_3)^2$$
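
The two formulas above differ only by the final square root. Since the square root is monotonic, sorting candidates by either value gives the same order. A minimal pure-Python sketch of both, with function names chosen for illustration only (this is not the library's implementation):

```python
import math


def l2sq(a, b):
    """Squared Euclidean distance: sum of squared coordinate differences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def l2(a, b):
    """Euclidean distance: the square root of the squared distance."""
    return math.sqrt(l2sq(a, b))


if __name__ == "__main__":
    query = (1, 2, 3)
    candidates = [(4, 6, 3), (1, 2, 4), (0, 0, 0)]
    # Sorting by either metric gives the same order, since sqrt is monotonic.
    by_l2sq = sorted(candidates, key=lambda c: l2sq(query, c))
    by_l2 = sorted(candidates, key=lambda c: l2(query, c))
    print(by_l2sq == by_l2)
```

For nearest-neighbor style comparisons, the square root can therefore be deferred until an actual distance value is needed, or skipped entirely.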

The square root can be prohibitively expensive on low-dimensional vectors, so L2sq is recommended wherever the exact distance isn't needed, for example when distances are only compared or ranked. Below are the numbers for 3D vectors on Intel Sapphire Rapids. Even on such tiny vectors, SIMD pays off: for bf16, the Genoa kernels are over 4x faster than serial code, 31 GB/s vs 7 GB/s.

----------------------------------------------------------------------------------------------------------
Benchmark                                                Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------------------------
l2sq_f16_haswell<3d>/min_time:10.000/threads:1        22.9 ns         22.9 ns    611238825 abs_delta=18.0749n bytes=5.58275G/s pairs=43.6152M/s relative_error=9.15186n
l2_f16_haswell<3d>/min_time:10.000/threads:1          29.0 ns         29.0 ns    481702518 abs_delta=32.4372n bytes=4.40798G/s pairs=34.4373M/s relative_error=24.4623n
l2sq_bf16_haswell<3d>/min_time:10.000/threads:1       23.4 ns         23.4 ns    598749180 abs_delta=1.63755n bytes=5.47449G/s pairs=42.7695M/s relative_error=1.12825n
l2_bf16_haswell<3d>/min_time:10.000/threads:1         29.6 ns         29.6 ns    472595574 abs_delta=25.1248n bytes=4.32103G/s pairs=33.7581M/s relative_error=18.0853n
l2sq_i8_haswell<3d>/min_time:10.000/threads:1         4.00 ns         4.00 ns   1000000000 abs_delta=0 bytes=32.0099G/s pairs=250.077M/s relative_error=0
l2_i8_haswell<3d>/min_time:10.000/threads:1           4.37 ns         4.37 ns   1000000000 abs_delta=3.39785u bytes=29.277G/s pairs=228.727M/s relative_error=20.9261n
l2sq_u8_haswell<3d>/min_time:10.000/threads:1         4.11 ns         4.11 ns   1000000000 abs_delta=0 bytes=31.1124G/s pairs=243.066M/s relative_error=0
l2_u8_haswell<3d>/min_time:10.000/threads:1           4.08 ns         4.08 ns   1000000000 abs_delta=3.39785u bytes=31.3922G/s pairs=245.252M/s relative_error=20.9261n
l2sq_bf16_genoa<3d>/min_time:10.000/threads:1         4.13 ns         4.13 ns   1000000000 abs_delta=7.77245m bytes=31.0277G/s pairs=242.404M/s relative_error=3.58281m
l2_bf16_genoa<3d>/min_time:10.000/threads:1           4.67 ns         4.67 ns   1000000000 abs_delta=2.54837m bytes=27.4161G/s pairs=214.188M/s relative_error=1.79376m
l2sq_f16_sapphire<3d>/min_time:10.000/threads:1       3.09 ns         3.09 ns   1000000000 abs_delta=777.823u bytes=41.4891G/s pairs=324.134M/s relative_error=369.752u
l2_f16_sapphire<3d>/min_time:10.000/threads:1         3.63 ns         3.63 ns   1000000000 abs_delta=256.898u bytes=35.3008G/s pairs=275.787M/s relative_error=184.879u
l2sq_i8_ice<3d>/min_time:10.000/threads:1             2.91 ns         2.91 ns   1000000000 abs_delta=0 bytes=43.9976G/s pairs=343.731M/s relative_error=0
l2_i8_ice<3d>/min_time:10.000/threads:1               3.73 ns         3.73 ns   1000000000 abs_delta=3.39785u bytes=34.301G/s pairs=267.977M/s relative_error=20.9261n
l2sq_u8_ice<3d>/min_time:10.000/threads:1             3.32 ns         3.32 ns   1000000000 abs_delta=0 bytes=38.5524G/s pairs=301.191M/s relative_error=0
l2_u8_ice<3d>/min_time:10.000/threads:1               4.09 ns         4.09 ns   1000000000 abs_delta=3.39785u bytes=31.334G/s pairs=244.797M/s relative_error=20.9261n
l2sq_f64_skylake<3d>/min_time:10.000/threads:1        2.38 ns         2.38 ns   1000000000 abs_delta=102.457a bytes=53.8795G/s pairs=420.934M/s relative_error=50.304a
l2_f64_skylake<3d>/min_time:10.000/threads:1          2.53 ns         2.53 ns   1000000000 abs_delta=26.4545a bytes=50.4946G/s pairs=394.489M/s relative_error=22.4433a
l2sq_f32_skylake<3d>/min_time:10.000/threads:1        2.76 ns         2.76 ns   1000000000 abs_delta=88.0723n bytes=46.4136G/s pairs=362.606M/s relative_error=42.7713n
l2_f32_skylake<3d>/min_time:10.000/threads:1          3.20 ns         3.20 ns   1000000000 abs_delta=29.4821n bytes=39.943G/s pairs=312.055M/s relative_error=21.3856n
l2sq_bf16_serial<3d>/min_time:10.000/threads:1        17.9 ns         17.9 ns    779709405 abs_delta=12.0012n bytes=7.15287G/s pairs=55.8818M/s relative_error=6.49416n
l2_bf16_serial<3d>/min_time:10.000/threads:1          17.9 ns         17.9 ns    777831068 abs_delta=4.15743n bytes=7.1374G/s pairs=55.7609M/s relative_error=3.24708n
l2sq_f16_serial<3d>/min_time:10.000/threads:1         10.1 ns         10.1 ns   1000000000 abs_delta=51.7317n bytes=12.6994G/s pairs=99.2144M/s relative_error=25.558n
l2_f16_serial<3d>/min_time:10.000/threads:1           3.66 ns         3.66 ns   1000000000 abs_delta=17.515n bytes=34.9496G/s pairs=273.043M/s relative_error=12.779n
l2sq_f32_serial<3d>/min_time:10.000/threads:1         2.58 ns         2.58 ns   1000000000 abs_delta=86.5599n bytes=49.5235G/s pairs=386.902M/s relative_error=41.2701n
l2_f32_serial<3d>/min_time:10.000/threads:1           3.10 ns         3.10 ns   1000000000 abs_delta=28.7089n bytes=41.2603G/s pairs=322.346M/s relative_error=20.635n
l2sq_f64_serial<3d>/min_time:10.000/threads:1         2.07 ns         2.07 ns   1000000000 abs_delta=0 bytes=61.8926G/s pairs=483.536M/s relative_error=0
l2_f64_serial<3d>/min_time:10.000/threads:1           2.70 ns         2.70 ns   1000000000 abs_delta=0 bytes=47.443G/s pairs=370.648M/s relative_error=0
l2sq_i8_serial<3d>/min_time:10.000/threads:1          3.70 ns         3.70 ns   1000000000 abs_delta=0 bytes=34.5992G/s pairs=270.306M/s relative_error=0
l2_i8_serial<3d>/min_time:10.000/threads:1            3.23 ns         3.23 ns   1000000000 abs_delta=0 bytes=39.6018G/s pairs=309.389M/s relative_error=0
l2sq_u8_serial<3d>/min_time:10.000/threads:1          2.75 ns         2.75 ns   1000000000 abs_delta=0 bytes=46.5618G/s pairs=363.764M/s relative_error=0
l2_u8_serial<3d>/min_time:10.000/threads:1            3.41 ns         3.41 ns   1000000000 abs_delta=0 bytes=37.5629G/s pairs=293.46M/s relative_error=0
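
To relate the table back to the 30% figure, the relative cost of the square root can be computed from pairs of rows. A minimal sketch with a few rows copied from the 3D table above; the helper name is hypothetical and the selection of rows is only illustrative.

```python
def sqrt_overhead(l2sq_ns: float, l2_ns: float) -> float:
    """Relative slowdown of l2 over l2sq, e.g. 0.30 means 30% slower."""
    return l2_ns / l2sq_ns - 1.0


# (kernel, l2sq time in ns, l2 time in ns), copied from the 3D table above.
ROWS = [
    ("f16_haswell", 22.9, 29.0),
    ("i8_ice", 2.91, 3.73),
    ("f32_skylake", 2.76, 3.20),
    ("f64_serial", 2.07, 2.70),
]

if __name__ == "__main__":
    for kernel, l2sq_ns, l2_ns in ROWS:
        print(f"{kernel}: {sqrt_overhead(l2sq_ns, l2_ns):+.0%}")
```

Not every row follows this pattern: a few rows, such as l2_f16_serial and l2_i8_serial, report l2 as faster than l2sq, so timings at the scale of a few nanoseconds should be read with care.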

@ashvardanian changed the title from "Fused-Multiply-Add & u8 Kernels" to "New u8 Kernels" on Oct 18, 2024
@ashvardanian merged commit 0313406 into main on Oct 18, 2024
35 checks passed