
Add Qwen2-7B model to torchtune. #1143

Merged: 12 commits into pytorch:main on Jul 30, 2024

Conversation

@fyabc (Contributor) commented Jul 5, 2024

Context

This PR adds support for Qwen2-7B, including model components and generation, full-parameter finetune, and LoRA finetune recipes.

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

  • Add support for Qwen2
  • Fix code format

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Implementation Details & Notes

  • transformer decoder: In Qwen2-0.5B and Qwen2-1.5B, the output projection shares the token embedding weights (tie_word_embeddings = True), so we add a new torchtune.models.qwen2.transformer.Qwen2TransformerDecoder module to support this feature (set output=None).
  • tokenizer: torchtune.models.qwen2._tokenizer.Qwen2Tokenizer constructs a "fast" tokenizer backed by Hugging Face's tokenizers library (following the official transformers implementation).
    • The existing torchtune.data._chat_formats.ChatMLFormat implementation adds an extra <|im_end|> tag at the end of the last assistant message during generation (see this discussion), so we override the tokenize_messages() method to fix this problem (it will not add <|im_end|> or the EOS token if messages[-1].role == 'assistant' and not messages[-1].content). A rough sketch of this behavior appears after this list.
  • verify the correctness: We compare the output logits and generated responses between torchtune and the official transformers implementation under the same prompts and inference hyperparameters.
  • QLoRA: Since the quantized torchtune.modules.peft.lora.LoRALinear does not support bias yet, QLoRA finetune recipes are not added in this PR (support will be added in the future).
  • More Qwen2 models (0.5B, 1.5B, etc.) will be added in the future.
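
To illustrate the tokenizer behavior above, here is a rough, self-contained sketch (plain-dict messages and a hypothetical render_chatml helper, not the actual Qwen2Tokenizer code):

# Illustrative sketch only -- not the torchtune Qwen2Tokenizer implementation.
# When the final message is an empty assistant turn (i.e. we are building a
# generation prompt), the turn is left open: no <|im_end|> and no EOS token.
IM_START, IM_END = "<|im_start|>", "<|im_end|>"

def render_chatml(messages: list) -> str:
    parts = []
    for i, msg in enumerate(messages):
        is_last = i == len(messages) - 1
        if is_last and msg["role"] == "assistant" and not msg["content"]:
            # Generation prompt: open the assistant turn but do not close it.
            parts.append(f"{IM_START}assistant\n")
        else:
            parts.append(f"{IM_START}{msg['role']}\n{msg['content']}{IM_END}\n")
    return "".join(parts)

prompt = render_chatml([
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": ""},  # empty final assistant message
])
assert prompt.endswith(f"{IM_START}assistant\n")  # no trailing <|im_end|>/EOS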


pytorch-bot bot commented Jul 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1143

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d0671ee with merge base 2dc11d9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @fyabc!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@pbontrager (Contributor)

Thank you for this great contribution. Could you provide a bit more information in the PR description around some of the decisions you made when implementing Qwen2? I'm particularly interested in why the TransformerDecoder wasn't sufficient for implementing Qwen. I am also interested in understanding better how you tested and verified that this implementation is correct. Looking forward to getting this into our library and thank you again.

@facebook-github-bot added the CLA Signed label on Jul 8, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@fyabc (Contributor, Author) commented Jul 8, 2024

Thank you for this great contribution. Could you provide a bit more information in the PR description around some of the decisions you made when implementing Qwen2? I'm particularly interested in why the TransformerDecoder wasn't sufficient for implementing Qwen. I am also interested in understanding better how you tested and verified that this implementation is correct. Looking forward to getting this into our library and thank you again.

Hi, I have updated some implementation details in the description, please check it again.

@ebsmothers (Contributor)

Hi @fyabc, thanks so much for this PR. A couple initial questions:

  1. The transformer decoder class looks pretty similar to the one we have for Gemma, which also ties the embedding weights. I see there is also an optional post-embedding norm in your implementation, but other than that are these two the same? Basically I am wondering if we should provide a single general TiedEmbeddingTransformerDecoder (or something like that) that can be used by both Gemma and Qwen2 models.
  2. What are the specific differences between Qwen2 RoPE and Llama2/Llama3 RoPE?
  3. What's the underlying tokenizer used by Qwen2? Is it just GPT-2 BPE tokenization? We currently do not have a dependency in torchtune on the tokenizers library, so adding this may require a bit more discussion.

@fyabc (Contributor, Author) commented Jul 9, 2024

Hi @fyabc, thanks so much for this PR. A couple initial questions:

  1. The transformer decoder class looks pretty similar to the one we have for Gemma, which also ties the embedding weights. I see there is also an optional post-embedding norm in your implementation, but other than that are these two the same? Basically I am wondering if we should provide a single general TiedEmbeddingTransformerDecoder (or something like that) that can be used by both Gemma and Qwen2 models.
  2. What are the specific differences between Qwen2 RoPE and Llama2/Llama3 RoPE?
  3. What's the underlying tokenizer used by Qwen2? Is it just GPT-2 BPE tokenization? We currently do not have a dependency in torchtune on the tokenizers library, so adding this may require a bit more discussion.

Thank you for your suggestions!

  1. Yes, the Qwen2 decoder is the same as the Gemma decoder; I think a general TiedEmbeddingTransformerDecoder class is a good solution (or should we change TransformerDecoder directly?).
  2. Like Phi3, Qwen2 RoPE is not numerically equivalent to Llama2/Llama3 RoPE in bf16 training (Qwen2 computes rotary embeddings in the input dtype, while Llama2/3 converts to fp32). A rough sketch of the difference follows below.
  3. Yes, the Qwen2 tokenizer is based on GPT-2 BPE tokenization, and I am implementing another version (derived from Qwen2Tokenizer) to remove the dependency on tokenizers.
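
For illustration, a minimal standalone sketch of that dtype difference (an illustrative helper only, not the actual torchtune or transformers RoPE code):

import torch

def rope_angles(seq_len, head_dim, base=1_000_000.0, dtype=torch.float32):
    # theta_i = base^(-2i/d); angle[p, i] = p * theta_i
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).to(dtype) / head_dim))
    pos = torch.arange(seq_len).to(dtype)
    return pos[:, None] * inv_freq[None, :]

# "Qwen2-style": angles computed directly in the activation dtype (bf16 here).
cos_bf16 = torch.cos(rope_angles(2048, 128, dtype=torch.bfloat16))
# "Llama-style": angles computed in fp32 and only cast down afterwards.
cos_fp32 = torch.cos(rope_angles(2048, 128, dtype=torch.float32))
# Close, but not bitwise identical under bf16 training.
print((cos_bf16.float() - cos_fp32).abs().max())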

@ebsmothers (Contributor)

  1. Yes, the Qwen2 decoder is the same as the Gemma decoder; I think a general TiedEmbeddingTransformerDecoder class is a good solution (or should we change TransformerDecoder directly?).
  2. Like Phi3, Qwen2 RoPE is not numerically equivalent to Llama2/Llama3 RoPE in bf16 training (Qwen2 computes rotary embeddings in the input dtype, while Llama2/3 converts to fp32).
  3. Yes, the Qwen2 tokenizer is based on GPT-2 BPE tokenization, and I am implementing another version (derived from Qwen2Tokenizer) to remove the dependency on tokenizers.

Thanks @fyabc for your responses! For (1) let's go with TiedEmbeddingTransformerDecoder, no need to change TransformerDecoder (since that change would impact a lot more models). If you want to also point Gemma to this new class that'd be great, but if not we can also do it as a follow-up.

Thanks for clarifying on (2). @joecummings has also been looking at our RoPE implementation so can weigh in with any thoughts he has.

(3) sounds great! Please let me know if you need any pointers from us on adding a new tokenizer; otherwise just ping on here when you are ready for a more detailed review of everything.

@fyabc (Contributor, Author) commented Jul 16, 2024

@ebsmothers Hi, I have updated this PR based on your suggestions:

  1. Added a new torchtune.modules.transformer.TiedEmbeddingTransformerDecoder class to handle models with tied embedding weights (I have not refactored the Gemma implementation). A toy sketch of the tying idea follows this list.
  2. Rewrote Qwen2Tokenizer to remove the dependency on the tokenizers library.
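
For reference, a toy sketch of the weight-tying idea (made-up sizes, attention/MLP layers omitted; this is not the actual TiedEmbeddingTransformerDecoder):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyTiedDecoder(nn.Module):
    """Toy decoder whose LM head reuses the token embedding weight (no separate output projection)."""

    def __init__(self, vocab_size: int = 1000, dim: int = 64):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # ... attention / MLP layers would go here; omitted for brevity ...
        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.norm(self.tok_embeddings(tokens))
        # Tied output head: project back onto the vocabulary with the embedding matrix.
        return F.linear(h, self.tok_embeddings.weight)

logits = TinyTiedDecoder()(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3, 1000])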

@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Collaborator

Is there generally interest in the 1.5B model too? If so, is it worth adding?

Perhaps as a more general point, I noticed in the Qwen2 technical report that weight tying is only used for the 0.5B and 1.5B models, right? Would users want to configure weight tying in the larger models, where it is not enabled by default?

Contributor Author

Yes, I will add the other model sizes (0.5B / 1.5B / 72B).
In addition, whether weight tying is enabled must be consistent between model initialization (in _model_builders.py here) and weight loading (in _convert_weights.py, which reads the configuration from the config.json file), so I think we should fix the weight-tying configuration in the default model builders to protect users from bugs.
Based on experience from the open-source community, most users have no need to modify weight tying when tuning Qwen models. Expert users can call torchtune.models.qwen2.qwen2() directly with the correct hyperparameters.
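
To make that constraint concrete, a hedged sketch of reading the tying flag from a checkpoint (the helper is hypothetical; the config.json key follows the Hugging Face convention):

import json

def checkpoint_ties_embeddings(checkpoint_dir: str) -> bool:
    # tie_word_embeddings in config.json indicates whether the checkpoint reuses
    # the embedding table as the LM head instead of storing a separate weight.
    with open(f"{checkpoint_dir}/config.json") as f:
        return bool(json.load(f).get("tie_word_embeddings", False))

# Hypothetical usage: the builder you instantiate must agree with the checkpoint,
# otherwise loading the state dict fails with a missing/unexpected output key.
# assert not checkpoint_ties_embeddings("/tmp/Qwen2-7B-Instruct")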

Collaborator

Thanks for the insight!

This brings me to a broader point and feel free to ignore me if this sounds crazy. If we don't think there would be interest from the community in the 0.5B/1.5B models, which do use weight tying, would we be able to simplify this PR a bunch by just using the original TransformerDecoder for the 7B and 72B models, which don't need weight tying?

If you think it's worth offering this flexibility I think the added decoder is fine, and it'll be nice to share it with Gemma. (also cc @joecummings @ebsmothers here).

Contributor

Yeah, we've made it pretty far in this PR so I think the best course of action is to land this with the weight tying; however, for future reference @SalmanMohammadi is absolutely correct that we prefer to start by landing only what is needed, which - in this case - would just be regular Qwen2 without weight tying.

This code will make it easy to add the small models, which is a huge benefit for our memory constrained users, so I'm excited for the quick follow up with those configs :)

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Collaborator

I'm sure @ebsmothers will have more thoughts on whether this is necessary here - but I'll share an excellent comment from him about the testing process in torchtune #840 (comment).

@joecummings (Contributor) left a comment

Overall, this is coming together very nicely and I appreciate your hard work to add a great new model to torchtune! Most of the comments I had are nits, but I did have one last thing to say on testing.

The individual unit tests are great for proving correctness of individual models, but an E2E run of this recipe on alpaca, comparing the loss curves between this and HuggingFace, is a fantastic sanity check to make sure torchtune users are getting the expected experience. It's also useful to output both loss curves to a Weights & Biases run so we can visually see that they're similar.
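
A minimal sketch of that comparison workflow (the project and run names below are placeholders, not anything defined in this PR):

import wandb

def log_loss_curve(run_name: str, losses: list) -> None:
    # One W&B run per framework; overlay the two curves in the W&B UI afterwards.
    run = wandb.init(project="qwen2-7b-finetune-parity", name=run_name)
    for step, loss in enumerate(losses):
        run.log({"loss": loss}, step=step)
    run.finish()

# e.g. log_loss_curve("torchtune-full-finetune", torchtune_losses)
#      log_loss_curve("hf-trainer-reference", hf_losses)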

@fyabc (Contributor, Author) commented Jul 17, 2024

Overall, this is coming together very nicely and I appreciate your hard work to add a great new model to torchtune! Most of the comments I had are nits, but I did have one last thing to say on testing.

The individual unit tests are great for proving correctness of individual models, but an E2E run of this recipe on alpaca, comparing the loss curves between this and HuggingFace, is a fantastic sanity check to make sure torchtune users are getting the expected experience. It's also useful to output both loss curves to a Weights & Biases run so we can visually see that they're similar.

Thank you for your suggestions! I will run all the recipes and report the results and comparisons with HuggingFace.

@SalmanMohammadi (Collaborator)

I've also added a few comments (but feel free to take these with a pinch of salt since I don't bring the experience that the maintainers do!).

This PR is great and you've generally fit everything into the codebase well. I've raised some general points about whether it's worth generalising the fast BPE tokenizer that you've implemented, and whether there's interest in the smaller weight-tied models, or a need for flexibility around weight tying in the larger models; I'd love to get some input on these points.

@RdoubleA (Contributor) left a comment

Sorry I'm late to the party - thank you for all the excellent work here! This was a huge effort to add an important model to the library, and you've integrated it into the codebase well.

One general comment I had on tied embeddings - do you know if users will ever want to use Qwen2 without tied embeddings? I noticed on HF this defaults to False. Asking because assuming it's always tied will simplify a lot of the builder logic and typing. For Gemma for example, we took a stance that it will always be tied, but I don't know enough about Qwen to make that call.

On generalizing the fast BPE tokenizer here - I would vote to keep it in Qwen until another model needs to use the same GPT-2 BPE tokenizer (cc @SalmanMohammadi)

Several review threads on torchtune/models/qwen2/_tokenizer.py (some on outdated code) have been resolved.
@joecummings (Contributor)

Hi @fyabc! Thanks for your patience as we reviewed this PR - we're still really excited to get the Qwen2 model into torchtune.

Would you mind taking another quick look at the comments by @SalmanMohammadi and @RdoubleA? Let me know how else I can help.

@fyabc (Contributor, Author) commented Jul 23, 2024

Hi @fyabc! Thanks for your patience as we reviewed this PR - we're still really excited to get the Qwen2 model into torchtune.

Would you mind taking another quick look at the comments by @SalmanMohammadi and @RdoubleA? Let me know how else I can help.

Sorry, I was busy with other work last week and forgot to reply. I will update my code to address these comments in a few days.

@joecummings (Contributor)

Sorry, I was busy with other work last week and forgot to reply. I will update my code to address these comments in a few days.

No worries at all! Please do let us know if there's any way we can help support you further - we appreciate all the work you've already done here.

@codecov-commenter

Codecov Report

Attention: Patch coverage is 85.98326% with 67 lines in your changes missing coverage. Please review.

Project coverage is 72.00%. Comparing base (7eb89e2) to head (d0671ee).
Report is 9 commits behind head on main.

Files Patch % Lines
torchtune/modules/transformer.py 17.64% 28 Missing ⚠️
torchtune/models/qwen2/_convert_weights.py 25.00% 21 Missing ⚠️
torchtune/models/qwen2/_tokenizer.py 94.50% 10 Missing ⚠️
torchtune/models/qwen2/_component_builders.py 93.22% 4 Missing ⚠️
torchtune/models/qwen2/_model_builders.py 85.71% 2 Missing ⚠️
torchtune/utils/_checkpointing/_checkpointer.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1143      +/-   ##
==========================================
+ Coverage   67.81%   72.00%   +4.18%     
==========================================
  Files         219      230      +11     
  Lines        9908    10479     +571     
==========================================
+ Hits         6719     7545     +826     
+ Misses       3189     2934     -255     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SalmanMohammadi (Collaborator)

Thanks so much for the updates @fyabc. This is looking great. As per @joecummings's comment above, would it be possible to verify training works as expected?

The individual unit tests are great for proving correctness of individual models, but an E2E run of this recipe on alpaca, comparing the loss curves between this and HuggingFace, is a fantastic sanity check to make sure torchtune users are getting the expected experience. It's also useful to output both loss curves to a Weights & Biases run so we can visually see that they're similar.

I think once you've had a chance to compare a training run, we should be good to land this (though Joe and Rafi may have some more input). Please let us know if there's more we can do to help - we really appreciate all the work you've put in here.

@fyabc (Contributor, Author) commented Jul 29, 2024

Thanks so much for the updates @fyabc. This is looking great. As per @joecummings's comment above, would it be possible to verify training works as expected?

The individual unit tests are great for proving correctness of individual models, but an E2E run of this recipe on alpaca, comparing the loss curves between this and HuggingFace, is a fantastic sanity check to make sure torchtune users are getting the expected experience. It's also useful to output both loss curves to a Weights & Biases run so we can visually see that they're similar.

I think once you've had a chance to compare a training run, we should be good to land this (though Joe and Rafi may have some more input). Please let us know if there's more we can do to help - we really appreciate all the work you've put in here.

Thank you for your suggestion! I am running the finetune recipes and comparing them with the official Hugging Face finetune script. I will report the experiment results to wandb and add the share link to the PR description.

@ebsmothers (Contributor) left a comment

This looks great! I left a few small comments but none of them are blocking. I think we can address anything else in a follow-up. Thank you so much for contributing Qwen2 to our library!

@@ -246,3 +247,152 @@ def forward(
# shape: [b, s, out_dim] - out_dim is usually the vocab size
output = self.output(h).float()
return output


class TiedEmbeddingTransformerDecoder(nn.Module):
Contributor

Now that this is in transformer.py as a core module we need to add a unit test for it. I don't think we need to block this PR on it, so I filed #1241 as a follow-up.
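
A hedged sketch of one property such a test could check (toy tensors written from scratch here, not torchtune's actual test utilities):

import torch
import torch.nn as nn
import torch.nn.functional as F

def test_tied_head_gradients_flow_to_embedding():
    emb = nn.Embedding(32, 8)
    tokens = torch.randint(0, 32, (2, 5))
    logits = F.linear(emb(tokens), emb.weight)  # tied output head
    logits.sum().backward()
    # Both the embedding lookup and the output projection contribute gradients
    # to the single shared weight.
    assert emb.weight.grad is not None
    assert emb.weight.grad.shape == (32, 8)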

#
# This config assumes that you've run the following command before launching
# this run:
# tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct
Contributor

I believe the download commands also need to add the --ignore-patterns flag to properly download safetensors files (this is slightly annoying but we need to do things this way for the time being)

Suggested change:
- # tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct
+ # tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns ""

attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
rope_base: float = 1_000_000.0,
tie_word_embeddings: bool = False,
Contributor

Out of curiosity, is this something we anticipate people experimenting with? I think it's fine to support but (a) it can have implications for checkpointing and/or FSDP wrapping logic, and (b) we are relaxing the contract of the builder function a bit to do it. Not a huge concern, but wanna make sure that it's valuable to add this kind of configurability.

Contributor

Edit: never mind on this point. I see that you mention it being relevant for the 0.5B and 1.5B models, so I think the approach makes sense.

@ebsmothers merged commit ca1d7a1 into pytorch:main on Jul 30, 2024
29 checks passed