-
Notifications
You must be signed in to change notification settings - Fork 51
/
evaluate.py
312 lines (282 loc) · 10.1 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import json
import shutil
from itertools import islice
from time import time
from typing import Tuple, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.mend import MENDHyperParams, MendRewriteExecutor
from dsets import (
AttributeSnippets,
CounterFactDataset,
MENDQADataset,
MultiCounterFactDataset,
get_tfidf_vectorizer,
)
from experiments.py.eval_utils_counterfact import compute_rewrite_quality_counterfact
from experiments.py.eval_utils_zsre import compute_rewrite_quality_zsre
from memit import MEMITHyperParams, apply_memit_to_model
from rome import ROMEHyperParams, apply_rome_to_model
from util import nethook
from util.globals import *
# Editing algorithm name -> (hyperparameter class, function applying the edit).
# NOTE: MendRewriteExecutor is instantiated once, here, when the module is imported.
ALG_DICT = dict(
    MEMIT=(MEMITHyperParams, apply_memit_to_model),
    ROME=(ROMEHyperParams, apply_rome_to_model),
    FT=(FTHyperParams, apply_ft_to_model),
    MEND=(MENDHyperParams, MendRewriteExecutor().apply_to_model),
)

# Dataset short name -> (dataset class, per-record evaluation function).
DS_DICT = dict(
    mcf=(MultiCounterFactDataset, compute_rewrite_quality_counterfact),
    cf=(CounterFactDataset, compute_rewrite_quality_counterfact),
    zsre=(MENDQADataset, compute_rewrite_quality_zsre),
)
def main(
    alg_name: str,
    model_name: Union[str, Tuple],
    hparams_fname: str,
    ds_name: str,
    dataset_size_limit: int,
    continue_from_run: str,
    skip_generation_tests: bool,
    generation_test_interval: int,
    conserve_memory: bool,
    dir_name: str,
    num_edits: int = 1,
    use_cache: bool = False,
):
    """Apply an editing algorithm over a dataset and write per-case metrics.

    Records are processed in chunks of ``num_edits``. For each chunk the
    requested rewrites are applied in place via ``ALG_DICT[alg_name]``.
    Every record in the chunk is then evaluated with the dataset's
    evaluation function, and its metrics are dumped to
    ``<run_dir>/<num_edits>_edits-case_<case_id>.json``. Finally, the
    original weights are written back into the model.

    Args:
        alg_name: Key into ``ALG_DICT`` ("MEMIT", "ROME", "FT", "MEND").
        model_name: Hugging Face model id, or an already-loaded
            ``(model, tokenizer)`` pair.
        hparams_fname: Hyperparameter file name under
            ``HPARAMS_DIR/<alg_name>``.
        ds_name: Key into ``DS_DICT`` ("mcf", "cf", "zsre").
        dataset_size_limit: Passed as ``size`` to the dataset constructor.
        continue_from_run: Name of an existing run directory under
            ``RESULTS_DIR/<dir_name>`` to resume. ``None``, or a name that
            does not exist, starts a new ``run_XXX`` directory.
        skip_generation_tests: If True, attribute snippets and the tf-idf
            vectorizer are not loaded and generation tests never run.
        generation_test_interval: Generation tests run only for records whose
            ``case_id`` is a multiple of this value. Values <= 0 disable
            generation tests (the CLI documents -1 for this).
        conserve_memory: Ask the algorithm to return the original weights
            on CPU instead of CUDA.
        dir_name: Subdirectory of ``RESULTS_DIR`` holding the runs.
        num_edits: Number of rewrites applied simultaneously per chunk.
        use_cache: Pass a k/v cache template to ROME/MEMIT.
    """
    # Set algorithm-specific variables
    params_class, apply_algo = ALG_DICT[alg_name]

    # Determine run directory.
    # Create a new dir if not continuing from a previous run OR the previous
    # run does not exist.
    if (
        continue_from_run is None
        or not (run_dir := RESULTS_DIR / dir_name / continue_from_run).exists()
    ):
        continue_from_run = None
    if continue_from_run is None:
        alg_dir = RESULTS_DIR / dir_name
        if alg_dir.exists():
            # isdecimal (not isnumeric) so every accepted suffix parses with int().
            id_list = [
                int(str(x).split("_")[-1])
                for x in alg_dir.iterdir()
                if str(x).split("_")[-1].isdecimal()
            ]
            run_id = 0 if not id_list else max(id_list) + 1
        else:
            run_id = 0
        run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}"
        run_dir.mkdir(parents=True, exist_ok=True)
    print(f"Results will be stored at {run_dir}")

    # Get run hyperparameters: a resumed run reuses its own saved copy.
    params_path = (
        run_dir / "params.json"
        if continue_from_run is not None
        else HPARAMS_DIR / alg_name / hparams_fname
    )
    hparams = params_class.from_json(params_path)
    if not (run_dir / "params.json").exists():
        shutil.copyfile(params_path, run_dir / "params.json")
    print(f"Executing {alg_name} with parameters {hparams}")

    # Instantiate vanilla model
    if isinstance(model_name, str):
        print("Instantiating model")
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    else:
        model, tok = model_name
        model_name = model.config._name_or_path

    # A non-positive interval disables generation tests. Previously, -1 ran
    # them on every case (x % -1 == 0), and 0 raised ZeroDivisionError.
    run_generation_tests = not skip_generation_tests and generation_test_interval > 0

    # Load data
    print("Loading dataset, attribute snippets, tf-idf data")
    snips = AttributeSnippets(DATA_DIR) if run_generation_tests else None
    vec = get_tfidf_vectorizer(DATA_DIR) if run_generation_tests else None

    if num_edits > 1:
        assert ds_name != "cf", f"{ds_name} does not support multiple edits"

    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, tok=tok, size=dataset_size_limit)

    # Get cache templates
    cache_template = None
    if use_cache:
        cache_template = (
            KV_DIR
            / f"{model_name.replace('/', '_')}_{alg_name}"
            / f"{ds_name}_layer_{{}}_clamp_{{}}_case_{{}}.npz"
        )
        print(f"Will load cache from {cache_template}")

    # Loop-invariant settings, computed once rather than per chunk.
    case_result_template = str(run_dir / "{}_edits-case_{}.json")
    args_conserve_memory = (
        dict(return_orig_weights_device="cpu") if conserve_memory else dict()
    )
    etc_args = (
        dict(cache_template=cache_template)
        if any(alg in alg_name for alg in ["ROME", "MEMIT"])
        else dict()
    )
    gen_test_vars = [snips, vec]

    # Iterate through dataset
    for record_chunks in chunks(ds, num_edits):
        # Skip the chunk when every case already has a result file.
        if all(
            Path(case_result_template.format(num_edits, record["case_id"])).exists()
            for record in record_chunks
        ):
            continue

        # Compute weight changes + record weights that changed
        case_ids = [record["case_id"] for record in record_chunks]
        start = time()
        edited_model, weights_copy = apply_algo(
            model,
            tok,
            [
                {"case_id": record["case_id"], **record["requested_rewrite"]}
                for record in record_chunks
            ],
            hparams,
            copy=False,
            return_orig_weights=True,
            **args_conserve_memory,
            **etc_args,
        )
        exec_time = time() - start
        print("Execution took", exec_time)

        # Evaluate new model
        start = time()
        for record in record_chunks:
            out_file = Path(case_result_template.format(num_edits, record["case_id"]))
            if out_file.exists():
                print(f"Skipping {out_file}; already exists")
                continue

            # Only test generation every generation_test_interval cases.
            test_generation = (
                generation_test_interval > 0
                and record["case_id"] % generation_test_interval == 0
            )
            metrics = {
                "case_id": record["case_id"],
                "grouped_case_ids": case_ids,
                "num_edits": num_edits,
                "requested_rewrite": record["requested_rewrite"],
                "time": exec_time,
                "post": ds_eval_method(
                    edited_model,
                    tok,
                    record,
                    *(gen_test_vars if test_generation else [None, None]),
                ),
            }

            # Dump metrics in .json
            with open(out_file, "w") as f:
                json.dump(metrics, f, indent=1)

        # Restore original weights so the next chunk edits the vanilla model.
        with torch.no_grad():
            for k, v in weights_copy.items():
                nethook.get_parameter(model, k)[...] = v.to("cuda")

        print("Evaluation took", time() - start)
def window(seq, n=2):
    """Yield a sliding window of width ``n`` over the items of ``seq``.

    ``s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...``

    Each window is a tuple. If ``seq`` has fewer than ``n`` items, nothing
    is yielded.
    """
    source = iter(seq)
    current = tuple(islice(source, n))
    if len(current) == n:
        yield current
    for item in source:
        # Drop the oldest element and append the newest one.
        current = (*current[1:], item)
        yield current
def chunks(arr, n):
    """Generate consecutive slices of ``arr``, each holding at most ``n`` items.

    ``arr`` must support ``len()`` and slicing. The final slice may be
    shorter than ``n``.
    """
    yield from (arr[offset : offset + n] for offset in range(0, len(arr), n))
if __name__ == "__main__":
    # Command-line entry point: parse the options below and forward them to main().
    import argparse

    parser = argparse.ArgumentParser()
    # NOTE: --alg_name, --model_name and --hparams_fname are required=True,
    # so argparse never falls back to their `default=` values.
    parser.add_argument(
        "--alg_name",
        choices=["MEMIT", "ROME", "FT", "MEND"],
        default="ROME",
        help="Editing algorithm to use. Results are saved in results/<alg_name>/<run_id>, "
        "where a new run_id is generated on each run. "
        "If continuing from previous run, specify the run_id in --continue_from_run.",
        required=True,
    )
    parser.add_argument(
        "--model_name",
        choices=["gpt2-medium", "gpt2-large", "gpt2-xl", "EleutherAI/gpt-j-6B"],
        default="gpt2-xl",
        help="Model to edit.",
        required=True,
    )
    parser.add_argument(
        "--hparams_fname",
        type=str,
        default="gpt2-xl.json",
        help="Name of hyperparameters file, located in the hparams/<alg_name> folder.",
        required=True,
    )
    parser.add_argument(
        "--ds_name",
        choices=["mcf", "cf", "zsre"],
        default="mcf",
        help="Dataset to perform evaluations on. Either CounterFact (cf), MultiCounterFact (mcf), or zsRE (zsre).",
    )
    parser.add_argument(
        "--continue_from_run",
        type=str,
        default=None,
        help="If continuing from previous run, set to run_id. Otherwise, leave as None.",
    )
    parser.add_argument(
        "--dataset_size_limit",
        type=int,
        default=None,
        # The limit is passed to whichever dataset --ds_name selects, not only CounterFact.
        help="Truncate the selected dataset to its first n records.",
    )
    # store_true flags already default to False, so no set_defaults() call is needed.
    parser.add_argument(
        "--skip_generation_tests",
        action="store_true",
        help="Only run fast probability-based tests without slow generation tests. "
        "Useful for quick debugging and hyperparameter sweeps.",
    )
    parser.add_argument(
        "--generation_test_interval",
        type=int,
        default=1,
        help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.",
    )
    parser.add_argument(
        "--conserve_memory",
        action="store_true",
        help="Reduce memory usage during evaluation at the cost of a minor slowdown. "
        "Backs up model weights on CPU instead of GPU.",
    )
    parser.add_argument(
        "--num_edits",
        type=int,
        default=1,
        help="Number of rewrites to perform simultaneously.",
    )
    parser.add_argument(
        "--use_cache",
        action="store_true",
        help="Use cached k/v pairs",
    )
    args = parser.parse_args()

    main(
        args.alg_name,
        args.model_name,
        args.hparams_fname,
        args.ds_name,
        args.dataset_size_limit,
        args.continue_from_run,
        args.skip_generation_tests,
        args.generation_test_interval,
        args.conserve_memory,
        dir_name=args.alg_name,
        num_edits=args.num_edits,
        use_cache=args.use_cache,
    )