# utils.py
from collections import defaultdict
from typing import Dict, List

import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from data import Vocab

# Shorthand for a batch or model output: a mapping from field name to tensor.
Tdict = Dict[str, Tensor]
def evaluate(data: DataLoader, model: nn.Module) -> Dict[str, list]:
    """Run the model over a labelled DataLoader and collect labels and outputs.

    Each batch is expected to be a dict of tensors; keys starting with
    "label_" are treated as ground-truth labels and collected alongside
    the model outputs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    lst = defaultdict(list)
    progress_bar = tqdm(data, ascii=True)
    with torch.no_grad():
        for batch in progress_bar:
            batch: Tdict = {k: v.to(device) for k, v in batch.items()}
            output: Tdict = model(**batch)
            # Keep the ground-truth labels so they can be compared to predictions.
            for k in batch:
                if k.startswith("label_"):
                    lst[k].extend(batch[k].tolist())
            # Flatten every model output into plain Python lists.
            for k in output:
                lst[k].extend(output[k].tolist())
    return lst
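

# Usage sketch (assumes a DataLoader whose batches are dicts of tensors with
# "label_*" keys, matching what evaluate() expects; `labelled_loader` and
# `model` below are hypothetical placeholders):
#
#     results = evaluate(labelled_loader, model)
#     # results maps each "label_*" key and each model output key to a flat
#     # Python list with one entry per example, accumulated over all batches.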
def fasta2df(file) -> pd.DataFrame:
    """Parse a FASTA file into a DataFrame with "identifier" and "sequence" columns."""
    rows = []
    columns = ["identifier", "sequence"]
    with open(file) as f:
        # Find the first header line (starts with ">").
        for line in f:
            if line[0] == ">":
                identifier = line[1:].rstrip()
                break
        else:
            raise ValueError("Empty file.")
        seq = ""
        for line in f:
            if line[0] != ">":
                # Sequence lines may be wrapped; concatenate them.
                seq += line.rstrip()
            else:
                # New header: store the previous record and start the next one.
                rows.append((identifier, seq))
                identifier = line[1:].rstrip()
                seq = ""
        # Store the final record.
        rows.append((identifier, seq))
    return pd.DataFrame.from_records(rows, columns=columns)
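

# Example sketch (identifiers and sequences below are made up): a FASTA file
#
#     >pep1
#     ACDEFGHIK
#     >pep2
#     LMNPQRST
#     VW
#
# would yield a two-row DataFrame:
#
#     identifier  sequence
#     pep1        ACDEFGHIK
#     pep2        LMNPQRSTVW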
class Peptides(Dataset):
    """Dataset of numericalized peptide sequences taken from a DataFrame."""

    def __init__(self, df: pd.DataFrame, vocab=Vocab()):
        self.vocab = vocab
        self.input = [vocab.numericalize(seq) for seq in df["sequence"]]

    def __len__(self):
        return len(self.input)

    def __getitem__(self, key):
        return self.input[key]

    def collate_fn(self, batch: List[List[int]]):
        # Pad the variable-length sequences in the batch to a common length.
        input = [torch.tensor(lst) for lst in batch]
        input = nn.utils.rnn.pad_sequence(
            input, batch_first=True, padding_value=self.vocab.stoi["<pad>"]
        )
        return input
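

# Usage sketch (assumes a FASTA file of peptides; "peptides.fasta" is a
# hypothetical path):
#
#     df = fasta2df("peptides.fasta")
#     dataset = Peptides(df)
#     loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
#     # Each batch is a LongTensor of shape (batch_size, max_seq_len), padded
#     # with the vocab's "<pad>" index, which is the format predict() expects.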
def predict(data: DataLoader, model: nn.Module) -> Dict[str, list]:
    """Run the model over an unlabelled DataLoader and collect its outputs.

    Unlike evaluate(), batches here are plain tensors of padded sequences
    rather than dicts, and no labels are collected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    lst = defaultdict(list)
    progress_bar = tqdm(data, ascii=True)
    with torch.no_grad():
        for batch in progress_bar:
            batch: Tensor = batch.to(device)
            output: Tdict = model(input=batch)
            # Flatten every model output into plain Python lists.
            for k in output:
                lst[k].extend(output[k].tolist())
    return lst
def dict2df(lst: dict) -> pd.DataFrame:
    """Convert predict() output lists into a results DataFrame.

    The Ig class indices 0, 1, 2 are mapped to the labels "A", "E", "M".
    """
    label = dict(zip(range(3), ["A", "E", "M"]))
    data = {
        "score": lst["score_p"],
        "epitope": lst["prediction_p"],
        "Ig": [label[e] for e in lst["prediction_ig"]],
    }
    return pd.DataFrame(data)
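

# Minimal end-to-end sketch, assuming a trained model saved with torch.save()
# and a FASTA file of peptides; "model.pt" and "peptides.fasta" are hypothetical
# paths, and the model is assumed to emit the output keys consumed by dict2df()
# above ("score_p", "prediction_p", "prediction_ig").
if __name__ == "__main__":
    df = fasta2df("peptides.fasta")
    dataset = Peptides(df)
    loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)

    model = torch.load("model.pt", map_location="cpu")
    outputs = predict(loader, model)

    results = dict2df(outputs)
    # Attach identifiers, assuming one output row per input sequence.
    results.insert(0, "identifier", df["identifier"].tolist())
    print(results.head())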