
Add KD distributed recipe #1631

Open
wants to merge 48 commits into main

Conversation

Contributor

@lindawangg lindawangg commented Sep 20, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

To enable distributed training for knowledge distillation.

Changelog

What are the changes made in this PR?

  • Builds on top of Add single device KD recipe #1539
  • KD distributed recipe (knowledge_distillation_distributed.py) is similar to lora_finetune_distributed.py.
  • KD config: knowledge_distillation_distributed.yaml

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed

[Loss curves: (left) single device, (right) distributed; batch size can also be increased in the distributed run.]
[Eval results are similar between the single-device and distributed runs.]

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Sep 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1631

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Sep 20, 2024
@lindawangg lindawangg marked this pull request as ready for review September 20, 2024 01:55
@felipemello1
Contributor

Hey @lindawangg, thanks for the recipe!! We have been a bit busy, but we will get to this PR.

Contributor

@ebsmothers ebsmothers left a comment


Just a few minor comments, otherwise looks good!

# To launch on a single device, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works only for distilling on a single device.
Contributor

nit: update this

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
Contributor

Not sure what the peak memory you're seeing is but with distributed you may be able to get away without this (especially for such small models) and get faster training

Contributor Author

Changed to False. It isn't needed for Qwen2, and training time also went from 1h to 20 mins.
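
For context (not from the PR): a minimal sketch of how a recipe could wire an enable_activation_checkpointing flag through to PyTorch's checkpoint wrapper, which is the knob being toggled in this thread. The helper name and the layer-class match are assumptions for illustration only.

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

def maybe_enable_ac(model: nn.Module, enable_activation_checkpointing: bool) -> None:
    # When the flag is False, skip wrapping entirely: activations stay in memory,
    # which trades peak memory for the faster training observed above.
    if not enable_activation_checkpointing:
        return
    apply_activation_checkpointing(
        model,
        # Wrap each decoder layer; matching by class name is an assumption here.
        check_fn=lambda m: m.__class__.__name__ == "TransformerSelfAttentionLayer",
    )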

Comment on lines 56 to 62
@pytest.mark.parametrize(
"reshard_after_forward",
[
True,
False,
],
)
Contributor

Not a big deal but you can probably get away without testing both of these cases. We already test it elsewhere and I don't expect it to change in KD vs in other recipes (lmk if you disagree though)

Comment on lines 75 to 78
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
Contributor

Honestly we should probably just bite the bullet and upload some small Qwen2-formatted checkpoints rather than overriding everything as Llama in these tests. (Btw you don't have to worry about this, I am just writing it down so we can hold ourselves accountable later 😃 )

Contributor Author

Since Llama 3.2 is released, I changed this to the llama3_2 distributed config, which uses the same LLAMA3 model type.

Comment on lines 44 to 102
"""
Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized
for single GPU training. Training on CPU is not supported.

Features:
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
loss in model quality (will depend on the model, training data and other settings). For
GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
precision are currently not supported.

- Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
controlled using the ``gradient_accumulation_steps`` flag.

Total Batch Size = batch_size * gradient accumulation steps.

For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32.

Gradient accumulation is especially useful when you are memory constrained. In this case,
accumulating gradients might give you better training speed than enabling activation
checkpointing.

- Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes
library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with
8-bit AdamW and Paged AdamW.

- Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
training. Currently we checkpoint both the adapter weights (trainable params only) and the
complete merged weights (adapter weights added back to the base model). For more details
please take a look at our LoRA tutorial
(https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).

Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are
only saved at the end of a given epoch and used in case of resuming training. Resuming
training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
currently not supported.

For more details on the checkpointer, please take a look at
our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).

- Logging. Terminal, Disk, WandB and TensorBoard are all supported.

For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.

Args:
cfg (DictConfig): OmegaConf object parsed from yaml file

Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
Contributor

I think this whole docstring is from the single-device recipe. May wanna make sure it lines up with the features that are in here (e.g. FSDP, and I don't think we really advertise low-precision optimizers in our distributed recipes (though they should probably work))

is_dora = False
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = (True,)
Contributor

Why a tuple? Also might be useful to run with QLoRA and/or DoRA just as a sanity check that nothing breaks if you haven't already

Contributor Author

I think that was a mistake; I don't remember why. Changed it to is_dora = True and tested that it works with DoRA.
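
For reference, the corrected flag-setting loop described in this thread would look roughly like this (based on the snippet under review and the fix the author describes; not copied from the final diff):

is_dora = False
for m in model.modules():
    # The hasattr check identifies DoRA adapter modules; set a plain boolean, not a tuple.
    if hasattr(m, "initialize_dora_magnitude"):
        is_dora = True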

@joecummings joecummings mentioned this pull request Oct 15, 2024
33 tasks
Contributor

@ebsmothers ebsmothers left a comment


A few comments, but no huge concerns from my side. Looks like CI is red for unrelated reasons, tagging @joecummings who is looking into it

# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
Contributor

Do you think it's also worthwhile to add a config with 70B model size? (Doesn't necessarily have to be in this PR, but it'd be useful to have at least one config that strictly requires distributed)

Contributor

As a follow-up to that, I wonder if we should include model sizes in the config names? I know it makes it a bit longer (and doesn't line up with what you did for the single-device configs), but otherwise we cannot really distinguish between configs for distilling 70B -> 1B vs 8B -> 1B. Similar to the other comment here, this is fine to save for a follow-up though

Contributor Author

I tested this config on the 70B model and verified it works, but I think it needs more tuning. We can add the 70B model in a separate PR and figure out how to change the naming. There weren't many changes needed to add the 70B model, just the model target and checkpointer, since the tokenizer has to be the same right now.


@pytest.mark.integration_test
@gpu_test(gpu_count=2)
def test_training_state_on_resume(self, tmpdir, monkeypatch):
Contributor

Did we lose test_loss along the way here? (Just wanna make sure it was deliberate)

Contributor Author

Oh, I think I read the previous comment wrong. I thought you meant that since test_loss is already tested, we didn't need to test it again, but now I realize you meant reshard_after_forward. Let me add test_loss back in.

Comment on lines +465 to +470
training.shard_model(
model=model,
shard_conditions=[_is_layer_name],
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
Contributor

Out of curiosity, did you try keeping the student model unsharded? I'm wondering what the tradeoff is here for perf vs memory. If the model is small enough that fully replicating it across all devices doesn't change the HW profile we're runnable on, but we get speedups by saving on comms, it might be worthwhile.

Contributor Author
Contributor Author

@lindawangg lindawangg Oct 22, 2024


I could do it for 1B, but I got OOM when trying to load the 3B student and 70B teacher models. We could make sharding the student model an option.
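
(Not part of the PR: a rough sketch of what gating the student shard on a config flag might look like, reusing the training.shard_model call quoted above; the shard_student_model key is a hypothetical name, and the else branch assumes the device is available in scope.)

# Hypothetical config flag; name and default are illustrative only.
if cfg.get("shard_student_model", True):
    training.shard_model(
        model=model,
        shard_conditions=[_is_layer_name],
        cpu_offload=fsdp_cpu_offload,
        reshard_after_forward=reshard_after_forward,
    )
else:
    # Replicate the (small) student on every rank and skip FSDP comms.
    model = model.to(device)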

Contributor

Did you see nice speedups on 1B, or were they pretty minimal? If the latter let's just leave it as is, otherwise we can consider exposing the option as you mentioned

Contributor Author

The speedup is pretty minimal, especially when using 8 GPUs; the number of devices influences speed more.

  • 1B w/o FSDP, 5 steps on 8 GPUs: 1:09
  • 1B w/o FSDP, 5 steps on 4 GPUs: 3:57
  • 1B w/ FSDP, 5 steps on 8 GPUs: 1:14
  • 1B w/ FSDP, 5 steps on 4 GPUs: 4:59


class_loss, kd_loss = self._loss_step(batch)
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss = loss / self._gradient_accumulation_steps
Contributor

Don't need to actually do anything for this PR, but just FYI we are likely to be changing how we normalize loss when gradient accumulation is enabled (see #1875)
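
For readers following along (not from the PR): a self-contained sketch of the kind of objective the snippet above weights with kd_ratio — cross-entropy on the labels plus a forward KL between teacher and student logits. Function and argument names here are illustrative; torchtune's actual loss modules may differ.

import torch.nn.functional as F

def kd_losses(student_logits, teacher_logits, labels, temperature=1.0):
    # Standard next-token cross-entropy on the student's predictions.
    class_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    # Forward KL(teacher || student): teacher probabilities vs. student log-probabilities.
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    )
    return class_loss, kd_loss

The two terms would then be combined as loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss and divided by gradient_accumulation_steps, exactly as in the quoted lines (subject to the normalization change referenced above).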
