
ensure reproducible deterministic numerics #597

Open
wants to merge 2 commits into main

Conversation

weifengpy (Contributor) commented on Oct 2, 2024

Resolves #593.

Grad norms differ noticeably when running the same config twice:
Screenshot 2024-10-01 at 8 34 40 PM

With this change, grad norms are exactly the same across repeated runs:
Screenshot 2024-10-01 at 8 53 19 PM
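For context, a minimal sketch of what the patched helper looks like, assembled from the diff below; the earlier part of the function body is elided.

# Sketch of set_determinism after this PR's additions (assembled from the diff below);
# the earlier RNG-seeding part of the function is elided.
import os
from typing import Optional

import torch


def set_determinism(seed: Optional[int]) -> None:
    # ... existing seeding logic elided ...
    torch.backends.cudnn.benchmark = False
    # set Python seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    # force deterministic kernels where PyTorch provides them
    torch.use_deterministic_algorithms(True)
    # env var for deterministic cuBLAS matmuls
    # https://github.com/pytorch/pytorch/blob/18525e185e211b3eab44c67a688e5df8396f6f97/torch/__init__.py#L1300
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"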

facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 2, 2024
@@ -48,6 +48,10 @@ def set_determinism(seed: Optional[int]) -> None:
     torch.backends.cudnn.benchmark = False
     # set Python seed
     os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.use_deterministic_algorithms(True)
Contributor:
I was going to add this earlier, but I found it crashes compile when used with fp8:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: "fill_out" not implemented for 'Float8_e4m3fn'

Thus, I don't think we want to add this at the moment. It works (compiles) if you don't use fp8, but a lot of the need for determinism is to better showcase fp8.
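If strict determinism is still wanted for non-fp8 runs, one possible workaround (purely a sketch, not part of this PR; the flag name is illustrative) is to gate the call on whether fp8 is enabled:

# Hypothetical sketch only, not this PR's code: skip the strict-determinism
# switch when fp8 is enabled, since it currently crashes torch.compile/inductor
# ("fill_out" not implemented for 'Float8_e4m3fn').
import torch


def maybe_use_deterministic_algorithms(fp8_enabled: bool) -> None:
    if not fp8_enabled:
        torch.use_deterministic_algorithms(True)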

Contributor:

I filed an issue to start work on getting this resolved:
pytorch/pytorch#137160

weifengpy (Contributor, author):

Thanks for explaining the context.

Contributor:

We should leave a TODO here reminding us to enable it after the issue gets fixed.

Also, to be honest, I don't know exactly what torch.use_deterministic_algorithms does. Does it cover everything else we are doing here?

+    torch.use_deterministic_algorithms(True)
+    # env var for deterministic CuBLAS
+    # https://github.com/pytorch/pytorch/blob/18525e185e211b3eab44c67a688e5df8396f6f97/torch/__init__.py#L1300
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Contributor:
Not sure how you found this arcane variable, but nice job finding it!

lessw2020 (Contributor) left a review:

Approved, but expressly on the condition of removing torch.use_deterministic_algorithms(True), as it crashes during compile with fp8.

The cuBLAS setting, though, is in my opinion worth landing ASAP; I verified there are no issues with compile/fp8.

AWS is doing runs now to redo the loss curves, and I've pinged them about this change, but it's easier if it lands in main. Thanks for finding this obscure cuBLAS setting to resolve the grad norm disparity!
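Concretely, a hedged sketch of the landing shape being asked for here (keep the cuBLAS env var, hold the strict-determinism call behind a TODO referencing pytorch/pytorch#137160); this is illustrative, not the merged code:

# Illustrative only, not the merged code: keep the cuBLAS setting, hold the
# strict-determinism switch behind a TODO until the fp8/compile issue is fixed.
import os

# TODO: re-enable once pytorch/pytorch#137160 is resolved
# torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # deterministic cuBLAS workspace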


lessw2020 (Contributor):

Re: "also tbh I don't know what exactly torch.use_deterministic_algorithms does. Does it cover everything else we are doing here?"

use_deterministic_algorithms only covers a subset of operations; my understanding is we will still need all the other settings. In summary, it affects convolutions generically (1D, 2D, etc.) and these specific operations:

torch.nn.ReplicationPad2d
torch.bmm()
torch.Tensor.__getitem__() when attempting to differentiate a CPU tensor and the index is a list of tensors
torch.Tensor.index_put() with accumulate=False
torch.Tensor.index_put() with accumulate=True when called on a CPU tensor
torch.Tensor.put_() with accumulate=True when called on a CPU tensor
torch.Tensor.scatter_add_() when called on a CUDA tensor
torch.gather() when called on a CUDA tensor that requires grad
torch.index_add() when called on a CUDA tensor
torch.index_select() when attempting to differentiate a CUDA tensor
torch.repeat_interleave() when attempting to differentiate a CUDA tensor
torch.Tensor.index_copy() when called on a CPU or CUDA tensor
torch.Tensor.scatter() when src type is Tensor and called on a CUDA tensor
torch.Tensor.scatter_reduce() when reduce='sum' or reduce='mean' and called on a CUDA tensor

It will error out on some other ops that don't have a deterministic implementation, but I think the main takeaway is that this is additive rather than a replacement for the other settings.
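To make the "additive" point concrete, here is a small self-contained check (a sketch only, not torchtitan's actual test plan) that enables both settings and confirms gradients are bitwise identical across two identically seeded runs:

# Minimal repro sketch (not the torchtitan test plan): enable both settings,
# run the same seeded forward/backward twice, and compare gradients bitwise.
import os

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # must be set before cuBLAS is used

import torch

torch.use_deterministic_algorithms(True)


def one_run() -> torch.Tensor:
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(256, 256).to(device)
    x = torch.randn(32, 256, device=device)
    model(x).sum().backward()
    return model.weight.grad.detach().clone()


print(torch.equal(one_run(), one_run()))  # expect: True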

Labels: CLA Signed (managed by the Meta Open Source bot)

Successfully merging this pull request may close these issues:
reproducible numerics for loss, weights and gradients for single node (8 GPUs)

4 participants