[Compile] Understand why FSDP2 saves both SDPA out and wo in for bwd #610

Open
awgu opened this issue Oct 11, 2024 · 0 comments
Labels
question Further information is requested

Comments

@awgu
Contributor

awgu commented Oct 11, 2024

With FSDP2 and per-transformer-block compile, torch.compile saves both the SDPA output and the contiguous transposed tensor (the input to wo) for backward:

```python
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)  # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
```
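As a minimal sketch (not torchtitan code), the following builds both tensors with hypothetical small sizes and checks that they hold exactly the same values, only in different layouts. Keeping both therefore stores this activation twice. It assumes PyTorch 2.0+ (for `F.scaled_dot_product_attention`) and assumes the `.view(bs, seqlen, dim)` flattening that produces the (bs, seqlen, dim) input to wo.

```python
import torch
import torch.nn.functional as F

# Hypothetical small sizes, for illustration only.
bs, n_heads, seqlen, head_dim = 2, 4, 8, 16
dim = n_heads * head_dim

xq, xk, xv = (torch.randn(bs, n_heads, seqlen, head_dim) for _ in range(3))

# Tensor 1: the SDPA output, which SDPA backward needs.
sdpa_out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

# Tensor 2: the wo input, which the wo weight gradient needs.
# transpose(1, 2) is a view; .contiguous() copies into
# (bs, seqlen, n_heads, head_dim) order if the view is not already contiguous.
wo_in = sdpa_out.transpose(1, 2).contiguous().view(bs, seqlen, dim)

# Undoing the layout change recovers the SDPA output exactly (no arithmetic involved).
recovered = wo_in.view(bs, seqlen, n_heads, head_dim).transpose(1, 2)
print(sdpa_out.numel() == wo_in.numel(), torch.equal(recovered, sdpa_out))
```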

However, with SimpleFSDP and full-model compile, torch.compile saves only the SDPA output. This means that FSDP2 saves an extra (bs, seqlen, dim) tensor per transformer block.
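A back-of-the-envelope estimate helps size that extra tensor. The configuration used here is hypothetical (roughly 8B-model-shaped: dim=4096, 32 layers, bs=1, seqlen=8192, bf16 activations). The estimate is per rank, because FSDP shards parameters, not activations.

```python
def extra_saved_bytes(bs: int, seqlen: int, dim: int, n_layers: int,
                      bytes_per_elem: int = 2) -> int:
    """Bytes for one extra (bs, seqlen, dim) activation saved per transformer block.

    bytes_per_elem=2 assumes bf16/fp16 activations; use 4 for fp32.
    """
    return n_layers * bs * seqlen * dim * bytes_per_elem


# Hypothetical configuration, not a measured result.
extra = extra_saved_bytes(bs=1, seqlen=8192, dim=4096, n_layers=32)
print(extra / 2**30, "GiB")
```

Each block contributes 8192 * 4096 * 2 bytes = 64 MiB under these assumptions, so the total grows linearly with sequence length, batch size, and depth.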

Normally, the SDPA output is required for the SDPA backward, and the input to wo is required for the wo backward (to compute the weight gradient). However, since the two tensors hold the same values in different layouts, it may be more memory-efficient to save only one and recompute the other (e.g., recompute the SDPA output by undoing the transpose of the wo input).
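One way to try the recompute idea in eager mode is a custom `torch.autograd.Function` that fuses the transpose with the wo projection. It saves only the SDPA output, which SDPA backward keeps alive anyway, and rebuilds the wo input during backward. This is a sketch under stated assumptions: `TransposeThenLinear` is a hypothetical name (not a PyTorch or torchtitan API), wo is assumed to have no bias, and FSDP is ignored. Under torch.compile the partitioner decides what gets saved, so whether this carries over to the compiled graph is untested.

```python
import torch
import torch.nn.functional as F


class TransposeThenLinear(torch.autograd.Function):
    """Hypothetical fused op: wo(sdpa_out.transpose(1, 2) flattened to (bs, seqlen, dim)).

    Saves the SDPA output instead of the transposed wo input and recomputes
    the wo input in backward. Assumes wo has no bias.
    """

    @staticmethod
    def forward(ctx, sdpa_out, weight):
        bs, n_heads, seqlen, head_dim = sdpa_out.shape
        wo_in = sdpa_out.transpose(1, 2).reshape(bs, seqlen, n_heads * head_dim)
        # Save the input tensor itself (no new buffer); wo_in is freed after forward.
        ctx.save_for_backward(sdpa_out, weight)
        return F.linear(wo_in, weight)

    @staticmethod
    def backward(ctx, grad_out):
        sdpa_out, weight = ctx.saved_tensors
        bs, n_heads, seqlen, head_dim = sdpa_out.shape
        dim = n_heads * head_dim
        # Recompute the wo input: a layout change only, no arithmetic.
        wo_in = sdpa_out.transpose(1, 2).reshape(bs * seqlen, dim)
        grad_2d = grad_out.reshape(bs * seqlen, -1)
        # Linear backward for y = x @ W.T: dW = grad.T @ x, dx = grad @ W.
        grad_weight = grad_2d.t() @ wo_in
        grad_wo_in = grad_2d @ weight
        # Undo the transpose so the gradient matches the SDPA output layout.
        grad_sdpa_out = grad_wo_in.view(bs, seqlen, n_heads, head_dim).transpose(1, 2)
        return grad_sdpa_out, grad_weight
```

The trade-off is one extra transpose copy in backward in exchange for not keeping a second (bs, seqlen, dim) buffer alive between forward and backward.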

One open question is why the activations saved for backward differ between SimpleFSDP with full-model compile and FSDP2 with per-transformer-block compile.

tianyu-l added the "question" (Further information is requested) label on Oct 15, 2024.