
Random OOM and crashes #8216

Open
alexanderswerdlow opened this issue Oct 4, 2024 · 1 comment

alexanderswerdlow commented Oct 4, 2024

❓ Questions and Help

I've found that I'm unable to train for more than ~20-80K steps without a crash, and it's difficult to figure out how to debug this. In a typical PyTorch training run, an OOM (or any other error) gives a clear message at a particular line, printed to the log/console.

However, about half the time my training run simply exits with no message on any rank, and the other half of the time it's clearly memory-related, with a "Resource Exhausted" message. The issue is that it's not clear where this new allocation happens (I have a fairly standard decoder-based transformer, no eval batches, and I'm not using any eager modes). I tried switching to nightly to get a recent dataloader memory fix, but that doesn't seem to help.

I know there are many flags that can be used for debugging, but it's unclear exactly which ones can be left on during training without a large performance hit. I've done all the suggested steps, including profiling and making sure there isn't re-compilation happening, etc. It would be good to clarify the impact of the flags somewhere so it's clear which are safe to use, and any other advice on how to debug this would be great!
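(For context, the recompilation check I mean is roughly the following; a minimal sketch assuming the torch_xla.debug.metrics reporting API, with an arbitrary reporting interval:)

```python
import torch_xla.debug.metrics as met

def report_xla_metrics(step, every=1000):
    # Dump the XLA metrics report every `every` steps; if CompileTime or the
    # compile counters keep growing after warmup, something is still recompiling.
    if step % every == 0:
        print(f"[step {step}]\n{met.metrics_report()}")
```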

Also, I should note this occurs with SPMD multi-node training. I have not spent time testing other modes, but this has happened with between 2 and 8 TPU v4 VMs, both in DDP-like configurations and several other mesh configurations.

JackCaoG (Collaborator) commented Oct 9, 2024

Hmm, good question. It seems like there are two problems here:

  1. training code OOM
  2. error discovery is difficult (e.g., where the OOM happened)

For 2, it is because, from XLA's perspective, it is executing a compiled program and hits the OOM in the middle of that execution (assuming there is no recompilation, the OOM is a runtime OOM). It is hard for PyTorch/XLA to map this OOM event back to a specific Python line. The silent-exit case seems weird, though; when I force a runtime error in SPMD it always throws an error, so it is hard for me to guess what happened there.

Regarding the code OOM, try

watch -n0 tpu-info

You should already have tpu-info installed when you install libtpu if you are using nightly. Try to see if the memory usage is slowly going up across runs; I am wondering if some small tensors are slowly accumulating in HBM.
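If it is easier to capture from inside the training loop, here is a minimal sketch of per-step HBM logging, assuming xm.get_memory_info is available in your build (the keys of the returned dict vary across versions, so it just prints the whole thing):

```python
import torch_xla.core.xla_model as xm

def log_hbm(step, every=500):
    # Print the local device's memory stats every `every` steps; a steady
    # upward trend between otherwise identical steps points at tensors
    # accumulating in HBM.
    if step % every == 0:
        info = xm.get_memory_info(xm.xla_device())
        print(f"[step {step}] HBM: {info}")
```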
