Skip to content

Commit

Permalink
Fix generation of loss spec from output spec (#835)
Browse files Browse the repository at this point in the history
## Description

In PiPPy we automatically generate loss spec from output chunk spec.
Terminology:
"loss spec": a one-hot map indicating which output value is a loss. It
consists of True, False values.
"output chunk spec": a data structure corresponding to output format
describing which (chunked) value should be merged and which should be
reduced (such as loss).

## Issue

In previous code, we only considered the case where the output chunk
spec is a dictionary.
But in cases such as `return logits, loss`, the output chunk spec is a
tuple, i.e. `(TensorChunkSpec(0), sum_reducer)`.

## Fix

Use `fx.node.map_aggregate` to generalize the auto generation.
It takes a lambda function that outputs True/False.

## Testing

torchrun --nproc-per-node 4 local_test_chunkspec.py
  • Loading branch information
kwen2501 authored Jul 10, 2023
1 parent edc0452 commit 11b2c49
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pippy/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def _compile(

# Figure out which output is loss from output_chunk_spec
output_loss_value_spec: Any = None
if isinstance(output_chunk_spec, dict):
output_loss_value_spec = {
k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items()
}
if output_chunk_spec is not None:
output_loss_value_spec = fx.node.map_aggregate(
output_chunk_spec, lambda v: isinstance(v, LossReducer)
)

logging.info("[PiPPy] Tracing model ...")
pipe_model = Pipe.from_tracing(
Expand Down Expand Up @@ -239,10 +239,10 @@ def compile_stage(

# Figure out which output is loss from output_chunk_spec
output_loss_value_spec: Any = None
if isinstance(output_chunk_spec, dict):
output_loss_value_spec = {
k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items()
}
if output_chunk_spec is not None:
output_loss_value_spec = fx.node.map_aggregate(
output_chunk_spec, lambda v: isinstance(v, LossReducer)
)

logging.info("[PiPPy] Tracing model ...")
pipe = Pipe.from_tracing(
Expand Down
136 changes: 136 additions & 0 deletions test/local_test_chunkspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import os
import unittest

import torch
import torch.distributed as dist

from pippy.compile import compile_stage
from pippy.IR import pipe_split
from pippy.microbatch import sum_reducer, TensorChunkSpec

d_hid = 512
chunk_size = 256

torch.manual_seed(0)


class ExampleCode(torch.nn.Module):
def __init__(self):
super().__init__()
self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin = torch.nn.Linear(d_hid, d_hid)
self.mse_loss = torch.nn.MSELoss(reduction="sum")

def forward(self, x, target):
x = torch.mm(x, self.mm_param)
skip_connection = x
x = torch.relu(x)
pipe_split()
x = torch.mm(x, self.mm_param)
x = self.lin(x)
pipe_split()
x = torch.relu(x)
x = x + skip_connection
x = torch.mm(x, self.mm_param2)
pipe_split()
x = self.lin(x)
x = torch.relu(x)
loss = self.mse_loss(x, target)
return loss, x


def run_worker(args):
ec = ExampleCode()
ec.to(args.device)
ec.train()

ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)
target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device)

stage = compile_stage(
ec,
args.rank,
args.world_size,
args.chunks,
args.device,
None,
[ec_x, target],
output_chunk_spec=(sum_reducer, TensorChunkSpec(0)),
)

# Run
if args.rank == 0:
out = stage(ec_x)
elif args.rank == args.world_size - 1:
out = stage(target)
else:
stage()

dist.barrier()
print(f"Rank {args.rank} completes")

# Last rank checks result
if args.rank == args.world_size - 1:
ref_out = ec(ec_x, target)
torch.testing.assert_close(out, ref_out)
print(
f"equivalence test passed, loss = {out[0]}, ref loss = {ref_out[0]}"
)


def main(args=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4))
)
parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1)))
parser.add_argument(
"--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost")
)
parser.add_argument(
"--master_port", type=str, default=os.getenv("MASTER_PORT", "29500")
)
parser.add_argument(
"--cuda", type=int, default=int(torch.cuda.is_available())
)
parser.add_argument(
"--chunks",
type=int,
default=4,
)
args = parser.parse_args(args)

if args.cuda:
dev_id = args.rank % torch.cuda.device_count()
args.device = torch.device(f"cuda:{dev_id}")
else:
args.device = torch.device("cpu")

# Init process group
backend = "nccl" if args.cuda else "gloo"
dist.init_process_group(
backend=backend,
rank=args.rank,
world_size=args.world_size,
)

run_worker(args)


if __name__ == "__main__":
main()


class LocalTestC10DBwdTest(unittest.TestCase):
def test_c10d_bwd(self):
import random

port = random.randint(29500, 30000)
args = [
"--master_port",
str(port),
]
main(args)

0 comments on commit 11b2c49

Please sign in to comment.