
Question: torch.compile has better throughput with 128 GPUs than with 8 GPUs #619

Open
dz1iang opened this issue Oct 15, 2024 · 3 comments


dz1iang commented Oct 15, 2024

Thank you for publishing the paper. I hope you can answer the following question:
Normally, per-GPU training throughput declines as the number of GPUs increases. However, in the paper, with torch.compile enabled, the throughput with 128 GPUs is better than with 8 GPUs.
[screenshot: compile]

@tianyu-l
Contributor

Thank you for the question. This is a great observation!

After some initial investigation, I think the difference might be caused by underlying hardware differences.

  • We ran the experiments on internal training clusters whose locations span multiple regions. I noticed that all jobs except the 128-GPU torch.compile one were run in the same region.
  • So the unexpected increase in throughput might be due to a hardware difference, e.g. the machines in the separate region may be a faster variant of H100 GPUs (one quick way to check is sketched right after this list). This particular experiment should be redone to validate the hypothesis.
  • I checked that for the other experiments the machines were from the same region, so those comparisons should be fair.
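For what it's worth, one rough way to check which H100 variant a given host has is something like the sketch below; it only reads the device name plus a couple of specs that typically differ between variants (max SM clock, power limit) via standard PyTorch and nvidia-smi queries:

```python
import subprocess

import torch

# Print the GPU model and a couple of specs that typically differ between
# H100 variants (e.g. SXM vs. PCIe/NVL): max SM clock and power limit.
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_properties(0))
print(
    subprocess.run(
        ["nvidia-smi", "--query-gpu=name,clocks.max.sm,power.limit", "--format=csv"],
        capture_output=True,
        text=True,
    ).stdout
)
```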

Will make an update once we have the new result.

tianyu-l self-assigned this Oct 15, 2024
tianyu-l added the question (Further information is requested) label Oct 15, 2024

dz1iang commented Oct 16, 2024


Thank you for your answer.
Additionally, I hope you can answer the following question:
Under pure FSDP, the speedup ratio from 8 GPUs to 128 GPUs is 90%. Is there still room for optimization? Since the hardware paired with the H100s is quite good, is there a chance for the speedup ratio at this scale to reach 95%?
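To be concrete about what I mean by "speedup ratio", here is a tiny sketch of the computation (the per-GPU throughput numbers below are made up for illustration, not taken from the paper):

```python
# "Speedup ratio" (weak-scaling efficiency): per-GPU throughput at 128 GPUs
# relative to 8 GPUs. The numbers below are placeholders, not measured values.
tokens_per_sec_per_gpu_8 = 1000.0   # hypothetical per-GPU throughput on 8 GPUs
tokens_per_sec_per_gpu_128 = 900.0  # hypothetical per-GPU throughput on 128 GPUs

ratio = tokens_per_sec_per_gpu_128 / tokens_per_sec_per_gpu_8
print(f"speedup ratio from 8 to 128 GPUs: {ratio:.0%}")  # -> 90%
```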

@tianyu-l
Contributor

That's another good question.

We haven't done any specific studies on the weak scaling you care about -- I'd need to understand better where this slowdown comes from before answering your question with confidence. Are you aware of any studies on this topic?

I can think of several slowdowns when scaling from 8 GPUs (single node) to 128 GPUs, for eager/pure FSDP:

  1. Most prominently, the communication overhead from FSDP becomes much larger. Although most communication in FSDP should be overlapped by computation, some of it can still be exposed (e.g. the very first all-gather of each iteration). This can likely be tuned by doing more careful FSDP wrapping (see the sketch after this list), so I'd say there's some room for improvement.
  2. Data loading overhead (CPU) -- the data loader would need to iterate 128/8 = 16 times more to make up a global batch. I think this overhead should be hidden by GPU computation anyway.
  3. The extra all-reduce to sync, e.g., the loss when doing metric logging (also shown in the sketch below). Given its small volume, this shouldn't contribute much.
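To make points 1 and 3 concrete, here is a minimal, self-contained sketch -- not torchtitan's actual code; the `TransformerBlock` below is just a stand-in -- of per-block FSDP wrapping plus the scalar all-reduce used for metric logging. It assumes a recent PyTorch (FSDP with `ModuleWrapPolicy`) and would be launched with something like `torchrun --nproc_per_node=8 fsdp_sketch.py`:

```python
# Minimal sketch, not torchtitan's actual code: per-block FSDP wrapping (point 1)
# plus the scalar metric-sync all-reduce (point 3).
# Launch with e.g.: torchrun --nproc_per_node=8 fsdp_sketch.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class TransformerBlock(nn.Module):
    """Stand-in block for illustration; a real model's block is more involved."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.attn(x, x, x, need_weights=False)[0]
        return x + self.mlp(x)


def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    model = nn.Sequential(*[TransformerBlock() for _ in range(8)]).to(device)

    # Point 1: wrap each TransformerBlock in its own FSDP unit, so the
    # all-gather for the next block can be prefetched while the current
    # block computes; mainly the first all-gather of an iteration is exposed.
    model = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),
        use_orig_params=True,
    )

    x = torch.randn(4, 128, 1024, device=device)
    loss = model(x).float().mean()
    loss.backward()

    # Point 3: metric logging only needs a single-scalar all-reduce, so its
    # cost is tiny compared to FSDP's parameter/gradient collectives.
    loss_for_logging = loss.detach().clone()
    dist.all_reduce(loss_for_logging, op=dist.ReduceOp.AVG)
    if dist.get_rank() == 0:
        print(f"avg loss across ranks: {loss_for_logging.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

With each block in its own FSDP unit, parameters are gathered block by block, so the all-gather for block i+1 can be prefetched while block i computes, and mostly the very first all-gather of each iteration stays exposed.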

On the other hand, there should be some (although maybe not much) savings when scaling up, e.g. each rank now holds a smaller fraction of the parameters to be updated by the optimizer.

More can be said if we look at the profiler traces and compare. In general, this seems to be a complex topic, and the "speedup ratio" can be task-specific. E.g., if we trained a 70B model instead, would the ratio be higher or lower? And that's not to mention the run-to-run and iteration-to-iteration variation.
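For reference, here is a minimal sketch of collecting such a trace with `torch.profiler` (the toy model, step counts, and output directory are arbitrary choices for illustration; in practice this would wrap the real training loop):

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# Toy stand-ins so the sketch runs on its own; in practice, profile the real
# training loop and compare the 8-GPU vs. 128-GPU traces for exposed comms.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1024, 1024).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)


def train_step():
    x = torch.randn(32, 1024, device=device)
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()


# Wait 1 step, warm up 2, record 3; the trace lands in ./profile_traces and
# can be opened with TensorBoard or Perfetto for comparison.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=3, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profile_traces"),
    record_shapes=True,
) as prof:
    for _ in range(6):
        train_step()
        prof.step()
```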
