-
Notifications
You must be signed in to change notification settings - Fork 2
/
group.py
54 lines (40 loc) · 1.3 KB
/
group.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Group fragment labels together according to their model
"""
import os
from sys import argv
import pandas
# Script entry, part 1: validate the command line and publish the data
# directory as the module-level SOURCE_DIR that prepare_test_set() reads.
if __name__ == "__main__":
    if len(argv) != 2:
        # Report the script under the name it was invoked with; the previously
        # hard-coded "prepare_test.py" did not match this file.
        print(f"Usage: python {argv[0]} source-data-dir")
        # `exit` is a helper injected by the `site` module and is missing when
        # Python runs without it (e.g. `python -S`); SystemExit always works.
        raise SystemExit(1)
    SOURCE_DIR = argv[1]
def prepare_test_set(source_dir: "str | None" = None) -> "dict[str, str]":
    """
    Group every label text by the model its fragment belongs to.

    Reads ``labels.csv`` (columns ``fragment_id`` and ``label``) and
    ``fragments.csv`` (columns ``unique_id`` and ``model``) from the data
    directory. Each label gets a trailing "." when it does not already end
    with one, and all labels of a model are joined with single spaces in
    ``labels.csv`` row order.

    Args:
        source_dir: Directory holding the two CSV files. When omitted, the
            module-level ``SOURCE_DIR`` set by the script entry point is used.

    Returns:
        Mapping of model name to its concatenated label text, with models
        ordered by their first appearance in ``labels.csv``.

    Raises:
        IndexError: If a label refers to a ``fragment_id`` that has no row
            in ``fragments.csv``.
    """
    if source_dir is None:
        source_dir = SOURCE_DIR
    labels = pandas.read_csv(os.path.join(source_dir, "labels.csv"))
    fragments = pandas.read_csv(os.path.join(source_dir, "fragments.csv"))

    # Index the fragments once instead of scanning the whole frame for every
    # label. setdefault keeps the first row of a duplicated unique_id, which is
    # the row a positional "first match" lookup would pick.
    model_by_fragment: dict = {}
    for unique_id, model in zip(fragments["unique_id"], fragments["model"]):
        model_by_fragment.setdefault(unique_id, model)

    # Collect sentences per model and join once at the end, rather than
    # growing a string with repeated concatenation.
    sentences_by_model: dict = {}
    for fragment_id, text in zip(labels["fragment_id"], labels["label"]):
        if fragment_id not in model_by_fragment:
            raise IndexError(
                f"label refers to unknown fragment_id {fragment_id!r}"
            )
        model_name = model_by_fragment[fragment_id]
        # add punctuation
        if not text.endswith("."):
            text += "."
        sentences_by_model.setdefault(model_name, []).append(text)

    return {
        model_name: " ".join(sentences)
        for model_name, sentences in sentences_by_model.items()
    }
# Script entry, part 2: group the labels per model and write them out as a
# two-column table (model, text) to data/grouped.csv.
if __name__ == "__main__":
    grouped = prepare_test_set()
    grouped_frame = pandas.DataFrame(
        data={"model": list(grouped.keys()), "text": list(grouped.values())}
    )
    # to_csv does not create missing parent directories; make sure the output
    # folder exists so the run does not fail after all the grouping work.
    os.makedirs("data", exist_ok=True)
    grouped_frame.to_csv("data/grouped.csv")