I've found that I'm unable to train for more than ~20-80K steps without a crash, and it's difficult to figure out how to debug this. In a typical PyTorch training run, I would get a clear OOM message pointing at a particular line (or some other error), and it would be printed to the log/console.
However, about half the time my training run simply exits with no message on any rank, and the other half of the time it clearly fails due to memory with a "Resource Exhausted" message. The issue is that it's not clear where this new allocation happens: I have a fairly standard decoder-based transformer, I don't run any eval batches, and I'm not using any eager modes. I tried switching to nightly to pick up a recent dataloader memory fix, but that doesn't seem to fix it.
I know there are many flags that can be used for debugging, but it's unclear exactly which ones can be used during training without a large performance hit. I've done all the suggested steps, including profiling and making sure no recompilation is happening. It would be good to document the impact of each flag somewhere so it's clear which are safe to leave on; any other advice on how to debug this would also be great!
Also, I should note this occurs with SPMD multi-node training. I have not spent time testing other modes, but it has happened with anywhere between 2 and 8 TPU v4 VMs, both in DDP-like configurations and in several other mesh configurations.
Hmm, good question. It seems like there are two problems here:
1. The training code runs out of memory (OOM).
2. Error discovery is difficult (for example, finding where the OOM happened).
For problem 2, the reason is that from XLA's perspective it is executing a compiled program and runs out of memory partway through it (assuming there is no recompilation, this is a runtime OOM). It is hard for PyTorch/XLA to map this OOM event back to a specific Python line. The silent-exit case does seem weird, though: when I force a runtime error in SPMD, it always throws an error, so it is hard for me to think of what happened there.
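For the silent exits, a low-overhead way to leave some evidence behind is Python's standard-library `faulthandler` plus a per-rank "last step reached" file. The sketch below is a general-purpose suggestion, not something PyTorch/XLA provides: the function names, the file layout, and the `rank` argument (whatever process index your launcher exposes) are all hypothetical. `faulthandler` only fires on fatal signals such as SIGSEGV or SIGABRT; it cannot intercept SIGKILL.

```python
import faulthandler
import os
import signal
import tempfile
import time
from typing import TextIO


def enable_crash_diagnostics(log_dir: str, rank: int) -> TextIO:
    """Send faulthandler output to a per-rank file (hypothetical layout).

    On SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL, Python dumps the traceback
    of every thread into this file before the process dies.
    """
    os.makedirs(log_dir, exist_ok=True)
    log = open(os.path.join(log_dir, f"crash_rank{rank}.log"), "a")
    # Keep the file open for the life of the process: faulthandler writes to
    # its file descriptor directly from the signal handler.
    faulthandler.enable(file=log, all_threads=True)
    # On Unix, `kill -USR1 <pid>` dumps all stacks without stopping the run,
    # which helps when a rank looks hung rather than dead.
    if hasattr(signal, "SIGUSR1"):
        faulthandler.register(signal.SIGUSR1, file=log, all_threads=True)
    return log


def write_heartbeat(log_dir: str, rank: int, step: int) -> None:
    """Atomically record the last step this rank's Python loop reached.

    XLA executes asynchronously, so the device may still be working on an
    earlier step when this is written.
    """
    path = os.path.join(log_dir, f"heartbeat_rank{rank}.txt")
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        f.write(f"{step} {time.time():.0f}\n")
    os.replace(tmp, path)  # atomic rename, so readers never see a torn file


if __name__ == "__main__":
    log_dir = tempfile.mkdtemp()
    rank = 0  # hypothetical: substitute your launcher's process index
    log = enable_crash_diagnostics(log_dir, rank)
    for step in range(1, 301):
        # ... one training step would run here ...
        if step % 100 == 0:  # write occasionally to keep the overhead small
            write_heartbeat(log_dir, rank, step)
    with open(os.path.join(log_dir, f"heartbeat_rank{rank}.txt")) as f:
        print(f.read().strip())
```

After a silent exit, comparing the heartbeat files shows which rank stopped first and at roughly what step. If that rank's crash log is empty, the process was likely killed by a signal Python cannot catch; for example, when the Linux OOM killer reclaims host RAM, the kernel log (`dmesg`) on that VM is usually where it shows up.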
Regarding the training-code OOM, try running:
watch -n0 tpu-info
You should already have tpu-info installed alongside libtpu if you are using nightly. Check whether memory usage slowly goes up over the course of a run; I am wondering if some small tensors are slowly accumulating in HBM.
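Eyeballing `tpu-info` works, but a slow leak over tens of thousands of steps is easier to catch with a trend check. The sketch below fits a least-squares slope (bytes per step) over a sliding window of samples and reports when it exceeds a threshold. `read_bytes_used` is a hypothetical hook: feed it numbers from `tpu-info` or from whatever memory query your PyTorch/XLA version offers. The window size and threshold are illustrative defaults, not tuned values, and sampling every N steps rather than every step may be wise if your memory query adds overhead.

```python
import itertools
from collections import deque
from typing import Callable, Deque, Iterable, Optional, Tuple


def least_squares_slope(points: Iterable[Tuple[int, int]]) -> float:
    """Slope of the best-fit line through (step, bytes_used) points."""
    pts = list(points)
    n = len(pts)
    mean_x = sum(x for x, _ in pts) / n
    mean_y = sum(y for _, y in pts) / n
    sxx = sum((x - mean_x) ** 2 for x, _ in pts)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in pts)
    return sxy / sxx if sxx else 0.0


class MemoryGrowthMonitor:
    """Flags steady upward drift in device memory (framework-agnostic sketch)."""

    def __init__(self, read_bytes_used: Callable[[], int], window: int = 50,
                 slope_threshold_bytes_per_step: float = 1024.0):
        self.read_bytes_used = read_bytes_used  # hypothetical memory hook
        self.samples: Deque[Tuple[int, int]] = deque(maxlen=window)
        self.threshold = slope_threshold_bytes_per_step

    def record(self, step: int) -> Optional[float]:
        """Sample memory at `step`.

        Returns the fitted slope (bytes/step) once the window is full and the
        slope exceeds the threshold; otherwise returns None.
        """
        self.samples.append((step, self.read_bytes_used()))
        if len(self.samples) < self.samples.maxlen:
            return None  # not enough history yet
        slope = least_squares_slope(self.samples)
        return slope if slope > self.threshold else None


if __name__ == "__main__":
    # Simulated reader: 1 GB baseline plus a 4 KiB leak per sample.
    leak = itertools.count()
    monitor = MemoryGrowthMonitor(lambda: 1_000_000_000 + 4096 * next(leak),
                                  window=10)
    for step in range(10):
        slope = monitor.record(step)
        if slope is not None:
            print(f"step {step}: memory growing ~{slope:.0f} bytes/step")
```

Passing the real step number (rather than a sample counter) keeps the slope in bytes per training step even when you only sample every N steps.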