From 4cd928d2a27e9026a1546889186ecf7747364d80 Mon Sep 17 00:00:00 2001 From: Wei-Chen Wang Date: Tue, 31 Oct 2023 20:17:33 -0400 Subject: [PATCH] CPU Optimization (#74) --- kernels/avx/matmul_avx_int8_int4.cc | 234 +++++++++++++++--- kernels/neon/matmul_neon_int8_int4.cc | 191 ++++++++++++-- kernels/pthread_pool.cc | 129 ++++++++++ kernels/pthread_pool.h | 43 ++++ .../nn_modules/non_cuda/Int4llamaAttention.cc | 8 + .../non_cuda/Int4llamaDecoderLayer.cc | 19 +- .../non_cuda/Int4llamaForCausalLM.cc | 2 + llm/src/ops/linear.cc | 2 +- 8 files changed, 577 insertions(+), 51 deletions(-) create mode 100644 kernels/pthread_pool.cc create mode 100644 kernels/pthread_pool.h diff --git a/kernels/avx/matmul_avx_int8_int4.cc b/kernels/avx/matmul_avx_int8_int4.cc index e8e1b4a1..2eda1726 100644 --- a/kernels/avx/matmul_avx_int8_int4.cc +++ b/kernels/avx/matmul_avx_int8_int4.cc @@ -6,58 +6,221 @@ #include "../matmul.h" +#include "pthread_pool.h" + struct int4_thread_args { int start_j, end_j; const struct matmul_params *params; }; -static inline void merge_int4_int8_dot_product_unroll2block(float *s, float *s_a, uint8_t *w_ptr, __m256i *x_ptr, - __m256 &acc0) { - // load 0 - 127 bit and 128 - 255 - __m128i raw_w_0 = _mm_loadu_si128((const __m128i *)w_ptr); - __m128i raw_w_128 = _mm_loadu_si128((const __m128i *)(w_ptr + 16)); +#define FP16_TO_FP32(x) ((float) (x)) +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 
32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if __AVXVNNI__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} + +// static inline void merge_int4_int8_dot_product_unroll2block(float *s, float *s_a, uint8_t *w_ptr, __m256i *x_ptr, +// __m256 &acc0) { +// // load 0 - 127 bit and 128 - 255 +// __m128i raw_w_0 = _mm_loadu_si128((const __m128i *)w_ptr); +// __m128i raw_w_128 = _mm_loadu_si128((const __m128i *)(w_ptr + 16)); + +// __m256 v_s = _mm256_set1_ps(s[0] * s_a[0]); +// __m256 v_s2 = _mm256_set1_ps(s[1] * s_a[1]); + +// __m256i activation = x_ptr[0]; +// __m256i activation2 = x_ptr[1]; + +// // Expand bytes into uint16_t values +// __m256i w_8_16exp = _mm256_cvtepu8_epi16(raw_w_0); +// __m256i w2_8_16exp = _mm256_cvtepu8_epi16(raw_w_128); + +// // Unpack values into individual bytes +// __m256i raw_w = _mm256_loadu_si256((const __m256i *)w_ptr); +// const __m256i lowMask = _mm256_set1_epi8(0xF); +// __m256i w_0 = _mm256_and_si256(lowMask, raw_w); +// __m256i high = _mm256_andnot_si256(lowMask, raw_w); +// __m256i w_128 = _mm256_srli_epi16(high, 4); +// const __m256i zero_point = _mm256_set1_epi8(8); +// w_0 = _mm256_sub_epi8(w_0, zero_point); +// w_128 = _mm256_sub_epi8(w_128, zero_point); + +// // Get absolute values 
of x vectors +// const __m256i ax = _mm256_sign_epi8(w_0, w_0); +// const __m256i ax2 = _mm256_sign_epi8(w_128, w_128); +// // Sign the values of the y vectors +// const __m256i sy = _mm256_sign_epi8(activation, w_0); +// const __m256i sy2 = _mm256_sign_epi8(activation2, w_128); +// // Perform multiplication and create 16-bit values +// const __m256i dot = _mm256_maddubs_epi16(ax, sy); +// const __m256i dot2 = _mm256_maddubs_epi16(ax2, sy2); + +// const __m256i ones = _mm256_set1_epi16(1); +// const __m256i summed_pairs = _mm256_madd_epi16(ones, dot); +// const __m256i summed_pairs2 = _mm256_madd_epi16(ones, dot2); +// __m256 intermediate = _mm256_cvtepi32_ps(summed_pairs); +// __m256 intermediate2 = _mm256_cvtepi32_ps(summed_pairs2); + +// acc0 = _mm256_fmadd_ps(intermediate, v_s, acc0); +// acc0 = _mm256_fmadd_ps(intermediate2, v_s2, acc0); +// } +inline static void merge_int4_int8_dot_product_unroll2block(float *s, float *s_a, uint8_t *w_ptr, __m256i *x_ptr, + __m256 &acc0) { __m256 v_s = _mm256_set1_ps(s[0] * s_a[0]); __m256 v_s2 = _mm256_set1_ps(s[1] * s_a[1]); __m256i activation = x_ptr[0]; __m256i activation2 = x_ptr[1]; - // Expand bytes into uint16_t values - __m256i w_8_16exp = _mm256_cvtepu8_epi16(raw_w_0); - __m256i w2_8_16exp = _mm256_cvtepu8_epi16(raw_w_128); - + // __m256i w_0 = bytes_from_nibbles_32(w_ptr); // Unpack values into individual bytes __m256i raw_w = _mm256_loadu_si256((const __m256i *)w_ptr); const __m256i lowMask = _mm256_set1_epi8(0xF); __m256i w_0 = _mm256_and_si256(lowMask, raw_w); + __m256i high = _mm256_andnot_si256(lowMask, raw_w); __m256i w_128 = _mm256_srli_epi16(high, 4); + const __m256i zero_point = _mm256_set1_epi8(8); w_0 = _mm256_sub_epi8(w_0, zero_point); w_128 = _mm256_sub_epi8(w_128, zero_point); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(w_0, w_0); - const __m256i ax2 = _mm256_sign_epi8(w_128, w_128); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(activation, w_0); - const __m256i sy2 = _mm256_sign_epi8(activation2, w_128); - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - const __m256i dot2 = _mm256_maddubs_epi16(ax2, sy2); + const __m256 intermediate = mul_sum_i8_pairs_float(w_0, activation); + const __m256 intermediate2 = mul_sum_i8_pairs_float(w_128, activation2); + // // Get absolute values of x vectors + // const __m256i ax = _mm256_sign_epi8(w_0, w_0); + // const __m256i ax2 = _mm256_sign_epi8(w_128, w_128); + // // Sign the values of the y vectors + // const __m256i sy = _mm256_sign_epi8(activation, w_0); + // const __m256i sy2 = _mm256_sign_epi8(activation2, w_128); + // // Perform multiplication and create 16-bit values + // const __m256i dot = _mm256_maddubs_epi16(ax, sy); + // const __m256i dot2 = _mm256_maddubs_epi16(ax2, sy2); - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, dot); - const __m256i summed_pairs2 = _mm256_madd_epi16(ones, dot2); - __m256 intermediate = _mm256_cvtepi32_ps(summed_pairs); - __m256 intermediate2 = _mm256_cvtepi32_ps(summed_pairs2); + // const __m256i ones = _mm256_set1_epi16(1); + // const __m256i summed_pairs = _mm256_madd_epi16(ones, dot); + // const __m256i summed_pairs2 = _mm256_madd_epi16(ones, dot2); + // __m256 intermediate = _mm256_cvtepi32_ps(summed_pairs); + // __m256 intermediate2 = _mm256_cvtepi32_ps(summed_pairs2); acc0 = _mm256_fmadd_ps(intermediate, v_s, acc0); acc0 = _mm256_fmadd_ps(intermediate2, v_s2, acc0); } -static 
void *fast_int8_int4_zp_no_offset_over_column_unroll2block(void *args) { +inline static void *fast_int8_int4_zp_no_offset_over_column_unroll2block(void *args) { int i, j, k; struct int4_thread_args *mat_args = (struct int4_thread_args *)args; const struct matmul_params *params = mat_args->params; @@ -86,7 +249,8 @@ static void *fast_int8_int4_zp_no_offset_over_column_unroll2block(void *args) { C->data_ptr[i * C->column + j] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7] + params->bias.data_ptr[j]; else - C->data_ptr[i * C->column + j] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7]; + // C->data_ptr[i * C->column + j] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7]; + C->data_ptr[i * C->column + j] = hsum_float_8(acc0); } } return NULL; @@ -162,25 +326,33 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_ // const int num_thread = 4; const int num_thread = params->opt_params.num_thread; int i, j, k; - pthread_t thread_pool[num_thread]; + // pthread_t thread_pool[num_thread]; struct int4_thread_args threads_args[num_thread]; assert(params->block_size == 32); // support block size 32 for now assert(params->A.column % (params->block_size * 2) == 0); - assert((params->C.column % (num_thread * 2)) == 0); // support block size 32 for now + // assert((params->C.column % (num_thread * 2)) == 0); // support block size 32 for now // quantize A assert((params->A.column * params->A.row) % params->block_size == 0); quantize_fp_to_int8_block_size32(params->A.data_ptr, params->A.column * params->A.row, params->A.int8_data_ptr, params->A_scales); + static void *pool = pool_start(fast_int8_int4_zp_no_offset_over_column_unroll2block, num_thread); + // Thread creation for (j = 0; j < num_thread; j++) { threads_args[j].start_j = j * (params->C.column / num_thread); - threads_args[j].end_j = (j + 1) * (params->C.column / num_thread); + if (j == num_thread - 1) { + threads_args[j].end_j = params->C.column; + } else { + threads_args[j].end_j = (j + 1) * (params->C.column / num_thread); + } threads_args[j].params = params; - pthread_create(&thread_pool[j], NULL, fast_int8_int4_zp_no_offset_over_column_unroll2block, &threads_args[j]); + // pthread_create(&thread_pool[j], NULL, fast_int8_int4_zp_no_offset_over_column_unroll2block, &threads_args[j]); + pool_enqueue(pool, &threads_args[j], NULL); } - // Join threads - for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL); + // // Join threads + // for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL); + pool_wait(pool); }; } // namespace matmul diff --git a/kernels/neon/matmul_neon_int8_int4.cc b/kernels/neon/matmul_neon_int8_int4.cc index 2f0e8f3e..a1089599 100644 --- a/kernels/neon/matmul_neon_int8_int4.cc +++ b/kernels/neon/matmul_neon_int8_int4.cc @@ -9,6 +9,13 @@ #include "../matmul.h" #include "common.h" +#include "pthread_pool.h" + +struct a8w4_thread_args { + int start_j, end_j; + const struct matmul_params* params; +}; + // Most of this function is from llama.cpp void quantize_fp32_to_int8(float* A, int8_t* qA, float* sA, int size, int block_size) { assert(size % block_size == 0); @@ -159,11 +166,6 @@ void matmul_int8_int4_no_offset(struct matmul_params* params) { } } -struct a8w4_thread_args { - int start_j, end_j; - const struct matmul_params* params; -}; - static void* matmul_int8_int4_no_offset_over_column(void* args) { struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; const struct matmul_params* params = 
mat_args->params; @@ -230,7 +232,7 @@ static void* matmul_int8_int4_no_offset_over_column(void* args) { return NULL; } -static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) { +inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) { struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; const struct matmul_params* params = mat_args->params; int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size; @@ -279,6 +281,46 @@ static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) { int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + // // Sequential data + // uint8x8_t hi_mask = vdup_n_u8(0xF0); // 11110000 + // uint8x8_t lo_mask = vdup_n_u8(0x0F); // 00001111 + // // Split the data into two 64-bit halves + // uint8x8_t w0_lower_half = vget_low_u8(w0); + // uint8x8_t w0_upper_half = vget_high_u8(w0); + // uint8x8_t w0_low_hi_nibbles = vshr_n_u8(vand_u8(w0_lower_half, hi_mask), 4); + // uint8x8_t w0_low_lo_nibbles = vand_u8(w0_lower_half, lo_mask); + // int8x16_t w0_low = vreinterpretq_s8_u8(vcombine_u8(w0_low_lo_nibbles, w0_low_hi_nibbles)); + // uint8x8_t w0_high_hi_nibbles = vshr_n_u8(vand_u8(w0_upper_half, hi_mask), 4); + // uint8x8_t w0_high_lo_nibbles = vand_u8(w0_upper_half, lo_mask); + // int8x16_t w0_high = vreinterpretq_s8_u8(vcombine_u8(w0_high_lo_nibbles, w0_high_hi_nibbles)); + + // uint8x8_t w1_lower_half = vget_low_u8(w1); + // uint8x8_t w1_upper_half = vget_high_u8(w1); + // uint8x8_t w1_low_hi_nibbles = vshr_n_u8(vand_u8(w1_lower_half, hi_mask), 4); + // uint8x8_t w1_low_lo_nibbles = vand_u8(w1_lower_half, lo_mask); + // int8x16_t w1_low = vreinterpretq_s8_u8(vcombine_u8(w1_low_lo_nibbles, w1_low_hi_nibbles)); + // uint8x8_t w1_high_hi_nibbles = vshr_n_u8(vand_u8(w1_upper_half, hi_mask), 4); + // uint8x8_t w1_high_lo_nibbles = vand_u8(w1_upper_half, lo_mask); + // int8x16_t w1_high = vreinterpretq_s8_u8(vcombine_u8(w1_high_lo_nibbles, w1_high_hi_nibbles)); + + // uint8x8_t w2_lower_half = vget_low_u8(w2); + // uint8x8_t w2_upper_half = vget_high_u8(w2); + // uint8x8_t w2_low_hi_nibbles = vshr_n_u8(vand_u8(w2_lower_half, hi_mask), 4); + // uint8x8_t w2_low_lo_nibbles = vand_u8(w2_lower_half, lo_mask); + // int8x16_t w2_low = vreinterpretq_s8_u8(vcombine_u8(w2_low_lo_nibbles, w2_low_hi_nibbles)); + // uint8x8_t w2_high_hi_nibbles = vshr_n_u8(vand_u8(w2_upper_half, hi_mask), 4); + // uint8x8_t w2_high_lo_nibbles = vand_u8(w2_upper_half, lo_mask); + // int8x16_t w2_high = vreinterpretq_s8_u8(vcombine_u8(w2_high_lo_nibbles, w2_high_hi_nibbles)); + + // uint8x8_t w3_lower_half = vget_low_u8(w3); + // uint8x8_t w3_upper_half = vget_high_u8(w3); + // uint8x8_t w3_low_hi_nibbles = vshr_n_u8(vand_u8(w3_lower_half, hi_mask), 4); + // uint8x8_t w3_low_lo_nibbles = vand_u8(w3_lower_half, lo_mask); + // int8x16_t w3_low = vreinterpretq_s8_u8(vcombine_u8(w3_low_lo_nibbles, w3_low_hi_nibbles)); + // uint8x8_t w3_high_hi_nibbles = vshr_n_u8(vand_u8(w3_upper_half, hi_mask), 4); + // uint8x8_t w3_high_lo_nibbles = vand_u8(w3_upper_half, lo_mask); + // int8x16_t w3_high = vreinterpretq_s8_u8(vcombine_u8(w3_high_lo_nibbles, w3_high_hi_nibbles)); + // apply offset w0_low = vsubq_s8(w0_low, offsets); w0_high = vsubq_s8(w0_high, offsets); @@ -326,6 +368,111 @@ static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) { return NULL; } +// inline static void* 
matmul_int8_int4_no_offset_over_column_unroll128(void* args) {
+//     struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
+//     const struct matmul_params* params = mat_args->params;
+//     int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size;
+//     const int num_block = k / block_size;
+
+//     for (int i = 0; i < m; i++) {
+//         for (int j = mat_args->start_j; j < mat_args->end_j; j++) {
+//             float32x4_t sumv0 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv1 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv2 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv3 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv4 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv5 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv6 = vdupq_n_f32(0.0f);
+//             float32x4_t sumv7 = vdupq_n_f32(0.0f);
+//             const unsigned char* w_start = &params->B.int4_data_ptr[j * k / 2];
+//             const signed char* a_start = &params->A.int8_data_ptr[i * k];
+//             float* s_a = &params->A_scales[i * k / 32];
+//             float* s_w = &params->scales[j * k / 32];
+
+//             const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
+//             const int8x16_t offsets = vdupq_n_s8(8);
+//             for (int q = 0; q < num_block; q += 8) {
+//                 int32x4_t int_sum0 = vdupq_n_s32(0);
+//                 int32x4_t int_sum1 = vdupq_n_s32(0);
+//                 int32x4_t int_sum2 = vdupq_n_s32(0);
+//                 int32x4_t int_sum3 = vdupq_n_s32(0);
+//                 int32x4_t int_sum4 = vdupq_n_s32(0);
+//                 int32x4_t int_sum5 = vdupq_n_s32(0);
+//                 int32x4_t int_sum6 = vdupq_n_s32(0);
+//                 int32x4_t int_sum7 = vdupq_n_s32(0);
+//                 float s_0 = *s_a++ * *s_w++;
+//                 float s_1 = *s_a++ * *s_w++;
+//                 float s_2 = *s_a++ * *s_w++;
+//                 float s_3 = *s_a++ * *s_w++;
+
+
+//                 const uint8x16_t w0 = vld1q_u8(w_start);       // 32 4bit weight
+//                 const uint8x16_t w1 = vld1q_u8(w_start + 16);  // 32 4bit weight
+//                 const uint8x16_t w2 = vld1q_u8(w_start + 32);  // 32 4bit weight
+//                 const uint8x16_t w3 = vld1q_u8(w_start + 48);  // 32 4bit weight
+//                 w_start += 64;
+
+//                 // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit
+//                 // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit
+//                 // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)...
+//                 // low; (0, 0), (1, 0), (2, 0), (3, 0) ...
+//                 // high: (16, 0), (17, 0), (18, 0), (19, 0) ...
+// int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); +// int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); +// int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); +// int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); +// int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); +// int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); +// int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); +// int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + +// // apply offset +// w0_low = vsubq_s8(w0_low, offsets); +// w0_high = vsubq_s8(w0_high, offsets); +// w1_low = vsubq_s8(w1_low, offsets); +// w1_high = vsubq_s8(w1_high, offsets); +// w2_low = vsubq_s8(w2_low, offsets); +// w2_high = vsubq_s8(w2_high, offsets); +// w3_low = vsubq_s8(w3_low, offsets); +// w3_high = vsubq_s8(w3_high, offsets); + +// // load 64 8-bit activation +// const int8x16_t a0 = vld1q_s8(a_start); +// const int8x16_t a1 = vld1q_s8(a_start + 16); +// const int8x16_t a2 = vld1q_s8(a_start + 32); +// const int8x16_t a3 = vld1q_s8(a_start + 48); +// const int8x16_t a4 = vld1q_s8(a_start + 64); +// const int8x16_t a5 = vld1q_s8(a_start + 80); +// const int8x16_t a6 = vld1q_s8(a_start + 96); +// const int8x16_t a7 = vld1q_s8(a_start + 112); +// a_start += 128; + +// // dot product into int32x4_t +// int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); +// int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); +// int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); +// int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); +// int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); +// int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); +// int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); +// int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + +// sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); +// sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); +// sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); +// sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); +// } +// if (params->bias.data_ptr) +// params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + +// vaddvq_f32(sumv2) + vaddvq_f32(sumv3); +// else +// params->C.data_ptr[i * n + j] = +// vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); +// } +// } + +// return NULL; +// } static void* matmul_int8_int4_no_offset_over_column_packed(void* args) { struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; @@ -412,23 +559,37 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_ // const int num_thread = 8; const int num_thread = params->opt_params.num_thread; - pthread_t thread_pool[num_thread]; + // pthread_t thread_pool[num_thread]; struct a8w4_thread_args threads_args[num_thread]; assert(params->block_size == 32); // support block size 32 for now +#ifdef PACK_QK + // This may lead to performance degradation + static void *pool = pool_start(matmul_int8_int4_no_offset_over_column_packed, num_thread); +#else + static void *pool = pool_start(matmul_int8_int4_no_offset_over_column_unroll128, num_thread); +#endif + // Thread creation for (j = 0; j < num_thread; j++) { threads_args[j].start_j = j * (params->C.column / num_thread); - threads_args[j].end_j = (j + 1) * (params->C.column / num_thread); + // threads_args[j].end_j = (j + 1) * (params->C.column / num_thread); + if (j == num_thread - 1) { + threads_args[j].end_j = params->C.column; + } else { + threads_args[j].end_j = (j + 1) * 
(params->C.column / num_thread);
+        }
         threads_args[j].params = params;
-#ifdef PACK_QK
-        // This may lead to performance degradation
-        pthread_create(&thread_pool[j], NULL, matmul_int8_int4_no_offset_over_column_packed, &threads_args[j]);
-#else
-        pthread_create(&thread_pool[j], NULL, matmul_int8_int4_no_offset_over_column_unroll128, &threads_args[j]);
-#endif
+// #ifdef PACK_QK
+//         // This may lead to performance degradation
+//         pthread_create(&thread_pool[j], NULL, matmul_int8_int4_no_offset_over_column_packed, &threads_args[j]);
+// #else
+//         pthread_create(&thread_pool[j], NULL, matmul_int8_int4_no_offset_over_column_unroll128, &threads_args[j]);
+// #endif
+        pool_enqueue(pool, &threads_args[j], NULL);
     }
     // Join threads
-    for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
+    // for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
+    pool_wait(pool);
 };
 } // namespace matmul
diff --git a/kernels/pthread_pool.cc b/kernels/pthread_pool.cc
new file mode 100644
index 00000000..6b4d88c8
--- /dev/null
+++ b/kernels/pthread_pool.cc
@@ -0,0 +1,129 @@
+#include "pthread_pool.h"
+#include <pthread.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+struct pool_queue {
+    void *arg;
+    char free;
+    struct pool_queue *next;
+};
+
+struct pool {
+    char cancelled;
+    void *(*fn)(void *);
+    unsigned int remaining;
+    unsigned int nthreads;
+    struct pool_queue *q;
+    struct pool_queue *end;
+    pthread_mutex_t q_mtx;
+    pthread_cond_t q_cnd;
+    pthread_t threads[1];
+};
+
+static void * thread(void *arg);
+
+void * pool_start(void * (*thread_func)(void *), unsigned int threads) {
+    struct pool *p = (struct pool *) malloc(sizeof(struct pool) + (threads-1) * sizeof(pthread_t));
+    int i;
+
+    pthread_mutex_init(&p->q_mtx, NULL);
+    pthread_cond_init(&p->q_cnd, NULL);
+    p->nthreads = threads;
+    p->fn = thread_func;
+    p->cancelled = 0;
+    p->remaining = 0;
+    p->end = NULL;
+    p->q = NULL;
+
+    for (i = 0; i < threads; i++) {
+        pthread_create(&p->threads[i], NULL, &thread, p);
+    }
+
+    return p;
+}
+
+void pool_enqueue(void *pool, void *arg, char free) {
+    struct pool *p = (struct pool *) pool;
+    struct pool_queue *q = (struct pool_queue *) malloc(sizeof(struct pool_queue));
+    q->arg = arg;
+    q->next = NULL;
+    q->free = free;
+
+    pthread_mutex_lock(&p->q_mtx);
+    if (p->end != NULL) p->end->next = q;
+    if (p->q == NULL) p->q = q;
+    p->end = q;
+    p->remaining++;
+    pthread_cond_signal(&p->q_cnd);
+    pthread_mutex_unlock(&p->q_mtx);
+}
+
+void pool_wait(void *pool) {
+    struct pool *p = (struct pool *) pool;
+
+    pthread_mutex_lock(&p->q_mtx);
+    while (!p->cancelled && p->remaining) {
+        pthread_cond_wait(&p->q_cnd, &p->q_mtx);
+    }
+    pthread_mutex_unlock(&p->q_mtx);
+}
+
+void pool_end(void *pool) {
+    struct pool *p = (struct pool *) pool;
+    struct pool_queue *q;
+    int i;
+
+    p->cancelled = 1;
+
+    pthread_mutex_lock(&p->q_mtx);
+    pthread_cond_broadcast(&p->q_cnd);
+    pthread_mutex_unlock(&p->q_mtx);
+
+    for (i = 0; i < p->nthreads; i++) {
+        pthread_join(p->threads[i], NULL);
+    }
+
+    while (p->q != NULL) {
+        q = p->q;
+        p->q = q->next;
+
+        if (q->free) free(q->arg);
+        free(q);
+    }
+
+    free(p);
+}
+
+static void * thread(void *arg) {
+    struct pool_queue *q;
+    struct pool *p = (struct pool *) arg;
+
+    while (!p->cancelled) {
+        pthread_mutex_lock(&p->q_mtx);
+        while (!p->cancelled && p->q == NULL) {
+            pthread_cond_wait(&p->q_cnd, &p->q_mtx);
+        }
+        if (p->cancelled) {
+            pthread_mutex_unlock(&p->q_mtx);
+            return NULL;
+        }
+        q = p->q;
+        p->q = q->next;
+        p->end = (q == p->end ? NULL : p->end);
+        pthread_mutex_unlock(&p->q_mtx);
+
+        p->fn(q->arg);
+
+        if (q->free) free(q->arg);
+        free(q);
+        q = NULL;
+
+        pthread_mutex_lock(&p->q_mtx);
+        p->remaining--;
+        pthread_cond_broadcast(&p->q_cnd);
+        pthread_mutex_unlock(&p->q_mtx);
+    }
+
+    return NULL;
+}
diff --git a/kernels/pthread_pool.h b/kernels/pthread_pool.h
new file mode 100644
index 00000000..009b6baa
--- /dev/null
+++ b/kernels/pthread_pool.h
@@ -0,0 +1,43 @@
+/** \file
+ * This file provides prototypes for an implementation of a pthread pool.
+ */
+
+#ifndef __PTHREAD_POOL_H__
+/**
+ * Create a new thread pool.
+ *
+ * New tasks should be enqueued with pool_enqueue. thread_func will be called
+ * once per queued task with its sole argument being the argument given to
+ * pool_enqueue.
+ *
+ * \param thread_func The function executed by each thread for each work item.
+ * \param threads The number of threads in the pool.
+ * \return A pointer to the thread pool.
+ */
+void * pool_start(void * (*thread_func)(void *), unsigned int threads);
+
+/**
+ * Enqueue a new task for the thread pool.
+ *
+ * \param pool A thread pool returned by pool_start.
+ * \param arg The argument to pass to the thread worker function.
+ * \param free If true, the argument will be freed after the task has completed.
+ */
+void pool_enqueue(void *pool, void *arg, char free);
+
+/**
+ * Wait for all queued tasks to be completed.
+ */
+void pool_wait(void *pool);
+
+/**
+ * Stop all threads in the pool.
+ *
+ * Note that this function will block until all threads have terminated.
+ * All queued items will also be freed, along with the pool itself.
+ * Remaining work item arguments will be freed depending on the free argument to
+ * pool_enqueue.
+ */
+void pool_end(void *pool);
+
+#endif
diff --git a/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc b/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
index a2917901..5c534af4 100644
--- a/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
+++ b/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
@@ -288,7 +288,9 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
     // Query states
     Matrix3D<float> query_states_unshape(query_states_unshape_arr, b, sqlen, embed_dim);
+    PROFILE_START(profile_name + "::q_proj");
     this->q_proj.forward(input.hidden_states, query_states_unshape);
+    PROFILE_END(profile_name + "::q_proj");
     Matrix3D<float> query_states(query_states_arr, this->num_heads, sqlen, this->head_dim);
     this->shape(query_states_unshape, query_states, sqlen);
@@ -306,13 +308,17 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
     // Key states
     Matrix3D<float> key_states_unshape(key_states_unshape_arr, b, sqlen, embed_dim);
+    PROFILE_START(profile_name + "::k_proj");
     this->k_proj.forward(input.hidden_states, key_states_unshape);
+    PROFILE_END(profile_name + "::k_proj");
     Matrix3D<float> key_states(key_states_arr, this->num_heads, sqlen, this->head_dim);
     this->shape(key_states_unshape, key_states, sqlen);
 
     // Value states
     Matrix3D<float> value_states_unshape(value_states_unshape_arr, b, sqlen, embed_dim);
+    PROFILE_START(profile_name + "::v_proj");
     this->v_proj.forward(input.hidden_states, value_states_unshape);
+    PROFILE_END(profile_name + "::v_proj");
     Matrix3D<float> value_states(value_states_arr, this->num_heads, sqlen, this->head_dim);
     this->shape(value_states_unshape, value_states, sqlen);
@@ -381,7 +387,9 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
     // Output projection
     Matrix3D<float> attn_output_fp(attn_output_fp_arr, 1, sqlen, this->num_heads * this->head_dim);
+    PROFILE_START(profile_name + "::o_proj");
     this->o_proj.forward(attn_output_transpose, attn_output_fp);
+    PROFILE_END(profile_name + "::o_proj");
 
     // Output assignment
     output.attn_output = attn_output_fp;
diff --git a/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc b/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc
index 0330ba24..d2566c41 100644
--- a/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc
+++ b/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc
@@ -30,12 +30,17 @@ static void add(Matrix3D<float> a, Matrix3D<float> b, Matrix3D<float> c) {
     PROFILE_END("Int4llamaDecoderLayer::add");
 }
 
-static void SiLuMul(Matrix3D<float> a, Matrix3D<float> b) {
+inline static float Silu(float x) {
+    return x / (1.0f + expf(-x));
+}
+
+inline static void SiLuMul(Matrix3D<float> a, Matrix3D<float> b) {
     PROFILE_START("Int4llamaDecoderLayer::MulSiLu");
     for (int i = 0; i < a.length(); i++) {
-        float v = a.m_data[i];
-        float silu_v = v * (1.0 / (1.0 + exp(-1 * v)));
-        a.m_data[i] = silu_v * b.m_data[i];
+        // float v = a.m_data[i];
+        // float silu_v = v * (1.0 / (1.0 + exp(-1 * v)));
+        // a.m_data[i] = silu_v * b.m_data[i];
+        a.m_data[i] = Silu(a.m_data[i]) * b.m_data[i];
     }
     PROFILE_END("Int4llamaDecoderLayer::MulSiLu");
 }
@@ -81,18 +86,24 @@ struct Int4llamaDecoderLayer_output Int4llamaDecoderLayer::forward(std::string p
     // Gate proj: embedding -> hidden_dim
     Matrix3D<float> gate_proj(gate_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, this->hidden_dim);
+    PROFILE_START("Int4llamaDecoderLayer::gate_proj");
     this->gate_proj.forward(post_attention_layernorm, gate_proj);
+    PROFILE_END("Int4llamaDecoderLayer::gate_proj");
 
     // up proj: embedding -> hidden_dim
     Matrix3D<float> up_proj(up_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, this->hidden_dim);
+    PROFILE_START("Int4llamaDecoderLayer::up_proj");
     this->up_proj.forward(post_attention_layernorm, up_proj);
+    PROFILE_END("Int4llamaDecoderLayer::up_proj");
 
     // silu
     SiLuMul(gate_proj, up_proj);
 
     // down proj: hidden_dim -> embedding
     Matrix3D<float> down_proj(down_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, this->embed_dim);
+    PROFILE_START("Int4llamaDecoderLayer::down_proj");
     this->down_proj.forward(gate_proj, down_proj);
+    PROFILE_END("Int4llamaDecoderLayer::down_proj");
 
     // Residual add
     add(residual_add, down_proj, residual_add);
diff --git a/llm/src/nn_modules/non_cuda/Int4llamaForCausalLM.cc b/llm/src/nn_modules/non_cuda/Int4llamaForCausalLM.cc
index f4add01e..1b6d819e 100644
--- a/llm/src/nn_modules/non_cuda/Int4llamaForCausalLM.cc
+++ b/llm/src/nn_modules/non_cuda/Int4llamaForCausalLM.cc
@@ -32,7 +32,9 @@ struct Int4LlamaForCausalLM_output Int4LlamaForCausalLM::forward(std::string par
 
     // Get logits
     Matrix3D<float> logits(logits_output, 1, sqlen, this->decoder.voc_size);
+    PROFILE_START("Int4LlamaForCausalLM::lm_head");
     this->lm_head.forward(decoder_output.last_hidden_state, logits);
+    PROFILE_END("Int4LlamaForCausalLM::lm_head");
 
     struct Int4LlamaForCausalLM_output LMoutput = {logits, decoder_output.past_keys, decoder_output.past_values};
     PROFILE_END(profile_name);
diff --git a/llm/src/ops/linear.cc b/llm/src/ops/linear.cc
index 136d50e8..4ff1ea02 100644
--- a/llm/src/ops/linear.cc
+++ b/llm/src/ops/linear.cc
@@ -174,7 +174,7 @@ void Linear_FP_int4::forward(const Matrix3D<float> &x, Matrix3D<float> &output)
     assert(x.m_dim_z / 2 == weight.m_dim_z);
     assert(output.m_dim_z > num_thread);
-    assert(output.m_dim_z % (num_thread * 2) == 0); // unroll column by 2
+    // assert(output.m_dim_z % (num_thread * 2) == 0); // unroll column by 2
 
     struct matmul_params params;
     params.A.row = x.m_dim_y;
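
For readers unfamiliar with the pool API added in kernels/pthread_pool.h, below is a minimal, self-contained sketch of the dispatch pattern the patched kernels now use: create the pool once, enqueue one column range per worker (with the last worker absorbing the remainder, as in the changed loops above), then wait. The worker function, argument struct, and sizes are illustrative placeholders and are not part of the patch.

// pool_usage_example.cc -- hypothetical example, not part of this commit
#include <stdio.h>
#include "pthread_pool.h"

// Placeholder per-task argument: a half-open range of output columns.
struct range_args {
    int start_j, end_j;
};

// Placeholder worker with the signature pool_start() expects.
static void *process_range(void *arg) {
    struct range_args *r = (struct range_args *)arg;
    printf("processing columns [%d, %d)\n", r->start_j, r->end_j);
    return NULL;
}

int main() {
    const int num_thread = 4;
    const int total_cols = 4100;  // deliberately not divisible by num_thread

    // Create the pool once; the kernels keep it in a function-local static so
    // the worker threads persist across matmul calls instead of being
    // created and joined on every invocation.
    void *pool = pool_start(process_range, num_thread);

    struct range_args args[num_thread];
    for (int j = 0; j < num_thread; j++) {
        args[j].start_j = j * (total_cols / num_thread);
        // Last worker takes the remainder, mirroring the patched loops.
        args[j].end_j = (j == num_thread - 1) ? total_cols
                                              : (j + 1) * (total_cols / num_thread);
        // Third argument is 0: args are stack-owned, so the pool must not free them.
        pool_enqueue(pool, &args[j], 0);
    }

    pool_wait(pool);  // returns once every queued range has been processed
    pool_end(pool);   // join the workers and release the pool
    return 0;
}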