# Set output_grads correctly (#840)
## Issue

A user reported the following error in the case where a stage returns multiple
output values and one of them is the loss:
```
    output_grads[i] for i in outputs_with_grads_idxs
IndexError: tuple index out of range
...
RuntimeError: 
        Failed to run backward stage stage_backward for stage %submod_7 : [ = call_module[target=submod_7](args = (%submod_6, %_inputs), kwargs = {})
        Stage output: ('Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([]), grad=True)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)')
        Output gradient: ('None',)
        Input: ['Tensor(torch.Size([100, 20, 4096]), grad=True)', 'Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)']
```
Note this part: `Output gradient: ('None',)`
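
A minimal sketch of the failure mode, using hypothetical index values chosen to mirror the trace above (six stage outputs, of which only the scalar loss at position 3 requires grad):
```
# Hypothetical values mirroring the trace above: six stage outputs,
# of which only index 3 (the scalar loss) requires grad.
outputs_with_grads_idxs = [3]

# The old code always fell back to a fixed 1-tuple:
output_grads = (None,)

# `stage_backward` indexes output_grads by output position, so:
grads = tuple(output_grads[i] for i in outputs_with_grads_idxs)
# IndexError: tuple index out of range (asks for index 3, but len() == 1)
```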

I can reproduce the issue in `local_test_c10d_bwd.py` by changing the output
to:
```
-        return {"loss": loss}
+        return {"logits": x, "loss": loss}
```

## Cause
The above issue is caused by the hardcoded fallback in the else case:
```
                # (None,) is for `stage_backward` signature
                bwd_kwargs["output_grads"] = (
                    grads if len(grads) > 0 else (None,)
                )
```
The `(None,)` fallback has a fixed length of 1; it should instead have one
entry per stage output.
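
This also explains why the original single-output test passed: with one output, the loss sits at index 0 and the fixed 1-tuple happens to line up. A sketch, again with assumed index values:
```
# Single output ({"loss": loss}): the loss sits at index 0, so the
# fixed default happens to work.
output_grads = (None,)
outputs_with_grads_idxs = [0]
tuple(output_grads[i] for i in outputs_with_grads_idxs)  # -> (None,)

# Two outputs ({"logits": x, "loss": loss}): the loss sits at index 1.
outputs_with_grads_idxs = [1]
tuple(output_grads[i] for i in outputs_with_grads_idxs)  # IndexError
```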

## Fix
Only update `bwd_kwargs["output_grads"]` when we have actually received
gradients; otherwise, use the tuple prepared during the IR phase, i.e.
`bwd_node.kwargs["output_grads"]`, which may look like `(None, None)` if
there are two outputs.
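
A sketch of the effect, with assumed index values for the two-output repro:
```
# Default prepared during the IR phase for two outputs:
output_grads = (None, None)
# Only the loss (index 1) requires grad:
outputs_with_grads_idxs = [1]
grads = tuple(output_grads[i] for i in outputs_with_grads_idxs)
# -> (None,): indexing succeeds, and a None grad is acceptable
# (assuming stage_backward hands these to torch.autograd.backward,
# which allows None for a scalar output such as the loss)
```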
kwen2501 authored Jul 14, 2023
1 parent a6f9997 commit e60ebea
Showing 2 changed files with 9 additions and 6 deletions.
pippy/PipelineStage.py (7 additions, 4 deletions):
```
@@ -594,10 +594,13 @@ def backward_maybe_with_nosync(bwd_kwargs: Dict, is_last_chunk: bool):
                     bwd_kwargs["stage_output"],
                     bwd_kwargs["input_values"],
                 ) = fwd_cache.pop(bwd_chunk)
-                # (None,) is for `stage_backward` signature
-                bwd_kwargs["output_grads"] = (
-                    grads if len(grads) > 0 else (None,)
-                )
+                # Fill actual gradients received for outputs
+                # If nothing received, as in the case of last stage, then we
+                # would use the default `output_grads` prepared in the IR phase,
+                # i.e. from `bwd_node.kwargs`. For example, it may look like
+                # this if there are two outputs: ('None', 'None')
+                if len(grads) > 0:
+                    bwd_kwargs["output_grads"] = grads

                 # `stage_backward` node does not have `args`, only `kwargs`
                 grads_input = backward_maybe_with_nosync(
```
test/local_test_c10d_bwd.py (2 additions, 2 deletions):
```
@@ -39,7 +39,7 @@ def forward(self, x, target):
         x = self.lin(x)
         x = torch.relu(x)
         loss = self.mse_loss(x, target)
-        return {"loss": loss}
+        return {"logits": x, "loss": loss}


 def run_worker(args):
@@ -74,7 +74,7 @@ def run_worker(args):
     # Last rank checks result
     if args.rank == args.world_size - 1:
         ref_out = ec(ec_x, target)
-        torch.testing.assert_close(out["loss"], ref_out["loss"])
+        torch.testing.assert_close(out, ref_out)
         print(
             f"equivalence test passed, loss = {out['loss']}, ref loss = {ref_out['loss']}"
         )
```
