Add Qwen2-7B model to torchtune. #1143
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1143
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit d0671ee with merge base 2dc11d9. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @fyabc! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for this great contribution. Could you provide a bit more information in the PR description around some of the decisions you made when implementing Qwen2? I'm interested in particular in why the TransformerDecoder wasn't sufficient for implementing Qwen. I am also interested in understanding better how you tested and verified that this implementation is correct. Looking forward to getting this into our library, and thank you again.
Hi, I have updated some implementation details in the description, please check it again. |
Hi @fyabc, thanks so much for this PR. A couple initial questions:
Thank you for your suggestions!
Thanks @fyabc for your responses!

For (1), let's go with your proposed approach. Thanks for clarifying on (2); @joecummings has also been looking at our RoPE implementation, so he can weigh in with any thoughts he has. (3) sounds great! Please let me know if you need any pointers from us on adding a new tokenizer; otherwise just ping on here when you are ready for a more detailed review of everything.
# Conflicts: # torchtune/utils/_checkpointing/_checkpointer.py
@ebsmothers Hi, I have updated this PR based on your suggestions:
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Is there generally interest in the 1.5B model too? If so, is it worth adding?
Perhaps as a more general point: I noticed in the Qwen2 technical report that weight tying is only used for the 0.5B and 1.5B models, right? Would users ever want to configure weight tying in the larger models, where it is disabled by default?
Yes, I will add the other model sizes (0.5B / 1.5B / 72B).

In addition, whether weight tying is enabled must be consistent between model initialization (in _model_builders.py here) and weight loading (in _convert_weights.py, which reads the configuration from the config.json file), so I think we should fix the weight tying configuration in the default model builders to protect users from bugs. Based on the experience of the open source community, most users have no need to modify weight tying when tuning Qwen models. Expert users can directly call torchtune.models.qwen2.qwen2() with the correct hyperparameters.
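To illustrate the consistency requirement discussed above, here is a minimal sketch (hypothetical function and key names, not the actual torchtune code): a weight-tied HF checkpoint simply omits the separate output-projection key, so the flag passed to the checkpoint converter must match what the model builder assumed, or loading silently breaks.

```python
def convert_hf_checkpoint(hf_sd: dict, tie_word_embeddings: bool) -> dict:
    """Map HF-style checkpoint keys to (hypothetical) torchtune names.

    Tied checkpoints ship no separate lm_head weight, so this flag must
    agree with the one used at model initialization.
    """
    converted = {"tok_embeddings.weight": hf_sd["model.embed_tokens.weight"]}
    if tie_word_embeddings:
        if "lm_head.weight" in hf_sd:
            raise ValueError("checkpoint has lm_head but builder ties weights")
        # output projection reuses tok_embeddings; nothing else to load
    else:
        converted["output.weight"] = hf_sd["lm_head.weight"]
    return converted

# A tied checkpoint (0.5B/1.5B style): no lm_head key is present.
tied_sd = {"model.embed_tokens.weight": [[0.1, 0.2]]}
out = convert_hf_checkpoint(tied_sd, tie_word_embeddings=True)
assert "output.weight" not in out
```

Hard-coding the flag in the default builders, as proposed, removes the possibility of the two sides disagreeing.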
Thanks for the insight!
This brings me to a broader point and feel free to ignore me if this sounds crazy. If we don't think there would be interest from the community in the 0.5B/1.5B models, which do use weight tying, would we be able to simplify this PR a bunch by just using the original TransformerDecoder
for the 7B and 72B models, which don't need weight tying?
If you think it's worth offering this flexibility I think the added decoder is fine, and it'll be nice to share it with Gemma. (also cc @joecummings @ebsmothers here).
Yeah, we've made it pretty far in this PR, so I think the best course of action is to land this with the weight tying; however, for future reference, @SalmanMohammadi is absolutely correct that we prefer to start by landing only what is needed, which - in this case - would just be regular Qwen2 without weight tying.
This code will make it easy to add the small models, which is a huge benefit for our memory constrained users, so I'm excited for the quick follow up with those configs :)
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
I'm sure @ebsmothers will have more thoughts on whether this is necessary here - but I'll share an excellent comment from him about the testing process in torchtune #840 (comment).
Overall, this is coming together very nicely and I appreciate your hard work to add a great new model to torchtune! Most of the comments I had are nits, but I did have one last thing to say on testing.
The individual unit tests are great for proving correctness of individual models, but a E2E run of this recipe on alpaca and comparing the loss curves between this and HuggingFace is a fantastic sanity check to make sure the torchtune users are getting the expected experience. It's also useful to output both loss curves to a Weights & Biases run so we can visually see that they're similar.
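The loss-curve comparison described above boils down to a simple stepwise tolerance check. A minimal sketch (hypothetical loss values and tolerance, not actual run data):

```python
# Hypothetical per-step losses from a torchtune run and an equivalent HF run.
torchtune_losses = [2.31, 1.98, 1.74, 1.62, 1.55]
hf_losses        = [2.30, 2.00, 1.73, 1.64, 1.54]

def curves_match(a, b, tol=0.05):
    """True if both curves have the same length and stay within tol
    of each other at every step (a crude but useful sanity check)."""
    return len(a) == len(b) and all(abs(x - y) <= tol for x, y in zip(a, b))

# In practice both curves would also be logged to the same W&B project
# so the overlap can be inspected visually.
assert curves_match(torchtune_losses, hf_losses)
```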
Thank you for your suggestions! I will run all recipes and report the results, with comparisons against HuggingFace.
I've also added a few comments (but feel free to take these with a pinch of salt since I don't bring the experience that the maintainers do!). This PR is great and you've generally fit everything into the codebase well. I've raised some general points about whether it's worth generalising the fast BPE tokenizer that you've implemented, and whether there's interest in the smaller weight-tied models, or a need for flexibility around weight tying in the larger models; I'd love to get some input on these points.
Sorry I'm late to the party - thank you for all the excellent work here! This was a huge effort to add an important model to the library, and you've integrated it into the codebase well.
One general comment I had on tied embeddings - do you know if users will ever want to use Qwen2 without tied embeddings? I noticed on HF this defaults to False. Asking because assuming it's always tied will simplify a lot of the builder logic and typing. For Gemma for example, we took a stance that it will always be tied, but I don't know enough about Qwen to make that call.
On generalizing the fast BPE tokenizer here - I would vote to keep it in Qwen until another model needs to use the same GPT-2 BPE tokenizer (cc @SalmanMohammadi)
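For readers unfamiliar with the GPT-2 style BPE mentioned above, its core is a greedy merge loop driven by a learned ranking of symbol pairs. A self-contained toy sketch (the tiny merge table is made up for illustration; a real tokenizer loads thousands of ranked merges and also applies byte-level pre-tokenization):

```python
def bpe_encode(token: str, merges: dict) -> list:
    """Greedy BPE: repeatedly merge the adjacent pair with the lowest
    merge rank. `merges` maps (left, right) -> rank, as in a GPT-2
    merges.txt file (lower rank = learned earlier = higher priority)."""
    parts = list(token)
    while len(parts) > 1:
        ranked = [(merges.get((parts[i], parts[i + 1]), float("inf")), i)
                  for i in range(len(parts) - 1)]
        rank, i = min(ranked)
        if rank == float("inf"):
            break  # no learned merge applies to any remaining pair
        parts[i:i + 2] = [parts[i] + parts[i + 1]]
    return parts

# Toy merge table: "l"+"o" was learned first, then "lo"+"w".
merges = {("l", "o"): 0, ("lo", "w"): 1}
assert bpe_encode("low", merges) == ["low"]
assert bpe_encode("lot", merges) == ["lo", "t"]
```

The "fast" tokenizers backed by HuggingFace's tokenizers library implement this loop in Rust, which is where the speedup over pure-Python implementations comes from.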
Hi @fyabc! Thanks for your patience as we reviewed this PR - we're still really excited to get the Qwen2 model into torchtune. Would you mind taking another quick look at the comments by @SalmanMohammadi and @RdoubleA? Let me know how else I can help.
Sorry, I was busy with other work last week and forgot to reply. I will update my code based on these comments in a few days.
No worries at all! Please do let us know if there's any way we can help support you further - we appreciate all the work you've already done here.
Codecov Report

Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1143      +/-   ##
==========================================
+ Coverage   67.81%   72.00%    +4.18%
==========================================
  Files         219      230       +11
  Lines        9908    10479      +571
==========================================
+ Hits         6719     7545      +826
+ Misses       3189     2934      -255

☔ View full report in Codecov by Sentry.
Thanks so much for the updates @fyabc. This is looking great. As per @joecummings's comment above, would it be possible to verify training works as expected?
I think once you've had a chance to compare a training run, we should be good to land this (though Joe and Rafi may have some more input). Please let us know if there's more we can do to help - we really appreciate all the work you've put in here.
Thank you for your suggestion! I am running the finetune recipes and comparing them with the official HuggingFace finetune script. I will report the experiment results to wandb and add the share link to the PR description.
This looks great! I left a few small comments but none of them are blocking. I think we can address anything else in a follow-up. Thank you so much for contributing Qwen2 to our library!
@@ -246,3 +247,152 @@ def forward(
        # shape: [b, s, out_dim] - out_dim is usually the vocab size
        output = self.output(h).float()
        return output


class TiedEmbeddingTransformerDecoder(nn.Module):
Now that this is in transformer.py as a core module we need to add a unit test for it. I don't think we need to block this PR on it, so filed #1241 as a follow-up.
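The essential property such a unit test would pin down is that the output projection genuinely shares storage with the token embedding, so mutating the embedding moves the corresponding logits. A toy numpy sketch (a stand-in class, not the real TiedEmbeddingTransformerDecoder):

```python
import numpy as np

class TiedDecoderSketch:
    """Toy stand-in for a tied-embedding decoder: the final projection
    reuses the token embedding matrix instead of a separate lm_head."""

    def __init__(self, vocab: int, dim: int, seed: int = 0):
        self.emb = np.random.default_rng(seed).normal(size=(vocab, dim))

    def forward(self, hidden):
        # [batch, dim] @ [dim, vocab] -> [batch, vocab] logits
        return hidden @ self.emb.T

def test_output_tracks_embedding():
    m = TiedDecoderSketch(vocab=5, dim=3)
    h = np.ones((1, 3))
    before = m.forward(h).copy()
    m.emb[0] += 1.0  # mutate the shared embedding row for token 0
    after = m.forward(h)
    assert not np.allclose(before[0, 0], after[0, 0])  # token-0 logit moved
    assert np.allclose(before[0, 1:], after[0, 1:])    # others unchanged

test_output_tracks_embedding()
```

A real test against the module would additionally fix a seed and compare logits to precomputed reference values, as torchtune's other model tests do.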
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct
I believe the download commands also need to add the --ignore-patterns flag to properly download safetensors files (this is slightly annoying, but we need to do things this way for the time being):
- # tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct
+ # tune download Qwen/Qwen2-7B-Instruct --output-dir /tmp/Qwen2-7B-Instruct --ignore-patterns ""
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    rope_base: float = 1_000_000.0,
    tie_word_embeddings: bool = False,
Out of curiosity, is this something we anticipate people to experiment with? I think it's fine to support but (a) it can have implications on checkpointing and/or FSDP wrapping logic, and (b) we are relaxing the contract of the builder function a bit to do it. Not a huge concern, but wanna make sure that it's valuable to add this kind of configurability.
Edit: nevermind on this point. I see that you mention it being relevant for 0.5B and 1.5B size models. So I think the approach makes sense
Context
This PR adds support for Qwen2-7B, including model components and generation / full-parameter finetune / LoRA finetune recipes.
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)

- run pre-commit hooks and linters (install them first via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
Implementation Details & Notes
- Qwen2-7B ties word embeddings (tie_word_embeddings = True), so we add a new torchtune.models.qwen2.transformer.Qwen2TransformerDecoder module to support this feature (it sets output=None).
- torchtune.models.qwen2._tokenizer.Qwen2Tokenizer constructs a "fast" tokenizer backed by HuggingFace's tokenizers library (following the official transformers implementation).
- torchtune.data._chat_formats.ChatMLFormat will add an extra <|IM_END|> tag at the end of the last assistant message in generation (see this discussion), so we override the tokenize_message() method to fix this problem (it will not add <|IM_END|> and the EOS token if messages[-1].role == 'assistant' and not messages[-1].content).
- Since torchtune.modules.peft.lora.LoRALinear does not support bias yet, QLoRA finetune recipes are not added in this PR (support will come in the future).
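The generation fix for ChatML can be sketched in a few lines of plain Python (hypothetical helper, not the actual torchtune override; control tokens shown in their common lowercase spelling for illustration): when the final message is an empty assistant turn, the formatter must leave that turn open so the model produces the reply, rather than closing it immediately.

```python
IM_START, IM_END = "<|im_start|>", "<|im_end|>"

def format_chatml(messages):
    """Render (role, content) pairs in ChatML. If the last message is an
    empty assistant turn (a generation prompt), leave it unclosed instead
    of appending <|im_end|>, so generation continues from there."""
    out = []
    for i, (role, content) in enumerate(messages):
        is_gen_prompt = (i == len(messages) - 1
                         and role == "assistant" and not content)
        if is_gen_prompt:
            out.append(f"{IM_START}{role}\n")        # left open for the model
        else:
            out.append(f"{IM_START}{role}\n{content}{IM_END}\n")
    return "".join(out)

prompt = format_chatml([("user", "hi"), ("assistant", "")])
assert prompt.endswith("<|im_start|>assistant\n")   # open assistant turn
assert IM_END in prompt                             # user turn still closed
```

Completed conversations (e.g. training examples with a non-empty assistant reply) are unaffected: every turn there gets its closing tag as usual.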