Skip to content

Commit

Permalink
Decouple dataset split from torch
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanx749 committed Mar 24, 2024
1 parent 3c07b8b commit 2b00721
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
33 changes: 1 addition & 32 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import scipy.linalg as linalg
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Dataset


Expand Down Expand Up @@ -88,41 +87,11 @@ def __init__(
self,
vocab=Vocab(),
root="./data",
neg_file="0r",
pos_file="123r",
split="train",
size=5000,
):
super().__init__()
self.vocab = vocab
df0 = pd.read_csv(Path(root, f"{neg_file}.csv"), index_col=0)
df0 = df0[df0["SARS_CoV2"] == 0]
df0 = df0.sample(frac=1, random_state=42)
if split == "train":
df0 = df0.iloc[: -2 * size]
elif split == "valid":
df0 = df0.iloc[-2 * size : -size]
elif split == "test":
df0 = df0.iloc[-size:]
if "Label" not in df0.columns:
df0["Label"] = 0
df1 = pd.read_csv(Path(root, f"{pos_file}.csv"), index_col=0)
df1 = df1[df1["SARS_CoV2"] == 0]
y = df1["Label"]
sss = StratifiedShuffleSplit(n_splits=1, test_size=size, random_state=42)
train_index, test_index = next(sss.split(y, y))
if split == "test":
df1 = df1.iloc[test_index]
else:
df1 = df1.iloc[train_index]
y = df1["Label"]
sss = StratifiedShuffleSplit(n_splits=1, test_size=size, random_state=42)
train_index, valid_index = next(sss.split(y, y))
if split == "valid":
df1 = df1.iloc[valid_index]
elif split == "train":
df1 = df1.iloc[train_index]
self.df = pd.concat([df0, df1])
self.df = pd.read_csv(Path(root, f"{split}.csv"), index_col=0)
self.input = [vocab.numericalize(seq) for seq in self.df["Description"]]
self.label = self.df["Label"].tolist()

Expand Down
33 changes: 33 additions & 0 deletions data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,36 @@ def fasta2id(name):
df2fasta(dfc[dfc["SARS_CoV2"] == 1], "sars")

# %%
from sklearn.model_selection import StratifiedShuffleSplit

size = 5000
splits = ("train", "valid", "test")

df_neg = pd.read_csv("0r.csv", index_col=0)
df_neg = df_neg[df_neg["SARS_CoV2"] == 0]
df_neg = df_neg.sample(frac=1, random_state=42)
if "Label" not in df_neg.columns:
df_neg["Label"] = 0
d_neg = dict.fromkeys(splits)
d_neg["train"] = df_neg.iloc[: -2 * size]
d_neg["valid"] = df_neg.iloc[-2 * size : -size]
d_neg["test"] = df_neg.iloc[-size:]

df_pos = pd.read_csv("123r.csv", index_col=0)
df_pos = df_pos[df_pos["SARS_CoV2"] == 0]
y = df_pos["Label"]
sss = StratifiedShuffleSplit(n_splits=1, test_size=size, random_state=42)
train_index, test_index = next(sss.split(y, y))
d_pos = dict.fromkeys(splits)
d_pos["test"] = df_pos.iloc[test_index]
df_pos = df_pos.iloc[train_index]
y = df_pos["Label"]
sss = StratifiedShuffleSplit(n_splits=1, test_size=size, random_state=42)
train_index, valid_index = next(sss.split(y, y))
d_pos["valid"] = df_pos.iloc[valid_index]
d_pos["train"] = df_pos.iloc[train_index]

for split in splits:
pd.concat([d_neg[split], d_pos[split]]).to_csv(f"{split}.csv")

# %%

0 comments on commit 2b00721

Please sign in to comment.