Support LLaVA (#89)
RaymondWang0 authored Feb 2, 2024
1 parent cc9dfb6 commit 73545ea
Showing 44 changed files with 10,981 additions and 86 deletions.
9 changes: 5 additions & 4 deletions .gitignore
@@ -14,6 +14,7 @@
.vscode/

assets/
models/
*.bin
!llama_vocab.bin
!starcoder_vocab.bin
@@ -31,8 +32,8 @@ voicechat
profile_*
!profile_*.cc
libtorch/
checkpoints/

llm/chat
llm/output.wav
llm/tmpfile
llm/TTS
output.wav
tmpfile
TTS/
152 changes: 141 additions & 11 deletions llm/application/chat.cc
@@ -9,7 +9,7 @@ std::map<std::string, int> model_config = {
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B},
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
{"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}
{"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}, {"LLaVA_7B", LLaVA_7B}, {"Clip_ViT_Large", Clip_ViT_Large}
};

std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
@@ -23,7 +23,9 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"}
{"CodeLLaMA_7B_Instruct", "models/CodeLLaMA_7B_Instruct"},
{"CodeLLaMA_13B_Instruct", "models/CodeLLaMA_13B_Instruct"},
{"StarCoder", "models/StarCoder"},
{"StarCoder_15.5B", "models/StarCoder"}
{"StarCoder_15.5B", "models/StarCoder"},
{"LLaVA_7B", "models/LLaVA_7B"},
{"Clip_ViT_Large", "models/CLIP_ViT_Large"}
};

std::map<std::string, int> data_format_list = {
@@ -33,7 +35,6 @@ std::map<std::string, int> data_format_list = {
bool isLLaMA(std::string s) {
std::string LLaMA_prefix = "LLaMA";
std::string CodeLLaMA_prefix = "CodeLLaMA";

if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix || s == "7b" || s == "13b")
return true;
else
@@ -42,7 +43,6 @@ bool isLLaMA(std::string s) {

bool isCodeLLaMA(std::string s) {
std::string CodeLLaMA_prefix = "CodeLLaMA";

if (s.substr(0, CodeLLaMA_prefix.size()) == CodeLLaMA_prefix)
return true;
else
@@ -51,13 +51,20 @@ bool isCodeLLaMA(std::string s) {

bool isStarCoder(std::string s) {
std::string StarCoder_prefix = "StarCoder";

if (s.substr(0, StarCoder_prefix.size()) == StarCoder_prefix)
return true;
else
return false;
}

bool isLLaVA(std::string s) {
std::string LLaVA_prefix = "LLaVA";
if (s.substr(0, LLaVA_prefix.size()) == LLaVA_prefix)
return true;
else
return false;
}

bool convertToBool(const char* str) {
if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) {
return true;
@@ -71,7 +78,7 @@ bool convertToBool(const char* str) {
}
}

int NUM_THREAD = 8;
int NUM_THREAD = 5;

int main(int argc, char* argv[]) {
bool use_voicechat = false;
@@ -92,19 +99,26 @@ int main(int argc, char* argv[]) {
std::string target_model = "LLaMA2_7B_chat";
std::string target_data_format = "INT4";
bool instruct = true;
std::string img_path = "images/monalisa.jpg";
Profiler::getInstance().for_demo = true;

std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl;
if (argc >= 3 && argc <= 5) {
auto target_str = argv[1];
target_model = argv[1];

if (argc >= 4) {
NUM_THREAD = atoi(argv[3]);
}
if (argc == 5) {
instruct = convertToBool(argv[4]);
if (isCodeLLaMA(target_model)) {
instruct = convertToBool(argv[4]);
}
else if (isLLaVA(target_model)) {
img_path = argv[4];
}
}

auto target_str = argv[1];
target_model = argv[1];
if (model_config.count(target_model) == 0) {
std::cerr << "Model config:" << target_str << " unsupported" << std::endl;
std::cerr << "Please select one of the following:";
@@ -161,6 +175,13 @@
else
std::cout << "Using data format: " << target_data_format << std::endl;
}
else if (isLLaVA(target_model)) {
std::cout << "Using model: " + target_model << std::endl;
if (target_data_format == "INT4" || target_data_format == "int4")
std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl;
else
std::cout << "Using data format: " << target_data_format << std::endl;
}
else { // OPT
target_model = "OPT6.7B";
target_data_format = "INT8";
@@ -225,7 +246,6 @@
input = " </s> <s>[INST] " + input + " [/INST] ";
}
}

}
else {
if (isCodeLLaMA(target_model)) {
@@ -297,7 +317,7 @@
}
} else {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for LLaMA7B." << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for LLaMA_7B." << std::endl;
}
} else if (isStarCoder(target_model)) {
int format_id = data_format_list[target_data_format];
@@ -349,6 +369,116 @@
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for StarCoder." << std::endl;
}
} else if (isLLaVA(target_model)) {
int format_id = data_format_list[target_data_format];

// Load model
std::cout << "Loading model... " << std::flush;
std::string clip_m_path = model_path["Clip_ViT_Large"];
std::string llama_m_path = model_path[target_model];

int clip_model_id = model_config["Clip_ViT_Large"];
int llama_model_id = model_config[target_model];

#ifdef MODEL_PREFIX
llama_m_path = MODEL_PREFIX + llama_m_path;
#endif

struct opt_params generation_config;
generation_config.n_predict = 512;
generation_config.repeat_penalty = 1.1f;
generation_config.temp = 0.2f;
generation_config.n_vocab = 32000;

int prompt_iter = 0;

if (format_id == FP32) {
Fp32CLIPVisionTransformer clip_model = Fp32CLIPVisionTransformer(clip_m_path, get_opt_model_config(clip_model_id));
Fp32LlamaForCausalLM llama_model = Fp32LlamaForCausalLM(llama_m_path, get_opt_model_config(llama_model_id));

// Get input from the user
while (true) {
std::string input;
if (prompt_iter == 1) {
std::cout << "Finished!" << std::endl;
}
if (prompt_iter > 0) {
if (use_voicechat){
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;
} else {
std::cout << "USER: ";
std::getline(std::cin, input);
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
std::cout << "ASSISTANT: " << std::endl;
}

if (prompt_iter == 0) {
input = "This is a chat between a user and an assistant.\n\n### USER: ";
prompt_iter += 1;
}
else if (prompt_iter == 1) {
input = "\n" + input + "\n### ASSISTANT:";
prompt_iter += 1;
}
else {
input = "### USER: " + input + "\n### ASSISTANT: \n";
}

LLaVAGenerate(llama_m_path, &llama_model, clip_m_path, &clip_model, LLaVA_FP32, input, img_path, generation_config, "models/llama_vocab.bin", true, false);
}
} else if (format_id == INT4) {
Fp32CLIPVisionTransformer clip_model = Fp32CLIPVisionTransformer(clip_m_path, get_opt_model_config(clip_model_id));
llama_m_path = "INT4/" + llama_m_path;
Int4LlamaForCausalLM llama_model = Int4LlamaForCausalLM(llama_m_path, get_opt_model_config(llama_model_id));

// Get input from the user
while (true) {
if (prompt_iter == 1) {
std::cout << "Finished!" << std::endl;
}
std::string input;
if (prompt_iter > 0) {
if (use_voicechat){
int result = std::system("./application/sts_utils/listen");
std::ifstream in("tmpfile");
std::getline(in, input);
result = std::system("rm tmpfile");
(void)result;
std::cout << input << std::endl;
} else {
std::cout << "USER: ";
std::getline(std::cin, input);
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
std::cout << "ASSISTANT: " << std::endl;
}

if (prompt_iter == 0) {
input = "This is a chat between a user and an assistant.\n\n### USER: ";
prompt_iter += 1;
}
else if (prompt_iter == 1) {
input = "\n" + input + "\n### ASSISTANT:";
prompt_iter += 1;
}
else {
input = "### USER: " + input + "\n### ASSISTANT: \n";
}

LLaVAGenerate(llama_m_path, &llama_model, clip_m_path, &clip_model, LLaVA_INT4, input, img_path, generation_config, "models/llama_vocab.bin", true, use_voicechat);
}
} else {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for LLaVA_7B." << std::endl;
}
} else { // OPT
#ifdef QM_CUDA
printf("OPT is not supported with CUDA backend yet.");
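For reference, the argv handling above implies a command line along these lines. This is a hedged sketch: the binary name and the role of argv[2] as the data-format selector are assumptions, since those parts of main() are collapsed in this diff.

// Hypothetical invocation, inferred from the argument parsing in main():
//   ./chat LLaVA_7B INT4 8 images/monalisa.jpg
// argv[1] = model name, argv[2] = data format (assumed), argv[3] = NUM_THREAD;
// argv[4] is the image path when isLLaVA(target_model) holds, and the
// instruct flag when isCodeLLaMA(target_model) holds.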
7 changes: 6 additions & 1 deletion llm/include/Generate.h
@@ -24,6 +24,7 @@ Adapted from llama.cpp:
#include "Int4OPTForCausalLM.h"
#include "Int4llamaForCausalLM.h"
#include "Int4GPTBigCodeForCausalLM.h"
#include "Fp32CLIPVisionTransformer.h"
#include "OPTForCausalLM.h"
#include "OPTTokenizer.h"
#include "operators.h"
@@ -100,11 +101,15 @@ std::vector<int> OPTGenerate(void* model, int model_type, std::vector<int> input
const struct opt_params generation_config, Encoder* encoder = NULL,
bool interactive = false, bool voicechat = false);

enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4, StarCoder_FP32, StarCoder_INT4 };
enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4, StarCoder_FP32, StarCoder_INT4, LLaVA_FP32, LLaVA_INT4 };
std::string LLaMAGenerate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config,
std::string voc_path, bool interactive, bool voicechat);

std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config,
std::string voc_path, bool interactive);

std::string LLaVAGenerate(std::string llama_param_path, void* llama_model_ptr, std::string clip_param_path, void* clip_model_ptr, int model_type,
std::string text, std::string img_path, const struct opt_params generation_config, std::string voc_path, bool interactive,
bool voicechat);

#endif // GENERATE_H
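Putting the new declaration together with chat.cc above, the LLaVA path is driven roughly as follows. This is a minimal sketch, not the repository's exact code: the model-id constants (LLaVA_7B, Clip_ViT_Large) and get_opt_model_config() are assumed visible through the included headers, and generation_config is left mostly at defaults.

#include "Generate.h"  // LLaVAGenerate, LLaVA_FP32/LLaVA_INT4, and the model classes

// One FP32 round trip: the CLIP vision tower and the LLaMA decoder are
// constructed side by side and handed to LLaVAGenerate together.
std::string run_llava_once(const std::string& prompt, const std::string& img_path) {
    std::string clip_m_path = "models/CLIP_ViT_Large";
    std::string llama_m_path = "models/LLaVA_7B";

    Fp32CLIPVisionTransformer clip_model(clip_m_path, get_opt_model_config(Clip_ViT_Large));
    Fp32LlamaForCausalLM llama_model(llama_m_path, get_opt_model_config(LLaVA_7B));

    struct opt_params generation_config;
    generation_config.n_predict = 512;
    generation_config.n_vocab = 32000;

    // Argument order follows the declaration above: language model first,
    // then the vision model, then the model-type tag, text, and image path.
    return LLaVAGenerate(llama_m_path, &llama_model, clip_m_path, &clip_model, LLaVA_FP32,
                         prompt, img_path, generation_config, "models/llama_vocab.bin",
                         /*interactive=*/false, /*voicechat=*/false);
}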
100 changes: 100 additions & 0 deletions llm/include/common.h
@@ -125,6 +125,106 @@ class Matrix3D {
Matrix3D() { m_data = NULL; }
};

template <typename T>
class Matrix4D {
public:
Matrix4D(T *data, int dim_w, int dim_x, int dim_y, int dim_z) :
m_data(data), m_dim_w(dim_w), m_dim_x(dim_x), m_dim_y(dim_y), m_dim_z(dim_z) {}

#if defined(__CUDACC__)
__host__ __device__ T &operator()(int w, int x, int y, int z) {
return m_data[w * m_dim_x * m_dim_y * m_dim_z + x * m_dim_y * m_dim_z + y * m_dim_z + z];
}

__host__ __device__ const T &operator()(int w, int x, int y, int z) const {
return m_data[w * m_dim_x * m_dim_y * m_dim_z + x * m_dim_y * m_dim_z + y * m_dim_z + z];
}
#else
T &operator()(int w, int x, int y, int z) {
if (w < 0 || w >= m_dim_w || x < 0 || x >= m_dim_x || y < 0 || y >= m_dim_y || z < 0 || z >= m_dim_z) {
printf("%d, %d, %d, %d\n", w, x, y, z);
printf("%d, %d, %d, %d\n", m_dim_w, m_dim_x, m_dim_y, m_dim_z);
throw std::out_of_range("Matrix4D: Indices out of range.");
}
return m_data[w * m_dim_x * m_dim_y * m_dim_z + x * m_dim_y * m_dim_z + y * m_dim_z + z];
}

const T &operator()(int w, int x, int y, int z) const {
if (w < 0 || w >= m_dim_w || x < 0 || x >= m_dim_x || y < 0 || y >= m_dim_y || z < 0 || z >= m_dim_z) {
printf("%d, %d, %d, %d\n", w, x, y, z);
printf("%d, %d, %d, %d\n", m_dim_w, m_dim_x, m_dim_y, m_dim_z);
throw std::out_of_range("Matrix4D: Indices out of range.");
}
return m_data[w * m_dim_x * m_dim_y * m_dim_z + x * m_dim_y * m_dim_z + y * m_dim_z + z];
}
#endif

bool operator==(const Matrix4D<T> &other) const {
if (m_dim_w != other.m_dim_w || m_dim_x != other.m_dim_x || m_dim_y != other.m_dim_y || m_dim_z != other.m_dim_z) {
return false;
}

for (int w = 0; w < m_dim_w; ++w) {
for (int x = 0; x < m_dim_x; ++x) {
for (int y = 0; y < m_dim_y; ++y) {
for (int z = 0; z < m_dim_z; ++z) {
if ((*this)(w, x, y, z) != other(w, x, y, z)) {
return false;
}
}
}
}
}

return true;
}

#if defined(__CUDACC__)
__host__ __device__ int length() const { return m_dim_w * m_dim_x * m_dim_y * m_dim_z; }
#else
int length() const { return m_dim_w * m_dim_x * m_dim_y * m_dim_z; }
#endif

T sum() const {
T sum = 0;
for (int i = 0; i < this->length(); i++) {
sum += this->m_data[i];
}
return sum;
}
T sum(int size) const {
T sum = 0;
for (int i = 0; i < size; i++) {
sum += this->m_data[i];
}
return sum;
}

T sum(int size, int start_idx) const {
T sum = 0;
for (int i = 0; i < size; i++) {
sum += this->m_data[start_idx + i];
}
return sum;
}

void load(const char *path) {
std::ifstream infile(path, std::ios::binary | std::ios::in);
if (infile.fail()) {
std::cout << strerror(errno) << ": " << path << std::endl;
throw("Expected error...");
} else {
infile.read(reinterpret_cast<char *>(this->m_data), this->length() * sizeof(T));
infile.close();
}
}
T *m_data;
int m_dim_w, m_dim_x, m_dim_y, m_dim_z;

// Default constructor
Matrix4D() { m_data = NULL; }
};

static inline void debug_info(std::string s) {
#ifdef DEBUG
std::cout << s << std::endl;
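The new Matrix4D mirrors the existing Matrix3D: a thin, non-owning view over a caller-owned flat buffer, with row-major indexing (offset = ((w*X + x)*Y + y)*Z + z) and bounds checks that throw std::out_of_range on the CPU path but are compiled out under __CUDACC__. A minimal usage sketch, assuming llm/include/common.h is on the include path:

#include <vector>
#include "common.h"  // Matrix4D

int main() {
    // A 2x3x4x5 view; the buffer is owned by the caller, not by Matrix4D.
    std::vector<float> buf(2 * 3 * 4 * 5, 1.0f);
    Matrix4D<float> m(buf.data(), 2, 3, 4, 5);

    m(1, 2, 3, 4) = 7.0f;   // writes buf[((1*3 + 2)*4 + 3)*5 + 4], i.e. buf[119]
    float total = m.sum();  // 119 ones plus 7.0f -> 126.0f

    return total == 126.0f ? 0 : 1;  // exit 0 on the expected sum
}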