From f17335a8a83e56ac0b1ba2a8de6c9022f2d23450 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Mon, 15 Jul 2024 17:46:18 -0400 Subject: [PATCH] Update docs for ConcatDataset (#1181) --- docs/source/tutorials/datasets.rst | 5 +++-- torchtune/datasets/_concat.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/datasets.rst b/docs/source/tutorials/datasets.rst index 878aa7cee0..7807c5bb21 100644 --- a/docs/source/tutorials/datasets.rst +++ b/docs/source/tutorials/datasets.rst @@ -395,8 +395,9 @@ you can also add more advanced behavior. Multiple in-memory datasets --------------------------- -It is also possible to train on multiple datasets and configure them individually. -You can even mix instruct and chat datasets or other custom datasets. +It is also possible to train on multiple datasets and configure them individually using +our :class:`~torchtune.datasets.ConcatDataset` interface. You can even mix instruct and chat datasets +or other custom datasets. .. code-block:: yaml diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index 307ed637d7..6c7522b884 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -41,13 +41,29 @@ class ConcatDataset(Dataset): _indexes (List[Tuple[int, int, int]]): A list of tuples where each tuple contains the starting index, the ending index, and the dataset index for quick lookup and access during indexing operations. - Example: + Examples: >>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2) >>> concat_dataset = ConcatDataset([dataset1, dataset2]) >>> print(len(concat_dataset)) # Total length of both datasets >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset + This can also be accomplished by passing in a list of datasets to the YAML config:: + + dataset: + - _component_: torchtune.datasets.instruct_dataset + source: vicgalle/alpaca-gpt4 + template: torchtune.data.AlpacaInstructTemplate + split: train + train_on_input: True + - _component_: torchtune.datasets.instruct_dataset + source: samsum + template: torchtune.data.SummarizeTemplate + column_map: {"output": "summary"} + output: summary + split: train + train_on_input: False + This class primarily focuses on providing a unified interface to access elements from multiple datasets, enhancing the flexibility in handling diverse data sources for training machine learning models. """