
Add support to Qwen2-0.5B and Qwen2-1.5B. #1247

Merged
merged 20 commits into pytorch:main on Aug 6, 2024

Conversation

fyabc
Contributor

@fyabc commented Jul 30, 2024

Context

This PR:

  • Add support for Qwen2-1.5B and Qwen2-0.5B, with full-parameter & LoRA finetuning recipes.
  • Change the default finetuning hyperparameters for the Qwen2 recipes.
  • Fix some bugs in tie_word_embeddings model loading & saving (a short weight-tying sketch follows this list).
  • Refactor TiedEmbeddingTransformerDecoder to align with the default TransformerDecoder.
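
For readers less familiar with tied embeddings, here is a minimal PyTorch sketch of what weight tying means for loading and saving. The module name and dimensions below are illustrative assumptions, not torchtune's actual TiedEmbeddingTransformerDecoder:

import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    # Illustrative only: the output projection shares its weight with the token
    # embedding, which is why a tie_word_embeddings checkpoint stores a single
    # weight matrix and why load/save logic has to treat that key specially.
    def __init__(self, vocab_size: int = 151_936, dim: int = 896):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.output.weight = self.tok_embeddings.weight  # tie the two weights

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # [batch, seq, dim]; a real model runs decoder layers here
        return self.output(h)            # [batch, seq, vocab_size]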

Loss curves for Qwen2-0.5B and 1.5B full-parameter finetuning on the alpaca-cleaned dataset have been uploaded to wandb (loss comparison curves are shown below).

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Loss Curves

All loss curves are smoothed with smooth_ratio=0.6.
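
For clarity, a small sketch of the smoothing assumed here. Treating smooth_ratio as the decay factor of a wandb-style exponential moving average is my assumption, not something stated in the PR:

def smooth(values: list[float], smooth_ratio: float = 0.6) -> list[float]:
    # Exponential moving average: a higher smooth_ratio keeps more of the
    # previous smoothed value, so the plotted curve looks flatter.
    # Assumes a non-empty input list.
    out, prev = [], values[0]
    for v in values:
        prev = smooth_ratio * prev + (1 - smooth_ratio) * v
        out.append(prev)
    return out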

loss-compare-0.5b-smooth0.6
wandb link: https://api.wandb.ai/links/fyabc-123/jxbvzezh

loss-compare-1.5b-smooth0.6
wandb link: https://api.wandb.ai/links/fyabc-123/w2lqv1v6


pytorch-bot bot commented Jul 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1247

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 856555b with merge base 9fd5d01:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Jul 30, 2024
@@ -0,0 +1,75 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Qwen2 0.5B model
#
Collaborator

@SalmanMohammadi Jul 31, 2024


I feel a bit like we need to rethink tune ls. Here's the output just now:

RECIPE                                   CONFIG                                  
full_finetune_single_device              llama2/7B_full_low_memory               
                                         code_llama2/7B_full_low_memory          
                                         llama3/8B_full_single_device            
                                         llama3_1/8B_full_single_device          
                                         mistral/7B_full_low_memory              
                                         phi3/mini_full_low_memory               
                                         qwen2/7B_full_low_memory                
full_finetune_distributed                llama2/7B_full                          
                                         llama2/13B_full                         
                                         llama3/8B_full                          
                                         llama3_1/8B_full                        
                                         llama3/70B_full                         
                                         llama3_1/70B_full                       
                                         mistral/7B_full                         
                                         gemma/2B_full                           
                                         gemma/7B_full                           
                                         phi3/mini_full                          
                                         qwen2/7B_full                           
lora_finetune_single_device              llama2/7B_lora_single_device            
                                         llama2/7B_qlora_single_device           
                                         code_llama2/7B_lora_single_device       
                                         code_llama2/7B_qlora_single_device      
                                         llama3/8B_lora_single_device            
                                         llama3_1/8B_lora_single_device          
                                         llama3/8B_qlora_single_device           
                                         llama3_1/8B_qlora_single_device         
                                         llama2/13B_qlora_single_device          
                                         mistral/7B_lora_single_device           
                                         mistral/7B_qlora_single_device          
                                         gemma/2B_lora_single_device             
                                         gemma/2B_qlora_single_device            
                                         gemma/7B_lora_single_device             
                                         gemma/7B_qlora_single_device            
                                         phi3/mini_lora_single_device            
                                         phi3/mini_qlora_single_device           
                                         qwen2/7B_lora_single_device             
lora_dpo_single_device                   llama2/7B_lora_dpo_single_device        
lora_dpo_distributed                     llama2/7B_lora_dpo                      
lora_finetune_distributed                llama2/7B_lora                          
                                         llama2/13B_lora                         
                                         llama2/70B_lora                         
                                         llama3/70B_lora                         
                                         llama3_1/70B_lora                       
                                         llama3/8B_lora                          
                                         llama3_1/8B_lora                        
                                         mistral/7B_lora                         
                                         gemma/2B_lora                           
                                         gemma/7B_lora                           
                                         phi3/mini_lora                          
                                         qwen2/7B_lora                           
lora_finetune_fsdp2                      llama2/7B_lora                          
                                         llama2/13B_lora                         
                                         llama2/70B_lora                         
                                         llama2/7B_qlora                         
                                         llama2/70B_qlora                        
generate                                 generation                              
eleuther_eval                            eleuther_evaluation                     
quantize                                 quantization                            
qat_distributed                          llama2/7B_qat_full                      
                                         llama3/8B_qat_full    

Imo this isn't scaling well as we add more and more cool models and support configs for new techniques (imagine how many we'll have for multimodal!). It's getting unwieldy.

Not sure if this is already on your radar @joecummings @ebsmothers, but I have an idea or two (one radical, one not-so-radical) to address this; happy to put an RFC up if there's consensus?

Contributor


Yeah, this is a good point. I know @joecummings had some ideas on this, so I'll defer to him, but I think a quick RFC on how to scale tune ls better would definitely be helpful.

Contributor


Let's chat - this is definitely on my radar. Glad you caught it, too.

@@ -68,7 +68,7 @@ def qwen2_hf_to_tune(

     for key, value in state_dict.items():
         if (
-            tie_word_embeddings and QWEN2_TIED_KEY not in key
+            tie_word_embeddings and QWEN2_TIED_KEY in key
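
To make the one-character fix above concrete, here is a hedged sketch of the kind of conversion loop it lives in. The value of QWEN2_TIED_KEY and the surrounding structure are assumptions based on this diff, not the verbatim torchtune code:

def qwen2_hf_to_tune_sketch(state_dict: dict, tie_word_embeddings: bool = False) -> dict:
    QWEN2_TIED_KEY = "lm_head.weight"  # assumed name of the tied output-projection key
    converted = {}
    for key, value in state_dict.items():
        # When the embeddings are tied there is no separate output-projection
        # weight to convert, so the tied key must be skipped -- hence `in key`,
        # not `not in key` as the old condition had it.
        if tie_word_embeddings and QWEN2_TIED_KEY in key:
            continue
        converted[key] = value  # the real converter also remaps HF names to torchtune names
    return converted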
Collaborator


I missed this the first time, but is there a reason the checkpointing logic for tied embeddings in Qwen2 should be different from Gemma? Ref #1168, #1169

Contributor


Good point, happy to have this filed as a refactor cleanup to start grouping some of these together in a "TIED_MODEL" key or something.

Collaborator

@SalmanMohammadi Jul 31, 2024


Sorry if I'm missing something, but do we need a refactor here? I thought we ended up not needing any special checkpointing logic for Gemma, so it's just a matter of applying the same changes as in https://github.com/pytorch/torchtune/pull/1168/files: removing qwen2/_convert_weights.py and removing any special logic in the checkpointers (here, for example), so it goes through the default model save/load path.

Happy to do this in a follow-up anyway, so this doesn't get blocked?

Contributor

@ebsmothers left a comment


Thanks for the PR! A few small comments here but this looks pretty good to me. Also the wandb link in the PR summary is broken for me, would definitely be interested to see the comparison with HF loss curves if that's possible

recipes/configs/qwen2/0.5B_full_low_memory.yaml (outdated review thread, resolved)
torchtune/models/qwen2/_model_builders.py (outdated review thread, resolved)
@fyabc
Contributor Author

fyabc commented Aug 2, 2024

Thanks for the PR! A few small comments here but this looks pretty good to me. Also the wandb link in the PR summary is broken for me, would definitely be interested to see the comparison with HF loss curves if that's possible

Hi, I have updated the loss curves comparing torchtune and HF in the description.

@ebsmothers
Contributor

Thanks for the PR! A few small comments here but this looks pretty good to me. Also the wandb link in the PR summary is broken for me, would definitely be interested to see the comparison with HF loss curves if that's possible

Hi, I have updated the loss curves comparing torchtune and HF in the description.

Thanks! Also bumping this comment @fyabc. Mainly I don't think we should need stuff like PagedAdamW or other memory optimizations that may slow down training for a 0.5B (or even 1.5B) model -- I imagine if we scrap these we can still support the same set of hardware but train faster. Let me know if this makes sense to you
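
A minimal sketch of what dropping the paged optimizer amounts to, using plain PyTorch with a stand-in module and an illustrative learning rate rather than the actual recipe config:

import torch
import torch.nn as nn

model = nn.Linear(896, 151_936, bias=False)  # stand-in for the small Qwen2 model

# A 0.5B/1.5B model's optimizer state fits comfortably in memory, so plain
# full-precision AdamW avoids the throughput cost of paged/8-bit optimizer
# states while still supporting the same set of hardware.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)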

@fyabc
Contributor Author

fyabc commented Aug 3, 2024

Thanks for the PR! A few small comments here but this looks pretty good to me. Also the wandb link in the PR summary is broken for me, would definitely be interested to see the comparison with HF loss curves if that's possible

Hi, I have updated the loss curves comparing torchtune and HF in the description.

Thanks! Also bumping this comment @fyabc. Mainly I don't think we should need stuff like PagedAdamW or other memory optimizations that may slow down training for a 0.5B (or even 1.5B) model -- I imagine if we scrap these we can still support the same set of hardware but train faster. Let me know if this makes sense to you

Thank you for your suggestions! I will update the related 0.5B and 1.5B recipes.

@joecummings
Contributor

Thank you for your suggestions! I will update the related 0.5B and 1.5B recipes.

Hi @fyabc! Anything we can do to help finish this up? We'd love to feature this on our README and post about it on our Discord channel, as well.

@fyabc
Contributor Author

fyabc commented Aug 6, 2024

Thank you for your suggestions! I will update the related 0.5B and 1.5B recipes.

Hi @fyabc! Anything we can do to help finish this up? We'd love to feature this on our README and post about it on our Discord channel, as well.

Hi, I have updated this PR to resolve the review comments.

All *_low_memory recipes have been renamed to *_single_device, and the model builder typing annotations have been changed to more exact types.

I think this PR is ready to merge; feel free to leave more suggestions.

Contributor

@joecummings left a comment


Awesome work as always :)

@SalmanMohammadi merged commit 5c7246e into pytorch:main on Aug 6, 2024
29 checks passed