use raw dataset instead of dataloaders to get forwards to work
eddogola committed Jul 14, 2023
1 parent f0a0dad commit 64ca53c
Showing 1 changed file with 32 additions and 23 deletions.
examples/checkpoint/toy_model.py (32 additions, 23 deletions)
@@ -19,9 +19,9 @@


 class RandomCustomDataset(Dataset):
-    def __init__(self, size=10000):
-        self.samples = [torch.randn(d_hid, d_hid) for _ in range(size)]
-        self.targets = [torch.randn(d_hid, d_hid) for _ in range(size)]
+    def __init__(self, chunks=1, size=10000):
+        self.samples = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]
+        self.targets = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)]

     def __len__(self):
         return len(self.samples)
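
Note: the new sample shape matches how the pipeline runtime splits a minibatch along the leading dimension. A minimal sketch of the shape arithmetic, assuming chunk_size and d_hid are the module-level constants defined near the top of toy_model.py (the concrete numbers below are illustrative stand-ins):

import torch

chunk_size, d_hid = 100, 512  # illustrative; the real values live in toy_model.py
chunks = 4                    # plays the role of args.chunks

# Each dataset sample now carries `chunks` microbatches' worth of rows,
# so splitting along dim 0 yields evenly sized (chunk_size, d_hid) pieces.
sample = torch.randn(chunks * chunk_size, d_hid)
microbatches = sample.chunk(chunks, dim=0)
assert all(mb.shape == (chunk_size, d_hid) for mb in microbatches)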
@@ -64,14 +64,18 @@ def run_worker(args):
     ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)
     target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)

-    ds = RandomCustomDataset()
+    ds = RandomCustomDataset(chunks=args.chunks)
     train_size = int(0.7*len(ds))
     test_size = len(ds) - train_size
     train_ds, test_ds = random_split(ds, [train_size, test_size])
-    train_dl, test_dl = DataLoader(train_ds, batch_size=1), DataLoader(test_ds, batch_size=1)
-    loaders = {
-        "train": train_dl,
-        "test": test_dl,
+    # train_dl, test_dl = DataLoader(train_ds), DataLoader(test_ds)
+    # loaders = {
+    #     "train": train_dl,
+    #     "test": test_dl,
+    # }
+    datasets = {
+        "train": train_ds,
+        "test": test_ds,
     }

     stage = compile_stage(
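
Note: the commit title says the raw dataset is used "to get forwards to work". One plausible reading (an assumption; the diff does not state it) is that DataLoader batching adds a leading batch dimension that the compiled stage's expected input shape does not have, while indexing the dataset directly yields each sample unbatched:

from torch.utils.data import DataLoader

x, y = ds[0]  # raw dataset sample: shape (chunks * chunk_size, d_hid)

batched = next(iter(DataLoader(ds, batch_size=1)))
x_b, y_b = batched  # DataLoader stacks samples: shape (1, chunks * chunk_size, d_hid)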
@@ -90,29 +94,27 @@ def run_worker(args):
     for epoch in range(2): # change to no. epochs
         print(f"Epoch: {epoch + 1}")

-        for k, loader in loaders.items():
+        for k, dataset in datasets.items():
             epoch_correct = 0
             epoch_all = 0

-            for i, (x_batch, y_batch) in enumerate(loader):
-                x_batch = x_batch.to(args.device)
-                y_batch = y_batch.to(args.device)
+            for i, (x, y) in enumerate(dataset):
+                x = x.to(args.device)
+                y = y.to(args.device)

                 if k == "train":
                     # Zero gradients
                     optimizer.zero_grad()

                     # Run
                     if args.rank == 0:
-                        out = stage(ec_x)
+                        out = stage(x)
                     elif args.rank == args.world_size - 1:
                         out = stage(target)

                         out_tensor = out['loss']

                         preds = out_tensor.argmax(-1)
-                        correct = (preds == y_batch).sum()
-                        all = len(y_batch)
+                        correct = (preds == y).sum()
+                        all = len(y)
                         epoch_correct += correct.item()
                         epoch_all += all
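
Note: the training branch drives the compiled stage differently per rank. A condensed sketch of the pattern as it appears above, with the division of labor spelled out in comments (this contract is inferred from the example itself, not from PiPPy documentation):

# rank 0 owns the first pipeline stage and injects the minibatch;
# the last rank owns the loss stage, injects the target, and is the
# only rank on which out["loss"] exists.
if args.rank == 0:
    out = stage(x)
elif args.rank == args.world_size - 1:
    out = stage(target)
    loss = out["loss"]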

@@ -123,12 +125,18 @@ def run_worker(args):
                 else:
                     stage.eval()
                     with torch.no_grad():
-                        out = stage(x_batch)
-                        preds = out.argmax(-1)
-                        correct = (preds == y_batch).sum()
-                        all = len(y_batch)
-                        epoch_correct += correct.item()
-                        epoch_all += all
+                        if args.rank == 0:
+                            out = stage(x)
+                        elif args.rank == args.world_size - 1:
+                            out = stage(x)
+
+                            preds = out.argmax(-1)
+                            correct = (preds == y).sum()
+                            all = len(y)
+                            epoch_correct += correct.item()
+                            epoch_all += all
+                        else:
+                            stage(x)

             print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}")

@@ -156,6 +164,7 @@ def main(args=None):
         type=int,
         default=4,
     )
+    parser.add_argument("--checkpoint_epochs", type=int, default=5)
     args = parser.parse_args(args)

     if args.cuda:
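
Note: the diff does not show where --checkpoint_epochs is consumed; presumably the folded part of the file uses it as a save interval. Purely as a hypothetical illustration (save_checkpoint is an invented helper, not from this commit):

if (epoch + 1) % args.checkpoint_epochs == 0:
    save_checkpoint(stage, epoch)  # hypothetical helper; not in the diff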
