RLHF with PPO #1005

Merged: 44 commits, Aug 5, 2024

Conversation

@SalmanMohammadi (Collaborator) commented May 19, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

#812

Background reading:

The N Implementation Details of RLHF with PPO, Huang et al.
The N+ Implementation Details of RLHF with PPO: A case study on TL;DR Summarization
The original RLHF paper - Fine-Tuning Language Models from Human Preferences, Ziegler et al.
Anthropic's RLHF paper - Training a Helpful and Harmless Assistant with Learning from Human Feedback
Training language models to follow instructions with human feedback, Ouyang et al.
Shameless plug, but I would have genuinely found this post helpful when I started out with PPO, even for skimming through some of the references - The theory of Proximal Policy Optimization implementations

Changelog:

  • Implemented LoRA PPO recipe
  • Refactored TransformerDecoder
    • Changes
      • The following changes were added in a submodule under torchtune.models.mistral:
        • Changed TransformerDecoder to TransformerDecoderWithHiddenLayer
        • Added TransformerLM which wraps an output projection around TransformerDecoderWithHiddenLayer
          • Component and model builders now return TransformerLM
        • Added TransformerLMWithValueHead, with two linear projections: one for the LM head and one for the value head.
      • Added support for checkpointing models that wrap TransformerDecoderWithHiddenLayer
      • Updated checkpointing to correctly convert HF weights to refactored models
      • Added support for checkpointing value heads.
    • Tests:
      • TODO
  • Added mistral value head models
    • Changes:
      • Added model and component builders
    • Tests
      • The implementation is identical to MistralClassifier.
  • Added PPOLoss (a generic sketch of the clipped objective is included after this changelog)
    • Tests:
      • test_ppo_loss tests for correct behaviour based on expected relative value and policy loss for different inputs.
  • Added utils.ppo_utils for various PPO utilities, with tests for all files, including:
    • _generation.py
      • Added custom_generate_next_token functions for generating with value head models, and for generating with masks and input positions.
      • Added get_causal_masks for creating masks of shape [bsz, seq_len, seq_len] which correctly mask leading padding tokens, suitable for use with scaled_dot_product_attention (see the mask sketch after this changelog).
      • Added a custom generate function which generates sequences using above functionality.
    • collate.py
      • Added support for collating input sequences by left-padding to a specified maximum sequence length.
    • rewards.py
      • Support for calculation of rewards, advantage estimation, adaptive and fixed KL controllers, and reward normalisation.
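
For illustration, here is a minimal sketch of how a padding-aware causal mask of that shape can be built for scaled_dot_product_attention. The function name and signature are assumptions made for this sketch, not the PR's actual get_causal_masks API:

import torch

def left_pad_causal_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    # padding_mask: [bsz, seq_len] bool tensor, True for real tokens, False for (leading) padding.
    # Returns a [bsz, seq_len, seq_len] bool mask where True marks positions that may be attended
    # to, suitable as the attn_mask argument of torch.nn.functional.scaled_dot_product_attention.
    _, seq_len = padding_mask.shape
    # Lower-triangular causal structure, shared across the batch.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=padding_mask.device)
    )
    # Block attention to the padding positions of each sequence.
    return causal.unsqueeze(0) & padding_mask.unsqueeze(1)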

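And here is a generic sketch of the clipped policy objective that a PPO loss typically computes, referenced from the PPOLoss item above. The names, shapes, and reduction are assumptions, and the PR's actual PPOLoss likely includes further terms (e.g. a clipped value loss) not shown here:

import torch

def clipped_policy_loss(
    pi_logprobs: torch.Tensor,   # log-probs of the taken tokens under the current policy
    old_logprobs: torch.Tensor,  # log-probs under the policy that generated the rollout
    advantages: torch.Tensor,    # advantage estimates, same shape as the log-probs
    epsilon: float = 0.2,        # clipping range
) -> torch.Tensor:
    # Importance ratio between the current and rollout policies.
    ratios = torch.exp(pi_logprobs - old_logprobs)
    unclipped = advantages * ratios
    clipped = advantages * torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon)
    # Maximise the clipped surrogate objective, i.e. minimise its negation.
    return -torch.minimum(unclipped, clipped).mean()
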
TODO:

  • Complete or remove all TODOs (@SalmanMohammadi)
  • Remove temporary changes for MPS support.
  • Upload randomly initialised models and write recipe tests for:
    • Verifying recipe checkpointing works correctly w.r.t. saving and loading base model and value head weights.
    • Verifying expected loss values.
  • Run full model training and verify loss curves
  • Add support for reward models and base models using different tokenizers.
  • Add docs to the API reference.

Adding this to open up discussion and get some feedback (@kartikayk) while I train models and verify correctness. Maybe a good place to start would be the TransformerDecoder refactor?

closes #812

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)


pytorch-bot bot commented May 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1005

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 4e6be43 with merge base 5019074:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 19, 2024
@SalmanMohammadi marked this pull request as ready for review June 8, 2024 15:00
@joecummings (Contributor) left a comment:

This is shaping up nicely - I very much like that we de-scoped this to focus on the full finetune first, with LoRA and QLoRA as follow-ups. A couple of high-level things to note:

  1. I'm concerned about the configs becoming too bloated and would like to discuss how to minimize storing lots of logic there.

  2. What is the largest model you can fit on an 80GB A100? I see you include configs for both 7B and 1B?

@SalmanMohammadi (Collaborator, Author) commented Jul 17, 2024

Thanks for another review.

I'm concerned about the configs becoming too bloated and would like to discuss how to minimize storing lots of logic there.

I feel this. One thing that stuck out to me when writing this: we currently need four checkpointers, two of which are solely used to point to the original weights for the policy and reward models, respectively. They're necessary because you need a reference to the original weights when resuming training, and the choice at the time came down to managing this state in the config vs. in the checkpoints.

The model definitions are also taking up a lot of space, but that's largely because I didn't see another obvious way to configure a 1B Llama2. The model definition in the Mistral config is annoying because that specific reward model uses a different vocab size. Please let me know if I can make this cleaner!

There are also ~30 lines of hyperparameters in the config. Hopefully this won't be overwhelming to the user once we include a cookbook. I could remove 5 or so of these from the config and set them as defaults in the recipe.

EDIT: I could also liberally use cfg.get to set the recipe up with default hyperparameter values, which should work for most use cases, and expose them in the recipe docs instead. A rough sketch of what this might look like is below.
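A hypothetical sketch of that cfg.get pattern; the key names and default values here are illustrative, not the recipe's actual config schema:

from omegaconf import DictConfig

def ppo_hyperparams(cfg: DictConfig) -> dict:
    # Pull optional hyperparameters from the config, falling back to recipe-side
    # defaults so the YAML stays small. Keys and defaults below are hypothetical.
    return {
        "gamma": cfg.get("gamma", 1.0),              # discount factor
        "lmbda": cfg.get("lmbda", 0.95),             # GAE lambda
        "epsilon": cfg.get("epsilon", 0.2),          # PPO clip range
        "whiten_rewards": cfg.get("whiten_rewards", False),
    }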

What is the largest model you can fit on an 80GB A100? I see you include configs for both 7B and 1B?

The training run in my above comment trained Mistral 7B on an 80GB A100.

@codecov-commenter commented Jul 18, 2024

Codecov Report

Attention: Patch coverage is 52.35546% with 445 lines in your changes missing coverage. Please review.

Project coverage is 67.96%. Comparing base (43c7332) to head (ba365a8).
Report is 4 commits behind head on main.

Files Patch % Lines
recipes/ppo_full_finetune_single_device.py 0.00% 330 Missing ⚠️
...ts/recipes/test_ppo_full_tunetune_single_device.py 16.19% 88 Missing ⚠️
torchtune/modules/rlhf/collate.py 45.00% 11 Missing ⚠️
torchtune/modules/rlhf/rewards.py 85.45% 8 Missing ⚠️
tests/recipes/utils.py 33.33% 4 Missing ⚠️
torchtune/modules/rlhf/sequence_processing.py 84.61% 2 Missing ⚠️
recipes/lora_dpo_distributed.py 0.00% 1 Missing ⚠️
recipes/lora_dpo_single_device.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1005      +/-   ##
==========================================
- Coverage   69.32%   67.96%   -1.37%     
==========================================
  Files         233      246      +13     
  Lines       10593    11434     +841     
==========================================
+ Hits         7344     7771     +427     
- Misses       3249     3663     +414     


@SalmanMohammadi (Collaborator, Author) commented:

bump bump @joecummings @ebsmothers.

My dearest reviewers,

Sorry to ping you when you're busy. In my defense, @kartikayk did tell me to. What can I do to help move this along? I'm more than happy to help reduce the review overhead if I can.

@bhack mentioned this pull request Jul 30, 2024
@SalmanMohammadi (Collaborator, Author) commented Aug 2, 2024

Outstanding discussions/tasks:

@joecummings @ebsmothers

  • @ebsmothers can you please upload my reward model to S3 so my recipe tests pass? See the PPO channel on Discord. Let me know if you need a refresh on this.
  • @joecummings pointed out that the config is a bit unwieldy. I think I could address this by removing a bunch of parameters from the config, setting sensible defaults, and documenting them properly in a recipe doc (which might look like [RFC][DOCS] Recipe [DOCS] ([DOC]umentation) #1230). Can we leave this as a follow-up, though? 🥺
  • Are you happy with my replication results here? RLHF with PPO #1005 (comment)
  • Can we also leave generalising the generation utils to a follow up? Joe has started discussion in Fix generation for bsz > 1 #1250.
  • My use of rng generator checkpointing (RLHF with PPO #1005 (comment)) is unprecedented in the codebase. Are you happy with this? I'm just adding another key to the checkpoint.
  • Do you care about generalizing the RLHF collation utils into torchtune.utils.collate? (Joe's comments here and here.) They currently aren't being used outside of the PPO recipe itself, and the DPO collation isn't being used outside the DPO recipe.

@@ -29,6 +29,7 @@
"llama2_tune": "/tmp/test-artifacts/small-ckpt-tune-03082024.pt",
"llama2_meta": "/tmp/test-artifacts/small-ckpt-meta-03082024.pt",
"llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt",
"llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-12072024.pt", # TODO (SalmanMohammadi)
Contributor commented:

Sorry for being an American chauvinist, but I changed the filename to small-ckpt-hf-reward-07122024.pt (really just to make it consistent with the format of the other ones). Also, I think you will need to update cache_artifacts.sh correspondingly.

Collaborator (Author) commented:

Oh how the Empire has fallen from grace.

@ebsmothers (Contributor) commented:

General comment on the checklist you left earlier: all the points look good to me, let's just file tasks for some of the more important todos that don't have them already.

Also, leaving some miscellaneous remarks here in response to several of your previous comments:

I ran the experiment on an A100 - the default config in the repo includes the memory optimization parameters needed to make this work. I used optimizer_in_bwd and PagedAdamW. Training was pretty slow at the start, but I was seeing >10x speedups once torch.compile kicked in.

Looking at the figures, it seems this is necessary even for an A100, since you are still pretty close to 80GB of allocated memory? I'm also curious whether the overall training speed is decent, as these configs can slow things down quite a bit.

Can we also leave generalising the generation utils to a follow up? Joe has started discussion in #1250.

Just want to confirm: will we actually be able to run this recipe without batched generation support?

@ebsmothers (Contributor) left a comment:

OK, a bunch more comments, but after that there are no major concerns. Home stretch here -- thanks again for your immense patience on this one.

Resolved review threads: torchtune/utils/pooling.py, torchtune/modules/rlhf/collate.py, torchtune/utils/collate.py, torchtune/modules/loss/ppo.py (two threads), recipes/ppo_full_finetune_single_device.py
Comment on lines +795 to +797
(seq_lens > 0) & (seq_lens < self._max_generated_tokens - 1),
seq_lens + 1,
seq_lens,
Contributor commented:

Sorry not sure I fully follow what the purpose of this is

Collaborator (Author) commented:

...
...
...
Dare I say..... excalidraw?

@SalmanMohammadi (Collaborator, Author) commented Aug 4, 2024:

In all seriousness, I can send you an equally confusing diagram from my notes on Discord. This took me a while to wrap my head around, and longer to explain coherently (disclaimer: this could all just be wrong, since my only reference is a single line from a Learning to Summarize implementation), so thanks for the nerd snipe.

The TL;DR: the value function estimates the return for the whole sequence at each step, which is the reward model score for the (query, truncated response) plus the per-token KL penalty. We want to use this for advantage estimation, and the advantage for the last action taken (the last valid non-padding token generated by the model) is:

[image: equation for the advantage at the final timestep]

So, we need the value estimate (return) for the sequence up to now, plus one step ahead. For the last token, this means we need to extend the padding mask out by one for the values: instead of masking everything after the last non-padding token, we mask everything one position after the last non-padding token.

These three lines do this, but add some logic so that if we're already at the end of the sequence, we don't extend the mask. A rough reconstruction of the full expression is sketched below.
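For context, under standard GAE the per-step term is δ_t = r_t + γ V_{t+1} - V_t, so the advantage of the last valid token needs a value estimate one position past the final non-padding token. Below is a hedged reconstruction of what the quoted lines appear to do; the enclosing torch.where call and the names are assumptions based on the diff context, not the recipe's exact code:

import torch

def extend_seq_lens_for_values(seq_lens: torch.Tensor, max_generated_tokens: int) -> torch.Tensor:
    # Treat one extra position past the last non-padding token as valid for the
    # value estimates, unless the sequence is empty or already at the generation limit.
    return torch.where(
        (seq_lens > 0) & (seq_lens < max_generated_tokens - 1),
        seq_lens + 1,
        seq_lens,
    )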

Contributor commented:

Thanks, I think this makes sense (though I reserve the right to be confused again later on)

Resolved review threads: recipes/ppo_full_finetune_single_device.py (three threads)
@SalmanMohammadi (Collaborator, Author) commented Aug 5, 2024

O' what glorious reviews. Thank you. I've addressed your comments.

Looking at the figures seems this is necessary even for A100? Since you are still pretty close to 80GB allocated memory. I'm also curious whether the overall training speed is decent as these configs can slow things down quite a bit.

For 7B, since we're fitting four 7B models in memory, it took a little wrangling to fit it all. The run I posted took around 3 hours. I haven't found any benchmarks on comparable hardware to estimate what speed/memory usage should look like. DeepSpeed's RLHF docs state:

Theoretically, the largest model you can train for this step is similar to the step-1 SFT finetuning if you enable

  • zero stage 3 (if you use multiple GPUs)
  • gradient checkpoint
  • LoRA
  • reference model offloading.

However, in practice, this is not always the case, and we are still investigating the reasons behind it. For now, we suggest that users use "Total-GPU-Memory-in-GB / 6" as the upper parameter bound in billions for the sum of the actor model and critical model, for safety. Nevertheless, users are welcome to try the real limit.

I'm not 100% clear on whether that upper-bound calculation assumes the specific config they list, but going by it (80 / 6 ≈ 13), their guidance caps the actor and critic models at roughly 13B parameters combined on an 80GB A100. TRL trained Pythia 6.9B with their PPOv2 trainer on 8xH100.

Just want to confirm: will we actually be able to run this recipe without batched generation support?

The generation utils I include in this PR do provide batched generation support.

@ebsmothers (Contributor) left a comment:

Thank you for your immense patience on this one. I left a couple of other follow-up comments, but none of them are blocking us from landing this. 🚀

@SalmanMohammadi merged commit c593c10 into pytorch:main Aug 5, 2024
3 checks passed
Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: none yet
Closes: [RFC] Proximal Policy Optimisation
7 participants