
Commit

fix gpu and update demo ui with shortcut (#35)
meenchen authored Aug 23, 2023
1 parent 989f554 commit 7c832a2
Showing 6 changed files with 73 additions and 23 deletions.
38 changes: 32 additions & 6 deletions transformer/application/chat.cc
@@ -5,26 +5,30 @@

std::map<std::string, int> model_config = {
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"LLaMA_7B_AWQ", LLaMA_7B}, {"LLaMA_7B_2_chat", LLaMA_7B}, {"LLaMA_13B_2_chat", LLaMA_13B}};
{"LLaMA_7B_AWQ", LLaMA_7B}, {"LLaMA_7B_2_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA_13B_2_chat", LLaMA_13B}, {"13b", LLaMA_13B}};

std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
{"OPT_1.3B", "models/OPT_1.3B"},
{"OPT_6.7B", "models/OPT_6.7B"},
{"LLaMA_7B", "models/LLaMA_7B"},
{"LLaMA_7B_AWQ", "models/LLaMA_7B_AWQ"},
{"LLaMA_7B_2_chat", "models/LLaMA_7B_2_chat"},
{"LLaMA_13B_2_chat", "models/LLaMA_13B_2_chat"}};
{"7b", "models/LLaMA_7B_2_chat"},
{"LLaMA_13B_2_chat", "models/LLaMA_13B_2_chat"},
{"13b", "models/LLaMA_13B_2_chat"}};

std::map<std::string, int> data_format_list = {
{"FP32", FP32},
{"INT8", INT8},
{"INT4", INT4},
{"int4", INT4},
{"fp32", FP32},
};

bool isLLaMA(std::string s) {
std::string LLaMA_prefix = "LLaMA";

if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix)
if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix || s == "7b" || s == "13b")
return true;
else
return false;
@@ -33,9 +37,11 @@ bool isLLaMA(std::string s) {
int main(int argc, char* argv[]) {
std::string target_model = "LLaMA_7B_2_chat";
std::string target_data_format = "INT4";
Profiler::getInstance().for_demo = true;

if (argc == 3) {
auto target_str = argv[1];
target_model = argv[1];
if (model_config.count(target_model) == 0) {
std::cerr << "Model config:" << target_str << " unsupported" << std::endl;
std::cerr << "Please select one of the following:";
@@ -46,12 +52,11 @@ int main(int argc, char* argv[]) {
throw("Unsupported model\n");
}
std::cout << "Model: " << argv[1] << " selected" << std::endl;
target_model = argv[1];

auto data_format_input = argv[2];
if (data_format_list.count(data_format_input) == 0) {
std::cerr << "Data format:" << data_format_input << " unsupported" << std::endl;
std::cerr << "Please select one of the following:";
std::cerr << "Please select one of the following: ";
for (const auto& k : data_format_list) {
std::cerr << k.first << ", ";
}
@@ -60,7 +65,23 @@ int main(int argc, char* argv[]) {
}
std::cout << "Data format: " << argv[2] << " selected" << std::endl;
target_data_format = argv[2];
} else {
} else if (argc == 2){
auto target_str = argv[1];
target_model = argv[1];
if (model_config.count(target_model) == 0) {
std::cerr << "Model config:" << target_str << " unsupported" << std::endl;
std::cerr << "Please select one of the following: ";
for (const auto& k : model_config) {
std::cerr << k.first << ", ";
}
std::cerr << std::endl;
throw("Unsupported model\n");
}
std::cout << "Model: " << argv[1] << " selected" << std::endl;

auto data_format_input = "INT4";
}
else {
if (isLLaMA(target_model)) {
std::cout << "Using model: " + target_model << std::endl;
std::cout << "Using LLaMA's default data format: " + target_data_format << std::endl;
@@ -118,6 +139,10 @@ int main(int argc, char* argv[]) {
std::cerr << "At this time, we only support FP32 and INT4 for LLaMA7B." << std::endl;
}
} else { // OPT
#ifdef QM_CUDA
printf("OPT is not supported with CUDA backend yet.");
exit(-1);
#else
// Load model
std::cout << "Loading model... " << std::flush;
int model_id = model_config[target_model];
@@ -175,5 +200,6 @@ int main(int argc, char* argv[]) {
std::vector<int> generated_ids =
OPTGenerate(&model, OPT_INT4, input_ids, generation_config, &encoder, true);
}
#endif  // QM_CUDA
}
};
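Usage note (not part of the diff): with these changes the demo accepts the shortcut names "7b" and "13b" alongside the full model names, lowercase data formats ("fp32", "int4"), and a single-argument invocation that falls back to INT4. A minimal, self-contained C++ sketch of that alias-and-default resolution, with hypothetical example values, follows.

#include <iostream>
#include <map>
#include <string>

int main(int argc, char* argv[]) {
    // Aliases mirror the additions to model_path above.
    std::map<std::string, std::string> model_path = {
        {"LLaMA_7B_2_chat", "models/LLaMA_7B_2_chat"},   {"7b", "models/LLaMA_7B_2_chat"},
        {"LLaMA_13B_2_chat", "models/LLaMA_13B_2_chat"}, {"13b", "models/LLaMA_13B_2_chat"}};

    std::string target_model = "LLaMA_7B_2_chat";  // default when no argument is given
    std::string target_data_format = "INT4";       // default when only the model is given

    if (argc >= 2) target_model = argv[1];
    if (argc == 3) target_data_format = argv[2];

    if (model_path.count(target_model) == 0) {
        std::cerr << "Unsupported model: " << target_model << std::endl;
        return 1;
    }
    std::cout << "Model dir: " << model_path[target_model]
              << ", data format: " << target_data_format << std::endl;
    return 0;
}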
4 changes: 4 additions & 0 deletions transformer/download_model.py
@@ -90,6 +90,10 @@
"url": "https://www.dropbox.com/scl/fi/du8rfgexkk4b2xp9j6yrn/LLaMA_7B_2_chat.zip?rlkey=2nao2sh4hi3t1dhltsoae2muw&dl=1", # noqa: E501
"md5sum": "d0b1d11e498ac7d0a2e90348e946a7f5",
},
"LLaMA_13B_2_chat": {
"url": "https://www.dropbox.com/scl/fi/fes1l27b9kv4dn4h0qjzu/LLaMA_13B_2_chat.zip?rlkey=u1j2kt96xpj764zkj1v87gw6u&dl=1", # noqa: E501
"md5sum": "802c81d86b6393aff3e93326e5b58f7f",
},
},
"INT8": {
"OPT_125m": {
40 changes: 27 additions & 13 deletions transformer/include/profiler.h
@@ -5,6 +5,7 @@

class Profiler {
public:
bool for_demo = false;
static Profiler& getInstance() {
static Profiler instance;
return instance;
@@ -35,20 +36,33 @@ class Profiler {
}

void report_internal() const {
std::cout << "Section, Total time(us), Average time(us), Count, GOPs" << std::endl;
for (const auto& entry : durations) {
std::string row;
row += entry.first + ", ";
row += std::to_string(entry.second) + ", ";
row += std::to_string(entry.second / counts.at(entry.first)) + ", ";
if (flops.count(entry.first) == 0)
row += std::to_string(counts.at(entry.first)) + ", N/A";
else {
row += std::to_string(counts.at(entry.first)) + ", ";
// ops and microsecond
row += std::to_string((((float)flops.at(entry.first)) / (float)(entry.second)) / 1000.0);
if (for_demo){
std::cout << "Section, Total time(s), ms/token, #tokens" << std::endl;

for (const auto& entry : durations) {
std::string row;
std::cout << entry.first + ", ";
float s = (float)(entry.second) / 1000000;
float ts = (float)counts.at(entry.first);
printf("Total time: %.1f s, %.1f ms/token, %.1f token/s, %d tokens\n" , s, s/ts*1000, ts/s, counts.at(entry.first));
}
}
else{
std::cout << "Section, Total time(us), Average time(us), Count, GOPs" << std::endl;
for (const auto& entry : durations) {
std::string row;
row += entry.first + ", ";
row += std::to_string(entry.second) + ", ";
row += std::to_string(entry.second / counts.at(entry.first)) + ", ";
if (flops.count(entry.first) == 0)
row += std::to_string(counts.at(entry.first)) + ", N/A";
else {
row += std::to_string(counts.at(entry.first)) + ", ";
// ops and microsecond
row += std::to_string((((float)flops.at(entry.first)) / (float)(entry.second)) / 1000.0);
}
std::cout << row << std::endl;
}
std::cout << row << std::endl;
}
}

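A minimal sketch (not from the commit, using assumed example numbers) of the demo-mode arithmetic in report_internal() above: a section's accumulated microseconds and token count are converted to total seconds, ms/token, and token/s.

#include <cstdio>

int main() {
    // Assumed example values: 12,500,000 us accumulated over 100 generated tokens.
    long long total_us = 12500000;
    int tokens = 100;

    float s = (float)total_us / 1000000;  // total seconds
    float ts = (float)tokens;
    // Matches the demo report: total s, ms/token, token/s, token count.
    printf("Total time: %.1f s, %.1f ms/token, %.1f token/s, %d tokens\n",
           s, s / ts * 1000, ts / s, tokens);
    return 0;
}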
8 changes: 4 additions & 4 deletions transformer/quantize_and_upload.py
@@ -11,7 +11,7 @@

from upload import subebackups

model_paths = ["models/LLaMA_7B", "models/LLaMA_7B_2_chat", "models/LLaMA_7B_AWQ", "models/LLaMA_13B_2_chat"]
model_paths = ["models/LLaMA_13B_2_chat"]

quantized_dir = "INT4"
db_prefix = "/MIT/transformer_assets/"
@@ -38,8 +38,8 @@ def _get_parser():
parser = _get_parser()
args = parser.parse_args()

if args.method not in ["QM_x86", "QM_ARM", "FP32", "INT8"]:
raise ValueError("expect method to be one of ['QM_x86', 'QM_ARM', 'FP32', 'INT8']")
if args.method not in ["QM_x86", "QM_ARM", "QM_CUDA", "FP32", "INT8"]:
raise ValueError("expect method to be one of ['QM_x86', 'QM_ARM', 'QM_CUDA', 'FP32', 'INT8']")
QM_method = args.method

if args.model_path:
@@ -49,7 +49,7 @@ def _get_parser():

for model_path in target_paths:
# quantize
if args.method in ["QM_x86", "QM_ARM"]:
if args.method in ["QM_x86", "QM_CUDA", "QM_ARM"]:
out_dir = quantized_dir
quantize_cmd = (
f"python model_quantizer.py --model_path {model_path} --method {QM_method} --output_path {out_dir}"
2 changes: 2 additions & 0 deletions transformer/src/OPTGenerate.cc
@@ -2,6 +2,7 @@
#include "common.h"
#include "utils.h"

#ifndef QM_CUDA // not support yet
// OPTGenerate function
std::vector<int> OPTGenerate(void *model_ptr, int model_type, std::vector<int> input_ids,
const struct opt_params generation_config, Encoder *encoder, bool interactive) {
@@ -175,3 +176,4 @@ std::vector<int> OPTGenerate(void *model_ptr, int model_type, std::vector<int> i

return generate_ids;
}
#endif
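The same compile-time guard appears in chat.cc above and in linear.cc below. A small illustrative sketch (assuming QM_CUDA is the macro defined by the CUDA build, as in this commit) of how the CPU-only OPT path is compiled out:

#include <cstdio>
#include <cstdlib>

#ifndef QM_CUDA  // CPU-only path; compiled out when the CUDA backend is enabled
static void run_opt_cpu() { printf("running OPT generation on CPU\n"); }
#endif

int main() {
#ifdef QM_CUDA
    printf("OPT is not supported with CUDA backend yet.\n");
    exit(-1);
#else
    run_opt_cpu();
#endif
    return 0;
}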
4 changes: 4 additions & 0 deletions transformer/src/ops/linear.cc
@@ -55,12 +55,14 @@ void Linear_FP::forward(const Matrix3D<float> &a, Matrix3D<float> &c) {
params.opt_params.num_thread = NUM_THREAD;

matmul::MatmulOperator op = matmul::MatmulOperator();
#ifndef QM_CUDA // not support yet
if (this->has_bias) {
params.bias.row = this->bias.m_dim_y;
params.bias.column = this->bias.m_dim_z;
params.bias.data_ptr = this->bias.m_data;
op.mat_mul_accelerator_transposed_fastover_column_bias((const struct matmul_params *)&params);
} else
#endif
op.mat_mul_accelerator_transposed_fastover_column((const struct matmul_params *)&params);

PROFILE_END(profile_name);
@@ -199,10 +201,12 @@ void Linear_FP_int4::forward(const Matrix3D<float> &x, Matrix3D<float> &output)
#ifdef PACK_QK
params.B.int4_data_ptr = (uint8_t *)this->packed_weights;
#endif
#ifndef QM_CUDA // not support yet
if (!this->has_bias)
params.bias.data_ptr = NULL;
else
params.bias.data_ptr = this->bias.m_data;
#endif
op.mat_mul_accelerator_int8_int4_fast_no_offset(&params);
#else
op.mat_mul_accelerator_int4_fast_no_offset(&params);
