Skip to content

Commit

Permalink
Improve consistency of dataset class names/docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jack89roberts committed Oct 16, 2024
1 parent 7125172 commit 778ceb1
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 22 deletions.
8 changes: 4 additions & 4 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from arcsf.config.experiment import ExperimentConfig
from arcsf.constants import EXPERIMENT_CONFIG_DIR
from arcsf.data.data_module import (
FinetuneDataset,
QAForgetDataset,
FinetuneQADataset,
ForgetQADataset,
QAFormatter,
get_data,
)
Expand Down Expand Up @@ -67,15 +67,15 @@ def main(experiment_path):
qa_formatter = QAFormatter(**experiment_config.model_config.qa_formatter_kwargs)

if experiment_config.train_type in ["full", "retain"]:
train_dataset = FinetuneDataset(
train_dataset = FinetuneQADataset(
data=retain, # if full training retain will contain all the data
tokenizer=tokenizer,
qa_formatter=qa_formatter,
)
base_truth_ratios = None
else:
loss_type = "idk" if experiment_config.train_type == "idk" else "normal"
train_dataset = QAForgetDataset(
train_dataset = ForgetQADataset(
(forget, retain),
tokenizer,
qa_formatter,
Expand Down
29 changes: 15 additions & 14 deletions src/arcsf/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ def __call__(self, question: str, answer: str) -> str:

class EvalQADataset(torch.utils.data.Dataset):
"""
Question answer format dataset, __getitem__ returns a tokenized question--answer
pair as a tuple.
Dataset class used for evaluation. __getitem__ returns a list of tokenized
question--answer pairs (formatted as a single string), with the first item
containing the ground truth answer and the rest perturbed answers.
"""

def __init__(
Expand Down Expand Up @@ -224,11 +225,11 @@ def __getitem__(self, idx):
return [gt_inputs] + perturbed_inputs


class FinetuneDataset(torch.utils.data.Dataset):
class FinetuneQADataset(torch.utils.data.Dataset):
"""
Finetune version of the dataset, __getitem__ returns a sample taken either from
retain, forget subsets, or a combination of both. Samples are formatted using a
question formatter allowing for autoregression.
Dataset class used for conventionally fine-tuning models on a whole question-answer
dataset. __getitem__ returns a single tokenized & formatted question--answer pair
(concatenated to a single string).
"""

def __init__(
Expand Down Expand Up @@ -265,11 +266,12 @@ def __getitem__(self, idx):
return self.tokenizer(inp)


class QAForgetDataset(torch.utils.data.Dataset):
class ForgetQADataset(torch.utils.data.Dataset):
"""
Q+A Forget version of the dataset, __getitem__ returns a retain and forget sample.
Both are formatted using a question formatter. There is an option to output samples
using "I don't know" synonyms by specifying loss_type as "idk".
Dataset class used for unlearning. __getitem__ returns a retain and forget sample.
Both are formatted using a question formatter and then tokenized. There is an option
to replace forget set answers with "I don't know" synonyms by specifying
loss_type as "idk".
"""

def __init__(
Expand Down Expand Up @@ -347,7 +349,7 @@ def __getitem__(self, idx):
class ForgetterDataCollator:
"""
Data collator that parses lists of forget and retain inputs as provided by
QAForgetDataset.
ForgetQADataset.
"""

def __init__(self, base_collator: Callable[[List[InputDataClass]], Dict[str, Any]]):
Expand All @@ -363,7 +365,7 @@ def __call__(
) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Args:
features: A list of outputs from QAForgetDataset, containing tuples of
features: A list of outputs from ForgetQADataset, containing tuples of
forget and retain data.
kwargs: Additional arguments to pass to the base collator.
Expand All @@ -379,8 +381,7 @@ def __call__(
class EvaluateDataCollator:
"""
Data collator for the evaluation scripts, on __call__ it takes a list of samples
from an EvalQADataset, and packs the clean/perturbed inputs into a padded batch.
"""

def __init__(self, tokenizer: PreTrainedTokenizer, padding_side="left"):
Expand Down
2 changes: 1 addition & 1 deletion src/arcsf/forget/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def compute_loss(
Args:
model: The model to compute the loss of.
inputs: Tuple of forget and either retain or IDK inputs, as returned by
QAForgetDataset. All child classes of Forgetter should expect two inputs
ForgetQADataset. All child classes of Forgetter should expect two inputs
in this order.
return_outputs: Whether to return the outputs of the model or just the loss.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from arcsf.data.data_module import (
EvalQADataset,
QAForgetDataset,
ForgetQADataset,
QAFormatter,
get_data,
get_idk_responses,
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_permutation(qa_formatter, dummy_tokenizer):
random_seed=42,
)
# create dataset object
data_set = QAForgetDataset(
data_set = ForgetQADataset(
data, dummy_tokenizer, qa_formatter, loss_type="standard"
)
# dataset creates a random permutation of retain indices
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_formatter():
def test_idk_targets(data, dummy_tokenizer):
"""Check that when using an idk loss, that the targets are correct."""
# load idk type dataset
idk_set = QAForgetDataset(
idk_set = ForgetQADataset(
(data, data),
dummy_tokenizer,
QAFormatter("{question} Answer:", " {answer}"),
Expand Down

0 comments on commit 778ceb1

Please sign in to comment.