FP8Linear saves new parameters in ckpt and I cannot load the saved ckpt #651

Open
goldhuang opened this issue Oct 24, 2024 · 5 comments

@goldhuang

[rank0]: Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
[rank0]: Traceback (most recent call last): (RANK 2)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 248, in all_gather
[rank0]:     result = map_fun()
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 228, in read_data
[rank0]:     all_reads = storage_reader.read_data(final_local_plan, planner)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 655, in read_data
[rank0]:     torch.load(
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/serialization.py", line 1359, in load
[rank0]:     raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
[rank0]: _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
[rank0]: 	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
[rank0]: 	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
[rank0]: 	WeightsUnpickler error: Unsupported global: GLOBAL torchao.float8.fsdp_utils.WeightWithDelayedFloat8CastTensor was not an allowed global by default. Please use `torch.serialization.add_safe_globals([WeightWithDelayedFloat8CastTensor])` to allowlist this global if you trust this class/function.

I'm using

        planner = DefaultLoadPlanner(allow_partial_load=not strict)
        torch.distributed.checkpoint.load(state_dict=state, checkpoint_id=path, planner=planner)

to load the distributed ckpt.
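
For reference, the allowlisting workaround that the error message itself suggests would look roughly like the sketch below. It should only be used if you trust the checkpoint source; `state`, `path`, and `strict` stand in for the same variables as in the snippet above.

        # Sketch only: allowlist the torchao tensor subclass named in the UnpicklingError
        # so that torch.load with weights_only=True can deserialize it. Do this only for
        # checkpoints from a trusted source.
        import torch
        import torch.distributed.checkpoint as dcp
        from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
        from torchao.float8.fsdp_utils import WeightWithDelayedFloat8CastTensor

        torch.serialization.add_safe_globals([WeightWithDelayedFloat8CastTensor])

        planner = DefaultLoadPlanner(allow_partial_load=not strict)
        dcp.load(state_dict=state, checkpoint_id=path, planner=planner)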

@tianyu-l added the bug label on Oct 25, 2024
@weifengpy
Contributor

Hi @goldhuang, could you share your training config for delayed scaling? I have not been promoting delayed scaling, but I want to take this chance to make it right.

@goldhuang
Author

@weifengpy

        # Imports assume torchao's float8 API; module_filter_fn is a user-defined
        # predicate (see the sketch after this block).
        from torchao.float8 import (
            CastConfig,
            Float8LinearConfig,
            ScalingType,
            convert_to_float8_training,
        )

        convert_to_float8_training(
            model,
            config=Float8LinearConfig(
                enable_fsdp_float8_all_gather=True,
                cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
                cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
                cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
                enable_pre_and_post_forward=False,
                enable_amax_init=False,
                pad_inner_dim=True,
            ),
            module_filter_fn=module_filter_fn,
        )
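
(`module_filter_fn` above is my own filter, which I haven't pasted. A minimal sketch of what such a filter could look like, assuming torchao's `(module, fqn) -> bool` convention, is below; the "output" name is only illustrative.)

        import torch.nn as nn

        # Illustrative filter: convert only nn.Linear layers, skipping layers whose
        # fully-qualified name marks them as ones to keep in high precision.
        def module_filter_fn(module: nn.Module, fqn: str) -> bool:
            return isinstance(module, nn.Linear) and "output" not in fqn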

I tried DYNAMIC at the very beginning, but ran into an issue with HSDP and filed pytorch/ao#1086 in the torchao repo. Then I switched to DELAYED.

According to the code, this issue only happens with enable_fsdp_float8_all_gather=True. Currently I'm using enable_fsdp_float8_all_gather=False, which makes my life much easier in several respects.

BTW, I find that enable_fsdp_float8_all_gather=True is not noticeably faster on 8 nodes (64 GPUs). When should I expect to see a benefit from setting enable_fsdp_float8_all_gather=True?

@weifengpy
Contributor

BTW, I find that enable_fsdp_float8_all_gather=True is not noticeably faster on 8 nodes (64 GPUs). When should I expect to see a benefit from setting enable_fsdp_float8_all_gather=True?

For 128 GPUs, enable_fsdp_float8_all_gather=False shows a 1.42x speedup and enable_fsdp_float8_all_gather=True shows a 1.50x speedup, so it's true that most of the gains come from the float8 compute itself. https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359

Thanks for sharing your recipe. I am giving it a try.

@goldhuang
Author

goldhuang commented Oct 25, 2024

@weifengpy Did you guys try with more than 128 GPUs? Like 1024 GPUs?

@weifengpy
Contributor

@weifengpy Did you guys try with more than 128 GPUs? Like 1024 GPUs?

For 1D FSDP, 128 GPUs is the largest scale I've tested; I have not tested 1D FSDP on 1024 GPUs. At that scale you'd probably need HSDP.
