
Commit 519a360

fix formatting errors
Eddy Ogola Onyango committed Jul 20, 2023
1 parent b834958 commit 519a360
Showing 1 changed file with 67 additions and 44 deletions.
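An observation for readers (not from the commit itself): the changes below match the conventions of the black formatter (double-quoted strings, one argument per line in wrapped calls, trailing commas) together with an isort-style regrouping of imports, though the commit message does not name a tool. A sketch of how such a change is typically produced, assuming black is installed:

    pip install black
    black examples/mnist/pippy_mnist.py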
examples/mnist/pippy_mnist.py (111 changes: 67 additions & 44 deletions)
@@ -2,23 +2,23 @@
 import argparse
 import os
 
+import pippy
+import pippy.fx
+
 import torch
 import torch.distributed as dist
+from pippy.IR import LossWrapper, PipeSplitWrapper
+from pippy.microbatch import sum_reducer, TensorChunkSpec
 from torch import nn, optim
 from torch.nn.functional import cross_entropy
 from torch.utils.data import DistributedSampler
 from torchvision import datasets, transforms  # type: ignore
 from tqdm import tqdm  # type: ignore
 
-import pippy
-import pippy.fx
-from pippy.IR import PipeSplitWrapper, LossWrapper
-from pippy.microbatch import sum_reducer, TensorChunkSpec
-
 
 pippy.fx.Tracer.proxy_buffer_attributes = True
 
-USE_TQDM = bool(int(os.getenv('USE_TQDM', '1')))
+USE_TQDM = bool(int(os.getenv("USE_TQDM", "1")))
 
 
 # Get process group for ranks in a pipeline
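An aside (not part of the commit): because USE_TQDM is read from the environment, the progress bar can be disabled without editing the script, for example:

    USE_TQDM=0 python examples/mnist/pippy_mnist.py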
@@ -62,19 +62,29 @@ def run_worker(args):

     batch_size = args.batch_size * args.chunks
 
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
 
-    train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
-    valid_data = datasets.MNIST('./data', train=False, transform=transform)
+    train_data = datasets.MNIST(
+        "./data", train=True, download=True, transform=transform
+    )
+    valid_data = datasets.MNIST("./data", train=False, transform=transform)
 
-    train_sampler = DistributedSampler(train_data, num_replicas=args.dp_group_size, rank=dp_rank, shuffle=False,
-                                       drop_last=False)
+    train_sampler = DistributedSampler(
+        train_data,
+        num_replicas=args.dp_group_size,
+        rank=dp_rank,
+        shuffle=False,
+        drop_last=False,
+    )
 
-    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
-    valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_data, batch_size=batch_size, sampler=train_sampler
+    )
+    valid_dataloader = torch.utils.data.DataLoader(
+        valid_data, batch_size=batch_size
+    )
 
     class OutputLossWrapper(LossWrapper):
         def __init__(self, module, loss_fn):
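An aside on the sampler arguments (not part of the commit): since num_replicas and rank are passed explicitly, DistributedSampler shards the training set deterministically without consulting a default process group, and each data-parallel rank receives a disjoint slice. A minimal, self-contained sketch with hypothetical sizes:

    # Hypothetical check, independent of this script: 8 items, 2 replicas.
    from torch.utils.data import DistributedSampler

    print(list(DistributedSampler(range(8), num_replicas=2, rank=0, shuffle=False)))
    # -> [0, 2, 4, 6]; rank=1 would yield [1, 3, 5, 7]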
@@ -88,11 +98,13 @@ def forward(self, input, target):
         nn.Flatten(),
         nn.Linear(28 * 28, 128),
         nn.ReLU(),
-        PipeSplitWrapper(nn.Sequential(
-            nn.Linear(128, 64),
-            nn.ReLU(),
-        )),
-        PipeSplitWrapper(nn.Linear(64, 10))
+        PipeSplitWrapper(
+            nn.Sequential(
+                nn.Linear(128, 64),
+                nn.ReLU(),
+            )
+        ),
+        PipeSplitWrapper(nn.Linear(64, 10)),
     )
 
     wrapper = OutputLossWrapper(model, cross_entropy)
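An observation (not part of the commit): each PipeSplitWrapper begins a new pipeline stage, so the annotated model splits into three stages, which lines up with the script's default world_size of 3 when no data-parallel replication is requested:

    # Conceptual stage layout (restated, not from the diff):
    # stage 0: Flatten -> Linear(784, 128) -> ReLU
    # stage 1: Linear(128, 64) -> ReLU     (first PipeSplitWrapper)
    # stage 2: Linear(64, 10)              (second PipeSplitWrapper)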
@@ -103,7 +115,7 @@ def forward(self, input, target):

     # sample input
     x = torch.randint(0, 5, (batch_size, 28, 28), device=args.device)
-    target = torch.randint(0, 9, (batch_size, ), device=args.device)
+    target = torch.randint(0, 9, (batch_size,), device=args.device)
 
     stage = pippy.compile_stage(
         wrapper,
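A note on the fold above (not part of the commit): x and target are random tensors whose values do not matter; they appear to serve as example inputs that pin down shapes and dtypes when the stage is compiled. Their leading dimension is the macro-batch size computed earlier:

    # Hypothetical arithmetic, since batch_size = args.batch_size * args.chunks:
    # --batch_size 10 --chunks 4  ->  x.shape == (40, 28, 28),
    # which the pipeline would split into 4 microbatches of 10 samples each.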
@@ -116,21 +128,22 @@ def forward(self, input, target):
         output_chunk_spec=output_chunk_spec,
     )
 
-    optimizer = optim.Adam(stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
+    optimizer = optim.Adam(
+        stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8
+    )
     # TODO: add back LR scheduler
-    #lr_sched = pipe_driver.instantiate_lr_scheduler(optim.lr_scheduler.LinearLR, verbose=LR_VERBOSE)
+    # lr_sched = pipe_driver.instantiate_lr_scheduler(optim.lr_scheduler.LinearLR, verbose=LR_VERBOSE)
 
-    loaders = {
-        "train": train_dataloader,
-        "valid": valid_dataloader
-    }
+    loaders = {"train": train_dataloader, "valid": valid_dataloader}
 
     for epoch in range(args.max_epochs):
         print(f"Epoch: {epoch + 1}")
         for k, dataloader in loaders.items():
             epoch_correct = 0
             epoch_all = 0
-            for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader):
+            for i, (x_batch, y_batch) in enumerate(
+                tqdm(dataloader) if USE_TQDM else dataloader
+            ):
                 x_batch = x_batch.to(args.device)
                 y_batch = y_batch.to(args.device)
                 if k == "train":
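Worth noting in the hunk above (not part of the commit): the optimizer is built over stage.submod.parameters(), so each rank updates only the parameters of the pipeline stage it owns. A hypothetical sanity check that could be added after compile_stage:

    # Hypothetical: report the slice of the model this rank optimizes.
    n_params = sum(p.numel() for p in stage.submod.parameters())
    print(f"rank {args.rank}: {n_params} local parameters")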
@@ -157,26 +170,36 @@ def forward(self, input, target):
             if pp_rank == args.pp_group_size - 1 and epoch_all > 0:
                 print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}")
 
-            #if k == "train":
+            # if k == "train":
             #     lr_sched.step()
 
-    print('Finished')
+    print("Finished")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 3)))
-    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
-    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
-    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
-
-    parser.add_argument('--max_epochs', type=int, default=10)
-    parser.add_argument('--batch_size', type=int, default=10)
-
-    parser.add_argument('--replicate', type=int, default=int(os.getenv("REPLICATE", '0')))
-    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
-    parser.add_argument('--visualize', type=int, default=0, choices=[0, 1])
-    parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1])
+    parser.add_argument(
+        "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 3))
+    )
+    parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1)))
+    parser.add_argument(
+        "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost")
+    )
+    parser.add_argument(
+        "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500")
+    )
+
+    parser.add_argument("--max_epochs", type=int, default=10)
+    parser.add_argument("--batch_size", type=int, default=10)
+
+    parser.add_argument(
+        "--replicate", type=int, default=int(os.getenv("REPLICATE", "0"))
+    )
+    parser.add_argument(
+        "--cuda", type=int, default=int(torch.cuda.is_available())
+    )
+    parser.add_argument("--visualize", type=int, default=0, choices=[0, 1])
+    parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1])
     parser.add_argument(
         "--chunks",
         type=int,
(diff truncated; the remainder of the --chunks argument is not shown)
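For context (not part of the commit): the argparse defaults fall back to the standard torch.distributed environment variables, so one plausible single-host launch, assuming one process per pipeline stage, is:

    torchrun --nproc_per_node=3 examples/mnist/pippy_mnist.py

torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, matching the defaults above.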
