From 04139d0758df51cb35b160b61d48d022fb966e1d Mon Sep 17 00:00:00 2001
From: Praateek
Date: Tue, 8 Oct 2024 14:42:18 -0700
Subject: [PATCH] fc

Signed-off-by: Praateek
---
 nemo_curator/datasets/doc_dataset.py    |  8 +++
 nemo_curator/utils/distributed_utils.py | 70 +++++++++++++++++--------
 2 files changed, 56 insertions(+), 22 deletions(-)

diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py
index b3c595cf..86a3ec4e 100644
--- a/nemo_curator/datasets/doc_dataset.py
+++ b/nemo_curator/datasets/doc_dataset.py
@@ -43,6 +43,7 @@ def read_json(
         files_per_partition: int = 1,
         add_filename: bool = False,
         input_meta: Union[str, dict] = None,
+        partition_size: str = "2gb",
     ):
         return cls(
             _read_json_or_parquet(
@@ -52,6 +53,7 @@
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
                 input_meta=input_meta,
+                partition_size=partition_size,
             )
         )
 
@@ -62,6 +64,7 @@ def read_parquet(
         backend="pandas",
         files_per_partition=1,
         add_filename=False,
+        partition_size: str = "2gb",
     ):
         return cls(
             _read_json_or_parquet(
@@ -70,6 +73,7 @@
                 backend=backend,
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
+                partition_size=partition_size,
             )
         )
 
@@ -175,6 +179,7 @@ def _read_json_or_parquet(
     files_per_partition: int,
     add_filename: bool,
     input_meta: Union[str, dict] = None,
+    partition_size: str = "2gb",
 ):
     """
     `input_files` may be a list or a string type.
@@ -205,6 +210,7 @@ def _read_json_or_parquet(
             files_per_partition=files_per_partition,
             add_filename=add_filename,
             input_meta=input_meta,
+            partition_size=partition_size,
         )
 
     # List of directories
@@ -222,6 +228,7 @@ def _read_json_or_parquet(
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
                 input_meta=input_meta,
+                partition_size=partition_size,
             )
             dfs.append(df)
 
@@ -245,6 +252,7 @@ def _read_json_or_parquet(
             files_per_partition=files_per_partition,
             add_filename=add_filename,
             input_meta=input_meta,
+            partition_size=partition_size,
         )
 
     else:
diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py
index 3f37eb90..75da340f 100644
--- a/nemo_curator/utils/distributed_utils.py
+++ b/nemo_curator/utils/distributed_utils.py
@@ -308,6 +308,7 @@ def read_data(
     file_type: str = "pickle",
     backend: str = "cudf",
     files_per_partition: int = 1,
+    partition_size: str = "2gb",
     add_filename: bool = False,
     input_meta: Union[str, dict] = None,
 ) -> Union[dd.DataFrame, dask_cudf.DataFrame]:
@@ -327,35 +328,60 @@ def read_data(
         A Dask-cuDF or a Dask-pandas DataFrame.
 
     """
-    if backend == "cudf":
-        # Try using cuDF. If not availible will throw an error.
-        test_obj = cudf.Series
-
     if file_type == "pickle":
         df = read_pandas_pickle(input_files[0], add_filename=add_filename)
         df = dd.from_pandas(df, npartitions=16)
         if backend == "cudf":
             df = df.to_backend("cudf")
-
-    elif file_type in ["json", "jsonl", "parquet"]:
+    elif file_type in {"json", "jsonl", "parquet"}:
         print(f"Reading {len(input_files)} files", flush=True)
-        input_files = sorted(input_files)
-        if files_per_partition > 1:
-            input_files = [
-                input_files[i : i + files_per_partition]
-                for i in range(0, len(input_files), files_per_partition)
-            ]
+
+        if backend == "cudf" and (
+            (file_type in {"json", "jsonl"})
+            or (file_type == "parquet" and not add_filename)
+        ):
+            # Try using cuDF. If not availible will throw an error.
+            # test_obj = cudf.Series
+            import dask_cudf
+
+            if file_type in {"json", "jsonl"}:
+                read_func = dask_cudf.read_json
+            elif file_type in {"parquet"}:
+                read_func = dask_cudf.read_parquet
+
+            read_kwargs = dict()
+            if file_type in {"json", "jsonl"}:
+                read_kwargs["lines"] = file_type == "jsonl"
+                if input_meta is not None:
+                    read_kwargs["prune_columns"] = True
+                    read_kwargs["dtype"] = (
+                        ast.literal_eval(input_meta)
+                        if isinstance(input_meta, str)
+                        else input_meta
+                    )
+
+            if add_filename:
+                read_kwargs["include_path_column"] = add_filename
+            df = read_func(input_files, blocksize=partition_size, **read_kwargs)
+
         else:
-            input_files = [[file] for file in input_files]
-        return dd.from_map(
-            read_single_partition,
-            input_files,
-            filetype=file_type,
-            backend=backend,
-            add_filename=add_filename,
-            input_meta=input_meta,
-            enforce_metadata=False,
-        )
+            input_files = sorted(input_files)
+            if files_per_partition > 1:
+                input_files = [
+                    input_files[i : i + files_per_partition]
+                    for i in range(0, len(input_files), files_per_partition)
+                ]
+            else:
+                input_files = [[file] for file in input_files]
+            return dd.from_map(
+                read_single_partition,
+                input_files,
+                filetype=file_type,
+                backend=backend,
+                add_filename=add_filename,
+                input_meta=input_meta,
+                enforce_metadata=False,
+            )
     else:
         raise RuntimeError("Could not read data, please check file type")
     return df
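
For review context only, after the patch: a minimal usage sketch of the new partition_size argument (not part of the patch itself). With backend="cudf" the patched read_data() dispatches json/jsonl reads to dask_cudf.read_json and forwards partition_size as the Dask blocksize, instead of grouping files via files_per_partition. The input path and the "512MB" value below are illustrative assumptions.

    # Sketch only: exercises DocumentDataset.read_json with the new argument.
    # "/data/jsonl_dir" and "512MB" are hypothetical values, not from the patch.
    from nemo_curator.datasets import DocumentDataset

    dataset = DocumentDataset.read_json(
        "/data/jsonl_dir",       # hypothetical directory of .jsonl files
        backend="cudf",          # takes the dask_cudf.read_json path added above
        partition_size="512MB",  # target bytes per output partition (blocksize)
    )
    print(dataset.df.npartitions)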