diff --git a/.gitignore b/.gitignore index 8bc1e14a..da1afdc4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.a *.pyc *.cuu +*.ccc .DS_Store .build/ .cache/ @@ -17,6 +18,7 @@ assets/ *.zip *.txt !requirements.txt +*.pt *.json test_* !test_*.cc diff --git a/kernels/avx/matmul_avx_int8_int4.cc b/kernels/avx/matmul_avx_int8_int4.cc index b25dd7d9..e8e1b4a1 100644 --- a/kernels/avx/matmul_avx_int8_int4.cc +++ b/kernels/avx/matmul_avx_int8_int4.cc @@ -159,7 +159,8 @@ static void quantize_fp_to_int8_block_size32(float *x, int size, int8_t *qx, flo namespace matmul { void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params) { - const int num_thread = 4; // params->opt_params.num_thread; + // const int num_thread = 4; + const int num_thread = params->opt_params.num_thread; int i, j, k; pthread_t thread_pool[num_thread]; struct int4_thread_args threads_args[num_thread]; diff --git a/kernels/neon/matmul_neon_int4.cc b/kernels/neon/matmul_neon_int4.cc index c54e375e..d43453e3 100644 --- a/kernels/neon/matmul_neon_int4.cc +++ b/kernels/neon/matmul_neon_int4.cc @@ -399,7 +399,8 @@ static void *fast_zp_no_offset_over_column_func_v3(void *args) { namespace matmul { void MatmulOperator::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) { - const int num_thread = 32; + // const int num_thread = 32; + const int num_thread = params->opt_params.num_thread; int i, j, k; pthread_t thread_pool[num_thread]; struct int4_thread_args threads_args[num_thread]; diff --git a/kernels/neon/matmul_neon_int4_offset.cc b/kernels/neon/matmul_neon_int4_offset.cc index f6099abd..5cac2a74 100644 --- a/kernels/neon/matmul_neon_int4_offset.cc +++ b/kernels/neon/matmul_neon_int4_offset.cc @@ -280,7 +280,8 @@ static void *fast_over_column_func_v1(void *args) { namespace matmul { void MatmulOperator::mat_mul_accelerator_int4_fast(const struct matmul_params *params) { - const int num_thread = 16; + // const int num_thread = 16; + const int num_thread = params->opt_params.num_thread; int i, j, k; pthread_t thread_pool[num_thread]; struct int4_thread_args threads_args[num_thread]; diff --git a/kernels/neon/matmul_neon_int8_int4.cc b/kernels/neon/matmul_neon_int8_int4.cc index 94dc6e20..2f0e8f3e 100644 --- a/kernels/neon/matmul_neon_int8_int4.cc +++ b/kernels/neon/matmul_neon_int8_int4.cc @@ -410,7 +410,8 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_ // ref imp. 
// matmul_int8_int4_no_offset(params); - const int num_thread = 8; + // const int num_thread = 8; + const int num_thread = params->opt_params.num_thread; pthread_t thread_pool[num_thread]; struct a8w4_thread_args threads_args[num_thread]; assert(params->block_size == 32); // support block size 32 for now diff --git a/llm/application/chat.cc b/llm/application/chat.cc index 6232a2d5..cbb51393 100644 --- a/llm/application/chat.cc +++ b/llm/application/chat.cc @@ -1,11 +1,13 @@ #include #include +#include #include "Generate.h" std::map model_config = { {"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B}, - {"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B}}; + {"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B}, + {"CodeLLaMA_7B", CodeLLaMA_7B}, {"CodeLLaMA_13B", CodeLLaMA_13B}}; std::map model_path = {{"OPT_125m", "models/OPT_125m"}, {"OPT_1.3B", "models/OPT_1.3B"}, @@ -14,7 +16,9 @@ std::map model_path = {{"OPT_125m", "models/OPT_125m"} {"LLaMA2_7B_chat", "models/LLaMA_7B_2_chat"}, {"LLaMA2_13B_chat", "models/LLaMA_13B_2_chat"}, {"7b", "models/LLaMA_7B_2_chat"}, - {"13b", "models/LLaMA_13B_2_chat"}}; + {"13b", "models/LLaMA_13B_2_chat"}, + {"CodeLLaMA_7B", "models/CodeLLaMA_7B_Instruct"}, + {"CodeLLaMA_13B", "models/CodeLLaMA_13B_Instruct"}}; std::map data_format_list = { {"FP32", FP32}, {"INT8", QINT8}, {"INT4", INT4}, {"int4", INT4}, {"fp32", FP32}, @@ -22,21 +26,52 @@ std::map data_format_list = { bool isLLaMA(std::string s) { std::string LLaMA_prefix = "LLaMA"; + std::string CodeLLaMA_prefix = "CodeLLaMA"; - if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s == "7b" || s == "13b") + if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix || s == "7b" || s == "13b") return true; else return false; } +bool isCodeLLaMA(std::string s) { + std::string CodeLLaMA_prefix = "CodeLLaMA"; + + if (s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix) + return true; + else + return false; +} + +bool convertToBool(const char* str) { + if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) { + return true; + } + else if (strcmp(str, "false") == 0 || strcmp(str, "0") == 0) { + return false; + } + else { + std::cerr << "Error: Invalid boolean value: " << str << std::endl; + exit(EXIT_FAILURE); + } +} + int main(int argc, char* argv[]) { std::string target_model = "LLaMA2_7B_chat"; std::string target_data_format = "INT4"; + bool instruct = true; Profiler::getInstance().for_demo = true; std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl; - if (argc == 3) { + if (argc >= 3 && argc <= 5) { + if (argc >= 4) { + NUM_THREAD = atoi(argv[3]); + } + if (argc == 5) { + instruct = convertToBool(argv[4]); + } + auto target_str = argv[1]; target_model = argv[1]; if (model_config.count(target_model) == 0) { @@ -108,10 +143,14 @@ int main(int argc, char* argv[]) { #endif struct opt_params generation_config; - generation_config.n_predict = 512; generation_config.n_vocab = 32000; - generation_config.temp = 0.1f; + generation_config.n_predict = 512; generation_config.repeat_penalty = 1.25f; + generation_config.temp = 0.1f; + if(isCodeLLaMA(target_model)) { + generation_config.temp = 0.2f; + generation_config.top_p = 0.95f; + } if (format_id == FP32) { Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id)); @@ -122,7 +161,21 @@ int 
main(int argc, char* argv[]) { std::cout << "USER: "; std::string input; std::getline(std::cin, input); - input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; + if (instruct) { + std::cout << "ASSISTANT: " << std::endl; + if (isCodeLLaMA(target_model)) { + input = "[INST] " + input + " [/INST]\n"; + } + } + else { + if (isCodeLLaMA(target_model)) { + std::cout << input; + } + } + + if (!isCodeLLaMA(target_model)) { + input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; + } LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, false); } @@ -136,7 +189,21 @@ int main(int argc, char* argv[]) { std::cout << "USER: "; std::string input; std::getline(std::cin, input); - input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; + if (instruct) { + std::cout << "ASSISTANT: " << std::endl; + if (isCodeLLaMA(target_model)) { + input = "[INST] " + input + " [/INST]"; + } + } + else { + if (isCodeLLaMA(target_model)) { + std::cout << input; + } + } + + if (!isCodeLLaMA(target_model)) { + input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n"; + } LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, false); } diff --git a/llm/application/voicechat.cc b/llm/application/voicechat.cc index 6d81b1d6..010c6190 100644 --- a/llm/application/voicechat.cc +++ b/llm/application/voicechat.cc @@ -123,10 +123,11 @@ int main(int argc, char* argv[]) { std::string output; std::string model_input; - std::system("./application/sts_utils/listen"); + int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); std::getline(in, input); - std::system("rm tmpfile"); + result = std::system("rm tmpfile"); + (void)result; std::cout << input << std::endl; if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.") @@ -145,10 +146,11 @@ int main(int argc, char* argv[]) { std::string output; std::string model_input; - std::system("./application/sts_utils/listen"); + int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); std::getline(in, input); - std::system("rm tmpfile"); + result = std::system("rm tmpfile"); + (void)result; std::cout << input << std::endl; if (input == " quit" || input == " Quit" || input == " Quit." 
|| input == " quit.") @@ -191,10 +193,11 @@ int main(int argc, char* argv[]) { std::string output; std::string model_input; - std::system("./application/sts_utils/listen"); + int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); std::getline(in, input); - std::system("rm tmpfile"); + result = std::system("rm tmpfile"); + (void)result; std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); std::cout << input << std::endl; @@ -215,10 +218,11 @@ int main(int argc, char* argv[]) { std::string output; std::string model_input; - std::system("./application/sts_utils/listen"); + int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); std::getline(in, input); - std::system("rm tmpfile"); + result = std::system("rm tmpfile"); + (void)result; std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); std::cout << input << std::endl; @@ -238,10 +242,11 @@ int main(int argc, char* argv[]) { std::string output; std::string model_input; - std::system("./application/sts_utils/listen"); + int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); std::getline(in, input); - std::system("rm tmpfile"); + result = std::system("rm tmpfile"); + (void)result; std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); std::cout << input << std::endl; diff --git a/llm/include/Generate.h b/llm/include/Generate.h index 1f6fd04b..b772714d 100644 --- a/llm/include/Generate.h +++ b/llm/include/Generate.h @@ -5,6 +5,9 @@ Adapted from llama.cpp: */ +#ifndef GENERATE_H +#define GENERATE_H + #include #include #include @@ -24,7 +27,8 @@ Adapted from llama.cpp: #include "operators.h" #include "utils.h" -inline std::mt19937 OPT_rng; +// inline std::mt19937 OPT_rng; // inline variables are only available with ‘-std=c++17’ or ‘-std=gnu++17’ +static std::mt19937 OPT_rng; typedef struct OPT_token_data { int id; // token id @@ -39,14 +43,14 @@ typedef struct OPT_token_data_array { } OPT_token_data_array; struct opt_params { - int32_t seed = -1; // RNG seed - int32_t n_threads = 1; // TODO: fix this - int32_t n_predict = 128; // new tokens to predict - int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_vocab = 50272; // vocabulary size + int32_t seed = -1; // RNG seed + int32_t n_threads = 1; // TODO: fix this + int32_t n_predict = 128; // new tokens to predict + int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_vocab = 50272; // vocabulary size // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens @@ -97,3 +101,5 @@ std::vector OPTGenerate(void* model, int model_type, std::vector input enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4 }; std::string LLaMAGenerate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config, std::string voc_path, bool interactive, bool voicechat); + +#endif // GENERATE_H diff --git a/llm/include/model.h 
b/llm/include/model.h index f838e871..507f2304 100644 --- a/llm/include/model.h +++ b/llm/include/model.h @@ -11,11 +11,11 @@ struct model_config { int hidden_dim; int vocsize; int padding_idx; - int qk; // group size + float rms_norm_eps; // RMSNorm epsilon (only for LLaMA models) - model_config() : model_config(1, 12, 12, 512, 768, 3072, 50272, 1) {} + model_config() : model_config(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6) {} model_config(int batch, int num_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize, - int padding_idx) + int padding_idx, float rms_norm_eps) : batch(batch), num_heads(num_heads), num_layers(num_layers), @@ -23,17 +23,20 @@ struct model_config { embed_dim(embed_dim), hidden_dim(hidden_dim), vocsize(vocsize), - padding_idx(padding_idx) {} + padding_idx(padding_idx), + rms_norm_eps(rms_norm_eps) {} }; -enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B }; +enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B }; enum { FP32, QINT8, INT4 }; -const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1); -const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1); -const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1); -const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1); -const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1); +const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0); +const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1, 0); +const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1, 0); +const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6); +const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6); +const struct model_config codellama_7B(1, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5); +const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5); static struct model_config get_opt_model_config(int choise) { struct model_config ret; switch (choise) { @@ -46,12 +49,18 @@ static struct model_config get_opt_model_config(int choise) { case OPT_6_7B: ret = opt_6_7B; break; - case LLaMA_7B:; + case LLaMA_7B: ret = llama_7B; break; - case LLaMA_13B:; + case LLaMA_13B: ret = llama_13B; break; + case CodeLLaMA_7B: + ret = codellama_7B; + break; + case CodeLLaMA_13B: + ret = codellama_13B; + break; default: throw("Unsupported model choice."); break; diff --git a/llm/include/nn_modules/Fp32llamaDecoder.h b/llm/include/nn_modules/Fp32llamaDecoder.h index eaa80176..8502d0ed 100644 --- a/llm/include/nn_modules/Fp32llamaDecoder.h +++ b/llm/include/nn_modules/Fp32llamaDecoder.h @@ -31,6 +31,7 @@ class Fp32llamaDecoder { struct Fp32llamaDecoder_output forward(const struct Fp32llamaDecoder_input& input); Embedding embed_tokens; LlamaRMSNorm norm; + float rms_norm_eps; int voc_size, embed_dim, padding_idx, hidden_dim, num_heads; std::vector layers; std::string profile_name = "Fp32llamaDecoder"; diff --git a/llm/include/nn_modules/Fp32llamaDecoderLayer.h b/llm/include/nn_modules/Fp32llamaDecoderLayer.h index dc85db67..075162e7 100644 --- a/llm/include/nn_modules/Fp32llamaDecoderLayer.h +++ b/llm/include/nn_modules/Fp32llamaDecoderLayer.h @@ -42,6 +42,7 @@ class Fp32llamaDecoderLayer { struct Fp32llamaDecoderLayer_output forward(const struct Fp32llamaDecoderLayer_input &input); int embed_dim, num_attention_heads, hidden_dim, layer_idx; + float rms_norm_eps; LlamaRMSNorm 
input_layernorm, post_attention_layernorm; Linear_FP gate_proj, down_proj, up_proj; Fp32llamaAttention attn; diff --git a/llm/include/nn_modules/Int4llamaDecoder.h b/llm/include/nn_modules/Int4llamaDecoder.h index 630901f7..7b5c57f2 100644 --- a/llm/include/nn_modules/Int4llamaDecoder.h +++ b/llm/include/nn_modules/Int4llamaDecoder.h @@ -44,6 +44,7 @@ class Int4llamaDecoder { Matrix3D prepare_decoder_attention_mask(int length, int past_length); struct Int4llamaDecoder_output forward(std::string param_path, const struct Int4llamaDecoder_input& input); int voc_size, embed_dim, padding_idx, hidden_dim, num_heads; + float rms_norm_eps; std::vector layers; std::string profile_name = "Int4llamaDecoder"; #ifdef QM_CUDA diff --git a/llm/include/nn_modules/Int4llamaDecoderLayer.h b/llm/include/nn_modules/Int4llamaDecoderLayer.h index 94895c93..e90fbabe 100644 --- a/llm/include/nn_modules/Int4llamaDecoderLayer.h +++ b/llm/include/nn_modules/Int4llamaDecoderLayer.h @@ -67,6 +67,7 @@ class Int4llamaDecoderLayer { std::string profile_name = "Int4llamaDecoderLayer"; int embed_dim, num_attention_heads, hidden_dim, layer_idx; + float rms_norm_eps; Int4llamaAttention attn; #ifdef QM_CUDA void free_cuda_memory(); diff --git a/llm/include/operators.h b/llm/include/operators.h index fba5c8db..dcdb8133 100644 --- a/llm/include/operators.h +++ b/llm/include/operators.h @@ -6,7 +6,8 @@ #include "matmul.h" #define BLK_SIZE 16 -#define NUM_THREAD 4 +// #define NUM_THREAD 8 +static int NUM_THREAD = 8; // include all ops #include "ops/BMM_F32T.h" diff --git a/llm/include/ops/LlamaRMSNorm.h b/llm/include/ops/LlamaRMSNorm.h index 649f458c..6b7a44b1 100644 --- a/llm/include/ops/LlamaRMSNorm.h +++ b/llm/include/ops/LlamaRMSNorm.h @@ -1,12 +1,12 @@ #include "common.h" +#include "utils.h" class LlamaRMSNorm { public: LlamaRMSNorm(Matrix3D _weight) : weight(_weight){}; LlamaRMSNorm(){}; - void forward(const Matrix3D &x, Matrix3D &output); + void forward(const Matrix3D &x, Matrix3D &output, float eps); Matrix3D weight; - float eps = 1e-6; private: std::string profile_name = "LlamaRMSNorm"; diff --git a/llm/include/ops/cuda/LlamaRMSNorm.cuh b/llm/include/ops/cuda/LlamaRMSNorm.cuh index 106a49b7..5fec6b20 100644 --- a/llm/include/ops/cuda/LlamaRMSNorm.cuh +++ b/llm/include/ops/cuda/LlamaRMSNorm.cuh @@ -4,9 +4,8 @@ class LlamaRMSNorm_cuda { public: LlamaRMSNorm_cuda(Matrix3D _weight) : weight(_weight){}; LlamaRMSNorm_cuda(){}; - void forward(const Matrix3D &x, Matrix3D &output); + void forward(const Matrix3D &x, Matrix3D &output, float eps); Matrix3D weight; - float eps = 1e-6; // half half_eps = 6.10352e-05; private: diff --git a/llm/src/OPTGenerate.cc b/llm/src/OPTGenerate.cc index ca038a6b..a698df8c 100644 --- a/llm/src/OPTGenerate.cc +++ b/llm/src/OPTGenerate.cc @@ -8,7 +8,8 @@ // Function to speak in the background void speakInBackground(const std::string& text) { std::string command = "./application/sts_utils/speak \"" + text + "\""; - std::system(command.c_str()); + int result = std::system(command.c_str()); + (void)result; } // OPTGenerate function diff --git a/llm/src/nn_modules/Fp32llamaDecoder.cc b/llm/src/nn_modules/Fp32llamaDecoder.cc index 0fd7dfb3..18d0b1c0 100644 --- a/llm/src/nn_modules/Fp32llamaDecoder.cc +++ b/llm/src/nn_modules/Fp32llamaDecoder.cc @@ -35,6 +35,7 @@ Fp32llamaDecoder::Fp32llamaDecoder(std::string param_path, const struct model_co this->hidden_dim = config.hidden_dim; this->num_heads = config.num_heads; this->padding_idx = config.padding_idx; + this->rms_norm_eps = config.rms_norm_eps; int 
max_sqlen = config.max_sqlen; @@ -104,7 +105,7 @@ struct Fp32llamaDecoder_output Fp32llamaDecoder::forward(const struct Fp32llamaD // Layernorm Matrix3D last_hidden_states(last_hidden_states_buf, 1, sqlen, this->embed_dim); - this->norm.forward(hidden_states, last_hidden_states); + this->norm.forward(hidden_states, last_hidden_states, rms_norm_eps); struct Fp32llamaDecoder_output output = {last_hidden_states, past_keys, past_values}; PROFILE_END(profile_name); diff --git a/llm/src/nn_modules/Fp32llamaDecoderLayer.cc b/llm/src/nn_modules/Fp32llamaDecoderLayer.cc index 05918079..9188555a 100644 --- a/llm/src/nn_modules/Fp32llamaDecoderLayer.cc +++ b/llm/src/nn_modules/Fp32llamaDecoderLayer.cc @@ -37,7 +37,7 @@ struct Fp32llamaDecoderLayer_output Fp32llamaDecoderLayer::forward(const struct // Layernorm Matrix3D hidden_states(hidden_states_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->input_layernorm.forward(input.hidden_states, hidden_states); + this->input_layernorm.forward(input.hidden_states, hidden_states, rms_norm_eps); // Attention struct Fp32llamaAttention_input attn_param(hidden_states, input.attention_mask, input.past_key, input.past_value, @@ -52,7 +52,7 @@ struct Fp32llamaDecoderLayer_output Fp32llamaDecoderLayer::forward(const struct // Layernorm Matrix3D post_attention_layernorm(final_layer_norm_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->post_attention_layernorm.forward(residual_add, post_attention_layernorm); + this->post_attention_layernorm.forward(residual_add, post_attention_layernorm, rms_norm_eps); // Gate proj: embedding -> hidden_dim Matrix3D gate_proj(gate_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, @@ -102,6 +102,8 @@ Fp32llamaDecoderLayer::Fp32llamaDecoderLayer(std::string param_path, const struc post_attention_layernorm_weight.load((param_path + "/post_attention_layernorm/weight.bin").c_str()); this->post_attention_layernorm = LlamaRMSNorm(post_attention_layernorm_weight); + this->rms_norm_eps = config.rms_norm_eps; + this->embed_dim = config.embed_dim; this->num_attention_heads = config.num_heads; this->hidden_dim = config.hidden_dim; diff --git a/llm/src/nn_modules/cuda/Int4llamaDecoder.cu b/llm/src/nn_modules/cuda/Int4llamaDecoder.cu index 7c6a0d30..a4769c72 100644 --- a/llm/src/nn_modules/cuda/Int4llamaDecoder.cu +++ b/llm/src/nn_modules/cuda/Int4llamaDecoder.cu @@ -30,6 +30,7 @@ Int4llamaDecoder::Int4llamaDecoder(std::string param_path, const struct model_co this->hidden_dim = config.hidden_dim; this->num_heads = config.num_heads; this->padding_idx = config.padding_idx; + this->rms_norm_eps = config.rms_norm_eps; // Embedding Matrix3D embweight(new float[voc_size * embed_dim], 1, voc_size, embed_dim); @@ -102,7 +103,7 @@ struct Int4llamaDecoder_output Int4llamaDecoder::forward(std::string param_path, } Matrix3D last_hidden_states(last_hidden_states_buf, 1, sqlen, this->embed_dim); - this->norm.forward(hidden_states, last_hidden_states); + this->norm.forward(hidden_states, last_hidden_states, rms_norm_eps); struct Int4llamaDecoder_output output = {last_hidden_states, past_keys, past_values}; PROFILE_END(profile_name); diff --git a/llm/src/nn_modules/cuda/Int4llamaDecoderLayer.cu b/llm/src/nn_modules/cuda/Int4llamaDecoderLayer.cu index 5af4c391..02c14c52 100644 --- a/llm/src/nn_modules/cuda/Int4llamaDecoderLayer.cu +++ b/llm/src/nn_modules/cuda/Int4llamaDecoderLayer.cu @@ -50,6 +50,8 @@ 
Int4llamaDecoderLayer::Int4llamaDecoderLayer(std::string param_path, const struc post_attention_layernorm_weight.load((param_path + "/post_attention_layernorm/weight.bin").c_str()); this->post_attention_layernorm = LlamaRMSNorm_cuda(post_attention_layernorm_weight); + this->rms_norm_eps = config.rms_norm_eps; + this->embed_dim = config.embed_dim; this->num_attention_heads = config.num_heads; this->hidden_dim = config.hidden_dim; @@ -73,7 +75,7 @@ struct Int4llamaDecoderLayer_output Int4llamaDecoderLayer::forward(std::string p Matrix3D hidden_states(hidden_states_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->input_layernorm.forward(input.hidden_states, hidden_states); + this->input_layernorm.forward(input.hidden_states, hidden_states, rms_norm_eps); struct Int4llamaAttention_input attn_param(hidden_states, input.attention_mask, input.past_key, input.past_value, input.has_past_key_value, this->layer_idx); @@ -87,7 +89,7 @@ struct Int4llamaDecoderLayer_output Int4llamaDecoderLayer::forward(std::string p Matrix3D post_attention_layernorm(final_layer_norm_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->post_attention_layernorm.forward(residual_add, post_attention_layernorm); + this->post_attention_layernorm.forward(residual_add, post_attention_layernorm, rms_norm_eps); Matrix3D gate_proj(gate_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, this->hidden_dim); diff --git a/llm/src/nn_modules/cuda/LLaMAGenerate.cu b/llm/src/nn_modules/cuda/LLaMAGenerate.cu index 93aa21ba..c76b1ed5 100644 --- a/llm/src/nn_modules/cuda/LLaMAGenerate.cu +++ b/llm/src/nn_modules/cuda/LLaMAGenerate.cu @@ -27,23 +27,27 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ break; } } - if (interactive) std::cout << "ASSISTANT: " << std::endl; + // if (interactive) std::cout << "ASSISTANT: " << std::endl; - bool has_past_kv = false; bool previous_two_hash = false; + int break_cnt = 2; + bool new_prompt = true; + static bool has_past_kv = false; #ifdef QM_CUDA - std::vector> past_keys, past_values; + static std::vector> past_keys, past_values; #else - std::vector> past_keys, past_values; + static std::vector> past_keys, past_values; #endif - std::vector> past_keys_fp32, past_values_fp32; + static std::vector> past_keys_fp32, past_values_fp32; int n_remain = generation_config.n_predict; - int break_cnt = 2; std::string output; while (n_remain != 0 && break_cnt) { std::vector logits(generation_config.n_vocab); int sqlen = 1; + if (new_prompt) { + sqlen = input_ids.size(); + } if (model_type == LLaMA_INT4) { Int4LlamaForCausalLM *model = static_cast(model_ptr); struct Int4LlamaForCausalLM_output model_output; @@ -52,13 +56,12 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat, past_keys, past_values}; } else { - sqlen = input_ids.size(); Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat}; } - if (has_past_kv) STATS_START("Inference latency"); + if (!new_prompt) STATS_START("Inference latency"); model_output = model->forward(param_path, model_input); - if (has_past_kv) STATS_END("Inference latency"); + if (!new_prompt) STATS_END("Inference latency"); past_keys = model_output.past_keys; past_values = model_output.past_values; // memcpy model_ouput.logits[-1] to logits @@ -72,13 +75,12 @@ std::string LLaMAGenerate(std::string 
param_path, void *model_ptr, int model_typ Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat, past_keys_fp32, past_values_fp32}; } else { - sqlen = input_ids.size(); Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat}; } - if (has_past_kv) STATS_START("Inference latency"); + if (!new_prompt) STATS_START("Inference latency"); model_output = model->forward(model_input); - if (has_past_kv) STATS_END("Inference latency"); + if (!new_prompt) STATS_END("Inference latency"); past_keys_fp32 = model_output.past_keys; past_values_fp32 = model_output.past_values; // memcpy model_ouput.logits[-1] to logits @@ -86,6 +88,7 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ generation_config.n_vocab * sizeof(float)); } has_past_kv = true; + new_prompt = false; // Generate const int n_ctx = generation_config.n_ctx; diff --git a/llm/src/nn_modules/non_cuda/Int4llamaDecoder.cc b/llm/src/nn_modules/non_cuda/Int4llamaDecoder.cc index fda0011d..8f481a47 100644 --- a/llm/src/nn_modules/non_cuda/Int4llamaDecoder.cc +++ b/llm/src/nn_modules/non_cuda/Int4llamaDecoder.cc @@ -35,6 +35,7 @@ Int4llamaDecoder::Int4llamaDecoder(std::string param_path, const struct model_co this->hidden_dim = config.hidden_dim; this->num_heads = config.num_heads; this->padding_idx = config.padding_idx; + this->rms_norm_eps = config.rms_norm_eps; int max_sqlen = config.max_sqlen; @@ -101,7 +102,7 @@ struct Int4llamaDecoder_output Int4llamaDecoder::forward(std::string param_path, // Layernorm Matrix3D last_hidden_states(last_hidden_states_buf, 1, sqlen, this->embed_dim); - this->norm.forward(hidden_states, last_hidden_states); + this->norm.forward(hidden_states, last_hidden_states, rms_norm_eps); struct Int4llamaDecoder_output output = {last_hidden_states, past_keys, past_values}; PROFILE_END(profile_name); diff --git a/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc b/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc index cbdb938f..0330ba24 100644 --- a/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc +++ b/llm/src/nn_modules/non_cuda/Int4llamaDecoderLayer.cc @@ -61,7 +61,7 @@ struct Int4llamaDecoderLayer_output Int4llamaDecoderLayer::forward(std::string p // Layernorm Matrix3D hidden_states(hidden_states_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->input_layernorm.forward(input.hidden_states, hidden_states); + this->input_layernorm.forward(input.hidden_states, hidden_states, rms_norm_eps); // Attention struct Int4llamaAttention_input attn_param(hidden_states, input.attention_mask, input.past_key, input.past_value, @@ -76,7 +76,7 @@ struct Int4llamaDecoderLayer_output Int4llamaDecoderLayer::forward(std::string p // Layernorm Matrix3D post_attention_layernorm(final_layer_norm_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, input.hidden_states.m_dim_z); - this->post_attention_layernorm.forward(residual_add, post_attention_layernorm); + this->post_attention_layernorm.forward(residual_add, post_attention_layernorm, rms_norm_eps); // Gate proj: embedding -> hidden_dim Matrix3D gate_proj(gate_proj_arr, input.hidden_states.m_dim_x, input.hidden_states.m_dim_y, @@ -147,6 +147,8 @@ Int4llamaDecoderLayer::Int4llamaDecoderLayer(std::string param_path, const struc post_attention_layernorm_weight.load((param_path + "/post_attention_layernorm/weight.bin").c_str()); this->post_attention_layernorm = LlamaRMSNorm(post_attention_layernorm_weight); + this->rms_norm_eps = 
config.rms_norm_eps; + this->embed_dim = config.embed_dim; this->num_attention_heads = config.num_heads; this->hidden_dim = config.hidden_dim; diff --git a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc index b0c15930..d72acd14 100644 --- a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc +++ b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc @@ -28,6 +28,11 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ const int n = llama_tokenize(vocab, text.c_str(), input_ids.data(), input_ids.size(), true); input_ids.resize(n); + bool is_codellama = false; + if (param_path.find("CodeLLaMA") != std::string::npos) { + is_codellama = true; + } + int n_consumed = 0; while ((int)input_ids.size() > n_consumed) { embd.push_back(input_ids[n_consumed]); @@ -39,18 +44,22 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ break; } } - if (interactive) std::cout << "ASSISTANT: " << std::endl; + // if (interactive) std::cout << "ASSISTANT: " << std::endl; - bool has_past_kv = false; bool previous_two_hash = false; - std::vector> past_keys, past_values; - int n_remain = generation_config.n_predict; int break_cnt = 2; + bool new_prompt = true; + static bool has_past_kv = false; + static std::vector> past_keys, past_values; + int n_remain = generation_config.n_predict; std::string output; while (n_remain != 0 && break_cnt) { std::vector logits(generation_config.n_vocab); int sqlen = 1; + if (new_prompt) { + sqlen = input_ids.size(); + } if (model_type == LLaMA_INT4) { Int4LlamaForCausalLM *model = static_cast(model_ptr); struct Int4LlamaForCausalLM_output model_output; @@ -59,13 +68,12 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat, past_keys, past_values}; } else { - sqlen = input_ids.size(); Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat}; } - if (has_past_kv) STATS_START("Inference latency"); + if (!new_prompt) STATS_START("Inference latency"); model_output = model->forward(param_path, model_input); - if (has_past_kv) STATS_END("Inference latency"); + if (!new_prompt) STATS_END("Inference latency"); past_keys = model_output.past_keys; past_values = model_output.past_values; // memcpy model_ouput.logits[-1] to logits @@ -79,13 +87,12 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat, past_keys, past_values}; } else { - sqlen = input_ids.size(); Matrix3D input_ids_mat(input_ids.data(), 1, 1, sqlen); model_input = {input_ids_mat}; } - if (has_past_kv) STATS_START("Inference latency"); + if (!new_prompt) STATS_START("Inference latency"); model_output = model->forward(model_input); - if (has_past_kv) STATS_END("Inference latency"); + if (!new_prompt) STATS_END("Inference latency"); past_keys = model_output.past_keys; past_values = model_output.past_values; // memcpy model_ouput.logits[-1] to logits @@ -171,6 +178,11 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ previous_two_hash = false; } + if (is_codellama && new_prompt) { + new_prompt = false; + continue; + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); embd.push_back(id); @@ -226,11 +238,12 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ } } } + + new_prompt = false; --n_remain; } if 
(voicechat && interactive){ sayInBackground(output); - } if (interactive) std::cout << std::endl; diff --git a/llm/src/ops/LlamaRMSNorm.cc b/llm/src/ops/LlamaRMSNorm.cc index e1b041c0..9a801d84 100644 --- a/llm/src/ops/LlamaRMSNorm.cc +++ b/llm/src/ops/LlamaRMSNorm.cc @@ -4,7 +4,7 @@ #include "operators.h" #include "utils.h" -void LlamaRMSNorm::forward(const Matrix3D &x, Matrix3D &output) { +void LlamaRMSNorm::forward(const Matrix3D &x, Matrix3D &output, float eps) { PROFILE_START(profile_name); const int last_dims = 2; diff --git a/llm/src/ops/cuda/LlamaRMSNorm.cu b/llm/src/ops/cuda/LlamaRMSNorm.cu index f86eb427..ffe5b159 100644 --- a/llm/src/ops/cuda/LlamaRMSNorm.cu +++ b/llm/src/ops/cuda/LlamaRMSNorm.cu @@ -2,7 +2,6 @@ #include #include "operators.h" -#include "utils.h" #include "reduction.cuh" static inline __device__ float to_float(half src) @@ -94,7 +93,7 @@ __global__ void generalT5LayerNorm( } -void LlamaRMSNorm_cuda::forward(const Matrix3D &x, Matrix3D &output) { +void LlamaRMSNorm_cuda::forward(const Matrix3D &x, Matrix3D &output, float eps) { int m = x.m_dim_x * x.m_dim_y; int n = x.m_dim_z; dim3 grid(m); diff --git a/llm/src/ops/linear.cc b/llm/src/ops/linear.cc index 252904ac..136d50e8 100644 --- a/llm/src/ops/linear.cc +++ b/llm/src/ops/linear.cc @@ -109,7 +109,7 @@ void Linear_FP_int4::forward_ref(const Matrix3D &a, Matrix3D &c) { } void Linear_FP_int4::forward_fast(const Matrix3D &x, Matrix3D &output) { - const int num_thread = 8; + const int num_thread = NUM_THREAD; Matrix3D b = this->weight; const int m = x.m_dim_y, n = b.m_dim_y, k = x.m_dim_z, b_size = b.m_dim_x; const long long ops = (long long)b_size * 2 * (long long)m * (long long)n * (long long)k; @@ -161,7 +161,7 @@ void Linear_FP_int4::initialize_memory(const int block_size) { #endif // USE_INT8_INT4_PRODUCT void Linear_FP_int4::forward(const Matrix3D &x, Matrix3D &output) { - const int num_thread = 16; + const int num_thread = NUM_THREAD; Matrix3D b = this->weight; const int m = x.m_dim_y, n = b.m_dim_y, k = x.m_dim_z, b_size = b.m_dim_x; const long long ops = (long long)b_size * 2 * (long long)m * (long long)n * (long long)k; diff --git a/llm/tests/cuda/test_Int4llamaAttention.cu b/llm/tests/cuda/test_Int4llamaAttention.cu index 12fb8cac..e269167e 100644 --- a/llm/tests/cuda/test_Int4llamaAttention.cu +++ b/llm/tests/cuda/test_Int4llamaAttention.cu @@ -15,7 +15,7 @@ void test_Int4llamaAttention() { const struct model_config llama7B = llama_7B; const int sqlen = 9, b = 1, embed_dim = llama7B.embed_dim, num_heads = llama7B.num_heads; - Int4llamaAttention attn = Int4llamaAttention("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", llama7B); + Int4llamaAttention attn = Int4llamaAttention("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", llama7B, 0); half* buffer_1; cudaMallocManaged(&buffer_1, sizeof(half) * embed_dim * sqlen * b); @@ -28,7 +28,7 @@ void test_Int4llamaAttention() { attn.initialized_memory(llama7B); struct Int4llamaAttention_input input(hidden_states, attention_mask, 0); - struct Int4llamaAttention_output output = attn.forward(input); + struct Int4llamaAttention_output output = attn.forward("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", input); cudaDeviceSynchronize(); half* buffer_3; @@ -68,7 +68,7 @@ void test_Int4llamaAttention_gen() { const int sqlen = 1, b = 1, past_sqlen = 9, embed_dim = llama7B.embed_dim, num_heads = llama7B.num_heads, head_dim = embed_dim / num_heads; - Int4llamaAttention attn = Int4llamaAttention("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", 
llama7B); + Int4llamaAttention attn = Int4llamaAttention("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", llama7B, 0); half* buffer_1; cudaMallocManaged(&buffer_1, sizeof(half) * embed_dim * sqlen * b); @@ -89,7 +89,7 @@ void test_Int4llamaAttention_gen() { attn.initialized_memory(llama7B); struct Int4llamaAttention_input input(hidden_states, attention_mask, past_key, past_value, true, 0); - struct Int4llamaAttention_output output = attn.forward(input); + struct Int4llamaAttention_output output = attn.forward("INT4/models/LLaMA_7B_2_chat/decoder/layer0/self_attn", input); cudaDeviceSynchronize(); half* buffer_5; diff --git a/llm/tests/cuda/test_Int4llamaDecoder.cu b/llm/tests/cuda/test_Int4llamaDecoder.cu index ab607283..1d0a6402 100644 --- a/llm/tests/cuda/test_Int4llamaDecoder.cu +++ b/llm/tests/cuda/test_Int4llamaDecoder.cu @@ -32,7 +32,7 @@ void test_Decoder() { struct Int4llamaDecoder_input input_1st = {input_ids}; Int4llamaDecoder decoder = Int4llamaDecoder("INT4/models/LLaMA_7B_2_chat/decoder/", llama7B); - struct Int4llamaDecoder_output output_1st = decoder.forward(input_1st); + struct Int4llamaDecoder_output output_1st = decoder.forward("INT4/models/LLaMA_7B_2_chat/decoder/", input_1st); cudaDeviceSynchronize(); half* buffer_2; @@ -63,7 +63,7 @@ void test_Decoder() { input_ids_2nd.load("assets/llama/tests/decoder/2nd/input_ids.bin"); struct Int4llamaDecoder_input input_2nd = {input_ids_2nd, output_1st.past_keys, output_1st.past_values}; - struct Int4llamaDecoder_output output_2nd = decoder.forward(input_2nd); + struct Int4llamaDecoder_output output_2nd = decoder.forward("INT4/models/LLaMA_7B_2_chat/decoder/", input_2nd); cudaDeviceSynchronize(); half* buffer_5; diff --git a/llm/tests/cuda/test_Int4llamaDecoderLayer.cu b/llm/tests/cuda/test_Int4llamaDecoderLayer.cu index e0a16c70..88c857c4 100644 --- a/llm/tests/cuda/test_Int4llamaDecoderLayer.cu +++ b/llm/tests/cuda/test_Int4llamaDecoderLayer.cu @@ -31,7 +31,7 @@ void test_Int4llamaDecoderLayer() { read_to_array_half("assets/llama/tests/layer0/sqlen9/attention_mask_half.bin", attention_mask.m_data, attention_mask.length()); struct Int4llamaDecoderLayer_input input(hidden_states, attention_mask); - struct Int4llamaDecoderLayer_output output = layer.forward(input); + struct Int4llamaDecoderLayer_output output = layer.forward("INT4/models/LLaMA_7B_2_chat/decoder/layer0", input, 0); cudaDeviceSynchronize(); half* buffer_3; @@ -92,7 +92,7 @@ void test_Int4llamaDecoderLayer_gen() { read_to_array_half("assets/llama/tests/atten/sqlen9/past_value_half.bin", past_value.m_data, past_value.length()); struct Int4llamaDecoderLayer_input input(hidden_states, attention_mask, past_key, past_value); - struct Int4llamaDecoderLayer_output output = layer.forward(input); + struct Int4llamaDecoderLayer_output output = layer.forward("INT4/models/LLaMA_7B_2_chat/decoder/layer0", input, 0); cudaDeviceSynchronize(); half* buffer_5; diff --git a/llm/tests/cuda/test_Int4llamaForCausalLM.cu b/llm/tests/cuda/test_Int4llamaForCausalLM.cu index c9cd6c30..0ba99981 100644 --- a/llm/tests/cuda/test_Int4llamaForCausalLM.cu +++ b/llm/tests/cuda/test_Int4llamaForCausalLM.cu @@ -37,7 +37,7 @@ void test_Int4LlamaForCausalLM() { struct Int4LlamaForCausalLM_input input_1st = {input_ids}; Int4LlamaForCausalLM model = Int4LlamaForCausalLM("INT4/models/LLaMA_7B_2_chat", config); - struct Int4LlamaForCausalLM_output output_1st = model.forward(input_1st); + struct Int4LlamaForCausalLM_output output_1st = model.forward("INT4/models/LLaMA_7B_2_chat", input_1st); float* 
buffer_2; cudaMallocManaged(&buffer_2, sizeof(float) * b * sqlen * voc_size); @@ -55,7 +55,7 @@ void test_Int4LlamaForCausalLM() { input_ids_2nd.load("assets/llama/tests/model/2nd_input_ids.bin"); struct Int4LlamaForCausalLM_input input_2nd = {input_ids_2nd, output_1st.past_keys, output_1st.past_values}; - struct Int4LlamaForCausalLM_output output_2nd = model.forward(input_2nd); + struct Int4LlamaForCausalLM_output output_2nd = model.forward("INT4/models/LLaMA_7B_2_chat", input_2nd); float* buffer_4; cudaMallocManaged(&buffer_4, sizeof(float) * b * 1 * voc_size); diff --git a/llm/tests/cuda/test_ops.cu b/llm/tests/cuda/test_ops.cu index 600e5dd3..9c902e5c 100644 --- a/llm/tests/cuda/test_ops.cu +++ b/llm/tests/cuda/test_ops.cu @@ -566,7 +566,7 @@ void test_LlamaRMSNorm() { LlamaRMSNorm op(weight); - op.forward(hidden_states, output); + op.forward(hidden_states, output, llama7B.rms_norm_eps); bool success = check_two_equal(output.m_data, outputGT.m_data, sqlen * embed_dim); if (!success) @@ -692,7 +692,7 @@ void test_LlamaRMSNorm_cuda() { Matrix3D output(buffer_4, 1, sqlen, embed_dim); LlamaRMSNorm_cuda op(weight); - op.forward(hidden_states, output); + op.forward(hidden_states, output, llama7B.rms_norm_eps); cudaDeviceSynchronize(); bool success = check_two_equal_half_half(output.m_data, outputGT.m_data, sqlen * embed_dim); diff --git a/llm/tests/non_cuda/test_Int4llamaAttention.cc b/llm/tests/non_cuda/test_Int4llamaAttention.cc index 23a26ef6..638d2feb 100644 --- a/llm/tests/non_cuda/test_Int4llamaAttention.cc +++ b/llm/tests/non_cuda/test_Int4llamaAttention.cc @@ -9,7 +9,7 @@ void test_Int4llamaAttention() { MemoryAllocator mem_buf; - Int4llamaAttention attn = Int4llamaAttention("models/LLaMA_7B/decoder/layer0/self_attn", llama7B); + Int4llamaAttention attn = Int4llamaAttention("models/LLaMA_7B/decoder/layer0/self_attn", llama7B, 0); Matrix3D hidden_states(mem_buf.get_fpbuffer(embed_dim * sqlen), b, sqlen, embed_dim); read_to_array("assets/llama/tests/atten/sqlen9/hidden_states.bin", hidden_states.m_data, b * sqlen * embed_dim); @@ -20,7 +20,7 @@ void test_Int4llamaAttention() { attn.initialized_memory(llama7B); struct Int4llamaAttention_input input(hidden_states, attention_mask, 0); - struct Int4llamaAttention_output output = attn.forward(input); + struct Int4llamaAttention_output output = attn.forward("models/LLaMA_7B/decoder/layer0/self_attn", input); Matrix3D attn_outputGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); read_to_array("assets/llama/tests/atten/sqlen9/attn_output.bin", attn_outputGT.m_data, b * sqlen * embed_dim); @@ -45,7 +45,7 @@ void test_Int4llamaAttention_gen() { MemoryAllocator mem_buf; - Int4llamaAttention attn = Int4llamaAttention("models/LLaMA_7B/decoder/layer0/self_attn", llama7B); + Int4llamaAttention attn = Int4llamaAttention("models/LLaMA_7B/decoder/layer0/self_attn", llama7B, 0); Matrix3D hidden_states(mem_buf.get_fpbuffer(embed_dim * sqlen), b, sqlen, embed_dim); hidden_states.load("assets/llama/tests/atten/sqlen1/hidden_states.bin"); @@ -59,7 +59,7 @@ void test_Int4llamaAttention_gen() { attn.initialized_memory(llama7B); struct Int4llamaAttention_input input(hidden_states, attention_mask, past_key, past_value, true, 0); - struct Int4llamaAttention_output output = attn.forward(input); + struct Int4llamaAttention_output output = attn.forward("models/LLaMA_7B/decoder/layer0/self_attn", input); Matrix3D attn_outputGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); 
attn_outputGT.load("assets/llama/tests/atten/sqlen1/attn_output.bin"); diff --git a/llm/tests/non_cuda/test_Int4llamaDecoder.cc b/llm/tests/non_cuda/test_Int4llamaDecoder.cc index 8e72d2a4..2c21f1d7 100644 --- a/llm/tests/non_cuda/test_Int4llamaDecoder.cc +++ b/llm/tests/non_cuda/test_Int4llamaDecoder.cc @@ -17,7 +17,7 @@ void test_Decoder() { Int4llamaDecoder decoder = Int4llamaDecoder("models/LLaMA_7B/decoder/", llama7B); - struct Int4llamaDecoder_output output_1st = decoder.forward(input_1st); + struct Int4llamaDecoder_output output_1st = decoder.forward("models/LLaMA_7B/decoder/", input_1st); // reasoning phase: 1st run Matrix3D last_hidden_state1_GT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); @@ -47,7 +47,7 @@ void test_Decoder() { input_ids_2nd.load("assets/llama/tests/decoder/2nd/input_ids.bin"); struct Int4llamaDecoder_input input_2nd = {input_ids_2nd, output_1st.past_keys, output_1st.past_values}; - struct Int4llamaDecoder_output output_2nd = decoder.forward(input_2nd); + struct Int4llamaDecoder_output output_2nd = decoder.forward("models/LLaMA_7B/decoder/", input_2nd); Matrix3D last_hidden_state2_GT(mem_buf.get_fpbuffer(b * 1 * embed_dim), b, 1, embed_dim); last_hidden_state2_GT.load("assets/llama/tests/decoder/2nd/last_hidden_state.bin"); diff --git a/llm/tests/non_cuda/test_Int4llamaDecoderLayer.cc b/llm/tests/non_cuda/test_Int4llamaDecoderLayer.cc index 25fb9ba3..dc74d49a 100644 --- a/llm/tests/non_cuda/test_Int4llamaDecoderLayer.cc +++ b/llm/tests/non_cuda/test_Int4llamaDecoderLayer.cc @@ -19,7 +19,7 @@ void test_Int4llamaDecoderLayer() { struct Int4llamaDecoderLayer_input input(hidden_states, attention_mask); - struct Int4llamaDecoderLayer_output output = layer.forward(input); + struct Int4llamaDecoderLayer_output output = layer.forward("models/LLaMA_7B/decoder/layer0", input, 0); Matrix3D outputGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); outputGT.load("assets/llama/tests/layer0/sqlen9/output_hidden_states.bin"); @@ -58,7 +58,7 @@ void test_Int4llamaDecoderLayer_gen() { struct Int4llamaDecoderLayer_input input(hidden_states, attention_mask, past_key, past_value); - struct Int4llamaDecoderLayer_output output = layer.forward(input); + struct Int4llamaDecoderLayer_output output = layer.forward("models/LLaMA_7B/decoder/layer0", input, 0); Matrix3D outputGT(mem_buf.get_fpbuffer(b * sqlen * embed_dim), b, sqlen, embed_dim); outputGT.load("assets/llama/tests/layer0/sqlen1/output_hidden_states.bin"); diff --git a/llm/tests/non_cuda/test_Int4llamaForCausalLM.cc b/llm/tests/non_cuda/test_Int4llamaForCausalLM.cc index abb83d81..17edef2f 100644 --- a/llm/tests/non_cuda/test_Int4llamaForCausalLM.cc +++ b/llm/tests/non_cuda/test_Int4llamaForCausalLM.cc @@ -20,7 +20,7 @@ void test_Int4LlamaForCausalLM() { Int4LlamaForCausalLM model = Int4LlamaForCausalLM("models/LLaMA_7B", config); - struct Int4LlamaForCausalLM_output output_1st = model.forward(input_1st); + struct Int4LlamaForCausalLM_output output_1st = model.forward("models/LLaMA_7B", input_1st); Matrix3D logits(mem_buf.get_fpbuffer(b * sqlen * voc_size), b, sqlen, voc_size); logits.load("assets/llama/tests/model/1st_logits.bin"); @@ -39,7 +39,7 @@ void test_Int4LlamaForCausalLM() { struct Int4LlamaForCausalLM_input input_2nd = {input_ids_2nd, output_1st.past_keys, output_1st.past_values}; struct Int4LlamaForCausalLM_output output_2nd; - for (int i = 0; i < 10; i++) output_2nd = model.forward(input_2nd); + for (int i = 0; i < 10; i++) output_2nd = model.forward("models/LLaMA_7B", 
input_2nd); logits = Matrix3D(mem_buf.get_fpbuffer(b * 1 * voc_size), b, 1, voc_size); logits.load("assets/llama/tests/model/2nd_logits.bin"); diff --git a/llm/tests/non_cuda/test_ops.cc b/llm/tests/non_cuda/test_ops.cc index f111d1aa..17a64f59 100644 --- a/llm/tests/non_cuda/test_ops.cc +++ b/llm/tests/non_cuda/test_ops.cc @@ -566,7 +566,7 @@ void test_LlamaRMSNorm() { LlamaRMSNorm op(weight); - op.forward(hidden_states, output); + op.forward(hidden_states, output, llama7B.rms_norm_eps); bool success = check_two_equal(output.m_data, outputGT.m_data, sqlen * embed_dim); if (!success) diff --git a/llm/tools/download_model.py b/llm/tools/download_model.py index d895004a..64d502f7 100644 --- a/llm/tools/download_model.py +++ b/llm/tools/download_model.py @@ -30,6 +30,14 @@ "url": "https://www.dropbox.com/scl/fi/qpzv3805ftdldvlocssu4/LLaMA_13B_2_chat.zip?rlkey=tfgnv9cz2i8lwuznyy6u3sf4k&dl=1", # noqa: E501 "md5sum": "59b73efa638be4131e5fd27c3fdee597", }, + "CodeLLaMA_7B_Instruct_fp32": { + "url": "", + "md5sum": "", + }, + "CodeLLaMA_13B_Instruct_fp32": { + "url": "", + "md5sum": "", + }, "opt_6.7B_fp32": { "url": "https://www.dropbox.com/scl/fi/mwy0uw51anodezy9rtcf1/OPT_6.7B.zip?rlkey=f8mtjg5eesuflrz3t5och4219&dl=1", "md5sum": "69cffdc090388ac2d2abcbe8163b0397", @@ -58,6 +66,14 @@ "url": "https://www.dropbox.com/scl/fi/rb7el1reycad98xrzif9a/LLaMA_13B_2_chat.zip?rlkey=wwd400no2uelcthvqxut3ojvj&dl=1", # noqa: E501 "md5sum": "f1f7693da630bb7aa269ecae5bcc397a", }, + "CodeLLaMA_7B_Instruct_awq_int4": { + "url": "https://www.dropbox.com/scl/fi/m6qcwnsg37sdtewvh41sb/CodeLLaMA_7B_Instruct.zip?rlkey=mlnn1s76k63zez44uatmsc7ij&dl=1", + "md5sum": "a5b4c15857944daaa1e1ee34c5917264", + }, + "CodeLLaMA_13B_Instruct_awq_int4": { + "url": "https://www.dropbox.com/scl/fi/7gcmtonyyyavdaeccnivi/CodeLLaMA_13B_Instruct.zip?rlkey=e1u6ne71prrtcjh1sp8hs5sns&dl=1", + "md5sum": "d749ec83a54dcf40a7d87e7dbfba42d4", + }, "opt_125m_awq_int4": { "url": "https://www.dropbox.com/scl/fi/3dedmlzi36jngj74iskr6/OPT_125m.zip?rlkey=hy7z46cwfbr4dlz9bcs1mtx5b&dl=1", # noqa: E501 "md5sum": "2b42c3866c54642557046140367217fa", @@ -84,6 +100,14 @@ "url": "https://www.dropbox.com/scl/fi/t4u1jkp7gav8om4m6xjjv/LLaMA_13B_2_chat.zip?rlkey=tahltmq9bqu3ofx03r4mrsk2r&dl=1", # noqa: E501 "md5sum": "3684e5740f44ed05e213d6d807a1f136", }, + "CodeLLaMA_7B_Instruct_awq_int4": { + "url": "https://www.dropbox.com/scl/fi/fav8kvwcuw1dpdiykny24/CodeLLaMA_7B_Instruct.zip?rlkey=bjhf467r8xb7di2lilqbgv8vm&dl=1", + "md5sum": "b208eec1b1bbb6532f26b68a7a3caae6", + }, + "CodeLLaMA_13B_Instruct_awq_int4": { + "url": "https://www.dropbox.com/scl/fi/0appg7uacff9z21hth06n/CodeLLaMA_13B_Instruct.zip?rlkey=v6fxuomhqmskwqgtclsat9pzt&dl=1", + "md5sum": "71ade74fe50b6beb378d52e19396926d", + }, "opt_125m_awq_int4": { "url": "https://www.dropbox.com/scl/fi/sl6kc1ql0877w550e4v17/OPT_125m.zip?rlkey=fsdqf3bc0vktl7iv6pfi6bbyx&dl=1", # noqa: E501 "md5sum": "c9c26bb5c8bf9867e21e525da744ef19", @@ -106,6 +130,14 @@ "url": "https://www.dropbox.com/scl/fi/fes1l27b9kv4dn4h0qjzu/LLaMA_13B_2_chat.zip?rlkey=u1j2kt96xpj764zkj1v87gw6u&dl=1", # noqa: E501 "md5sum": "802c81d86b6393aff3e93326e5b58f7f", }, + "CodeLLaMA_7B_Instruct_awq_int4": { + "url": "", + "md5sum": "", + }, + "CodeLLaMA_13B_Instruct_awq_int4": { + "url": "", + "md5sum": "", + }, }, "INT8": { "opt_125m_smooth_int8": { diff --git a/llm/tools/llama_exporter.py b/llm/tools/llama_exporter.py index f886bae0..5213dc33 100644 --- a/llm/tools/llama_exporter.py +++ b/llm/tools/llama_exporter.py @@ -1,4 +1,4 @@ -"""Implementation of 
exporting LLaMA PyTorch model to TinyLLMEngine format. +"""Implementation of exporting LLaMA PyTorch model to TinyChatEngine format. Usage: python llama_exporter.py @@ -99,11 +99,11 @@ def _export_attention_params(attn, prefix: str): def main(): - """Export a LLaMA model to TinyLLMEngine format.""" - parser = argparse.ArgumentParser(description="export LLaMA pytorch model to TinyLLMEngine format.") + """Export a LLaMA model to TinyChatEngine format.""" + parser = argparse.ArgumentParser(description="export LLaMA pytorch model to TinyChatEngine format.") parser.add_argument("--hf_path", type=str, help="Path to huggingface model hub", default=None) - parser.add_argument("model", type=str, help="Path of the LLaMA torch model") - parser.add_argument("output", type=str, help="Output directory of the exported model") + parser.add_argument("--model", type=str, help="Path of the LLaMA torch model") + parser.add_argument("--output", type=str, help="Output directory of the exported model") args = parser.parse_args() @@ -118,12 +118,29 @@ def main(): print("Loading model...") if args.model.endswith(".pt"): - model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16) + if args.model.split("/")[-1].lower().startswith("llama-2"): + if args.model.split("-")[2].lower() == "7b": + print("Loading LLaMA 7B model...") + model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16) + elif args.model.split("-")[2].lower() == "13b": + print("Loading LLaMA 13B model...") + model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-13b-hf", torch_dtype=torch.float16) + elif args.model.split("/")[-1].lower().startswith("codellama"): + if args.model.split("-")[1].lower() == "7b": + print("Loading CodeLLaMA 7B model...") + model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-Instruct-hf", torch_dtype=torch.float16) + elif args.model.split("-")[1].lower() == "13b": + print("Loading CodeLLaMA 13B model...") + model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-13b-Instruct-hf", torch_dtype=torch.float16) + else: + print("Model not supported.") + return + model.load_state_dict(torch.load(args.model)) else: model = LlamaForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16) else: - model = LlamaForCausalLM.from_pretrained(args.hf_path, torch_dtype=torch.float16) + model = LlamaForCausalLM.from_pretrained(args.hf_path, torch_dtype=torch.bfloat16) print("Start exporting the model...") _export_model(model, args.output) diff --git a/llm/tools/model_quantizer.py b/llm/tools/model_quantizer.py index d5f1f1fe..f5bfd3ee 100644 --- a/llm/tools/model_quantizer.py +++ b/llm/tools/model_quantizer.py @@ -124,9 +124,9 @@ def _quantize_model( layer_num = 24 elif model_name_size == "OPT_6.7B": layer_num = 32 - elif model_name_size.startswith("LLaMA_7B"): + elif model_name_size.startswith("LLaMA_7B") or model_name_size.startswith("CodeLLaMA_7B"): layer_num = 32 - elif model_name_size.startswith("LLaMA_13B"): + elif model_name_size.startswith("LLaMA_13B") or model_name_size.startswith("CodeLLaMA_13B"): layer_num = 40 else: raise ValueError( @@ -273,11 +273,11 @@ def _quantize_model( print(f"Quantization of layer {idx} finished.") # LLaMA - elif model_name.startswith("LLaMA"): - if model_name.startswith("LLaMA_7B"): + elif model_name.startswith("LLaMA") or model_name.startswith("CodeLLaMA"): + if model_name.startswith("LLaMA_7B") or model_name.startswith("CodeLLaMA_7B"): embed_dim = 4096 hidden_dim = 11008 - elif
model_name.startswith("LLaMA_13B"): + elif model_name.startswith("LLaMA_13B") or model_name.startswith("CodeLLaMA_13B"): embed_dim = 5120 hidden_dim = 13824 else: diff --git a/llm/tools/upload.py b/llm/tools/upload.py index bb46d150..12f6dbe5 100644 --- a/llm/tools/upload.py +++ b/llm/tools/upload.py @@ -9,7 +9,11 @@ import dropbox files_to_upload = [ - "assets.zip", + "CodeLLaMA_13B_Instruct.zip", + "CodeLLaMA_7B_Instruct.zip", + # "LLaMA_13B_2_chat.zip", + # "LLaMA_7B_2_chat.zip", + # "assets.zip", ] @@ -44,7 +48,7 @@ def subebackups(file_path, target_path, token): parser.add_argument("token", help="Your Dropbox OAuth2 token.") args = parser.parse_args() - db_prefix = "/MIT/transformer_assets/" + db_prefix = "/HAN Lab Public Space/Projects/TinyChatEngine/assets and models/QM_x86/" local_prefix = "uploads" for file in files_to_upload:
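The rms_norm_eps field added to model_config (1e-6 for the LLaMA-2 configs, 1e-5 for the CodeLLaMA configs) is now passed down to LlamaRMSNorm::forward instead of the old hard-coded eps = 1e-6 member. For reference, a minimal CPU sketch of RMSNorm with a runtime epsilon, assuming the standard formulation y = x / sqrt(mean(x^2) + eps) * w; the helper name and flat-vector layout below are illustrative and not the repository's Matrix3D-based kernel.

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative RMSNorm over one row; the real LlamaRMSNorm::forward normalizes the
// last dimension of a Matrix3D and now receives eps as a forward() argument.
static void rms_norm_row(const std::vector<float> &x, const std::vector<float> &weight,
                         std::vector<float> &out, float eps) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
    for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * inv_rms * weight[i];
}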
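chat.cc now picks the prompt template per model: CodeLLaMA-Instruct input is wrapped in [INST] ... [/INST] in instruct mode (and passed through unwrapped otherwise), while the other chat models keep the existing human/assistant template. A sketch of that branching with a hypothetical build_prompt helper; the strings mirror the ones used in the diff.

#include <string>

// Hypothetical helper mirroring the prompt selection in chat.cc.
static std::string build_prompt(const std::string &user_input, bool is_codellama, bool instruct) {
    if (is_codellama) {
        // CodeLLaMA-Instruct expects the [INST] ... [/INST] wrapper in instruct mode;
        // plain completion mode feeds the raw input unchanged.
        return instruct ? "[INST] " + user_input + " [/INST]" : user_input;
    }
    // Default template used for the LLaMA-2 chat models.
    return "A chat between a human and an assistant.\n\n### Human: " + user_input + "\n### Assistant: \n";
}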
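The hard-coded kernel thread counts (4/8/16/32) are replaced by params->opt_params.num_thread and the NUM_THREAD global, which chat.cc fills from an optional third CLI argument (e.g. ./chat CodeLLaMA_7B INT4 8, assuming the demo binary is built as chat). A sketch of splitting output columns across a runtime thread count, in the spirit of the int4_thread_args split; the worker body and argument struct here are placeholders, not the repository's kernels.

#include <pthread.h>
#include <vector>

struct range_args {
    int start_col;
    int end_col;
};

// Placeholder worker: a real kernel would compute its [start_col, end_col) slice here.
static void *worker(void *arg) {
    (void)static_cast<range_args *>(arg);
    return nullptr;
}

static void launch(int num_thread, int n_cols) {
    std::vector<pthread_t> pool(num_thread);
    std::vector<range_args> args(num_thread);
    for (int t = 0; t < num_thread; ++t) {
        // Divide columns as evenly as possible across the runtime-chosen thread count.
        args[t] = {t * n_cols / num_thread, (t + 1) * n_cols / num_thread};
        pthread_create(&pool[t], nullptr, worker, &args[t]);
    }
    for (int t = 0; t < num_thread; ++t) pthread_join(pool[t], nullptr);
}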