Support new features (#67)
RaymondWang0 authored Oct 8, 2023
1 parent a6cb261 commit 3aa79df
Showing 42 changed files with 298 additions and 123 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,7 @@
*.a
*.pyc
*.cuu
*.ccc
.DS_Store
.build/
.cache/
@@ -17,6 +18,7 @@ assets/
*.zip
*.txt
!requirements.txt
*.pt
*.json
test_*
!test_*.cc
3 changes: 2 additions & 1 deletion kernels/avx/matmul_avx_int8_int4.cc
@@ -159,7 +159,8 @@ static void quantize_fp_to_int8_block_size32(float *x, int size, int8_t *qx, flo
namespace matmul {

void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params) {
const int num_thread = 4; // params->opt_params.num_thread;
// const int num_thread = 4;
const int num_thread = params->opt_params.num_thread;
int i, j, k;
pthread_t thread_pool[num_thread];
struct int4_thread_args threads_args[num_thread];
3 changes: 2 additions & 1 deletion kernels/neon/matmul_neon_int4.cc
@@ -399,7 +399,8 @@ static void *fast_zp_no_offset_over_column_func_v3(void *args) {
namespace matmul {

void MatmulOperator::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) {
const int num_thread = 32;
// const int num_thread = 32;
const int num_thread = params->opt_params.num_thread;
int i, j, k;
pthread_t thread_pool[num_thread];
struct int4_thread_args threads_args[num_thread];
3 changes: 2 additions & 1 deletion kernels/neon/matmul_neon_int4_offset.cc
@@ -280,7 +280,8 @@ static void *fast_over_column_func_v1(void *args) {
namespace matmul {

void MatmulOperator::mat_mul_accelerator_int4_fast(const struct matmul_params *params) {
const int num_thread = 16;
// const int num_thread = 16;
const int num_thread = params->opt_params.num_thread;
int i, j, k;
pthread_t thread_pool[num_thread];
struct int4_thread_args threads_args[num_thread];
3 changes: 2 additions & 1 deletion kernels/neon/matmul_neon_int8_int4.cc
@@ -410,7 +410,8 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_
// ref imp.
// matmul_int8_int4_no_offset(params);

const int num_thread = 8;
// const int num_thread = 8;
const int num_thread = params->opt_params.num_thread;
pthread_t thread_pool[num_thread];
struct a8w4_thread_args threads_args[num_thread];
assert(params->block_size == 32); // support block size 32 for now
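All four kernel diffs above make the same change: the hard-coded thread count is replaced by the value carried in params->opt_params.num_thread. Below is a minimal caller-side sketch, assuming only the names that appear in these diffs (matmul_params, opt_params.num_thread, the matmul namespace) plus the application-level NUM_THREAD set in chat.cc further down; the rest of the setup is omitted and purely illustrative, not taken from the repository.

    // Hypothetical caller: forward the application-level thread count to the kernel.
    struct matmul_params params;
    // ... fill in operands, scales, block_size, etc. (omitted) ...
    params.opt_params.num_thread = NUM_THREAD;  // e.g. parsed from argv[3] in chat.cc
    matmul::MatmulOperator op;
    op.mat_mul_accelerator_int8_int4_fast_no_offset(&params);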
83 changes: 75 additions & 8 deletions llm/application/chat.cc
@@ -1,11 +1,13 @@
#include <iostream>
#include <map>
#include <cstring>

#include "Generate.h"

std::map<std::string, int> model_config = {
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B}};
{"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B},
{"CodeLLaMA_7B", CodeLLaMA_7B}, {"CodeLLaMA_13B", CodeLLaMA_13B}};

std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
{"OPT_1.3B", "models/OPT_1.3B"},
@@ -14,29 +16,62 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"}
{"LLaMA2_7B_chat", "models/LLaMA_7B_2_chat"},
{"LLaMA2_13B_chat", "models/LLaMA_13B_2_chat"},
{"7b", "models/LLaMA_7B_2_chat"},
{"13b", "models/LLaMA_13B_2_chat"}};
{"13b", "models/LLaMA_13B_2_chat"},
{"CodeLLaMA_7B", "models/CodeLLaMA_7B_Instruct"},
{"CodeLLaMA_13B", "models/CodeLLaMA_13B_Instruct"}};

std::map<std::string, int> data_format_list = {
{"FP32", FP32}, {"INT8", QINT8}, {"INT4", INT4}, {"int4", INT4}, {"fp32", FP32},
};

bool isLLaMA(std::string s) {
std::string LLaMA_prefix = "LLaMA";
std::string CodeLLaMA_prefix = "CodeLLaMA";

if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s == "7b" || s == "13b")
if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix || s == "7b" || s == "13b")
return true;
else
return false;
}

bool isCodeLLaMA(std::string s) {
std::string CodeLLaMA_prefix = "CodeLLaMA";

if (s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix)
return true;
else
return false;
}

bool convertToBool(const char* str) {
if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) {
return true;
}
else if (strcmp(str, "false") == 0 || strcmp(str, "0") == 0) {
return false;
}
else {
std::cerr << "Error: Invalid boolean value: " << str << std::endl;
exit(EXIT_FAILURE);
}
}

int main(int argc, char* argv[]) {
std::string target_model = "LLaMA2_7B_chat";
std::string target_data_format = "INT4";
bool instruct = true;
Profiler::getInstance().for_demo = true;

std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl;

if (argc == 3) {
if (argc >= 3 && argc <= 5) {
if (argc >= 4) {
NUM_THREAD = atoi(argv[3]);
}
if (argc == 5) {
instruct = convertToBool(argv[4]);
}

auto target_str = argv[1];
target_model = argv[1];
if (model_config.count(target_model) == 0) {
@@ -108,10 +143,14 @@ int main(int argc, char* argv[]) {
#endif

struct opt_params generation_config;
generation_config.n_predict = 512;
generation_config.n_vocab = 32000;
generation_config.temp = 0.1f;
generation_config.n_predict = 512;
generation_config.repeat_penalty = 1.25f;
generation_config.temp = 0.1f;
if(isCodeLLaMA(target_model)) {
generation_config.temp = 0.2f;
generation_config.top_p = 0.95f;
}

if (format_id == FP32) {
Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id));
@@ -122,7 +161,21 @@ int main(int argc, char* argv[]) {
std::cout << "USER: ";
std::string input;
std::getline(std::cin, input);
input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
if (instruct) {
std::cout << "ASSISTANT: " << std::endl;
if (isCodeLLaMA(target_model)) {
input = "<s>[INST] " + input + " [/INST]\n";
}
}
else {
if (isCodeLLaMA(target_model)) {
std::cout << input;
}
}

if (!isCodeLLaMA(target_model)) {
input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
}

LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, false);
}
@@ -136,7 +189,21 @@ int main(int argc, char* argv[]) {
std::cout << "USER: ";
std::string input;
std::getline(std::cin, input);
input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
if (instruct) {
std::cout << "ASSISTANT: " << std::endl;
if (isCodeLLaMA(target_model)) {
input = "<s>[INST] " + input + " [/INST]";
}
}
else {
if (isCodeLLaMA(target_model)) {
std::cout << input;
}
}

if (!isCodeLLaMA(target_model)) {
input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
}

LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, false);
}
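Taken together, the new argument handling above extends the chat demo's command line with an optional thread count (argv[3], stored in NUM_THREAD) and an optional instruct flag (argv[4], parsed by convertToBool). A usage sketch follows; the argument order comes from the parsing in main, while the binary name and the example values are assumptions, not taken from this diff.

    ./chat <model> <precision> [num_threads] [instruct]
    ./chat CodeLLaMA_7B INT4 8 true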
25 changes: 15 additions & 10 deletions llm/application/voicechat.cc
@@ -123,10 +123,11 @@ int main(int argc, char* argv[]) {
std::string output;
std::string model_input;

std::system("./application/sts_utils/listen");
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
std::system("rm tmpfile");
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;

if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.")
@@ -145,10 +146,11 @@ int main(int argc, char* argv[]) {
std::string output;
std::string model_input;

std::system("./application/sts_utils/listen");
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
std::system("rm tmpfile");
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;

if (input == " quit" || input == " Quit" || input == " Quit." || input == " quit.")
@@ -191,10 +193,11 @@ int main(int argc, char* argv[]) {
std::string output;
std::string model_input;

std::system("./application/sts_utils/listen");
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
std::system("rm tmpfile");
result = std::system("rm tmpfile");
(void)result;
std::vector<int> input_ids = encoder.encode(input);
std::string decoded = encoder.decode(input_ids);
std::cout << input << std::endl;
@@ -215,10 +218,11 @@ int main(int argc, char* argv[]) {
std::string output;
std::string model_input;

std::system("./application/sts_utils/listen");
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
std::system("rm tmpfile");
result = std::system("rm tmpfile");
(void)result;
std::vector<int> input_ids = encoder.encode(input);
std::string decoded = encoder.decode(input_ids);
std::cout << input << std::endl;
@@ -238,10 +242,11 @@ int main(int argc, char* argv[]) {
std::string output;
std::string model_input;

std::system("./application/sts_utils/listen");
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
std::system("rm tmpfile");
result = std::system("rm tmpfile");
(void)result;
std::vector<int> input_ids = encoder.encode(input);
std::string decoded = encoder.decode(input_ids);
std::cout << input << std::endl;
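Each voicechat.cc hunk now stores the return value of std::system and casts it to void, which silences unused-result warnings without acting on the status. A minimal alternative sketch, not part of this commit, that checks the exit status of the listen utility instead of discarding it:

    // Requires <cstdlib> and <iostream>; sketch only.
    int result = std::system("./application/sts_utils/listen");
    if (result != 0) {
        std::cerr << "listen exited with status " << result << std::endl;
    }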
24 changes: 15 additions & 9 deletions llm/include/Generate.h
@@ -5,6 +5,9 @@ Adapted from llama.cpp:
*/

#ifndef GENERATE_H
#define GENERATE_H

#include <algorithm>
#include <cassert>
#include <cstdio>
@@ -24,7 +27,8 @@ Adapted from llama.cpp:
#include "operators.h"
#include "utils.h"

inline std::mt19937 OPT_rng;
// inline std::mt19937 OPT_rng; // inline variables are only available with ‘-std=c++17’ or ‘-std=gnu++17’
static std::mt19937 OPT_rng;

typedef struct OPT_token_data {
int id; // token id
@@ -39,14 +43,14 @@ typedef struct OPT_token_data_array {
} OPT_token_data_array;

struct opt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = 1; // TODO: fix this
int32_t n_predict = 128; // new tokens to predict
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_vocab = 50272; // vocabulary size
int32_t seed = -1; // RNG seed
int32_t n_threads = 1; // TODO: fix this
int32_t n_predict = 128; // new tokens to predict
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_vocab = 50272; // vocabulary size

// sampling parameters
std::unordered_map<int, float> logit_bias; // logit bias for specific tokens
@@ -97,3 +101,5 @@ std::vector<int> OPTGenerate(void* model, int model_type, std::vector<int> input
enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4 };
std::string LLaMAGenerate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config,
std::string voc_path, bool interactive, bool voicechat);

#endif // GENERATE_H
33 changes: 21 additions & 12 deletions llm/include/model.h
@@ -11,29 +11,32 @@ struct model_config {
int hidden_dim;
int vocsize;
int padding_idx;
int qk; // group size
float rms_norm_eps; // RMSNorm epsilon (only for LLaMA models)

model_config() : model_config(1, 12, 12, 512, 768, 3072, 50272, 1) {}
model_config() : model_config(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6) {}
model_config(int batch, int num_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
int padding_idx)
int padding_idx, float rms_norm_eps)
: batch(batch),
num_heads(num_heads),
num_layers(num_layers),
max_sqlen(max_sqlen),
embed_dim(embed_dim),
hidden_dim(hidden_dim),
vocsize(vocsize),
padding_idx(padding_idx) {}
padding_idx(padding_idx),
rms_norm_eps(rms_norm_eps) {}
};

enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B };
enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B };
enum { FP32, QINT8, INT4 };

const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1);
const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1);
const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1);
const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1);
const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1);
const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1, 0);
const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1, 0);
const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
const struct model_config codellama_7B(1, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
static struct model_config get_opt_model_config(int choise) {
struct model_config ret;
switch (choise) {
@@ -46,12 +49,18 @@ static struct model_config get_opt_model_config(int choise) {
case OPT_6_7B:
ret = opt_6_7B;
break;
case LLaMA_7B:;
case LLaMA_7B:
ret = llama_7B;
break;
case LLaMA_13B:;
case LLaMA_13B:
ret = llama_13B;
break;
case CodeLLaMA_7B:
ret = codellama_7B;
break;
case CodeLLaMA_13B:
ret = codellama_13B;
break;
default:
throw("Unsupported model choice.");
break;
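With rms_norm_eps now part of model_config, the CodeLLaMA entries differ from the base LLaMA ones in two visible ways: a 32016-token vocabulary and an epsilon of 1e-5 instead of 1e-6. A minimal sketch of reading the new configuration through get_opt_model_config; the enum value and field names come from this header, while the printout itself is only illustrative.

    // Requires <cstdio>; CodeLLaMA_7B, vocsize, and rms_norm_eps come from model.h.
    struct model_config cfg = get_opt_model_config(CodeLLaMA_7B);
    std::printf("vocab=%d eps=%g\n", cfg.vocsize, cfg.rms_norm_eps);  // expected: vocab=32016 eps=1e-05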
1 change: 1 addition & 0 deletions llm/include/nn_modules/Fp32llamaDecoder.h
@@ -31,6 +31,7 @@ class Fp32llamaDecoder {
struct Fp32llamaDecoder_output forward(const struct Fp32llamaDecoder_input& input);
Embedding embed_tokens;
LlamaRMSNorm norm;
float rms_norm_eps;
int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
std::vector<Fp32llamaDecoderLayer> layers;
std::string profile_name = "Fp32llamaDecoder";
1 change: 1 addition & 0 deletions llm/include/nn_modules/Fp32llamaDecoderLayer.h
@@ -42,6 +42,7 @@ class Fp32llamaDecoderLayer {
struct Fp32llamaDecoderLayer_output forward(const struct Fp32llamaDecoderLayer_input &input);

int embed_dim, num_attention_heads, hidden_dim, layer_idx;
float rms_norm_eps;
LlamaRMSNorm input_layernorm, post_attention_layernorm;
Linear_FP gate_proj, down_proj, up_proj;
Fp32llamaAttention attn;
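The decoder and decoder-layer headers gain an rms_norm_eps member so the epsilon chosen in model_config can reach LlamaRMSNorm. For reference, RMSNorm scales each hidden vector by the reciprocal root-mean-square of its elements, with eps added under the square root for numerical stability. A minimal scalar sketch, assuming nothing about the repository's actual LlamaRMSNorm implementation:

    #include <cmath>

    // RMSNorm over one hidden vector x[0..dim) with learned weight w and epsilon eps.
    void rms_norm(float *out, const float *x, const float *w, int dim, float eps) {
        float ss = 0.0f;
        for (int i = 0; i < dim; ++i) ss += x[i] * x[i];
        const float scale = 1.0f / std::sqrt(ss / dim + eps);
        for (int i = 0; i < dim; ++i) out[i] = w[i] * x[i] * scale;
    }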