From 64ca53c88d3adbe897fb8cb0619341b9a66a7cce Mon Sep 17 00:00:00 2001
From: Eddy
Date: Fri, 14 Jul 2023 11:50:51 -0700
Subject: [PATCH] use raw dataset instead of dataloaders to get forwards to
 work

---
 examples/checkpoint/toy_model.py | 55 +++++++++++++++++++-------------
 1 file changed, 32 insertions(+), 23 deletions(-)

diff --git a/examples/checkpoint/toy_model.py b/examples/checkpoint/toy_model.py
index 078357ee0..0a9d0474a 100644
--- a/examples/checkpoint/toy_model.py
+++ b/examples/checkpoint/toy_model.py
@@ -19,9 +19,9 @@
 class RandomCustomDataset(Dataset):
-    def __init__(self, size=10000):
-        self.samples = [torch.randn(d_hid, d_hid) for _ in range(size)]
-        self.targets = [torch.randn(d_hid, d_hid) for _ in range(size)]
+    def __init__(self, chunks=1, size=10000):
+        self.samples = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]
+        self.targets = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]
 
     def __len__(self):
         return len(self.samples)
@@ -64,14 +64,18 @@ def run_worker(args):
     ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)
     target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)
 
-    ds = RandomCustomDataset()
+    ds = RandomCustomDataset(chunks=args.chunks)
     train_size = int(0.7*len(ds))
     test_size = len(ds) - train_size
     train_ds, test_ds = random_split(ds, [train_size, test_size])
-    train_dl, test_dl = DataLoader(train_ds, batch_size=1), DataLoader(test_ds, batch_size=1)
-    loaders = {
-        "train": train_dl,
-        "test": test_dl,
+    # train_dl, test_dl = DataLoader(train_ds), DataLoader(test_ds)
+    # loaders = {
+    #     "train": train_dl,
+    #     "test": test_dl,
+    # }
+    datasets = {
+        "train": train_ds,
+        "test": test_ds,
     }
 
     stage = compile_stage(
@@ -90,13 +94,13 @@
 
     for epoch in range(2): # change to no. epochs
         print(f"Epoch: {epoch + 1}")
-        for k, loader in loaders.items():
+        for k, dataset in datasets.items():
             epoch_correct = 0
             epoch_all = 0
 
-            for i, (x_batch, y_batch) in enumerate(loader):
-                x_batch = x_batch.to(args.device)
-                y_batch = y_batch.to(args.device)
+            for i, (x, y) in enumerate(dataset):
+                x = x.to(args.device)
+                y = y.to(args.device)
 
                 if k == "train":
                     # Zero gradients
@@ -104,15 +108,13 @@
                     # Run
                     if args.rank == 0:
-                        out = stage(ec_x)
+                        out = stage(x)
                     elif args.rank == args.world_size - 1:
                         out = stage(target)
 
-                    out_tensor = out['loss']
-                    preds = out_tensor.argmax(-1)
-                    correct = (preds == y_batch).sum()
-                    all = len(y_batch)
+                    correct = (out.argmax(-1) == y).sum()
+                    all = len(y)
                     epoch_correct += correct.item()
                     epoch_all += all
@@ -123,12 +125,18 @@
                 else:
                     stage.eval()
                     with torch.no_grad():
-                        out = stage(x_batch)
-                        preds = out.argmax(-1)
-                        correct = (preds == y_batch).sum()
-                        all = len(y_batch)
-                        epoch_correct += correct.item()
-                        epoch_all += all
+                        if args.rank == 0:
+                            out = stage(x)
+                        elif args.rank == args.world_size - 1:
+                            out = stage(x)
+
+                            preds = out.argmax(-1)
+                            correct = (preds == y).sum()
+                            all = len(y)
+                            epoch_correct += correct.item()
+                            epoch_all += all
+                        else:
+                            stage(x)
 
             print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}")
@@ -156,6 +164,7 @@ def main(args=None):
         type=int,
         default=4,
     )
+    parser.add_argument("--checkpoint_epochs", type=int, default=5)
    args = parser.parse_args(args)

    if args.cuda:
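
For context, a minimal self-contained sketch (not part of the patch) of the shape mismatch this change works around: PyTorch's DataLoader runs each item through a collate function that prepends a batch dimension, so with batch_size=1 the compiled stage received tensors of shape (1, chunks * chunk_size, d_hid) rather than the (chunks * chunk_size, d_hid) it was built for, while indexing the raw dataset (or a random_split Subset) yields per-sample tensors directly. The d_hid, chunk_size, and dataset sizes below are illustrative assumptions, and __getitem__ is assumed to return a (sample, target) pair as in the example.

import torch
from torch.utils.data import DataLoader, Dataset, random_split

d_hid, chunk_size, chunks = 16, 4, 2  # illustrative values, not the example's

class RandomCustomDataset(Dataset):
    def __init__(self, chunks=1, size=8):
        self.samples = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]
        self.targets = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.targets[idx]

ds = RandomCustomDataset(chunks=chunks)
train_ds, test_ds = random_split(ds, [6, 2])

# Indexing the raw dataset/Subset yields the per-sample shape the stage expects.
x, y = train_ds[0]
print(x.shape)   # torch.Size([8, 16]) == (chunks * chunk_size, d_hid)

# A DataLoader's default collate_fn prepends a batch dimension.
xb, yb = next(iter(DataLoader(train_ds, batch_size=1)))
print(xb.shape)  # torch.Size([1, 8, 16])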