
reproducible numerics for loss, weights and gradients for single node (8 GPUs) #593

Open
weifengpy opened this issue Oct 1, 2024 · 2 comments · May be fixed by #597
Labels
enhancement New feature or request

Comments

@weifengpy (Contributor) commented Oct 1, 2024

By default, torchtitan uses FSDP2 mixed precision (param_dtype=bfloat16, reduce_dtype=float32).
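For reference, a minimal sketch of that default policy with the FSDP2 APIs (not torchtitan's actual code; assumes a PyTorch version where fully_shard lives under torch.distributed._composable.fsdp and that torch.distributed is already initialized, e.g., via torchrun):

```python
import torch
import torch.nn as nn
# Import path may differ across PyTorch versions (assumption for this sketch).
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

model = nn.Linear(1024, 1024)  # stand-in for a transformer block
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # params/activations cast to bf16 for compute
    reduce_dtype=torch.float32,  # gradient reduce-scatter/all-reduce in fp32
)
fully_shard(model, mp_policy=mp_policy)  # requires an initialized process group
```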

For low-precision dtypes (float8 and int8), it's natural to compare the loss curve against bfloat16 and see how well they match (it's also a good idea to compare weight norms and gradient norms).

For bfloat16 itself, multiple runs will yield different loss curves, and that non-determinism should be understood and documented (e.g., NCCL gradient reduction, attention, seeding). Otherwise it's hard to tell whether numeric differences are coming from the low-precision dtypes.
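As a starting point, here is a sketch of the standard single-process PyTorch determinism knobs; note these do not by themselves address cross-rank effects like NCCL reduction order, which is one of the suspects above:

```python
import os
import torch

torch.manual_seed(0)                      # seed CPU and all CUDA RNGs
torch.use_deterministic_algorithms(True)  # error out on known-nondeterministic ops
torch.backends.cudnn.benchmark = False    # disable nondeterministic autotuning
# Some deterministic cuBLAS paths additionally require a fixed workspace config:
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```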

I plotted gradient norms (loss = sum(p.grad) over model.parameters()) using llama3-8b on 8 GPUs, with deterministic model init and a deterministic data loader.
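A sketch of that diagnostic (the helper name is mine): after backward(), collapse every parameter gradient into one scalar so two runs can be compared exactly:

```python
import torch

def grad_scalar(model: torch.nn.Module) -> torch.Tensor:
    # Sum all gradient elements in float64; any run-to-run numeric
    # divergence shows up directly in this one value.
    return sum(
        p.grad.double().sum()
        for p in model.parameters()
        if p.grad is not None
    )
```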

For bfloat16, gradients are quite different across repeated runs:
[Screenshot: gradient norms diverging across repeated bfloat16 runs]

Turning off gradient norm clipping helps a lot, but does not explain all of the divergence:
[Screenshot: gradient norms with clipping disabled, still partially diverging]
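One plausible reason clipping amplifies divergence (my reading, not confirmed in the thread): clip_grad_norm_ rescales every gradient by the same factor derived from a single global norm, so a tiny difference anywhere perturbs all parameters on the next optimizer step. A sketch:

```python
import torch

model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()
# All gradients are divided by the same global-norm-derived factor; a small
# run-to-run difference in total_norm therefore rescales every gradient.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```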

Filing the issue here; hopefully it can be a good candidate for what's next.

weifengpy changed the title from "reproducible numerics for loss, weights and gradients" to "reproducible numerics for loss, weights and gradients for single node (8 GPUs)" on Oct 1, 2024
@awgu (Contributor) commented Oct 1, 2024

IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?

I think we can try enabling deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
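A sketch of one way to test this: force SDPA onto the math backend (deterministic but slower) to rule out flash's nondeterministic backward, assuming a PyTorch version that provides torch.nn.attention.sdpa_kernel:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel(SDPBackend.MATH):  # bypass flash's nondeterministic backward
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```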

@weifengpy (Contributor, Author) replied:

> IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?
>
> I think we can try enabling deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

good call out!

weifengpy linked a pull request (#597) on Oct 2, 2024 that will close this issue
yf225 added the enhancement (New feature or request) label on Oct 4, 2024