run_long_bench.py
executable file · 190 lines (168 loc) · 7.55 KB
# Adapted from https://github.com/THUDM/LongBench/blob/main/pred.py
import os
from datasets import load_dataset
import torch
from loguru import logger
from tqdm import tqdm
import numpy as np
import random
import argparse
import time
import json
from datetime import datetime

os.environ["WANDB_DISABLED"] = "true"

from longbench_utils import scorer, MODEL2MAXLEN, DATASET2PROMPT, DATASET2MAXLEN
from utils import load_model_and_tokenizer, add_common_args
from palu.quant_utils import configure_latent_quantizer
import palu.model
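
# Overall flow of this script: load a (Palu-compressed) model and its tokenizer,
# configure the latent quantizer, build dataset-specific prompts for the selected
# LongBench tasks, generate predictions greedily, score them with
# longbench_utils.scorer, and append per-dataset scores to a JSON-lines file
# under results/Longbench/.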

def post_process(response, model_name):
    if "xgen" in model_name:
        response = response.strip().replace("Assistant:", "")
    elif "internlm" in model_name:
        response = response.split("<eoa>")[0]
    return response

# Customized prompt building for chat models
def build_chat(tokenizer, prompt, model_name):
    # Copied from KIVI
    if "longchat" in model_name.lower() or "vicuna" in model_name.lower():
        from fastchat.model import get_conversation_template
        conv = get_conversation_template("vicuna")
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    elif "mistral-v0.2-instruct" in model_name.lower():
        messages = [
            {
                "role": "user",
                "content": prompt
            }
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
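
# get_pred runs one LongBench subset end-to-end: format each example with the
# task's prompt template, middle-truncate prompts that exceed max_length (keeping
# the first and last max_length/2 tokens), optionally wrap the prompt with the
# chat template, generate greedily, and decode only the newly generated tokens.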
def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name):
    preds = []
    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)
        # Truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = int(max_length / 2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:  # chat models are better off without the chat template on these tasks
            prompt = build_chat(tokenizer, prompt, model_name)
        input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = input.input_ids.shape[-1]
        if dataset == "samsum":  # prevent illegal output on samsum (model endlessly repeats "\nDialogue"), might be a prompting issue
            output = model.generate(
                **input,
                max_new_tokens=max_gen,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                min_length=context_length + 1,
                eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
            )[0]
        else:
            output = model.generate(
                **input,
                max_new_tokens=max_gen,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                pad_token_id=tokenizer.eos_token_id,
            )[0]
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
        pred = post_process(pred, model_name)
        preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})
    return preds
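
# seed_everything fixes the Python, NumPy, and PyTorch (CPU and all CUDA devices)
# RNG seeds and switches cuDNN to deterministic mode so that runs are reproducible.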
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
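
# main ties everything together: it loads the model and tokenizer, applies the
# latent quantizer settings (bit-width, group size, symmetric quantization,
# clip ratio, Hadamard transform), derives the base model type from the checkpoint
# path to look up its maximum context length, then evaluates and scores each
# requested dataset in turn.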
def main(args):
    model2maxlen = MODEL2MAXLEN
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path, use_flash_attn2=args.flash2)
    configure_latent_quantizer(
        model, n_bits=args.lt_bits,
        group_size=args.lt_group_size,
        sym=args.lt_sym,
        clip_ratio=args.lt_clip_ratio,
        hadamard=args.lt_hadamard
    )
    # NOTE(brian1009): This is a hack to recover the model name.
    # We assume the model name is the last part of the path and that Palu's
    # compression information follows the model name, separated by "_".
    # Hence, we split the path by "/" and keep only the part before the first "_".
    # Example: Mistral-7B-Instruct-v0.2_ratio-0.7_gs-4-fisher_uniform
    raw_model_name = args.model_name_or_path.split("/")[-1]
    model_type = raw_model_name.split('_')[0]
    model.eval()
    if model_type not in model2maxlen:
        raise ValueError(f"Model {model_type} not supported")
    max_length = model2maxlen[model_type]
    logger.info(f"Running model: {raw_model_name}")
    logger.info(f"Max length: {max_length}")
    datasets = args.datasets
    # We design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
    dataset2prompt = DATASET2PROMPT
    dataset2maxlen = DATASET2MAXLEN
    # Predict on each dataset
    if not os.path.exists("Longbench/pred"):
        os.makedirs("Longbench/pred")
    results = {}
    for dataset in datasets:
        logger.info("Evaluating dataset: {}".format(dataset))
        start_time = time.time()
        data = load_dataset('THUDM/LongBench', dataset, split='test')
        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]
        preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_type)
        end_time = time.time()
        elapsed_time = end_time - start_time
        logger.info(f"Elapsed time for dataset {dataset}: {elapsed_time/60} minutes")
        # Calculate score
        predictions, answers, lengths = [], [], []
        for pred in preds:
            predictions.append(pred["pred"])
            answers.append(pred["answers"])
            if "length" in pred:
                lengths.append(pred["length"])
            all_classes = pred["all_classes"]
        score = scorer(dataset, predictions, answers, all_classes)
        logger.info(f"dataset: {dataset}")
        logger.info(f"score: {score}")
        # Log the results of each dataset
        with open(f"results/Longbench/{raw_model_name}_bits_{args.lt_bits}.json", "a") as f:
            data_to_log = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "dataset": dataset,
                "score": score,
            }
            json.dump(data_to_log, f)
            f.write("\n")
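
# Example invocation (a sketch; the exact flag spellings for the model path and
# quantization options come from add_common_args, which is assumed here to expose
# them as --model_name_or_path, --lt_bits, etc., mirroring the attribute names used above):
#   python run_long_bench.py --model_name_or_path <path/to/palu_model> \
#       --lt_bits 16 --datasets triviaqa,qasper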
if __name__ == '__main__':
    seed_everything(42)
    parser = argparse.ArgumentParser()
    add_common_args(parser)
    parser.add_argument(
        '--datasets', type=lambda s: [item for item in s.split(',')],
        default=["triviaqa", "qasper", "trec", "samsum", "lcc", "repobench-p", "qmsum", "multi_news"],
        help='The datasets to be evaluated'
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Whether to print verbose information or not."
    )
    args = parser.parse_args()
    logger.remove()
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO" if not args.verbose else "DEBUG")
    # Create directory to log evaluation results.
    os.makedirs("results/Longbench", exist_ok=True)
    main(args)