From 5a3b5e53168e681bb2ed21b462b0484a96808108 Mon Sep 17 00:00:00 2001 From: Ben Eisner Date: Mon, 24 Jul 2023 12:26:10 -0400 Subject: [PATCH] fix formatting --- flowbot3d/datasets/flow_dataset_pyg.py | 2 +- flowbot3d/train.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flowbot3d/datasets/flow_dataset_pyg.py b/flowbot3d/datasets/flow_dataset_pyg.py index 1f88bff..1958951 100644 --- a/flowbot3d/datasets/flow_dataset_pyg.py +++ b/flowbot3d/datasets/flow_dataset_pyg.py @@ -23,7 +23,7 @@ def __init__( randomize_joints: bool = True, randomize_camera: bool = True, n_points: Optional[int] = 1200, - seed: int = 42 + seed: int = 42, ) -> None: super().__init__() diff --git a/flowbot3d/train.py b/flowbot3d/train.py index 2b87a25..d42fdb4 100644 --- a/flowbot3d/train.py +++ b/flowbot3d/train.py @@ -51,7 +51,7 @@ def create_flowbot_datasets( root=root / "raw", split="umpnet-train-train", randomize_camera=randomize_camera, - seed=seed + seed=seed, ), data_keys=rpd.UMPNET_TRAIN_TRAIN_OBJ_IDS, root=root, @@ -60,7 +60,7 @@ def create_flowbot_datasets( randomize_camera, ), n_repeat=100, - n_workers=32, + n_workers=32, n_proc_per_worker=n_proc, ) @@ -70,7 +70,7 @@ def create_flowbot_datasets( root=root / "raw", split="umpnet-train-test", randomize_camera=randomize_camera, - seed=seed + seed=seed, ), data_keys=rpd.UMPNET_TRAIN_TEST_OBJ_IDS, root=root, @@ -79,7 +79,7 @@ def create_flowbot_datasets( randomize_camera, ), n_repeat=1, - n_workers=32, + n_workers=32, n_proc_per_worker=n_proc, ) @@ -89,7 +89,7 @@ def create_flowbot_datasets( root=root / "raw", split="umpnet-test", randomize_camera=randomize_camera, - seed=seed + seed=seed, ), data_keys=rpd.UMPNET_TEST_OBJ_IDS, root=root, @@ -98,7 +98,7 @@ def create_flowbot_datasets( randomize_camera, ), n_repeat=1, - n_workers=32, + n_workers=32, n_proc_per_worker=n_proc, ) elif dataset == "single": @@ -108,7 +108,7 @@ def create_flowbot_datasets( root=root / "raw", split=["7179"], randomize_camera=randomize_camera, - seed=seed + seed=seed, ), data_keys=["7179"], root=root, @@ -117,7 +117,7 @@ def create_flowbot_datasets( randomize_camera, ), n_repeat=1, - n_workers=32, + n_workers=32, n_proc_per_worker=n_proc, ) train_dset = dset