## Issue

When a stage has multiple output values and one of them is the loss, a user reported the following error:

```
    output_grads[i] for i in outputs_with_grads_idxs
IndexError: tuple index out of range
...
RuntimeError: Failed to run backward stage stage_backward for stage %submod_7 : [ = call_module[target=submod_7](args = (%submod_6, %_inputs), kwargs = {})
Stage output: ('Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([]), grad=True)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)')
Output gradient: ('None',)
Input: ['Tensor(torch.Size([100, 20, 4096]), grad=True)', 'Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)']
```

Note this part: `Output gradient: ('None',)`. The stage has six outputs, but the gradient tuple has only one entry.

I can reproduce the issue in local_test_c10d_bwd.py if I change the output to:

```
- return {"loss": loss}
+ return {"logits": x, "loss": loss}
```

## Cause

The issue is caused by the hard-coded fallback in the else case:

```
# (None,) is for `stage_backward` signature
bwd_kwargs["output_grads"] = (
    grads if len(grads) > 0 else (None,)
)
```

The fallback tuple `(None,)` always has length one, but its length should match the number of stage outputs. When the output that requires a gradient is not at index 0, indexing into the tuple goes out of range.

## Fix

Only update `bwd_kwargs["output_grads"]` when we have actually received gradients; otherwise, use the tuple prepared during the IR phase, i.e. `bwd_node.kwargs["output_grads"]`, which may look like `(None, None)` if there are two outputs.
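The difference between the old fallback and the fix can be illustrated with a minimal, self-contained sketch. This is not the actual runtime code: the helper names `select_output_grads` and `gather_grads` and the example index list are hypothetical. They only mimic the indexing expression from the traceback and the fallback described above, for a stage with two outputs (`logits` at index 0, `loss` at index 1).

```python
def select_output_grads(received_grads, ir_output_grads):
    """Hypothetical helper mirroring the fix: prefer received gradients,
    otherwise fall back to the placeholder tuple prepared in the IR phase."""
    if len(received_grads) > 0:
        return tuple(received_grads)
    return ir_output_grads


def gather_grads(output_grads, outputs_with_grads_idxs):
    """Mimics the indexing expression seen in the traceback."""
    return [output_grads[i] for i in outputs_with_grads_idxs]


# Assumed setup: two outputs, only the loss (index 1) requires a gradient.
ir_placeholder = (None, None)   # stands in for bwd_node.kwargs["output_grads"]
outputs_with_grads_idxs = [1]

# Old behavior: the fixed-length fallback (None,) is too short for index 1.
try:
    gather_grads((None,), outputs_with_grads_idxs)
    old_behavior = "ok"
except IndexError:
    old_behavior = "IndexError"

# New behavior: no gradients were received, so the IR placeholder is used.
new_result = gather_grads(
    select_output_grads((), ir_placeholder), outputs_with_grads_idxs
)
```

With the IR-phase placeholder, the tuple length follows the number of stage outputs, so the lookup succeeds regardless of where the loss sits in the output.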