Update models (#102)
RaymondWang0 authored Apr 18, 2024
1 parent 572492b commit d46a858
Showing 10 changed files with 382 additions and 21 deletions.
149 changes: 143 additions & 6 deletions llm/application/chat.cc
@@ -7,11 +7,12 @@
#include "interface.h"

std::map<std::string, int> model_config = {
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B},
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"7b", LLaMA_7B}, {"LLaMA2_7B_chat", LLaMA_7B}, {"13b", LLaMA_13B}, {"LLaMA2_13B_chat", LLaMA_13B},
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
{"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}, {"LLaVA_7B", LLaVA_7B}, {"LLaVA_13B", LLaVA_13B},
{"VILA_7B", VILA_7B}, {"VILA_13B", VILA_13B}, {"Clip_ViT_Large", Clip_ViT_Large}
{"VILA_2.7B", VILA_2_7B}, {"VILA_7B", VILA_7B}, {"VILA_13B", VILA_13B}, {"Clip_ViT_Large", Clip_ViT_Large},
{"Mistral_7B", Mistral_7B}
};

std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
@@ -28,9 +29,11 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"}
{"StarCoder_15.5B", "models/StarCoder"},
{"LLaVA_7B", "models/LLaVA_7B"},
{"LLaVA_13B", "models/LLaVA_13B"},
{"VILA_2.7B", "models/VILA_2.7B"},
{"VILA_7B", "models/VILA_7B"},
{"VILA_13B", "models/VILA_13B"},
{"Clip_ViT_Large", "models/CLIP_ViT_Large"}
{"Clip_ViT_Large", "models/CLIP_ViT_Large"},
{"Mistral_7B", "models/Mistral_7B"},
};
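Note: these two tables drive name resolution at startup. model_config maps the CLI model name to an internal enum id, and model_path maps it to the weight directory. A minimal sketch of the lookup, assuming the argument handling used elsewhere in chat.cc (the argv parsing shown here is an illustration, not code from this commit):

    // Sketch: resolve "./chat Mistral_7B INT4 8" to a model id and weight path.
    std::string target_model = (argc > 1) ? argv[1] : "LLaMA2_7B_chat";
    if (model_config.count(target_model) == 0) {
        std::cerr << "Unknown model: " << target_model << std::endl;
        return 1;
    }
    int model_id = model_config[target_model];       // e.g. Mistral_7B
    std::string m_path = model_path[target_model];   // e.g. "models/Mistral_7B"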

std::map<std::string, int> data_format_list = {
@@ -78,6 +81,14 @@ bool isVILA(std::string s) {
return false;
}

bool isMistral(std::string s) {
std::string Mistral_prefix = "Mistral";
if (s.substr(0, Mistral_prefix.size()) == Mistral_prefix)
return true;
else
return false;
}
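Note: isMistral follows the same prefix test as the existing isLLaMA, isCodeLLaMA, isStarCoder, isLLaVA, and isVILA helpers. A hedged refactoring sketch (the has_prefix helper below is hypothetical, not part of this commit):

    // Hypothetical shared helper; each is* check becomes a one-liner.
    static bool has_prefix(const std::string& s, const std::string& prefix) {
        return s.compare(0, prefix.size(), prefix) == 0;
    }
    bool isMistral(std::string s) { return has_prefix(s, "Mistral"); }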

bool convertToBool(const char* str) {
if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) {
return true;
@@ -177,7 +188,7 @@ int main(int argc, char* argv[]) {

auto data_format_input = "INT4";
} else {
if (isLLaMA(target_model) || isCodeLLaMA(target_model) || isStarCoder(target_model) || isLLaVA(target_model) || isVILA(target_model)) {
if (isLLaMA(target_model) || isCodeLLaMA(target_model) || isStarCoder(target_model) || isLLaVA(target_model) || isVILA(target_model) || isMistral(target_model)) {
std::cout << "Using model: " + target_model << std::endl;
if (target_data_format == "INT4" || target_data_format == "int4")
std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl;
@@ -719,6 +730,132 @@ int main(int argc, char* argv[]) {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for VILA_7B." << std::endl;
}
} else if (isMistral(target_model)) {
int format_id = data_format_list[target_data_format];

// Voicechat instructions
if (use_voicechat) {
std::cout << "You are using the TinyVoiceChat." << std::endl;
std::cout << "*Usage instructions*" << std::endl;
std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl;
std::cout << "- Please start speaking after \"USER: [Start speaking]\" shows up." << std::endl;
std::cout << "- Please press `Ctrl+C` multiple times to exit the program." << std::endl << std::endl;
}

// Load model
std::cout << "Loading model... " << std::flush;
int model_id = model_config[target_model];
std::string m_path = model_path[target_model];

#ifdef MODEL_PREFIX
m_path = MODEL_PREFIX + m_path;
#endif

struct opt_params generation_config;
generation_config.n_predict = 512;        // maximum number of new tokens per reply
generation_config.repeat_penalty = 1.0f;  // 1.0f disables the repetition penalty
generation_config.temp = 0.3f;            // sampling temperature
generation_config.n_vocab = 32000;        // Mistral tokenizer vocabulary size

bool first_prompt = true;

if (format_id == FP32) {
Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl << std::endl;

// Get input from the user
while (true) {
std::string input;
if (use_voicechat) {
// Set prompt color
set_print_yellow();
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
// set user input color
set_print_red();
std::getline(in, input);
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;
// reset color
set_print_reset();
} else {
// Set prompt color
set_print_yellow();
std::cout << "USER: ";
// set user input color
set_print_red();
std::getline(std::cin, input);
// reset color
set_print_reset();
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
if (instruct) {
std::cout << "ASSISTANT: ";
}

if (first_prompt) {
input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
first_prompt = false;
}
else {
input = "### Human: " + input + "\n### Assistant: \n";
}

LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, false);
}
} else if (format_id == INT4) {
m_path = "INT4/" + m_path;
Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl << std::endl;

// Get input from the user
while (true) {
std::string input;
if (use_voicechat) {
// Set prompt color
set_print_yellow();
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
// set user input color
set_print_red();
std::getline(in, input);
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;
// reset color
set_print_reset();
} else {
// Set prompt color
set_print_yellow();
std::cout << "USER: ";
// set user input color
set_print_red();
std::getline(std::cin, input);
// reset color
set_print_reset();
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
if (instruct) {
std::cout << "ASSISTANT: ";
}

if (first_prompt) {
input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
first_prompt = false;
}
else {
input = "### Human: " + input + "\n### Assistant: \n";
}

LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, use_voicechat);
}
} else {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for Mistral-7B." << std::endl;
}
} else { // OPT
#ifdef QM_CUDA
printf("OPT is not supported with CUDA backend yet.");
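Note on the new Mistral branch: the FP32 and INT4 paths differ only in the model class, the LLaMA_FP32/LLaMA_INT4 flag, the "INT4/" path prefix, and the final voicechat argument; the prompt construction is duplicated verbatim between them. One hedged way to factor it out (build_prompt is a hypothetical helper, not part of this commit):

    // Hypothetical helper shared by both precision branches.
    static std::string build_prompt(const std::string& input, bool& first_prompt) {
        if (first_prompt) {
            first_prompt = false;
            return std::string("A chat between a curious human (\"Human\") and an artificial "
                   "intelligence assistant (\"Assistant\"). The assistant gives helpful, "
                   "detailed, and polite answers to the human's questions.\n\n### Human: ")
                   + input + "\n### Assistant: ";
        }
        return "### Human: " + input + "\n### Assistant: \n";
    }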
10 changes: 9 additions & 1 deletion llm/include/model.h
@@ -48,7 +48,7 @@ struct model_config {
mmproj_dim(mmproj_dim) {}
};

enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_7B, VILA_13B, Clip_ViT_Large };
enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_2_7B, VILA_7B, VILA_13B, Clip_ViT_Large, Mistral_7B };
enum { FP32, QINT8, INT4 };

const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
@@ -61,9 +61,11 @@ const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1,
const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
const struct model_config llava_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
const struct model_config llava_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
const struct model_config vila_2_7B(1, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
const struct model_config vila_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
const struct model_config vila_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
const struct model_config clip_vit_large(1, 16, 23, 2048, 1024, 4096, 0, 1, 0, 336, 14, 768, 4096); // llava's and vila's clip model uses only 23 layers out of 24
const struct model_config mistral_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
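Note: the positional model_config arguments are easy to misread, so the annotation below spells out the assumed field order, inferred from the nine-argument constructor in this header (a reading aid, not code from this commit):

    // Assumed field order: (batch, num_heads, num_layers, max_sqlen,
    //                       embed_dim, hidden_dim, vocab_size, padding_idx, rms_norm_eps)
    const struct model_config mistral_7B(
        1,       // batch
        32,      // num_heads
        32,      // num_layers
        2048,    // max_sqlen
        4096,    // embed_dim
        11008,   // hidden_dim
        32000,   // vocab_size
        1,       // padding_idx
        1e-6);   // rms_norm_eps: 1e-6 for Mistral vs. 1e-5 for the LLaMA-family entries above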

static struct model_config get_opt_model_config(int choice) {
struct model_config ret;
@@ -98,6 +100,9 @@ static struct model_config get_opt_model_config(int choice) {
case LLaVA_13B:
ret = llava_13B;
break;
case VILA_2_7B:
ret = vila_2_7B;
break;
case VILA_7B:
ret = vila_7B;
break;
@@ -107,6 +112,9 @@
case Clip_ViT_Large:
ret = clip_vit_large;
break;
case Mistral_7B:
ret = mistral_7B;
break;
default:
throw("Unsupported model choice.");
break;
2 changes: 2 additions & 0 deletions llm/mistral
@@ -0,0 +1,2 @@
#!/bin/bash
./chat Mistral_7B INT4 5
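Usage note (inferred, not stated in this commit): the script launches the INT4 build of Mistral-7B; by analogy with chat.cc's argument order of model name, precision, and a numeric third argument, the trailing 5 presumably sets the thread count.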
24 changes: 24 additions & 0 deletions llm/tools/download_model.py
@@ -58,6 +58,10 @@
"url": "https://www.dropbox.com/scl/fi/0uroj92srmo6z4ib4xr43/LLaVA_13B_CLIP_ViT-L.zip?rlkey=34x3r8yfh8ztiqbisg5z64hmd&dl=1", # noqa: E501
"md5sum": "3d4afd8051c779c014ba69aec7886961",
},
"VILA_2.7B_CLIP_ViT-L_fp32": {
"url": "https://www.dropbox.com/scl/fi/f1vfgtwhr88yhpd8aabwp/VILA_2.7B_CLIP_ViT-L.zip?rlkey=qesrenbana7elbwk0szu53nzj&dl=1", # noqa: E501
"md5sum": "48455c57594ea1a6b44496fda3877c75",
},
"VILA_7B_CLIP_ViT-L_fp32": {
"url": "https://www.dropbox.com/scl/fi/4oi3g3uypx2hgmw6hkahy/VILA_7B_CLIP_ViT-L.zip?rlkey=0393uexrzh4ofevkr0yaldefd&dl=1", # noqa: E501
"md5sum": "d2201fd2853da56c3e2b4b7043b1d37a",
@@ -66,6 +70,10 @@
"url": "https://www.dropbox.com/scl/fi/vc1956by8v275t0ol6vw5/StarCoder_15.5B.zip?rlkey=aydnpd9w9jhgtlfqo5krkd0yx&dl=1",
"md5sum": "e3e9301866f47ab84817b46467ac49f6",
},
"Mistral_7B_v0.2_Instruct_fp32": {
"url": "",
"md5sum": "",
},
}

Qmodels = {
@@ -110,6 +118,10 @@
"url": "https://www.dropbox.com/scl/fi/hzqrq72xrk2uwupkktmpk/LLaVA_13B_CLIP_ViT-L.zip?rlkey=zit6e00fic7vdygrlg0cybivq&dl=1", # noqa: E501
"md5sum": "fec078d99449df73c0f1236377b53eb3",
},
"VILA_2.7B_awq_int4_CLIP_ViT-L": {
"url": "https://www.dropbox.com/scl/fi/pc9vohr7dyde2k3pbhai7/VILA_2.7B_CLIP_ViT-L.zip?rlkey=5dfayissvbj5unuuhzxzipaxk&dl=1",
"md5sum": "177b1a58707355c641da4f15fb3c7a71",
},
"VILA_7B_awq_int4_CLIP_ViT-L": {
"url": "https://www.dropbox.com/scl/fi/9axqkn8e95p7zxy97ixjx/VILA_7B_CLIP_ViT-L.zip?rlkey=mud5qg3rr3yec12qcvsltca5w&dl=1", # noqa: E501
"md5sum": "29aa8688b59dfde21d0b0b0b94b0ac27",
@@ -118,6 +130,10 @@
"url": "https://www.dropbox.com/scl/fi/fe4dkrnzc25bt166w6bby/StarCoder_15.5B.zip?rlkey=ml1x96uep2k03z78ci7s1c0yb&dl=1",
"md5sum": "0f16236c0aec0b32b553248cc78b8caf",
},
"Misitral_7B_v0.2_Instruct_awq_int4": {
"url": "https://www.dropbox.com/scl/fi/ssr6bn9a6l9d4havu04om/Mistral_7B_v0.2_Instruct.zip?rlkey=73yqj6pw300o3izwr43etjqkr&dl=1",
"md5sum": "ee96bcdee3d09046719f7d31d7f023f4",
},
},
"QM_x86": {
"LLaMA_7B_awq_int4": {
@@ -160,6 +176,10 @@
"url": "https://www.dropbox.com/scl/fi/7u8wihmvvr9jlio2rjw2f/LLaVA_13B_CLIP_ViT-L.zip?rlkey=bimpaaemyb3rp30wgkznytkuv&dl=1", # noqa: E501
"md5sum": "f22e8d5d754c64f0aa34d5531d3059bc",
},
"VILA_2.7B_awq_int4_CLIP_ViT-L": {
"url": "https://www.dropbox.com/scl/fi/gldsl2fh6g5f0fvwnf8kq/VILA_2.7B_CLIP_ViT-L.zip?rlkey=oj2y01xt4vwtbg7vdg4g1btxd&dl=1",
"md5sum": "e83ff23d58a0b91c732a9e3928aa344a",
},
"VILA_7B_awq_int4_CLIP_ViT-L": {
"url": "https://www.dropbox.com/scl/fi/25cw3ob1oar6p3maxg6lq/VILA_7B_CLIP_ViT-L.zip?rlkey=b4vr29gvsdxlj9bg3i5cwsnjn&dl=1", # noqa: E501
"md5sum": "7af675198ec3c73d440ccc96b2722813",
@@ -168,6 +188,10 @@
"url": "https://www.dropbox.com/scl/fi/86o2cblncmfd3xvuyyaqc/StarCoder_15.5B.zip?rlkey=2gswnyq9xihencaduddylpb2k&dl=1",
"md5sum": "48383ce0bf01b137069e3612cab8525f",
},
"Mistral_7B_v0.2_Instruct_awq_int4": {
"url": "https://www.dropbox.com/scl/fi/2f7djt8z8lhkd60velfb3/Mistral_7B_v0.2_Instruct.zip?rlkey=gga6mh8trxf6durck4y4cyihe&dl=1",
"md5sum": "22e8692d7481807b4151f28c54f112da",
},
},
"QM_CUDA": {
"LLaMA2_7B_chat_awq_int4": {
8 changes: 4 additions & 4 deletions llm/tools/llama_exporter.py
@@ -120,17 +120,17 @@ def main():
if args.model.endswith(".pt"):
if args.model.split("/")[-1].lower().startswith("llama-2"):
if args.model.split("-")[2].lower() == "7b":
print("Loading LLaMA 7B model...");
print("Loading LLaMA 7B model...")
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", torch_dtype=torch.float16)
elif args.model.split("-")[2].lower() == "13b":
print("Loading LLaMA 13B model...");
print("Loading LLaMA 13B model...")
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-13b-hf", torch_dtype=torch.float16)
elif args.model.split("/")[-1].lower().startswith("codellama"):
if args.model.split("-")[1].lower() == "7b":
print("Loading CodaLLaMA 7B model...");
print("Loading CodaLLaMA 7B model...")
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-Instruct-hf", torch_dtype=torch.float16)
elif args.model.split("-")[1].lower() == "13b":
print("Loading CodaLLaMA 13B model...");
print("Loading CodaLLaMA 13B model...")
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-13b-Instruct-hf", torch_dtype=torch.float16)
else:
print("Model not supported.")