
Commit

update pippy mnist compile stage example to one that works
Eddy Ogola Onyango committed Jul 20, 2023
1 parent 885f71e commit 8b0d423
Showing 1 changed file with 123 additions and 112 deletions.
235 changes: 123 additions & 112 deletions examples/mnist/new_pippy_mnist.py
# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import os

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler
from torchvision import datasets, transforms  # type: ignore
from tqdm import tqdm  # type: ignore

import pippy
import pippy.fx
from pippy.IR import PipeSplitWrapper, LossWrapper
from pippy.microbatch import sum_reducer, TensorChunkSpec

pippy.fx.Tracer.proxy_buffer_attributes = True

USE_TQDM = bool(int(os.getenv('USE_TQDM', '1')))


# Get process group for ranks in a pipeline
def get_pp_subgroup(args):
    my_pp_rank = args.rank // args.dp_group_size
    my_dp_rank = args.rank % args.dp_group_size
    for dp_rank in range(0, args.dp_group_size):
        pp_group_ranks = list(
            range(dp_rank, args.world_size, args.dp_group_size)
        )
        pp_group = dist.new_group(ranks=pp_group_ranks)
        if dp_rank == my_dp_rank:
            my_pp_group = pp_group
    print(f"Rank {args.rank} done getting pp group")
    return my_pp_group, my_pp_rank


# Get DP process group for ranks with the same stage
def get_dp_subgroup(args):
    my_pp_rank = args.rank // args.dp_group_size
    my_dp_rank = args.rank % args.dp_group_size
    for pp_rank in range(0, args.pp_group_size):
        dp_group_ranks = list(
            range(
                pp_rank * args.dp_group_size, (pp_rank + 1) * args.dp_group_size
            )
        )
        dp_group = dist.new_group(ranks=dp_group_ranks)
        if pp_rank == my_pp_rank:
            my_dp_group = dp_group
    print(f"Rank {args.rank} done getting dp group")
    return my_dp_group, my_dp_rank
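
# As an illustration of the mapping above (hypothetical sizes, not used by the
# example): with world_size=6 and dp_group_size=2, rank r maps to
# pp_rank = r // 2 and dp_rank = r % 2, so the PP groups are {0, 2, 4} and
# {1, 3, 5}, while the DP groups are {0, 1}, {2, 3} and {4, 5}.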


def run_worker(args):
    torch.manual_seed(42)

    # Get DP and PP sub process groups
    dp_group, dp_rank = get_dp_subgroup(args)
    pp_group, pp_rank = get_pp_subgroup(args)

    batch_size = args.batch_size * args.chunks

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
    valid_data = datasets.MNIST('./data', train=False, transform=transform)

    train_sampler = DistributedSampler(train_data, num_replicas=args.dp_group_size, rank=dp_rank, shuffle=False,
                                       drop_last=False)

    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
    valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)

    class OutputLossWrapper(LossWrapper):
        def __init__(self, module, loss_fn):
            super().__init__(module, loss_fn)

        def forward(self, input, target):
            output = self.module(input)
            return output, self.loss_fn(output, target)
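
    # The LossWrapper pattern above keeps the loss computation at the end of
    # the wrapped forward, so it ends up in the last pipeline stage and
    # backward can run inside the pipeline; output_chunk_spec below pairs
    # TensorChunkSpec(0) (chunk the output along the batch dimension) with
    # sum_reducer (sum the per-microbatch losses).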

    # define model
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        PipeSplitWrapper(nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
        )),
        PipeSplitWrapper(nn.Linear(64, 10))
    )
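
    # The two PipeSplitWrapper markers split the model into three stages
    # (Flatten/Linear/ReLU, Linear/ReLU, Linear), matching the hard-coded
    # pp_group_size of 3 below.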

    # define model wrapper
    wrapper = OutputLossWrapper(model, cross_entropy)
    wrapper.to(args.device)

    output_chunk_spec = (TensorChunkSpec(0), sum_reducer)

    # sample input
    x = torch.randint(0, 5, (batch_size, 28, 28), device=args.device)
    target = torch.randint(0, 9, (batch_size, ), device=args.device)

    stage = pippy.compile_stage(
        wrapper,
        pp_rank,
        args.pp_group_size,
        args.chunks,
        args.device,
        pp_group,
        [x, target],
        output_chunk_spec=output_chunk_spec,
    )

    # setup optimizer (each rank only optimizes the parameters of its own stage)
    optimizer = optim.Adam(stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
    # TODO: add back LR scheduler
    #lr_sched = pipe_driver.instantiate_lr_scheduler(optim.lr_scheduler.LinearLR, verbose=LR_VERBOSE)

    loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader
    }

    batches_events_contexts = []

    for epoch in range(args.max_epochs):
        print(f"Epoch: {epoch + 1}")
        for k, dataloader in loaders.items():
            epoch_correct = 0
            epoch_all = 0
            for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader):
                x_batch = x_batch.to(args.device)
                y_batch = y_batch.to(args.device)

                if k == "train":
                    stage.train()
                    outp = None
                    optimizer.zero_grad()

                    # Each rank feeds only what its stage needs: the first stage
                    # takes the images, the last stage takes the targets (for the
                    # loss) and gets back (output, loss), and the middle stage is
                    # driven with no direct inputs.
                    if pp_rank == 0:
                        stage(x_batch)
                    elif pp_rank == args.pp_group_size - 1:
                        outp, _ = stage(y_batch)
                    else:
                        stage()
                    optimizer.step()

                    if outp is not None:
                        preds = outp.argmax(-1)
                        correct = (preds == y_batch).sum()
                        all = len(y_batch)
                        epoch_correct += correct.item()
                        epoch_all += all
                else:
                    # TODO: add evaluation support in PiPPy
                    pass

            if pp_rank == args.pp_group_size - 1 and epoch_all > 0:
                print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}")

            #if k == "train":
            #    lr_sched.step()

    dist.barrier()
    print('Finished')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 3)))
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))

    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=10)

    parser.add_argument('--replicate', type=int, default=int(os.getenv("REPLICATE", '0')))
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    parser.add_argument('--visualize', type=int, default=0, choices=[0, 1])
    parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1])
    parser.add_argument(
        "--chunks",
        type=int,
        default=4,
    )
    args = parser.parse_args()

    args.pp_group_size = 3
    assert args.world_size % args.pp_group_size == 0
    args.dp_group_size = args.world_size // args.pp_group_size
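    # For example, world_size=6 with pp_group_size=3 gives dp_group_size=2,
    # i.e. two data-parallel replicas of a three-stage pipeline; the default
    # world_size of 3 is a single pipeline with no data parallelism.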

    if args.cuda:
        dev_id = args.rank % torch.cuda.device_count()
        args.device = torch.device(f"cuda:{dev_id}")
    else:
        args.device = torch.device("cpu")

    # Init process group
    backend = "nccl" if args.cuda else "gloo"
    dist.init_process_group(
        backend=backend,
        rank=args.rank,
        world_size=args.world_size,
    )

    run_worker(args)
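
Usage note (an assumption; the commit itself does not show a launch command): the script reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment and hard-codes pp_group_size = 3, so it should be launchable on a single host, with any world size that is a multiple of 3, via a standard launcher that sets those variables, for example:

torchrun --nproc_per_node=3 examples/mnist/new_pippy_mnist.py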
