diff --git a/scripts/train.py b/scripts/train.py index 0a069d7..046af59 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,8 +8,8 @@ from arcsf.config.experiment import ExperimentConfig from arcsf.constants import EXPERIMENT_CONFIG_DIR from arcsf.data.data_module import ( - FinetuneDataset, - QAForgetDataset, + FinetuneQADataset, + ForgetQADataset, QAFormatter, get_data, ) @@ -67,7 +67,7 @@ def main(experiment_path): qa_formatter = QAFormatter(**experiment_config.model_config.qa_formatter_kwargs) if experiment_config.train_type in ["full", "retain"]: - train_dataset = FinetuneDataset( + train_dataset = FinetuneQADataset( data=retain, # if full training retain will contain all the data tokenizer=tokenizer, qa_formatter=qa_formatter, @@ -75,7 +75,7 @@ def main(experiment_path): base_truth_ratios = None else: loss_type = "idk" if experiment_config.train_type == "idk" else "normal" - train_dataset = QAForgetDataset( + train_dataset = ForgetQADataset( (forget, retain), tokenizer, qa_formatter, diff --git a/src/arcsf/data/data_module.py b/src/arcsf/data/data_module.py index aa53a0b..350e9e6 100644 --- a/src/arcsf/data/data_module.py +++ b/src/arcsf/data/data_module.py @@ -98,8 +98,9 @@ def __call__(self, question: str, answer: str) -> str: class EvalQADataset(torch.utils.data.Dataset): """ - Question answer format dataset, __getitem__ returns a tokenized question--answer - pair as a tuple. + Dataset class used for evaluation. __getitem__ returns a list of tokenized + question--answer pairs (formatted as a single string), with the first item + containing the ground truth answer and the rest perturbed answers. """ def __init__( @@ -224,11 +225,11 @@ def __getitem__(self, idx): return [gt_inputs] + perturbed_inputs -class FinetuneDataset(torch.utils.data.Dataset): +class FinetuneQADataset(torch.utils.data.Dataset): """ - Finetune version of the dataset, __getitem__ returns a sample taken either from - retain, forget subsets, or a combination of both. 
Samples are formatted using a - question formatter allowing for autoregression. + Dataset class used for conventionally fine-tuning models on a whole question-answer + dataset. __getitem__ returns a single tokenized & formatted question--answer pair + (concatenated to a single string). """ def __init__( @@ -265,11 +266,12 @@ def __getitem__(self, idx): return self.tokenizer(inp) -class QAForgetDataset(torch.utils.data.Dataset): """ - Q+A Forget version of the dataset, __getitem__ returns a retain and forget sample. - Both are formatted using a question formatter. There is an option to output samples - using "I don't know" synonyms by specifying loss_type as "idk". + Dataset class used for unlearning. __getitem__ returns a retain and forget sample. + Both are formatted using a question formatter and then tokenized. There is an option + to replace forget set answers with "I don't know" synonyms by specifying + loss_type as "idk". """ def __init__( @@ -347,7 +349,7 @@ def __getitem__(self, idx): class ForgetterDataCollator: """ Data collator that parses lists of forget and retain inputs as provided by - QAForgetDataset. + ForgetQADataset. """ def __init__(self, base_collator: Callable[[List[InputDataClass]], Dict[str, Any]]): @@ -363,7 +365,7 @@ def __call__( ) -> tuple[dict[str, Any], dict[str, Any]]: """ Args: - features: A list of outputs from QAForgetDataset, containing tuples of + features: A list of outputs from ForgetQADataset, containing tuples of forget and retain data. kwargs: Additional arguments to pass to the base collator. @@ -379,8 +381,7 @@ def __call__( class EvaluateDataCollator: """ Data collator for the evaluation scripts, on __call__ it takes a list of samples - from the evaluation dataset, and packs each clean/perturbed inputs into a padded - batch. + from an EvalQADataset, and packs each clean/perturbed input into a padded batch.
""" def __init__(self, tokenizer: PreTrainedTokenizer, padding_side="left"): diff --git a/src/arcsf/forget/base.py b/src/arcsf/forget/base.py index 7b89865..b3c32af 100644 --- a/src/arcsf/forget/base.py +++ b/src/arcsf/forget/base.py @@ -95,7 +95,7 @@ def compute_loss( Args: model: The model to compute the loss of. inputs: Tuple of forget and either retain or IDK inputs, as returned by - QAForgetDataset. All child classes of Forgetter should expect two inputs + ForgetQADataset. All child classes of Forgetter should expect two inputs in this order. return_outputs: Whether to return the outputs of the model or just the loss. diff --git a/tests/test_data_module.py b/tests/test_data_module.py index ea36291..fc42bfb 100644 --- a/tests/test_data_module.py +++ b/tests/test_data_module.py @@ -3,7 +3,7 @@ from arcsf.data.data_module import ( EvalQADataset, - QAForgetDataset, + ForgetQADataset, QAFormatter, get_data, get_idk_responses, @@ -62,7 +62,7 @@ def test_permutation(qa_formatter, dummy_tokenizer): random_seed=42, ) # create dataset object - data_set = QAForgetDataset( + data_set = ForgetQADataset( data, dummy_tokenizer, qa_formatter, loss_type="standard" ) # dataset creates a random permutation of retain indices @@ -106,7 +106,7 @@ def test_formatter(): def test_idk_targets(data, dummy_tokenizer): """Check that when using an idk loss, that the targets are correct.""" # load idk type dataset - idk_set = QAForgetDataset( + idk_set = ForgetQADataset( (data, data), dummy_tokenizer, QAFormatter("{question} Answer:", " {answer}"),